mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 21:41:39 +00:00
fix(quality): TICKET-012 propagate request context instead of context.Background()
- Updated AgentService interface to accept context.Context parameter in all methods - Replaced context.Background() calls with proper ctx parameter in agent.go - Updated AgentGroupService interface to accept context.Context parameter - Replaced context.Background() calls with proper ctx parameter in agent_group.go - Updated handler methods to pass r.Context() to service methods - Context now properly propagates through request lifecycle for timeout/cancellation - Improved request tracing and cancellation behavior
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -12,12 +13,12 @@ import (
|
||||
|
||||
// AgentGroupService defines the service interface for agent group operations.
|
||||
type AgentGroupService interface {
|
||||
ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error)
|
||||
GetAgentGroup(id string) (*domain.AgentGroup, error)
|
||||
CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error)
|
||||
UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error)
|
||||
DeleteAgentGroup(id string) error
|
||||
ListMembers(id string) ([]domain.Agent, int64, error)
|
||||
ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error)
|
||||
GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error)
|
||||
CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error)
|
||||
UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error)
|
||||
DeleteAgentGroup(ctx context.Context, id string) error
|
||||
ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error)
|
||||
}
|
||||
|
||||
// AgentGroupHandler handles HTTP requests for agent group operations.
|
||||
@@ -54,7 +55,7 @@ func (h AgentGroupHandler) ListAgentGroups(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
groups, total, err := h.svc.ListAgentGroups(page, perPage)
|
||||
groups, total, err := h.svc.ListAgentGroups(r.Context(), page, perPage)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agent groups", requestID)
|
||||
return
|
||||
@@ -86,7 +87,7 @@ func (h AgentGroupHandler) GetAgentGroup(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.svc.GetAgentGroup(id)
|
||||
group, err := h.svc.GetAgentGroup(r.Context(), id)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
|
||||
return
|
||||
@@ -120,7 +121,7 @@ func (h AgentGroupHandler) CreateAgentGroup(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
created, err := h.svc.CreateAgentGroup(group)
|
||||
created, err := h.svc.CreateAgentGroup(r.Context(), group)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") {
|
||||
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
|
||||
@@ -157,7 +158,7 @@ func (h AgentGroupHandler) UpdateAgentGroup(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.svc.UpdateAgentGroup(id, group)
|
||||
updated, err := h.svc.UpdateAgentGroup(r.Context(), id, group)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
|
||||
@@ -186,7 +187,7 @@ func (h AgentGroupHandler) DeleteAgentGroup(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.DeleteAgentGroup(id); err != nil {
|
||||
if err := h.svc.DeleteAgentGroup(r.Context(), id); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
|
||||
return
|
||||
@@ -217,7 +218,7 @@ func (h AgentGroupHandler) ListAgentGroupMembers(w http.ResponseWriter, r *http.
|
||||
}
|
||||
id := parts[0]
|
||||
|
||||
members, total, err := h.svc.ListMembers(id)
|
||||
members, total, err := h.svc.ListMembers(r.Context(), id)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list group members", requestID)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -12,16 +13,16 @@ import (
|
||||
|
||||
// AgentService defines the service interface for agent operations.
|
||||
type AgentService interface {
|
||||
ListAgents(page, perPage int) ([]domain.Agent, int64, error)
|
||||
GetAgent(id string) (*domain.Agent, error)
|
||||
RegisterAgent(agent domain.Agent) (*domain.Agent, error)
|
||||
Heartbeat(agentID string, metadata *domain.AgentMetadata) error
|
||||
CSRSubmit(agentID string, csrPEM string) (string, error)
|
||||
CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error)
|
||||
CertificatePickup(agentID, certID string) (string, error)
|
||||
GetWork(agentID string) ([]domain.Job, error)
|
||||
GetWorkWithTargets(agentID string) ([]domain.WorkItem, error)
|
||||
UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error
|
||||
ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error)
|
||||
GetAgent(ctx context.Context, id string) (*domain.Agent, error)
|
||||
RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error)
|
||||
Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error
|
||||
CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error)
|
||||
CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error)
|
||||
CertificatePickup(ctx context.Context, agentID, certID string) (string, error)
|
||||
GetWork(ctx context.Context, agentID string) ([]domain.Job, error)
|
||||
GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error)
|
||||
UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error
|
||||
}
|
||||
|
||||
// AgentHandler handles HTTP requests for agent operations.
|
||||
@@ -58,7 +59,7 @@ func (h AgentHandler) ListAgents(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
agents, total, err := h.svc.ListAgents(page, perPage)
|
||||
agents, total, err := h.svc.ListAgents(r.Context(), page, perPage)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agents", requestID)
|
||||
return
|
||||
@@ -92,7 +93,7 @@ func (h AgentHandler) GetAgent(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
id = parts[0]
|
||||
|
||||
agent, err := h.svc.GetAgent(id)
|
||||
agent, err := h.svc.GetAgent(r.Context(), id)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusNotFound, "Agent not found", requestID)
|
||||
return
|
||||
@@ -131,7 +132,7 @@ func (h AgentHandler) RegisterAgent(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
created, err := h.svc.RegisterAgent(agent)
|
||||
created, err := h.svc.RegisterAgent(r.Context(), agent)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to register agent", requestID)
|
||||
return
|
||||
@@ -182,7 +183,7 @@ func (h AgentHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.svc.Heartbeat(agentID, metadata); err != nil {
|
||||
if err := h.svc.Heartbeat(r.Context(), agentID, metadata); err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID)
|
||||
return
|
||||
}
|
||||
@@ -234,9 +235,9 @@ func (h AgentHandler) AgentCSRSubmit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// If certificate_id is provided, sign the CSR for that specific certificate
|
||||
if req.CertificateID != "" {
|
||||
status, err = h.svc.CSRSubmitForCert(agentID, req.CertificateID, req.CSRPEM)
|
||||
status, err = h.svc.CSRSubmitForCert(r.Context(), agentID, req.CertificateID, req.CSRPEM)
|
||||
} else {
|
||||
status, err = h.svc.CSRSubmit(agentID, req.CSRPEM)
|
||||
status, err = h.svc.CSRSubmit(r.Context(), agentID, req.CSRPEM)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -271,7 +272,7 @@ func (h AgentHandler) AgentCertificatePickup(w http.ResponseWriter, r *http.Requ
|
||||
agentID := parts[0]
|
||||
certID := parts[2]
|
||||
|
||||
certPEM, err := h.svc.CertificatePickup(agentID, certID)
|
||||
certPEM, err := h.svc.CertificatePickup(r.Context(), agentID, certID)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found or not ready", requestID)
|
||||
return
|
||||
@@ -303,7 +304,7 @@ func (h AgentHandler) AgentGetWork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
agentID := parts[0]
|
||||
|
||||
workItems, err := h.svc.GetWorkWithTargets(agentID)
|
||||
workItems, err := h.svc.GetWorkWithTargets(r.Context(), agentID)
|
||||
if err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get pending work", requestID)
|
||||
return
|
||||
@@ -353,7 +354,7 @@ func (h AgentHandler) AgentReportJobStatus(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.UpdateJobStatus(agentID, jobID, req.Status, req.Error); err != nil {
|
||||
if err := h.svc.UpdateJobStatus(r.Context(), agentID, jobID, req.Status, req.Error); err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update job status", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"log/slog"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// DNSSolver defines the interface for DNS-01 challenge provisioning.
|
||||
@@ -55,6 +57,16 @@ func (s *ScriptDNSSolver) Present(ctx context.Context, domain, token, keyAuth st
|
||||
return fmt.Errorf("DNS present script not configured")
|
||||
}
|
||||
|
||||
// Validate domain name to prevent injection attacks
|
||||
if err := validation.ValidateDomainName(domain); err != nil {
|
||||
return fmt.Errorf("invalid domain name: %w", err)
|
||||
}
|
||||
|
||||
// Validate ACME token to prevent injection attacks
|
||||
if err := validation.ValidateACMEToken(token); err != nil {
|
||||
return fmt.Errorf("invalid ACME token: %w", err)
|
||||
}
|
||||
|
||||
fqdn := "_acme-challenge." + domain
|
||||
|
||||
s.Logger.Info("creating DNS TXT record via script",
|
||||
@@ -72,6 +84,16 @@ func (s *ScriptDNSSolver) CleanUp(ctx context.Context, domain, token, keyAuth st
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate domain name to prevent injection attacks
|
||||
if err := validation.ValidateDomainName(domain); err != nil {
|
||||
return fmt.Errorf("invalid domain name: %w", err)
|
||||
}
|
||||
|
||||
// Validate ACME token to prevent injection attacks
|
||||
if err := validation.ValidateACMEToken(token); err != nil {
|
||||
return fmt.Errorf("invalid ACME token: %w", err)
|
||||
}
|
||||
|
||||
fqdn := "_acme-challenge." + domain
|
||||
|
||||
s.Logger.Info("removing DNS TXT record via script",
|
||||
@@ -90,6 +112,16 @@ func (s *ScriptDNSSolver) PresentPersist(ctx context.Context, domain, token, rec
|
||||
return fmt.Errorf("DNS present script not configured")
|
||||
}
|
||||
|
||||
// Validate domain name to prevent injection attacks
|
||||
if err := validation.ValidateDomainName(domain); err != nil {
|
||||
return fmt.Errorf("invalid domain name: %w", err)
|
||||
}
|
||||
|
||||
// Validate ACME token to prevent injection attacks
|
||||
if err := validation.ValidateACMEToken(token); err != nil {
|
||||
return fmt.Errorf("invalid ACME token: %w", err)
|
||||
}
|
||||
|
||||
fqdn := "_validation-persist." + domain
|
||||
|
||||
s.Logger.Info("creating persistent DNS TXT record via script",
|
||||
|
||||
@@ -193,3 +193,136 @@ echo "FQDN=$CERTCTL_DNS_FQDN" > ` + outputFile + `
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Security tests for DNS injection prevention
|
||||
|
||||
func TestScriptDNSSolver_Present_RejectInvalidDomain(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "present.sh")
|
||||
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
}{
|
||||
{
|
||||
name: "domain with command injection semicolon",
|
||||
domain: "example.com; rm -rf /",
|
||||
},
|
||||
{
|
||||
name: "domain with backtick injection",
|
||||
domain: "example.com`whoami`",
|
||||
},
|
||||
{
|
||||
name: "domain with command substitution",
|
||||
domain: "example.com$(whoami)",
|
||||
},
|
||||
{
|
||||
name: "domain with pipe injection",
|
||||
domain: "example.com | cat /etc/passwd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
|
||||
err := solver.Present(ctx, tt.domain, "test-token", "test-key-auth")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid domain: %s", tt.domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScriptDNSSolver_Present_RejectInvalidToken(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "present.sh")
|
||||
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "token with command injection",
|
||||
token: "token$(whoami)",
|
||||
},
|
||||
{
|
||||
name: "token with backtick injection",
|
||||
token: "token`id`",
|
||||
},
|
||||
{
|
||||
name: "token with semicolon",
|
||||
token: "token;malicious",
|
||||
},
|
||||
{
|
||||
name: "token with pipe",
|
||||
token: "token|cat",
|
||||
},
|
||||
{
|
||||
name: "token with space",
|
||||
token: "token value",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
|
||||
err := solver.Present(ctx, "example.com", tt.token, "test-key-auth")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid token: %s", tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScriptDNSSolver_CleanUp_RejectInvalidDomain(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "cleanup.sh")
|
||||
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
solver := acmeissuer.NewScriptDNSSolver("", scriptPath, logger)
|
||||
err := solver.CleanUp(ctx, "example.com; rm -rf /", "test-token", "test-key-auth")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for command injection in domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScriptDNSSolver_PresentPersist_RejectInvalidDomain(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "present.sh")
|
||||
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
|
||||
err := solver.PresentPersist(ctx, "example.com`whoami`", "test-token", "letsencrypt.org; accounturi=https://example.com/acct/1")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for command injection in domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScriptDNSSolver_PresentPersist_RejectInvalidToken(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "present.sh")
|
||||
os.WriteFile(scriptPath, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
solver := acmeissuer.NewScriptDNSSolver(scriptPath, "", logger)
|
||||
err := solver.PresentPersist(ctx, "example.com", "token$(whoami)", "letsencrypt.org; accounturi=https://example.com/acct/1")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for command injection in token")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,22 +97,28 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
return fmt.Errorf("sign_script is required")
|
||||
}
|
||||
|
||||
// Verify sign_script exists and is executable
|
||||
if _, err := os.Stat(cfg.SignScript); err != nil {
|
||||
// Verify sign_script exists and is a regular file
|
||||
if info, err := os.Stat(cfg.SignScript); err != nil {
|
||||
return fmt.Errorf("sign_script not accessible: %w", err)
|
||||
} else if !info.Mode().IsRegular() {
|
||||
return fmt.Errorf("sign_script must be a regular file, got %s", info.Mode())
|
||||
}
|
||||
|
||||
// Verify revoke_script exists if specified
|
||||
// Verify revoke_script exists and is a regular file if specified
|
||||
if cfg.RevokeScript != "" {
|
||||
if _, err := os.Stat(cfg.RevokeScript); err != nil {
|
||||
if info, err := os.Stat(cfg.RevokeScript); err != nil {
|
||||
return fmt.Errorf("revoke_script not accessible: %w", err)
|
||||
} else if !info.Mode().IsRegular() {
|
||||
return fmt.Errorf("revoke_script must be a regular file, got %s", info.Mode())
|
||||
}
|
||||
}
|
||||
|
||||
// Verify crl_script exists if specified
|
||||
// Verify crl_script exists and is a regular file if specified
|
||||
if cfg.CRLScript != "" {
|
||||
if _, err := os.Stat(cfg.CRLScript); err != nil {
|
||||
if info, err := os.Stat(cfg.CRLScript); err != nil {
|
||||
return fmt.Errorf("crl_script not accessible: %w", err)
|
||||
} else if !info.Mode().IsRegular() {
|
||||
return fmt.Errorf("crl_script must be a regular file, got %s", info.Mode())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -556,3 +556,68 @@ func generateMockCertPEM() string {
|
||||
Bytes: certBytes,
|
||||
}))
|
||||
}
|
||||
|
||||
// Security tests for script path validation
|
||||
|
||||
func TestOpenSSLConnector_ValidateConfig_RejectNonRegularFile(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
// Try to use a directory as a script path
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
config := &openssl.Config{
|
||||
SignScript: tmpDir, // This is a directory, not a regular file
|
||||
}
|
||||
connector := openssl.New(config, logger)
|
||||
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when sign_script is not a regular file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenSSLConnector_ValidateConfig_ValidateRevokeScriptPath(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
signScript := filepath.Join(tmpDir, "sign.sh")
|
||||
os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
// Try to use a nonexistent file as revoke_script
|
||||
config := &openssl.Config{
|
||||
SignScript: signScript,
|
||||
RevokeScript: "/nonexistent/revoke.sh",
|
||||
}
|
||||
connector := openssl.New(config, logger)
|
||||
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when revoke_script is nonexistent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenSSLConnector_ValidateConfig_ValidateCRLScriptPath(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
signScript := filepath.Join(tmpDir, "sign.sh")
|
||||
os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755)
|
||||
|
||||
// Try to use a directory as crl_script
|
||||
config := &openssl.Config{
|
||||
SignScript: signScript,
|
||||
CRLScript: tmpDir, // This is a directory, not a regular file
|
||||
}
|
||||
connector := openssl.New(config, logger)
|
||||
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when crl_script is not a regular file")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
|
||||
|
||||
// Validate Apache configuration before reload
|
||||
c.logger.Debug("validating Apache configuration", "validate_command", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("Apache validation failed", "error", err, "output", string(output))
|
||||
@@ -156,7 +156,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
|
||||
|
||||
// Graceful reload
|
||||
c.logger.Debug("reloading Apache", "reload_command", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
|
||||
if output, err := reloadCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Apache reload failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("Apache reload failed", "error", err, "output", string(output))
|
||||
@@ -196,7 +196,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
|
||||
startTime := time.Now()
|
||||
|
||||
// Validate Apache configuration
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// Config represents the HAProxy deployment target configuration.
|
||||
@@ -53,12 +54,22 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
return fmt.Errorf("HAProxy reload_command is required")
|
||||
}
|
||||
|
||||
// Validate commands to prevent injection attacks
|
||||
if err := validation.ValidateShellCommand(cfg.ReloadCommand); err != nil {
|
||||
return fmt.Errorf("invalid reload_command: %w", err)
|
||||
}
|
||||
if cfg.ValidateCommand != "" {
|
||||
if err := validation.ValidateShellCommand(cfg.ValidateCommand); err != nil {
|
||||
return fmt.Errorf("invalid validate_command: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Info("validating HAProxy configuration",
|
||||
"pem_path", cfg.PEMPath)
|
||||
|
||||
// Verify validate command works if provided
|
||||
if cfg.ValidateCommand != "" {
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
|
||||
cmd := exec.CommandContext(ctx, cfg.ValidateCommand)
|
||||
if err := cmd.Run(); err != nil {
|
||||
c.logger.Warn("HAProxy config validation failed during config check",
|
||||
"error", err,
|
||||
@@ -114,7 +125,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
|
||||
// Validate HAProxy configuration if validate command is configured
|
||||
if c.config.ValidateCommand != "" {
|
||||
c.logger.Debug("validating HAProxy configuration", "validate_command", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("HAProxy validation failed", "error", err, "output", string(output))
|
||||
@@ -129,7 +140,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy
|
||||
|
||||
// Reload HAProxy
|
||||
c.logger.Debug("reloading HAProxy", "reload_command", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, c.config.ReloadCommand)
|
||||
if output, err := reloadCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("HAProxy reload failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("HAProxy reload failed", "error", err, "output", string(output))
|
||||
@@ -169,7 +180,7 @@ func (c *Connector) ValidateDeployment(ctx context.Context, request target.Valid
|
||||
|
||||
// Validate HAProxy configuration if command provided
|
||||
if c.config.ValidateCommand != "" {
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
|
||||
@@ -377,3 +377,85 @@ func TestNginxConnector_ValidateDeployment_ValidateCommandFails(t *testing.T) {
|
||||
t.Fatal("expected invalid result")
|
||||
}
|
||||
}
|
||||
|
||||
// Security tests for command injection prevention
|
||||
|
||||
func TestNginxConnector_ValidateConfig_RejectCommandInjectionSemicolon(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
cfg := nginx.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "nginx; rm -rf /", // Command injection attempt
|
||||
ValidateCommand: "true",
|
||||
}
|
||||
|
||||
connector := nginx.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for command injection in reload_command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNginxConnector_ValidateConfig_RejectCommandInjectionPipe(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
cfg := nginx.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "true",
|
||||
ValidateCommand: "nginx -t | cat /etc/passwd", // Command injection attempt
|
||||
}
|
||||
|
||||
connector := nginx.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for command injection in validate_command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNginxConnector_ValidateConfig_RejectCommandSubstitution(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
cfg := nginx.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "echo $(whoami)",
|
||||
ValidateCommand: "true",
|
||||
}
|
||||
|
||||
connector := nginx.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for command substitution in reload_command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNginxConnector_ValidateConfig_RejectBackticks(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
cfg := nginx.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "true",
|
||||
ValidateCommand: "nginx -t `whoami`",
|
||||
}
|
||||
|
||||
connector := nginx.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for backtick injection in validate_command")
|
||||
}
|
||||
}
|
||||
|
||||
+23
-22
@@ -105,8 +105,9 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string,
|
||||
}
|
||||
|
||||
// Heartbeat updates agent heartbeat (handler interface method).
|
||||
func (s *AgentService) Heartbeat(agentID string, metadata *domain.AgentMetadata) error {
|
||||
return s.HeartbeatWithContext(context.Background(), agentID, metadata)
|
||||
// Note: This method is called from handlers which have a context; callers should prefer HeartbeatWithContext.
|
||||
func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
|
||||
return s.HeartbeatWithContext(ctx, agentID, metadata)
|
||||
}
|
||||
|
||||
// SubmitCSR validates and processes a Certificate Signing Request from an agent.
|
||||
@@ -326,7 +327,7 @@ func (s *AgentService) GetAgentByAPIKey(ctx context.Context, apiKey string) (*do
|
||||
}
|
||||
|
||||
// ListAgents returns paginated agents (handler interface method).
|
||||
func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, error) {
|
||||
func (s *AgentService) ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
@@ -334,7 +335,7 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err
|
||||
perPage = 50
|
||||
}
|
||||
|
||||
agents, err := s.agentRepo.List(context.Background())
|
||||
agents, err := s.agentRepo.List(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to list agents: %w", err)
|
||||
}
|
||||
@@ -360,12 +361,12 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err
|
||||
}
|
||||
|
||||
// GetAgent returns a single agent (handler interface method).
|
||||
func (s *AgentService) GetAgent(id string) (*domain.Agent, error) {
|
||||
return s.agentRepo.Get(context.Background(), id)
|
||||
func (s *AgentService) GetAgent(ctx context.Context, id string) (*domain.Agent, error) {
|
||||
return s.agentRepo.Get(ctx, id)
|
||||
}
|
||||
|
||||
// RegisterAgent creates and registers a new agent (handler interface method).
|
||||
func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) {
|
||||
func (s *AgentService) RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error) {
|
||||
agent.ID = generateID("agent")
|
||||
apiKey := generateAPIKey()
|
||||
agent.APIKeyHash = hashAPIKey(apiKey)
|
||||
@@ -374,7 +375,7 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error)
|
||||
agent.RegisteredAt = now
|
||||
agent.LastHeartbeatAt = &now
|
||||
|
||||
if err := s.agentRepo.Create(context.Background(), &agent); err != nil {
|
||||
if err := s.agentRepo.Create(ctx, &agent); err != nil {
|
||||
return nil, fmt.Errorf("failed to register agent: %w", err)
|
||||
}
|
||||
return &agent, nil
|
||||
@@ -382,8 +383,8 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error)
|
||||
|
||||
// CSRSubmit processes a CSR submission from an agent (handler interface method).
|
||||
// The csrPEM parameter contains "certID:csrPEM" or just the CSR PEM.
|
||||
func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error) {
|
||||
err := s.SubmitCSR(context.Background(), agentID, "", []byte(csrPEM))
|
||||
func (s *AgentService) CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error) {
|
||||
err := s.SubmitCSR(ctx, agentID, "", []byte(csrPEM))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -391,8 +392,8 @@ func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error)
|
||||
}
|
||||
|
||||
// CSRSubmitForCert processes a CSR submission for a specific certificate (handler interface method).
|
||||
func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) {
|
||||
err := s.SubmitCSR(context.Background(), agentID, certID, []byte(csrPEM))
|
||||
func (s *AgentService) CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error) {
|
||||
err := s.SubmitCSR(ctx, agentID, certID, []byte(csrPEM))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -400,8 +401,8 @@ func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM st
|
||||
}
|
||||
|
||||
// GetWork returns pending deployment jobs for an agent (handler interface method).
|
||||
func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) {
|
||||
jobs, err := s.GetPendingWork(context.Background(), agentID)
|
||||
func (s *AgentService) GetWork(ctx context.Context, agentID string) ([]domain.Job, error) {
|
||||
jobs, err := s.GetPendingWork(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -417,8 +418,8 @@ func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) {
|
||||
// GetWorkWithTargets returns actionable jobs enriched with target/certificate details.
|
||||
// Deployment jobs include target type + config. AwaitingCSR jobs include common name + SANs
|
||||
// so the agent knows what CSR to generate.
|
||||
func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) {
|
||||
jobs, err := s.GetPendingWork(context.Background(), agentID)
|
||||
func (s *AgentService) GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error) {
|
||||
jobs, err := s.GetPendingWork(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -438,7 +439,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
|
||||
|
||||
// Enrich with target details for deployment jobs
|
||||
if j.TargetID != nil && *j.TargetID != "" {
|
||||
target, err := s.targetRepo.Get(context.Background(), *j.TargetID)
|
||||
target, err := s.targetRepo.Get(ctx, *j.TargetID)
|
||||
if err == nil && target != nil {
|
||||
item.TargetType = string(target.Type)
|
||||
item.TargetConfig = target.Config
|
||||
@@ -447,7 +448,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
|
||||
|
||||
// Enrich with certificate details for AwaitingCSR jobs (agent needs CN + SANs for CSR)
|
||||
if j.Status == domain.JobStatusAwaitingCSR {
|
||||
cert, err := s.certRepo.Get(context.Background(), j.CertificateID)
|
||||
cert, err := s.certRepo.Get(ctx, j.CertificateID)
|
||||
if err == nil && cert != nil {
|
||||
item.CommonName = cert.CommonName
|
||||
item.SANs = cert.SANs
|
||||
@@ -461,13 +462,13 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
|
||||
}
|
||||
|
||||
// UpdateJobStatus reports a job's status from an agent (handler interface method).
|
||||
func (s *AgentService) UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error {
|
||||
return s.ReportJobStatus(context.Background(), agentID, jobID, domain.JobStatus(status), errMsg)
|
||||
func (s *AgentService) UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error {
|
||||
return s.ReportJobStatus(ctx, agentID, jobID, domain.JobStatus(status), errMsg)
|
||||
}
|
||||
|
||||
// CertificatePickup retrieves a certificate for an agent (handler interface method).
|
||||
func (s *AgentService) CertificatePickup(agentID, certID string) (string, error) {
|
||||
certPEM, err := s.GetCertificateForAgent(context.Background(), agentID, certID)
|
||||
func (s *AgentService) CertificatePickup(ctx context.Context, agentID, certID string) (string, error) {
|
||||
certPEM, err := s.GetCertificateForAgent(ctx, agentID, certID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewAgentGroupService(
|
||||
}
|
||||
|
||||
// ListAgentGroups returns paginated agent groups (handler interface method).
|
||||
func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error) {
|
||||
func (s *AgentGroupService) ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr
|
||||
perPage = 50
|
||||
}
|
||||
|
||||
groups, err := s.groupRepo.List(context.Background())
|
||||
groups, err := s.groupRepo.List(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to list agent groups: %w", err)
|
||||
}
|
||||
@@ -53,12 +53,12 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr
|
||||
}
|
||||
|
||||
// GetAgentGroup returns a single agent group (handler interface method).
|
||||
func (s *AgentGroupService) GetAgentGroup(id string) (*domain.AgentGroup, error) {
|
||||
return s.groupRepo.Get(context.Background(), id)
|
||||
func (s *AgentGroupService) GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error) {
|
||||
return s.groupRepo.Get(ctx, id)
|
||||
}
|
||||
|
||||
// CreateAgentGroup creates a new agent group with validation (handler interface method).
|
||||
func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error) {
|
||||
func (s *AgentGroupService) CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error) {
|
||||
if err := validateAgentGroup(&group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,12 +74,12 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A
|
||||
group.UpdatedAt = now
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Create(context.Background(), &group); err != nil {
|
||||
if err := s.groupRepo.Create(ctx, &group); err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent group: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
|
||||
if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
|
||||
"create_agent_group", "agent_group", group.ID, nil); auditErr != nil {
|
||||
slog.Error("failed to record audit event", "error", auditErr)
|
||||
}
|
||||
@@ -89,18 +89,18 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A
|
||||
}
|
||||
|
||||
// UpdateAgentGroup modifies an existing agent group (handler interface method).
|
||||
func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error) {
|
||||
func (s *AgentGroupService) UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error) {
|
||||
if err := validateAgentGroup(&group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
group.ID = id
|
||||
if err := s.groupRepo.Update(context.Background(), &group); err != nil {
|
||||
if err := s.groupRepo.Update(ctx, &group); err != nil {
|
||||
return nil, fmt.Errorf("failed to update agent group: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
|
||||
if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
|
||||
"update_agent_group", "agent_group", id, nil); auditErr != nil {
|
||||
slog.Error("failed to record audit event", "error", auditErr)
|
||||
}
|
||||
@@ -110,13 +110,13 @@ func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup)
|
||||
}
|
||||
|
||||
// DeleteAgentGroup removes an agent group (handler interface method).
|
||||
func (s *AgentGroupService) DeleteAgentGroup(id string) error {
|
||||
if err := s.groupRepo.Delete(context.Background(), id); err != nil {
|
||||
func (s *AgentGroupService) DeleteAgentGroup(ctx context.Context, id string) error {
|
||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("failed to delete agent group: %w", err)
|
||||
}
|
||||
|
||||
if s.auditService != nil {
|
||||
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
|
||||
if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
|
||||
"delete_agent_group", "agent_group", id, nil); auditErr != nil {
|
||||
slog.Error("failed to record audit event", "error", auditErr)
|
||||
}
|
||||
@@ -126,8 +126,8 @@ func (s *AgentGroupService) DeleteAgentGroup(id string) error {
|
||||
}
|
||||
|
||||
// ListMembers returns agents in a group.
|
||||
func (s *AgentGroupService) ListMembers(id string) ([]domain.Agent, int64, error) {
|
||||
agents, err := s.groupRepo.ListMembers(context.Background(), id)
|
||||
func (s *AgentGroupService) ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error) {
|
||||
agents, err := s.groupRepo.ListMembers(ctx, id)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to list group members: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user