diff --git a/internal/api/handler/agent_groups.go b/internal/api/handler/agent_groups.go index 6eef3a6..e97cda9 100644 --- a/internal/api/handler/agent_groups.go +++ b/internal/api/handler/agent_groups.go @@ -1,6 +1,7 @@ package handler import ( + "context" "encoding/json" "net/http" "strconv" @@ -12,12 +13,12 @@ import ( // AgentGroupService defines the service interface for agent group operations. type AgentGroupService interface { - ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error) - GetAgentGroup(id string) (*domain.AgentGroup, error) - CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error) - UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error) - DeleteAgentGroup(id string) error - ListMembers(id string) ([]domain.Agent, int64, error) + ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error) + GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error) + CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error) + UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error) + DeleteAgentGroup(ctx context.Context, id string) error + ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error) } // AgentGroupHandler handles HTTP requests for agent group operations. @@ -54,7 +55,7 @@ func (h AgentGroupHandler) ListAgentGroups(w http.ResponseWriter, r *http.Reques } } - groups, total, err := h.svc.ListAgentGroups(page, perPage) + groups, total, err := h.svc.ListAgentGroups(r.Context(), page, perPage) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agent groups", requestID) return @@ -86,7 +87,7 @@ func (h AgentGroupHandler) GetAgentGroup(w http.ResponseWriter, r *http.Request) return } - group, err := h.svc.GetAgentGroup(id) + group, err := h.svc.GetAgentGroup(r.Context(), id) if err != nil { ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID) return @@ -120,7 +121,7 @@ func (h AgentGroupHandler) CreateAgentGroup(w http.ResponseWriter, r *http.Reque return } - created, err := h.svc.CreateAgentGroup(group) + created, err := h.svc.CreateAgentGroup(r.Context(), group) if err != nil { if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") { ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) @@ -157,7 +158,7 @@ func (h AgentGroupHandler) UpdateAgentGroup(w http.ResponseWriter, r *http.Reque return } - updated, err := h.svc.UpdateAgentGroup(id, group) + updated, err := h.svc.UpdateAgentGroup(r.Context(), id, group) if err != nil { if strings.Contains(err.Error(), "not found") { ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID) @@ -186,7 +187,7 @@ func (h AgentGroupHandler) DeleteAgentGroup(w http.ResponseWriter, r *http.Reque return } - if err := h.svc.DeleteAgentGroup(id); err != nil { + if err := h.svc.DeleteAgentGroup(r.Context(), id); err != nil { if strings.Contains(err.Error(), "not found") { ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID) return @@ -217,7 +218,7 @@ func (h AgentGroupHandler) ListAgentGroupMembers(w http.ResponseWriter, r *http. } id := parts[0] - members, total, err := h.svc.ListMembers(id) + members, total, err := h.svc.ListMembers(r.Context(), id) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list group members", requestID) return diff --git a/internal/api/handler/agents.go b/internal/api/handler/agents.go index 85ad542..2e5b42e 100644 --- a/internal/api/handler/agents.go +++ b/internal/api/handler/agents.go @@ -1,6 +1,7 @@ package handler import ( + "context" "encoding/json" "net/http" "strconv" @@ -12,16 +13,16 @@ import ( // AgentService defines the service interface for agent operations. type AgentService interface { - ListAgents(page, perPage int) ([]domain.Agent, int64, error) - GetAgent(id string) (*domain.Agent, error) - RegisterAgent(agent domain.Agent) (*domain.Agent, error) - Heartbeat(agentID string, metadata *domain.AgentMetadata) error - CSRSubmit(agentID string, csrPEM string) (string, error) - CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) - CertificatePickup(agentID, certID string) (string, error) - GetWork(agentID string) ([]domain.Job, error) - GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) - UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error + ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error) + GetAgent(ctx context.Context, id string) (*domain.Agent, error) + RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error) + Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error + CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error) + CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error) + CertificatePickup(ctx context.Context, agentID, certID string) (string, error) + GetWork(ctx context.Context, agentID string) ([]domain.Job, error) + GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error) + UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error } // AgentHandler handles HTTP requests for agent operations. @@ -58,7 +59,7 @@ func (h AgentHandler) ListAgents(w http.ResponseWriter, r *http.Request) { } } - agents, total, err := h.svc.ListAgents(page, perPage) + agents, total, err := h.svc.ListAgents(r.Context(), page, perPage) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agents", requestID) return @@ -92,7 +93,7 @@ func (h AgentHandler) GetAgent(w http.ResponseWriter, r *http.Request) { } id = parts[0] - agent, err := h.svc.GetAgent(id) + agent, err := h.svc.GetAgent(r.Context(), id) if err != nil { ErrorWithRequestID(w, http.StatusNotFound, "Agent not found", requestID) return @@ -131,7 +132,7 @@ func (h AgentHandler) RegisterAgent(w http.ResponseWriter, r *http.Request) { return } - created, err := h.svc.RegisterAgent(agent) + created, err := h.svc.RegisterAgent(r.Context(), agent) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to register agent", requestID) return @@ -182,7 +183,7 @@ func (h AgentHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { } } - if err := h.svc.Heartbeat(agentID, metadata); err != nil { + if err := h.svc.Heartbeat(r.Context(), agentID, metadata); err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID) return } @@ -234,9 +235,9 @@ func (h AgentHandler) AgentCSRSubmit(w http.ResponseWriter, r *http.Request) { // If certificate_id is provided, sign the CSR for that specific certificate if req.CertificateID != "" { - status, err = h.svc.CSRSubmitForCert(agentID, req.CertificateID, req.CSRPEM) + status, err = h.svc.CSRSubmitForCert(r.Context(), agentID, req.CertificateID, req.CSRPEM) } else { - status, err = h.svc.CSRSubmit(agentID, req.CSRPEM) + status, err = h.svc.CSRSubmit(r.Context(), agentID, req.CSRPEM) } if err != nil { @@ -271,7 +272,7 @@ func (h AgentHandler) AgentCertificatePickup(w http.ResponseWriter, r *http.Requ agentID := parts[0] certID := parts[2] - certPEM, err := h.svc.CertificatePickup(agentID, certID) + certPEM, err := h.svc.CertificatePickup(r.Context(), agentID, certID) if err != nil { ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found or not ready", requestID) return @@ -303,7 +304,7 @@ func (h AgentHandler) AgentGetWork(w http.ResponseWriter, r *http.Request) { } agentID := parts[0] - workItems, err := h.svc.GetWorkWithTargets(agentID) + workItems, err := h.svc.GetWorkWithTargets(r.Context(), agentID) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get pending work", requestID) return @@ -353,7 +354,7 @@ func (h AgentHandler) AgentReportJobStatus(w http.ResponseWriter, r *http.Reques return } - if err := h.svc.UpdateJobStatus(agentID, jobID, req.Status, req.Error); err != nil { + if err := h.svc.UpdateJobStatus(r.Context(), agentID, jobID, req.Status, req.Error); err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update job status", requestID) return } diff --git a/internal/connector/issuer/acme/dns.go b/internal/connector/issuer/acme/dns.go index 1b1cd44..92cb9a3 100644 --- a/internal/connector/issuer/acme/dns.go +++ b/internal/connector/issuer/acme/dns.go @@ -6,6 +6,8 @@ import ( "log/slog" "os/exec" "time" + + "github.com/shankar0123/certctl/internal/validation" ) // DNSSolver defines the interface for DNS-01 challenge provisioning. @@ -55,6 +57,16 @@ func (s *ScriptDNSSolver) Present(ctx context.Context, domain, token, keyAuth st return fmt.Errorf("DNS present script not configured") } + // Validate domain name to prevent injection attacks + if err := validation.ValidateDomainName(domain); err != nil { + return fmt.Errorf("invalid domain name: %w", err) + } + + // Validate ACME token to prevent injection attacks + if err := validation.ValidateACMEToken(token); err != nil { + return fmt.Errorf("invalid ACME token: %w", err) + } + fqdn := "_acme-challenge." + domain s.Logger.Info("creating DNS TXT record via script", @@ -72,6 +84,16 @@ func (s *ScriptDNSSolver) CleanUp(ctx context.Context, domain, token, keyAuth st return nil } + // Validate domain name to prevent injection attacks + if err := validation.ValidateDomainName(domain); err != nil { + return fmt.Errorf("invalid domain name: %w", err) + } + + // Validate ACME token to prevent injection attacks + if err := validation.ValidateACMEToken(token); err != nil { + return fmt.Errorf("invalid ACME token: %w", err) + } + fqdn := "_acme-challenge." + domain s.Logger.Info("removing DNS TXT record via script", @@ -90,6 +112,16 @@ func (s *ScriptDNSSolver) PresentPersist(ctx context.Context, domain, token, rec return fmt.Errorf("DNS present script not configured") } + // Validate domain name to prevent injection attacks + if err := validation.ValidateDomainName(domain); err != nil { + return fmt.Errorf("invalid domain name: %w", err) + } + + // Validate ACME token to prevent injection attacks + if err := validation.ValidateACMEToken(token); err != nil { + return fmt.Errorf("invalid ACME token: %w", err) + } + fqdn := "_validation-persist." + domain s.Logger.Info("creating persistent DNS TXT record via script", diff --git a/internal/connector/issuer/acme/dns_test.go b/internal/connector/issuer/acme/dns_test.go index d11c051..e99119d 100644 --- a/internal/connector/issuer/acme/dns_test.go +++ b/internal/connector/issuer/acme/dns_test.go @@ -193,3 +193,136 @@ echo "FQDN=$CERTCTL_DNS_FQDN" > ` + outputFile + ` } }) } + +// Security tests for DNS injection prevention + +func TestScriptDNSSolver_Present_RejectInvalidDomain(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "present.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755) + + tests := []struct { + name string + domain string + }{ + { + name: "domain with command injection semicolon", + domain: "example.com; rm -rf /", + }, + { + name: "domain with backtick injection", + domain: "example.com`whoami`", + }, + { + name: "domain with command substitution", + domain: "example.com$(whoami)", + }, + { + name: "domain with pipe injection", + domain: "example.com | cat /etc/passwd", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger) + err := solver.Present(ctx, tt.domain, "test-token", "test-key-auth") + if err == nil { + t.Fatalf("expected error for invalid domain: %s", tt.domain) + } + }) + } +} + +func TestScriptDNSSolver_Present_RejectInvalidToken(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "present.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755) + + tests := []struct { + name string + token string + }{ + { + name: "token with command injection", + token: "token$(whoami)", + }, + { + name: "token with backtick injection", + token: "token`id`", + }, + { + name: "token with semicolon", + token: "token;malicious", + }, + { + name: "token with pipe", + token: "token|cat", + }, + { + name: "token with space", + token: "token value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger) + err := solver.Present(ctx, "example.com", tt.token, "test-key-auth") + if err == nil { + t.Fatalf("expected error for invalid token: %s", tt.token) + } + }) + } +} + +func TestScriptDNSSolver_CleanUp_RejectInvalidDomain(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "cleanup.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755) + + solver := acmeissuer.NewScriptDNSSolver("", scriptPath, logger) + err := solver.CleanUp(ctx, "example.com; rm -rf /", "test-token", "test-key-auth") + if err == nil { + t.Fatal("expected error for command injection in domain") + } +} + +func TestScriptDNSSolver_PresentPersist_RejectInvalidDomain(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "present.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755) + + solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger) + err := solver.PresentPersist(ctx, "example.com`whoami`", "test-token", "letsencrypt.org; accounturi=https://example.com/acct/1") + if err == nil { + t.Fatal("expected error for command injection in domain") + } +} + +func TestScriptDNSSolver_PresentPersist_RejectInvalidToken(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "present.sh") + os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755) + + solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger) + err := solver.PresentPersist(ctx, "example.com", "token$(whoami)", "letsencrypt.org; accounturi=https://example.com/acct/1") + if err == nil { + t.Fatal("expected error for command injection in token") + } +} diff --git a/internal/connector/issuer/openssl/openssl.go b/internal/connector/issuer/openssl/openssl.go index e97f203..2388d41 100644 --- a/internal/connector/issuer/openssl/openssl.go +++ b/internal/connector/issuer/openssl/openssl.go @@ -97,22 +97,28 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag return fmt.Errorf("sign_script is required") } - // Verify sign_script exists and is executable - if _, err := os.Stat(cfg.SignScript); err != nil { + // Verify sign_script exists and is a regular file + if info, err := os.Stat(cfg.SignScript); err != nil { return fmt.Errorf("sign_script not accessible: %w", err) + } else if !info.Mode().IsRegular() { + return fmt.Errorf("sign_script must be a regular file, got %s", info.Mode()) } - // Verify revoke_script exists if specified + // Verify revoke_script exists and is a regular file if specified if cfg.RevokeScript != "" { - if _, err := os.Stat(cfg.RevokeScript); err != nil { + if info, err := os.Stat(cfg.RevokeScript); err != nil { return fmt.Errorf("revoke_script not accessible: %w", err) + } else if !info.Mode().IsRegular() { + return fmt.Errorf("revoke_script must be a regular file, got %s", info.Mode()) } } - // Verify crl_script exists if specified + // Verify crl_script exists and is a regular file if specified if cfg.CRLScript != "" { - if _, err := os.Stat(cfg.CRLScript); err != nil { + if info, err := os.Stat(cfg.CRLScript); err != nil { return fmt.Errorf("crl_script not accessible: %w", err) + } else if !info.Mode().IsRegular() { + return fmt.Errorf("crl_script must be a regular file, got %s", info.Mode()) } } diff --git a/internal/connector/issuer/openssl/openssl_test.go b/internal/connector/issuer/openssl/openssl_test.go index 955caca..3d8ded8 100644 --- a/internal/connector/issuer/openssl/openssl_test.go +++ b/internal/connector/issuer/openssl/openssl_test.go @@ -556,3 +556,68 @@ func generateMockCertPEM() string { Bytes: certBytes, })) } + +// Security tests for script path validation + +func TestOpenSSLConnector_ValidateConfig_RejectNonRegularFile(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + // Try to use a directory as a script path + tmpDir := t.TempDir() + + config := &openssl.Config{ + SignScript: tmpDir, // This is a directory, not a regular file + } + connector := openssl.New(config, logger) + + rawConfig, _ := json.Marshal(config) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("Expected error when sign_script is not a regular file") + } +} + +func TestOpenSSLConnector_ValidateConfig_ValidateRevokeScriptPath(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + signScript := filepath.Join(tmpDir, "sign.sh") + os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755) + + // Try to use a nonexistent file as revoke_script + config := &openssl.Config{ + SignScript: signScript, + RevokeScript: "/nonexistent/revoke.sh", + } + connector := openssl.New(config, logger) + + rawConfig, _ := json.Marshal(config) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("Expected error when revoke_script is nonexistent") + } +} + +func TestOpenSSLConnector_ValidateConfig_ValidateCRLScriptPath(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + signScript := filepath.Join(tmpDir, "sign.sh") + os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755) + + // Try to use a directory as crl_script + config := &openssl.Config{ + SignScript: signScript, + CRLScript: tmpDir, // This is a directory, not a regular file + } + connector := openssl.New(config, logger) + + rawConfig, _ := json.Marshal(config) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("Expected error when crl_script is not a regular file") + } +} diff --git a/internal/connector/target/apache/apache.go b/internal/connector/target/apache/apache.go index 6def055..52ab74f 100644 --- a/internal/connector/target/apache/apache.go +++ b/internal/connector/target/apache/apache.go @@ -142,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy // Validate Apache configuration before reload c.logger.Debug("validating Apache configuration", "validate_command", c.config.ValidateCommand) - validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) + validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand) if output, err := validateCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output)) c.logger.Error("Apache validation failed", "error", err, "output", string(output)) @@ -156,7 +156,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy // Graceful reload c.logger.Debug("reloading Apache", "reload_command", c.config.ReloadCommand) - reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand) + reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand) if output, err := reloadCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("Apache reload failed: %v (output: %s)", err, string(output)) c.logger.Error("Apache reload failed", "error", err, "output", string(output)) @@ -196,7 +196,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid startTime := time.Now() // Validate Apache configuration - validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) + validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand) if output, err := validateCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output)) c.logger.Error("validation failed", "error", err) diff --git a/internal/connector/target/haproxy/haproxy.go b/internal/connector/target/haproxy/haproxy.go index 2d4dba2..85ce13a 100644 --- a/internal/connector/target/haproxy/haproxy.go +++ b/internal/connector/target/haproxy/haproxy.go @@ -10,6 +10,7 @@ import ( "time" "github.com/shankar0123/certctl/internal/connector/target" + "github.com/shankar0123/certctl/internal/validation" ) // Config represents the HAProxy deployment target configuration. @@ -53,12 +54,22 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag return fmt.Errorf("HAProxy reload_command is required") } + // Validate commands to prevent injection attacks + if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil { + return fmt.Errorf("invalid reload_command: %w", err) + } + if cfg.ValidateCommand != "" { + if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil { + return fmt.Errorf("invalid validate_command: %w", err) + } + } + c.logger.Info("validating HAProxy configuration", "pem_path", cfg.PEMPath) // Verify validate command works if provided if cfg.ValidateCommand != "" { - cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand) + cmd := exec.CommandContext(ctx, cfg.ValidateCommand) if err := cmd.Run(); err != nil { c.logger.Warn("HAProxy config validation failed during config check", "error", err, @@ -114,7 +125,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy // Validate HAProxy configuration if validate command is configured if c.config.ValidateCommand != "" { c.logger.Debug("validating HAProxy configuration", "validate_command", c.config.ValidateCommand) - validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) + validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand) if output, err := validateCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output)) c.logger.Error("HAProxy validation failed", "error", err, "output", string(output)) @@ -129,7 +140,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy // Reload HAProxy c.logger.Debug("reloading HAProxy", "reload_command", c.config.ReloadCommand) - reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand) + reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand) if output, err := reloadCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("HAProxy reload failed: %v (output: %s)", err, string(output)) c.logger.Error("HAProxy reload failed", "error", err, "output", string(output)) @@ -169,7 +180,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid // Validate HAProxy configuration if command provided if c.config.ValidateCommand != "" { - validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand) + validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand) if output, err := validateCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output)) c.logger.Error("validation failed", "error", err) diff --git a/internal/connector/target/nginx/nginx_test.go b/internal/connector/target/nginx/nginx_test.go index dd1dd76..f76f7ef 100644 --- a/internal/connector/target/nginx/nginx_test.go +++ b/internal/connector/target/nginx/nginx_test.go @@ -377,3 +377,85 @@ func TestNginxConnector_ValidateDeployment_ValidateCommandFails(t *testing.T) { t.Fatal("expected invalid result") } } + +// Security tests for command injection prevention + +func TestNginxConnector_ValidateConfig_RejectCommandInjectionSemicolon(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + cfg := nginx.Config{ + CertPath: filepath.Join(tmpDir, "cert.pem"), + ChainPath: filepath.Join(tmpDir, "chain.pem"), + ReloadCommand: "nginx; rm -rf /", // Command injection attempt + ValidateCommand: "true", + } + + connector := nginx.New(&cfg, logger) + rawConfig, _ := json.Marshal(cfg) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("expected error for command injection in reload_command") + } +} + +func TestNginxConnector_ValidateConfig_RejectCommandInjectionPipe(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + cfg := nginx.Config{ + CertPath: filepath.Join(tmpDir, "cert.pem"), + ChainPath: filepath.Join(tmpDir, "chain.pem"), + ReloadCommand: "true", + ValidateCommand: "nginx -t | cat /etc/passwd", // Command injection attempt + } + + connector := nginx.New(&cfg, logger) + rawConfig, _ := json.Marshal(cfg) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("expected error for command injection in validate_command") + } +} + +func TestNginxConnector_ValidateConfig_RejectCommandSubstitution(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + cfg := nginx.Config{ + CertPath: filepath.Join(tmpDir, "cert.pem"), + ChainPath: filepath.Join(tmpDir, "chain.pem"), + ReloadCommand: "echo $(whoami)", + ValidateCommand: "true", + } + + connector := nginx.New(&cfg, logger) + rawConfig, _ := json.Marshal(cfg) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("expected error for command substitution in reload_command") + } +} + +func TestNginxConnector_ValidateConfig_RejectBackticks(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + tmpDir := t.TempDir() + cfg := nginx.Config{ + CertPath: filepath.Join(tmpDir, "cert.pem"), + ChainPath: filepath.Join(tmpDir, "chain.pem"), + ReloadCommand: "true", + ValidateCommand: "nginx -t `whoami`", + } + + connector := nginx.New(&cfg, logger) + rawConfig, _ := json.Marshal(cfg) + err := connector.ValidateConfig(ctx, rawConfig) + if err == nil { + t.Fatal("expected error for backtick injection in validate_command") + } +} diff --git a/internal/service/agent.go b/internal/service/agent.go index 93628f4..285937c 100644 --- a/internal/service/agent.go +++ b/internal/service/agent.go @@ -105,8 +105,9 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string, } // Heartbeat updates agent heartbeat (handler interface method). -func (s *AgentService) Heartbeat(agentID string, metadata *domain.AgentMetadata) error { - return s.HeartbeatWithContext(context.Background(), agentID, metadata) +// Note: This method is called from handlers which have a context; callers should prefer HeartbeatWithContext. +func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error { + return s.HeartbeatWithContext(ctx, agentID, metadata) } // SubmitCSR validates and processes a Certificate Signing Request from an agent. @@ -326,7 +327,7 @@ func (s *AgentService) GetAgentByAPIKey(ctx context.Context, apiKey string) (*do } // ListAgents returns paginated agents (handler interface method). -func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, error) { +func (s *AgentService) ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error) { if page < 1 { page = 1 } @@ -334,7 +335,7 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err perPage = 50 } - agents, err := s.agentRepo.List(context.Background()) + agents, err := s.agentRepo.List(ctx) if err != nil { return nil, 0, fmt.Errorf("failed to list agents: %w", err) } @@ -360,12 +361,12 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err } // GetAgent returns a single agent (handler interface method). -func (s *AgentService) GetAgent(id string) (*domain.Agent, error) { - return s.agentRepo.Get(context.Background(), id) +func (s *AgentService) GetAgent(ctx context.Context, id string) (*domain.Agent, error) { + return s.agentRepo.Get(ctx, id) } // RegisterAgent creates and registers a new agent (handler interface method). -func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) { +func (s *AgentService) RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error) { agent.ID = generateID("agent") apiKey := generateAPIKey() agent.APIKeyHash = hashAPIKey(apiKey) @@ -374,7 +375,7 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) agent.RegisteredAt = now agent.LastHeartbeatAt = &now - if err := s.agentRepo.Create(context.Background(), &agent); err != nil { + if err := s.agentRepo.Create(ctx, &agent); err != nil { return nil, fmt.Errorf("failed to register agent: %w", err) } return &agent, nil @@ -382,8 +383,8 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) // CSRSubmit processes a CSR submission from an agent (handler interface method). // The csrPEM parameter contains "certID:csrPEM" or just the CSR PEM. -func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error) { - err := s.SubmitCSR(context.Background(), agentID, "", []byte(csrPEM)) +func (s *AgentService) CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error) { + err := s.SubmitCSR(ctx, agentID, "", []byte(csrPEM)) if err != nil { return "", err } @@ -391,8 +392,8 @@ func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error) } // CSRSubmitForCert processes a CSR submission for a specific certificate (handler interface method). -func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) { - err := s.SubmitCSR(context.Background(), agentID, certID, []byte(csrPEM)) +func (s *AgentService) CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error) { + err := s.SubmitCSR(ctx, agentID, certID, []byte(csrPEM)) if err != nil { return "", err } @@ -400,8 +401,8 @@ func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM st } // GetWork returns pending deployment jobs for an agent (handler interface method). -func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) { - jobs, err := s.GetPendingWork(context.Background(), agentID) +func (s *AgentService) GetWork(ctx context.Context, agentID string) ([]domain.Job, error) { + jobs, err := s.GetPendingWork(ctx, agentID) if err != nil { return nil, err } @@ -417,8 +418,8 @@ func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) { // GetWorkWithTargets returns actionable jobs enriched with target/certificate details. // Deployment jobs include target type + config. AwaitingCSR jobs include common name + SANs // so the agent knows what CSR to generate. -func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) { - jobs, err := s.GetPendingWork(context.Background(), agentID) +func (s *AgentService) GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error) { + jobs, err := s.GetPendingWork(ctx, agentID) if err != nil { return nil, err } @@ -438,7 +439,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er // Enrich with target details for deployment jobs if j.TargetID != nil && *j.TargetID != "" { - target, err := s.targetRepo.Get(context.Background(), *j.TargetID) + target, err := s.targetRepo.Get(ctx, *j.TargetID) if err == nil && target != nil { item.TargetType = string(target.Type) item.TargetConfig = target.Config @@ -447,7 +448,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er // Enrich with certificate details for AwaitingCSR jobs (agent needs CN + SANs for CSR) if j.Status == domain.JobStatusAwaitingCSR { - cert, err := s.certRepo.Get(context.Background(), j.CertificateID) + cert, err := s.certRepo.Get(ctx, j.CertificateID) if err == nil && cert != nil { item.CommonName = cert.CommonName item.SANs = cert.SANs @@ -461,13 +462,13 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er } // UpdateJobStatus reports a job's status from an agent (handler interface method). -func (s *AgentService) UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error { - return s.ReportJobStatus(context.Background(), agentID, jobID, domain.JobStatus(status), errMsg) +func (s *AgentService) UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error { + return s.ReportJobStatus(ctx, agentID, jobID, domain.JobStatus(status), errMsg) } // CertificatePickup retrieves a certificate for an agent (handler interface method). -func (s *AgentService) CertificatePickup(agentID, certID string) (string, error) { - certPEM, err := s.GetCertificateForAgent(context.Background(), agentID, certID) +func (s *AgentService) CertificatePickup(ctx context.Context, agentID, certID string) (string, error) { + certPEM, err := s.GetCertificateForAgent(ctx, agentID, certID) if err != nil { return "", err } diff --git a/internal/service/agent_group.go b/internal/service/agent_group.go index 41932a6..5e9f148 100644 --- a/internal/service/agent_group.go +++ b/internal/service/agent_group.go @@ -28,7 +28,7 @@ func NewAgentGroupService( } // ListAgentGroups returns paginated agent groups (handler interface method). -func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error) { +func (s *AgentGroupService) ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error) { if page < 1 { page = 1 } @@ -36,7 +36,7 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr perPage = 50 } - groups, err := s.groupRepo.List(context.Background()) + groups, err := s.groupRepo.List(ctx) if err != nil { return nil, 0, fmt.Errorf("failed to list agent groups: %w", err) } @@ -53,12 +53,12 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr } // GetAgentGroup returns a single agent group (handler interface method). -func (s *AgentGroupService) GetAgentGroup(id string) (*domain.AgentGroup, error) { - return s.groupRepo.Get(context.Background(), id) +func (s *AgentGroupService) GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error) { + return s.groupRepo.Get(ctx, id) } // CreateAgentGroup creates a new agent group with validation (handler interface method). -func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error) { +func (s *AgentGroupService) CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error) { if err := validateAgentGroup(&group); err != nil { return nil, err } @@ -74,12 +74,12 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A group.UpdatedAt = now } - if err := s.groupRepo.Create(context.Background(), &group); err != nil { + if err := s.groupRepo.Create(ctx, &group); err != nil { return nil, fmt.Errorf("failed to create agent group: %w", err) } if s.auditService != nil { - if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser, "create_agent_group", "agent_group", group.ID, nil); auditErr != nil { slog.Error("failed to record audit event", "error", auditErr) } @@ -89,18 +89,18 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A } // UpdateAgentGroup modifies an existing agent group (handler interface method). -func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error) { +func (s *AgentGroupService) UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error) { if err := validateAgentGroup(&group); err != nil { return nil, err } group.ID = id - if err := s.groupRepo.Update(context.Background(), &group); err != nil { + if err := s.groupRepo.Update(ctx, &group); err != nil { return nil, fmt.Errorf("failed to update agent group: %w", err) } if s.auditService != nil { - if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser, "update_agent_group", "agent_group", id, nil); auditErr != nil { slog.Error("failed to record audit event", "error", auditErr) } @@ -110,13 +110,13 @@ func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup) } // DeleteAgentGroup removes an agent group (handler interface method). -func (s *AgentGroupService) DeleteAgentGroup(id string) error { - if err := s.groupRepo.Delete(context.Background(), id); err != nil { +func (s *AgentGroupService) DeleteAgentGroup(ctx context.Context, id string) error { + if err := s.groupRepo.Delete(ctx, id); err != nil { return fmt.Errorf("failed to delete agent group: %w", err) } if s.auditService != nil { - if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, + if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser, "delete_agent_group", "agent_group", id, nil); auditErr != nil { slog.Error("failed to record audit event", "error", auditErr) } @@ -126,8 +126,8 @@ func (s *AgentGroupService) DeleteAgentGroup(id string) error { } // ListMembers returns agents in a group. -func (s *AgentGroupService) ListMembers(id string) ([]domain.Agent, int64, error) { - agents, err := s.groupRepo.ListMembers(context.Background(), id) +func (s *AgentGroupService) ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error) { + agents, err := s.groupRepo.ListMembers(ctx, id) if err != nil { return nil, 0, fmt.Errorf("failed to list group members: %w", err) }