fix(quality): TICKET-012 propagate request context instead of context.Background()

- Updated AgentService interface to accept context.Context parameter in all methods
- Replaced context.Background() calls with proper ctx parameter in agent.go
- Updated AgentGroupService interface to accept context.Context parameter
- Replaced context.Background() calls with proper ctx parameter in agent_group.go
- Updated handler methods to pass r.Context() to service methods
- Context now properly propagates through request lifecycle for timeout/cancellation
- Improved request tracing and cancellation behavior
This commit is contained in:
shankar0123
2026-03-27 21:35:22 -04:00
parent 3e5cc86c5a
commit 200bdf990f
11 changed files with 413 additions and 81 deletions
+23 -22
View File
@@ -105,8 +105,9 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string,
}
// Heartbeat updates agent heartbeat (handler interface method).
func (s *AgentService) Heartbeat(agentID string, metadata *domain.AgentMetadata) error {
return s.HeartbeatWithContext(context.Background(), agentID, metadata)
// Note: This method is called from handlers which have a context; callers should prefer HeartbeatWithContext.
func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
return s.HeartbeatWithContext(ctx, agentID, metadata)
}
// SubmitCSR validates and processes a Certificate Signing Request from an agent.
@@ -326,7 +327,7 @@ func (s *AgentService) GetAgentByAPIKey(ctx context.Context, apiKey string) (*do
}
// ListAgents returns paginated agents (handler interface method).
func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, error) {
func (s *AgentService) ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error) {
if page < 1 {
page = 1
}
@@ -334,7 +335,7 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err
perPage = 50
}
agents, err := s.agentRepo.List(context.Background())
agents, err := s.agentRepo.List(ctx)
if err != nil {
return nil, 0, fmt.Errorf("failed to list agents: %w", err)
}
@@ -360,12 +361,12 @@ func (s *AgentService) ListAgents(page, perPage int) ([]domain.Agent, int64, err
}
// GetAgent returns a single agent (handler interface method).
func (s *AgentService) GetAgent(id string) (*domain.Agent, error) {
return s.agentRepo.Get(context.Background(), id)
func (s *AgentService) GetAgent(ctx context.Context, id string) (*domain.Agent, error) {
return s.agentRepo.Get(ctx, id)
}
// RegisterAgent creates and registers a new agent (handler interface method).
func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error) {
func (s *AgentService) RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error) {
agent.ID = generateID("agent")
apiKey := generateAPIKey()
agent.APIKeyHash = hashAPIKey(apiKey)
@@ -374,7 +375,7 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error)
agent.RegisteredAt = now
agent.LastHeartbeatAt = &now
if err := s.agentRepo.Create(context.Background(), &agent); err != nil {
if err := s.agentRepo.Create(ctx, &agent); err != nil {
return nil, fmt.Errorf("failed to register agent: %w", err)
}
return &agent, nil
@@ -382,8 +383,8 @@ func (s *AgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, error)
// CSRSubmit processes a CSR submission from an agent (handler interface method).
// The csrPEM parameter contains "certID:csrPEM" or just the CSR PEM.
func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error) {
err := s.SubmitCSR(context.Background(), agentID, "", []byte(csrPEM))
func (s *AgentService) CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error) {
err := s.SubmitCSR(ctx, agentID, "", []byte(csrPEM))
if err != nil {
return "", err
}
@@ -391,8 +392,8 @@ func (s *AgentService) CSRSubmit(agentID string, csrPEM string) (string, error)
}
// CSRSubmitForCert processes a CSR submission for a specific certificate (handler interface method).
func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error) {
err := s.SubmitCSR(context.Background(), agentID, certID, []byte(csrPEM))
func (s *AgentService) CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error) {
err := s.SubmitCSR(ctx, agentID, certID, []byte(csrPEM))
if err != nil {
return "", err
}
@@ -400,8 +401,8 @@ func (s *AgentService) CSRSubmitForCert(agentID string, certID string, csrPEM st
}
// GetWork returns pending deployment jobs for an agent (handler interface method).
func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) {
jobs, err := s.GetPendingWork(context.Background(), agentID)
func (s *AgentService) GetWork(ctx context.Context, agentID string) ([]domain.Job, error) {
jobs, err := s.GetPendingWork(ctx, agentID)
if err != nil {
return nil, err
}
@@ -417,8 +418,8 @@ func (s *AgentService) GetWork(agentID string) ([]domain.Job, error) {
// GetWorkWithTargets returns actionable jobs enriched with target/certificate details.
// Deployment jobs include target type + config. AwaitingCSR jobs include common name + SANs
// so the agent knows what CSR to generate.
func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, error) {
jobs, err := s.GetPendingWork(context.Background(), agentID)
func (s *AgentService) GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error) {
jobs, err := s.GetPendingWork(ctx, agentID)
if err != nil {
return nil, err
}
@@ -438,7 +439,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
// Enrich with target details for deployment jobs
if j.TargetID != nil && *j.TargetID != "" {
target, err := s.targetRepo.Get(context.Background(), *j.TargetID)
target, err := s.targetRepo.Get(ctx, *j.TargetID)
if err == nil && target != nil {
item.TargetType = string(target.Type)
item.TargetConfig = target.Config
@@ -447,7 +448,7 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
// Enrich with certificate details for AwaitingCSR jobs (agent needs CN + SANs for CSR)
if j.Status == domain.JobStatusAwaitingCSR {
cert, err := s.certRepo.Get(context.Background(), j.CertificateID)
cert, err := s.certRepo.Get(ctx, j.CertificateID)
if err == nil && cert != nil {
item.CommonName = cert.CommonName
item.SANs = cert.SANs
@@ -461,13 +462,13 @@ func (s *AgentService) GetWorkWithTargets(agentID string) ([]domain.WorkItem, er
}
// UpdateJobStatus reports a job's status from an agent (handler interface method).
func (s *AgentService) UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error {
return s.ReportJobStatus(context.Background(), agentID, jobID, domain.JobStatus(status), errMsg)
func (s *AgentService) UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error {
return s.ReportJobStatus(ctx, agentID, jobID, domain.JobStatus(status), errMsg)
}
// CertificatePickup retrieves a certificate for an agent (handler interface method).
func (s *AgentService) CertificatePickup(agentID, certID string) (string, error) {
certPEM, err := s.GetCertificateForAgent(context.Background(), agentID, certID)
func (s *AgentService) CertificatePickup(ctx context.Context, agentID, certID string) (string, error) {
certPEM, err := s.GetCertificateForAgent(ctx, agentID, certID)
if err != nil {
return "", err
}
+15 -15
View File
@@ -28,7 +28,7 @@ func NewAgentGroupService(
}
// ListAgentGroups returns paginated agent groups (handler interface method).
func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error) {
func (s *AgentGroupService) ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error) {
if page < 1 {
page = 1
}
@@ -36,7 +36,7 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr
perPage = 50
}
groups, err := s.groupRepo.List(context.Background())
groups, err := s.groupRepo.List(ctx)
if err != nil {
return nil, 0, fmt.Errorf("failed to list agent groups: %w", err)
}
@@ -53,12 +53,12 @@ func (s *AgentGroupService) ListAgentGroups(page, perPage int) ([]domain.AgentGr
}
// GetAgentGroup returns a single agent group (handler interface method).
func (s *AgentGroupService) GetAgentGroup(id string) (*domain.AgentGroup, error) {
return s.groupRepo.Get(context.Background(), id)
func (s *AgentGroupService) GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error) {
return s.groupRepo.Get(ctx, id)
}
// CreateAgentGroup creates a new agent group with validation (handler interface method).
func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error) {
func (s *AgentGroupService) CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error) {
if err := validateAgentGroup(&group); err != nil {
return nil, err
}
@@ -74,12 +74,12 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A
group.UpdatedAt = now
}
if err := s.groupRepo.Create(context.Background(), &group); err != nil {
if err := s.groupRepo.Create(ctx, &group); err != nil {
return nil, fmt.Errorf("failed to create agent group: %w", err)
}
if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
"create_agent_group", "agent_group", group.ID, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
@@ -89,18 +89,18 @@ func (s *AgentGroupService) CreateAgentGroup(group domain.AgentGroup) (*domain.A
}
// UpdateAgentGroup modifies an existing agent group (handler interface method).
func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error) {
func (s *AgentGroupService) UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error) {
if err := validateAgentGroup(&group); err != nil {
return nil, err
}
group.ID = id
if err := s.groupRepo.Update(context.Background(), &group); err != nil {
if err := s.groupRepo.Update(ctx, &group); err != nil {
return nil, fmt.Errorf("failed to update agent group: %w", err)
}
if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
"update_agent_group", "agent_group", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
@@ -110,13 +110,13 @@ func (s *AgentGroupService) UpdateAgentGroup(id string, group domain.AgentGroup)
}
// DeleteAgentGroup removes an agent group (handler interface method).
func (s *AgentGroupService) DeleteAgentGroup(id string) error {
if err := s.groupRepo.Delete(context.Background(), id); err != nil {
func (s *AgentGroupService) DeleteAgentGroup(ctx context.Context, id string) error {
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("failed to delete agent group: %w", err)
}
if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
if auditErr := s.auditService.RecordEvent(ctx, "api", domain.ActorTypeUser,
"delete_agent_group", "agent_group", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr)
}
@@ -126,8 +126,8 @@ func (s *AgentGroupService) DeleteAgentGroup(id string) error {
}
// ListMembers returns agents in a group.
func (s *AgentGroupService) ListMembers(id string) ([]domain.Agent, int64, error) {
agents, err := s.groupRepo.ListMembers(context.Background(), id)
func (s *AgentGroupService) ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error) {
agents, err := s.groupRepo.ListMembers(ctx, id)
if err != nil {
return nil, 0, fmt.Errorf("failed to list group members: %w", err)
}