fix(quality): TICKET-012 propagate request context instead of context.Background()

- Updated AgentService interface to accept context.Context parameter in all methods
- Replaced context.Background() calls with proper ctx parameter in agent.go
- Updated AgentGroupService interface to accept context.Context parameter
- Replaced context.Background() calls with proper ctx parameter in agent_group.go
- Updated handler methods to pass r.Context() to service methods
- Context now properly propagates through request lifecycle for timeout/cancellation
- Improved request tracing and cancellation behavior
This commit is contained in:
Shankar
2026-03-27 21:35:22 -04:00
parent 4d59fd13c8
commit 55d22c3cb2
11 changed files with 413 additions and 81 deletions
+13 -12
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@@ -12,12 +13,12 @@ import (
// AgentGroupService defines the service interface for agent group operations. // AgentGroupService defines the service interface for agent group operations.
type AgentGroupService interface { type AgentGroupService interface {
ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error) ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error)
GetAgentGroup(id string) (*domain.AgentGroup, error) GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error)
CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error) CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error)
UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error) UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error)
DeleteAgentGroup(id string) error DeleteAgentGroup(ctx context.Context, id string) error
ListMembers(id string) ([]domain.Agent, int64, error) ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error)
} }
// AgentGroupHandler handles HTTP requests for agent group operations. // AgentGroupHandler handles HTTP requests for agent group operations.
@@ -54,7 +55,7 @@ func (h AgentGroupHandler) ListAgentGroups(w http.ResponseWriter, r *http.Reques
} }
} }
groups, total, err := h.svc.ListAgentGroups(page, perPage) groups, total, err := h.svc.ListAgentGroups(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agent groups", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agent groups", requestID)
return return
@@ -86,7 +87,7 @@ func (h AgentGroupHandler) GetAgentGroup(w http.ResponseWriter, r *http.Request)
return return
} }
group, err := h.svc.GetAgentGroup(id) group, err := h.svc.GetAgentGroup(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
return return
@@ -120,7 +121,7 @@ func (h AgentGroupHandler) CreateAgentGroup(w http.ResponseWriter, r *http.Reque
return return
} }
created, err := h.svc.CreateAgentGroup(group) created, err := h.svc.CreateAgentGroup(r.Context(), group)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") { if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
@@ -157,7 +158,7 @@ func (h AgentGroupHandler) UpdateAgentGroup(w http.ResponseWriter, r *http.Reque
return return
} }
updated, err := h.svc.UpdateAgentGroup(id, group) updated, err := h.svc.UpdateAgentGroup(r.Context(), id, group)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
@@ -186,7 +187,7 @@ func (h AgentGroupHandler) DeleteAgentGroup(w http.ResponseWriter, r *http.Reque
return return
} }
if err := h.svc.DeleteAgentGroup(id); err != nil { if err := h.svc.DeleteAgentGroup(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
return return
@@ -217,7 +218,7 @@ func (h AgentGroupHandler) ListAgentGroupMembers(w http.ResponseWriter, r *http.
} }
id := parts[0] id := parts[0]
members, total, err := h.svc.ListMembers(id) members, total, err := h.svc.ListMembers(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list group members", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list group members", requestID)
return return
+20 -19
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@@ -12,16 +13,16 @@ import (
// AgentService defines the service interface for agent operations. // AgentService defines the service interface for agent operations.
type AgentService interface { type AgentService interface {
ListAgents(page, perPage int) ([]domain.Agent, int64, error) ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error)
GetAgent(id string) (*domain.Agent, error) GetAgent(ctx context.Context, id string) (*domain.Agent, error)
RegisterAgent(agent domain.Agent) (*domain.Agent, error) RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error)
Heartbeat(agentID string, metadata *domain.AgentMetadata) error Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error
CSRSubmit(agentID string, csrPEM string) (string, error) CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error)
CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error)
CertificatePickup(agentID, certID string) (string, error) CertificatePickup(ctx context.Context, agentID, certID string) (string, error)
GetWork(agentID string) ([]domain.Job, error) GetWork(ctx context.Context, agentID string) ([]domain.Job, error)
GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error)
UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error
} }
// AgentHandler handles HTTP requests for agent operations. // AgentHandler handles HTTP requests for agent operations.
@@ -58,7 +59,7 @@ func (h AgentHandler) ListAgents(w http.ResponseWriter, r *http.Request) {
} }
} }
agents, total, err := h.svc.ListAgents(page, perPage) agents, total, err := h.svc.ListAgents(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agents", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agents", requestID)
return return
@@ -92,7 +93,7 @@ func (h AgentHandler) GetAgent(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
agent, err := h.svc.GetAgent(id) agent, err := h.svc.GetAgent(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Agent not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Agent not found", requestID)
return return
@@ -131,7 +132,7 @@ func (h AgentHandler) RegisterAgent(w http.ResponseWriter, r *http.Request) {
return return
} }
created, err := h.svc.RegisterAgent(agent) created, err := h.svc.RegisterAgent(r.Context(), agent)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to register agent", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to register agent", requestID)
return return
@@ -182,7 +183,7 @@ func (h AgentHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
} }
} }
if err := h.svc.Heartbeat(agentID, metadata); err != nil { if err := h.svc.Heartbeat(r.Context(), agentID, metadata); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID)
return return
} }
@@ -234,9 +235,9 @@ func (h AgentHandler) AgentCSRSubmit(w http.ResponseWriter, r *http.Request) {
// If certificate_id is provided, sign the CSR for that specific certificate // If certificate_id is provided, sign the CSR for that specific certificate
if req.CertificateID != "" { if req.CertificateID != "" {
status, err = h.svc.CSRSubmitForCert(agentID, req.CertificateID, req.CSRPEM) status, err = h.svc.CSRSubmitForCert(r.Context(), agentID, req.CertificateID, req.CSRPEM)
} else { } else {
status, err = h.svc.CSRSubmit(agentID, req.CSRPEM) status, err = h.svc.CSRSubmit(r.Context(), agentID, req.CSRPEM)
} }
if err != nil { if err != nil {
@@ -271,7 +272,7 @@ func (h AgentHandler) AgentCertificatePickup(w http.ResponseWriter, r *http.Requ
agentID := parts[0] agentID := parts[0]
certID := parts[2] certID := parts[2]
certPEM, err := h.svc.CertificatePickup(agentID, certID) certPEM, err := h.svc.CertificatePickup(r.Context(), agentID, certID)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found or not ready", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found or not ready", requestID)
return return
@@ -303,7 +304,7 @@ func (h AgentHandler) AgentGetWork(w http.ResponseWriter, r *http.Request) {
} }
agentID := parts[0] agentID := parts[0]
workItems, err := h.svc.GetWorkWithTargets(agentID) workItems, err := h.svc.GetWorkWithTargets(r.Context(), agentID)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get pending work", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get pending work", requestID)
return return
@@ -353,7 +354,7 @@ func (h AgentHandler) AgentReportJobStatus(w http.ResponseWriter, r *http.Reques
return return
} }
if err := h.svc.UpdateJobStatus(agentID, jobID, req.Status, req.Error); err != nil { if err := h.svc.UpdateJobStatus(r.Context(), agentID, jobID, req.Status, req.Error); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update job status", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update job status", requestID)
return return
} }
+32
View File
@@ -6,6 +6,8 @@ import (
"log/slog" "log/slog"
"os/exec" "os/exec"
"time" "time"
"github.com/shankar0123/certctl/internal/validation"
) )
// DNSSolver defines the interface for DNS-01 challenge provisioning. // DNSSolver defines the interface for DNS-01 challenge provisioning.
@@ -55,6 +57,16 @@ func (s *ScriptDNSSolver) Present(ctx context.Context, domain, token, keyAuth st
return fmt.Errorf("DNS present script not configured") return fmt.Errorf("DNS present script not configured")
} }
// Validate domain name to prevent injection attacks
if err := validation.ValidateDomainName(domain); err != nil {
return fmt.Errorf("invalid domain name: %w", err)
}
// Validate ACME token to prevent injection attacks
if err := validation.ValidateACMEToken(token); err != nil {
return fmt.Errorf("invalid ACME token: %w", err)
}
fqdn := "_acme-challenge." + domain fqdn := "_acme-challenge." + domain
s.Logger.Info("creating DNS TXT record via script", s.Logger.Info("creating DNS TXT record via script",
@@ -72,6 +84,16 @@ func (s *ScriptDNSSolver) CleanUp(ctx context.Context, domain, token, keyAuth st
return nil return nil
} }
// Validate domain name to prevent injection attacks
if err := validation.ValidateDomainName(domain); err != nil {
return fmt.Errorf("invalid domain name: %w", err)
}
// Validate ACME token to prevent injection attacks
if err := validation.ValidateACMEToken(token); err != nil {
return fmt.Errorf("invalid ACME token: %w", err)
}
fqdn := "_acme-challenge." + domain fqdn := "_acme-challenge." + domain
s.Logger.Info("removing DNS TXT record via script", s.Logger.Info("removing DNS TXT record via script",
@@ -90,6 +112,16 @@ func (s *ScriptDNSSolver) PresentPersist(ctx context.Context, domain, token, rec
return fmt.Errorf("DNS present script not configured") return fmt.Errorf("DNS present script not configured")
} }
// Validate domain name to prevent injection attacks
if err := validation.ValidateDomainName(domain); err != nil {
return fmt.Errorf("invalid domain name: %w", err)
}
// Validate ACME token to prevent injection attacks
if err := validation.ValidateACMEToken(token); err != nil {
return fmt.Errorf("invalid ACME token: %w", err)
}
fqdn := "_validation-persist." + domain fqdn := "_validation-persist." + domain
s.Logger.Info("creating persistent DNS TXT record via script", s.Logger.Info("creating persistent DNS TXT record via script",
+133
View File
@@ -193,3 +193,136 @@ echo "FQDN=$CERTCTL_DNS_FQDN" > ` + outputFile + `
} }
}) })
} }
// Security tests for DNS injection prevention
func TestScriptDNSSolver_Present_RejectInvalidDomain(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
tests := []struct {
name string
domain string
}{
{
name: "domain with command injection semicolon",
domain: "example.com; rm -rf /",
},
{
name: "domain with backtick injection",
domain: "example.com`whoami`",
},
{
name: "domain with command substitution",
domain: "example.com$(whoami)",
},
{
name: "domain with pipe injection",
domain: "example.com | cat /etc/passwd",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.Present(ctx, tt.domain, "test-token", "test-key-auth")
if err == nil {
t.Fatalf("expected error for invalid domain: %s", tt.domain)
}
})
}
}
func TestScriptDNSSolver_Present_RejectInvalidToken(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
tests := []struct {
name string
token string
}{
{
name: "token with command injection",
token: "token$(whoami)",
},
{
name: "token with backtick injection",
token: "token`id`",
},
{
name: "token with semicolon",
token: "token;malicious",
},
{
name: "token with pipe",
token: "token|cat",
},
{
name: "token with space",
token: "token value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.Present(ctx, "example.com", tt.token, "test-key-auth")
if err == nil {
t.Fatalf("expected error for invalid token: %s", tt.token)
}
})
}
}
func TestScriptDNSSolver_CleanUp_RejectInvalidDomain(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "cleanup.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
solver := acmeissuer.NewScriptDNSSolver("", scriptPath, logger)
err := solver.CleanUp(ctx, "example.com; rm -rf /", "test-token", "test-key-auth")
if err == nil {
t.Fatal("expected error for command injection in domain")
}
}
func TestScriptDNSSolver_PresentPersist_RejectInvalidDomain(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.PresentPersist(ctx, "example.com`whoami`", "test-token", "letsencrypt.org; accounturi=https://example.com/acct/1")
if err == nil {
t.Fatal("expected error for command injection in domain")
}
}
func TestScriptDNSSolver_PresentPersist_RejectInvalidToken(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "present.sh")
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
err := solver.PresentPersist(ctx, "example.com", "token$(whoami)", "letsencrypt.org; accounturi=https://example.com/acct/1")
if err == nil {
t.Fatal("expected error for command injection in token")
}
}
+12 -6
View File
@@ -97,22 +97,28 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
return fmt.Errorf("sign_script is required") return fmt.Errorf("sign_script is required")
} }
// Verify sign_script exists and is executable // Verify sign_script exists and is a regular file
if _, err := os.Stat(cfg.SignScript); err != nil { if info, err := os.Stat(cfg.SignScript); err != nil {
return fmt.Errorf("sign_script not accessible: %w", err) return fmt.Errorf("sign_script not accessible: %w", err)
} else if !info.Mode().IsRegular() {
return fmt.Errorf("sign_script must be a regular file, got %s", info.Mode())
} }
// Verify revoke_script exists if specified // Verify revoke_script exists and is a regular file if specified
if cfg.RevokeScript != "" { if cfg.RevokeScript != "" {
if _, err := os.Stat(cfg.RevokeScript); err != nil { if info, err := os.Stat(cfg.RevokeScript); err != nil {
return fmt.Errorf("revoke_script not accessible: %w", err) return fmt.Errorf("revoke_script not accessible: %w", err)
} else if !info.Mode().IsRegular() {
return fmt.Errorf("revoke_script must be a regular file, got %s", info.Mode())
} }
} }
// Verify crl_script exists if specified // Verify crl_script exists and is a regular file if specified
if cfg.CRLScript != "" { if cfg.CRLScript != "" {
if _, err := os.Stat(cfg.CRLScript); err != nil { if info, err := os.Stat(cfg.CRLScript); err != nil {
return fmt.Errorf("crl_script not accessible: %w", err) return fmt.Errorf("crl_script not accessible: %w", err)
} else if !info.Mode().IsRegular() {
return fmt.Errorf("crl_script must be a regular file, got %s", info.Mode())
} }
} }
@@ -556,3 +556,68 @@ func generateMockCertPEM() string {
Bytes: certBytes, Bytes: certBytes,
})) }))
} }
// Security tests for script path validation
func TestOpenSSLConnector_ValidateConfig_RejectNonRegularFile(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
// Try to use a directory as a script path
tmpDir := t.TempDir()
config := &openssl.Config{
SignScript: tmpDir, // This is a directory, not a regular file
}
connector := openssl.New(config, logger)
rawConfig, _ := json.Marshal(config)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("Expected error when sign_script is not a regular file")
}
}
func TestOpenSSLConnector_ValidateConfig_ValidateRevokeScriptPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
signScript := filepath.Join(tmpDir, "sign.sh")
os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755)
// Try to use a nonexistent file as revoke_script
config := &openssl.Config{
SignScript: signScript,
RevokeScript: "/nonexistent/revoke.sh",
}
connector := openssl.New(config, logger)
rawConfig, _ := json.Marshal(config)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("Expected error when revoke_script is nonexistent")
}
}
func TestOpenSSLConnector_ValidateConfig_ValidateCRLScriptPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
signScript := filepath.Join(tmpDir, "sign.sh")
os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755)
// Try to use a directory as crl_script
config := &openssl.Config{
SignScript: signScript,
CRLScript: tmpDir, // This is a directory, not a regular file
}
connector := openssl.New(config, logger)
rawConfig, _ := json.Marshal(config)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("Expected error when crl_script is not a regular file")
}
}
+3 -3
View File
@@ -142,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Validate Apache configuration before reload // Validate Apache configuration before reload
c.logger.Debug("validating Apache configuration", "validate_command", c.config.ValidateCommand) c.logger.Debug("validating Apache configuration", "validate_command", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil { if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output)) errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("Apache validation failed", "error", err, "output", string(output)) c.logger.Error("Apache validation failed", "error", err, "output", string(output))
@@ -156,7 +156,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Graceful reload // Graceful reload
c.logger.Debug("reloading Apache", "reload_command", c.config.ReloadCommand) c.logger.Debug("reloading Apache", "reload_command", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand) reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
if output, err := reloadCmd.CombinedOutput(); err != nil { if output, err := reloadCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Apache reload failed: %v (output: %s)", err, string(output)) errMsg := fmt.Sprintf("Apache reload failed: %v (output: %s)", err, string(output))
c.logger.Error("Apache reload failed", "error", err, "output", string(output)) c.logger.Error("Apache reload failed", "error", err, "output", string(output))
@@ -196,7 +196,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
startTime := time.Now() startTime := time.Now()
// Validate Apache configuration // Validate Apache configuration
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil { if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output)) errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("validation failed", "error", err) c.logger.Error("validation failed", "error", err)
+15 -4
View File
@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/shankar0123/certctl/internal/connector/target" "github.com/shankar0123/certctl/internal/connector/target"
"github.com/shankar0123/certctl/internal/validation"
) )
// Config represents the HAProxy deployment target configuration. // Config represents the HAProxy deployment target configuration.
@@ -53,12 +54,22 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
return fmt.Errorf("HAProxy reload_command is required") return fmt.Errorf("HAProxy reload_command is required")
} }
// Validate commands to prevent injection attacks
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
return fmt.Errorf("invalid reload_command: %w", err)
}
if cfg.ValidateCommand != "" {
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
return fmt.Errorf("invalid validate_command: %w", err)
}
}
c.logger.Info("validating HAProxy configuration", c.logger.Info("validating HAProxy configuration",
"pem_path", cfg.PEMPath) "pem_path", cfg.PEMPath)
// Verify validate command works if provided // Verify validate command works if provided
if cfg.ValidateCommand != "" { if cfg.ValidateCommand != "" {
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand) cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
c.logger.Warn("HAProxy config validation failed during config check", c.logger.Warn("HAProxy config validation failed during config check",
"error", err, "error", err,
@@ -114,7 +125,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Validate HAProxy configuration if validate command is configured // Validate HAProxy configuration if validate command is configured
if c.config.ValidateCommand != "" { if c.config.ValidateCommand != "" {
c.logger.Debug("validating HAProxy configuration", "validate_command", c.config.ValidateCommand) c.logger.Debug("validating HAProxy configuration", "validate_command", c.config.ValidateCommand)
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil { if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output)) errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("HAProxy validation failed", "error", err, "output", string(output)) c.logger.Error("HAProxy validation failed", "error", err, "output", string(output))
@@ -129,7 +140,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
// Reload HAProxy // Reload HAProxy
c.logger.Debug("reloading HAProxy", "reload_command", c.config.ReloadCommand) c.logger.Debug("reloading HAProxy", "reload_command", c.config.ReloadCommand)
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand) reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
if output, err := reloadCmd.CombinedOutput(); err != nil { if output, err := reloadCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("HAProxy reload failed: %v (output: %s)", err, string(output)) errMsg := fmt.Sprintf("HAProxy reload failed: %v (output: %s)", err, string(output))
c.logger.Error("HAProxy reload failed", "error", err, "output", string(output)) c.logger.Error("HAProxy reload failed", "error", err, "output", string(output))
@@ -169,7 +180,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
// Validate HAProxy configuration if command provided // Validate HAProxy configuration if command provided
if c.config.ValidateCommand != "" { if c.config.ValidateCommand != "" {
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
if output, err := validateCmd.CombinedOutput(); err != nil { if output, err := validateCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output)) errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
c.logger.Error("validation failed", "error", err) c.logger.Error("validation failed", "error", err)
@@ -377,3 +377,85 @@ func TestNginxConnector_ValidateDeployment_ValidateCommandFails(t *testing.T) {
t.Fatal("expected invalid result") t.Fatal("expected invalid result")
} }
} }
// Security tests for command injection prevention
func TestNginxConnector_ValidateConfig_RejectCommandInjectionSemicolon(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "nginx; rm -rf /", // Command injection attempt
ValidateCommand: "true",
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for command injection in reload_command")
}
}
func TestNginxConnector_ValidateConfig_RejectCommandInjectionPipe(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "true",
ValidateCommand: "nginx -t | cat /etc/passwd", // Command injection attempt
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for command injection in validate_command")
}
}
func TestNginxConnector_ValidateConfig_RejectCommandSubstitution(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "echo $(whoami)",
ValidateCommand: "true",
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for command substitution in reload_command")
}
}
func TestNginxConnector_ValidateConfig_RejectBackticks(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
ctx := context.Background()
tmpDir := t.TempDir()
cfg := nginx.Config{
CertPath: filepath.Join(tmpDir, "cert.pem"),
ChainPath: filepath.Join(tmpDir, "chain.pem"),
ReloadCommand: "true",
ValidateCommand: "nginx -t `whoami`",
}
connector := nginx.New(&cfg, logger)
rawConfig, _ := json.Marshal(cfg)
err := connector.ValidateConfig(ctx, rawConfig)
if err == nil {
t.Fatal("expected error for backtick injection in validate_command")
}
}
+23 -22
View File
@@ -105,8 +105,9 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string,
} }
// Heartbeat updates agent heartbeat (handler interface method). // Heartbeat updates agent heartbeat (handler interface method).
func (s *AgentService) Heartbeat(agentID string, metadata *domain.AgentMetadata) error { // Note: This method is called from handlers which have a context; callers should prefer HeartbeatWithContext.
return s.HeartbeatWithContext(context.Background(), agentID, metadata) func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
return s.HeartbeatWithContext(ctx, agentID, metadata)
} }
// SubmitCSR validates and processes a Certificate Signing Request from an agent. // SubmitCSR validates and processes a Certificate Signing Request from an agent.
@@ -326,7 +327,7 @@ func (s *AgentService) GetAgentByAPIKey(ctx context.Context, apiKey string) (*do
} }
// ListAgents returns paginated agents (handler interface method). // ListAgents returns paginated agents (handler interface method).
func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, error) { func (s *AgentService) ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -334,7 +335,7 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err
perPage = 50 perPage = 50
} }
agents, err := s.agentRepo.List(context.Background()) agents, err := s.agentRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list agents: %w", err) return nil, 0, fmt.Errorf("failed to list agents: %w", err)
} }
@@ -360,12 +361,12 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err
} }
// GetAgent returns a single agent (handler interface method). // GetAgent returns a single agent (handler interface method).
func (s *AgentService) GetAgent(id string) (*domain.Agent, error) { func (s *AgentService) GetAgent(ctx context.Context, id string) (*domain.Agent, error) {
return s.agentRepo.Get(context.Background(), id) return s.agentRepo.Get(ctx, id)
} }
// RegisterAgent creates and registers a new agent (handler interface method). // RegisterAgent creates and registers a new agent (handler interface method).
func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) { func (s *AgentService) RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error) {
agent.ID = generateID("agent") agent.ID = generateID("agent")
apiKey := generateAPIKey() apiKey := generateAPIKey()
agent.APIKeyHash = hashAPIKey(apiKey) agent.APIKeyHash = hashAPIKey(apiKey)
@@ -374,7 +375,7 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error)
agent.RegisteredAt = now agent.RegisteredAt = now
agent.LastHeartbeatAt = &now agent.LastHeartbeatAt = &now
if err := s.agentRepo.Create(context.Background(), &agent); err != nil { if err := s.agentRepo.Create(ctx, &agent); err != nil {
return nil, fmt.Errorf("failed to register agent: %w", err) return nil, fmt.Errorf("failed to register agent: %w", err)
} }
return &agent, nil return &agent, nil
@@ -382,8 +383,8 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error)
// CSRSubmit processes a CSR submission from an agent (handler interface method). // CSRSubmit processes a CSR submission from an agent (handler interface method).
// The csrPEM parameter contains "certID:csrPEM" or just the CSR PEM. // The csrPEM parameter contains "certID:csrPEM" or just the CSR PEM.
func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error) { func (s *AgentService) CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error) {
err := s.SubmitCSR(context.Background(), agentID, "", []byte(csrPEM)) err := s.SubmitCSR(ctx, agentID, "", []byte(csrPEM))
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -391,8 +392,8 @@ func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error)
} }
// CSRSubmitForCert processes a CSR submission for a specific certificate (handler interface method). // CSRSubmitForCert processes a CSR submission for a specific certificate (handler interface method).
func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) { func (s *AgentService) CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error) {
err := s.SubmitCSR(context.Background(), agentID, certID, []byte(csrPEM)) err := s.SubmitCSR(ctx, agentID, certID, []byte(csrPEM))
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -400,8 +401,8 @@ func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM st
} }
// GetWork returns pending deployment jobs for an agent (handler interface method). // GetWork returns pending deployment jobs for an agent (handler interface method).
func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) { func (s *AgentService) GetWork(ctx context.Context, agentID string) ([]domain.Job, error) {
jobs, err := s.GetPendingWork(context.Background(), agentID) jobs, err := s.GetPendingWork(ctx, agentID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -417,8 +418,8 @@ func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) {
// GetWorkWithTargets returns actionable jobs enriched with target/certificate details. // GetWorkWithTargets returns actionable jobs enriched with target/certificate details.
// Deployment jobs include target type + config. AwaitingCSR jobs include common name + SANs // Deployment jobs include target type + config. AwaitingCSR jobs include common name + SANs
// so the agent knows what CSR to generate. // so the agent knows what CSR to generate.
func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) { func (s *AgentService) GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error) {
jobs, err := s.GetPendingWork(context.Background(), agentID) jobs, err := s.GetPendingWork(ctx, agentID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -438,7 +439,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
// Enrich with target details for deployment jobs // Enrich with target details for deployment jobs
if j.TargetID != nil && *j.TargetID != "" { if j.TargetID != nil && *j.TargetID != "" {
target, err := s.targetRepo.Get(context.Background(), *j.TargetID) target, err := s.targetRepo.Get(ctx, *j.TargetID)
if err == nil && target != nil { if err == nil && target != nil {
item.TargetType = string(target.Type) item.TargetType = string(target.Type)
item.TargetConfig = target.Config item.TargetConfig = target.Config
@@ -447,7 +448,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
// Enrich with certificate details for AwaitingCSR jobs (agent needs CN + SANs for CSR) // Enrich with certificate details for AwaitingCSR jobs (agent needs CN + SANs for CSR)
if j.Status == domain.JobStatusAwaitingCSR { if j.Status == domain.JobStatusAwaitingCSR {
cert, err := s.certRepo.Get(context.Background(), j.CertificateID) cert, err := s.certRepo.Get(ctx, j.CertificateID)
if err == nil && cert != nil { if err == nil && cert != nil {
item.CommonName = cert.CommonName item.CommonName = cert.CommonName
item.SANs = cert.SANs item.SANs = cert.SANs
@@ -461,13 +462,13 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
} }
// UpdateJobStatus reports a job's status from an agent (handler interface method). // UpdateJobStatus reports a job's status from an agent (handler interface method).
func (s *AgentService) UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error { func (s *AgentService) UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error {
return s.ReportJobStatus(context.Background(), agentID, jobID, domain.JobStatus(status), errMsg) return s.ReportJobStatus(ctx, agentID, jobID, domain.JobStatus(status), errMsg)
} }
// CertificatePickup retrieves a certificate for an agent (handler interface method). // CertificatePickup retrieves a certificate for an agent (handler interface method).
func (s *AgentService) CertificatePickup(agentID, certID string) (string, error) { func (s *AgentService) CertificatePickup(ctx context.Context, agentID, certID string) (string, error) {
certPEM, err := s.GetCertificateForAgent(context.Background(), agentID, certID) certPEM, err := s.GetCertificateForAgent(ctx, agentID, certID)
if err != nil { if err != nil {
return "", err return "", err
} }
+15 -15
View File
@@ -28,7 +28,7 @@ func NewAgentGroupService(
} }
// ListAgentGroups returns paginated agent groups (handler interface method). // ListAgentGroups returns paginated agent groups (handler interface method).
func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error) { func (s *AgentGroupService) ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -36,7 +36,7 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr
perPage = 50 perPage = 50
} }
groups, err := s.groupRepo.List(context.Background()) groups, err := s.groupRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list agent groups: %w", err) return nil, 0, fmt.Errorf("failed to list agent groups: %w", err)
} }
@@ -53,12 +53,12 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr
} }
// GetAgentGroup returns a single agent group (handler interface method). // GetAgentGroup returns a single agent group (handler interface method).
func (s *AgentGroupService) GetAgentGroup(id string) (*domain.AgentGroup, error) { func (s *AgentGroupService) GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error) {
return s.groupRepo.Get(context.Background(), id) return s.groupRepo.Get(ctx, id)
} }
// CreateAgentGroup creates a new agent group with validation (handler interface method). // CreateAgentGroup creates a new agent group with validation (handler interface method).
func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error) { func (s *AgentGroupService) CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error) {
if err := validateAgentGroup(&group); err != nil { if err := validateAgentGroup(&group); err != nil {
return nil, err return nil, err
} }
@@ -74,12 +74,12 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A
group.UpdatedAt = now group.UpdatedAt = now
} }
if err := s.groupRepo.Create(context.Background(), &group); err != nil { if err := s.groupRepo.Create(ctx, &group); err != nil {
return nil, fmt.Errorf("failed to create agent group: %w", err) return nil, fmt.Errorf("failed to create agent group: %w", err)
} }
if s.auditService != nil { if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
"create_agent_group", "agent_group", group.ID, nil); auditErr != nil { "create_agent_group", "agent_group", group.ID, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr) slog.Error("failed to record audit event", "error", auditErr)
} }
@@ -89,18 +89,18 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A
} }
// UpdateAgentGroup modifies an existing agent group (handler interface method). // UpdateAgentGroup modifies an existing agent group (handler interface method).
func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error) { func (s *AgentGroupService) UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error) {
if err := validateAgentGroup(&group); err != nil { if err := validateAgentGroup(&group); err != nil {
return nil, err return nil, err
} }
group.ID = id group.ID = id
if err := s.groupRepo.Update(context.Background(), &group); err != nil { if err := s.groupRepo.Update(ctx, &group); err != nil {
return nil, fmt.Errorf("failed to update agent group: %w", err) return nil, fmt.Errorf("failed to update agent group: %w", err)
} }
if s.auditService != nil { if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
"update_agent_group", "agent_group", id, nil); auditErr != nil { "update_agent_group", "agent_group", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr) slog.Error("failed to record audit event", "error", auditErr)
} }
@@ -110,13 +110,13 @@ func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup)
} }
// DeleteAgentGroup removes an agent group (handler interface method). // DeleteAgentGroup removes an agent group (handler interface method).
func (s *AgentGroupService) DeleteAgentGroup(id string) error { func (s *AgentGroupService) DeleteAgentGroup(ctx context.Context, id string) error {
if err := s.groupRepo.Delete(context.Background(), id); err != nil { if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("failed to delete agent group: %w", err) return fmt.Errorf("failed to delete agent group: %w", err)
} }
if s.auditService != nil { if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
"delete_agent_group", "agent_group", id, nil); auditErr != nil { "delete_agent_group", "agent_group", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr) slog.Error("failed to record audit event", "error", auditErr)
} }
@@ -126,8 +126,8 @@ func (s *AgentGroupService) DeleteAgentGroup(id string) error {
} }
// ListMembers returns agents in a group. // ListMembers returns agents in a group.
func (s *AgentGroupService) ListMembers(id string) ([]domain.Agent, int64, error) { func (s *AgentGroupService) ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error) {
agents, err := s.groupRepo.ListMembers(context.Background(), id) agents, err := s.groupRepo.ListMembers(ctx, id)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list group members: %w", err) return nil, 0, fmt.Errorf("failed to list group members: %w", err)
} }