fix(quality): TICKET-012 propagate request context instead of context.Background()

- Updated AgentService interface to accept context.Context parameter in all methods
- Replaced context.Background() calls with proper ctx parameter in agent.go
- Updated AgentGroupService interface to accept context.Context parameter
- Replaced context.Background() calls with proper ctx parameter in agent_group.go
- Updated handler methods to pass r.Context() to service methods
- Context now properly propagates through request lifecycle for timeout/cancellation
- Improved request tracing and cancellation behavior
This commit is contained in:
shankar0123
2026-03-27 21:35:22 -04:00
parent 3e5cc86c5a
commit 200bdf990f
11 changed files with 413 additions and 81 deletions
+13 -12
View File
@@ -1,6 +1,7 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"strconv"
@@ -12,12 +13,12 @@ import (
// AgentGroupService defines the service interface for agent group operations.
type AgentGroupService interface {
ListAgentGroups(page, perPage int) ([]domain.AgentGroup, int64, error)
GetAgentGroup(id string) (*domain.AgentGroup, error)
CreateAgentGroup(group domain.AgentGroup) (*domain.AgentGroup, error)
UpdateAgentGroup(id string, group domain.AgentGroup) (*domain.AgentGroup, error)
DeleteAgentGroup(id string) error
ListMembers(id string) ([]domain.Agent, int64, error)
ListAgentGroups(ctx context.Context, page, perPage int) ([]domain.AgentGroup, int64, error)
GetAgentGroup(ctx context.Context, id string) (*domain.AgentGroup, error)
CreateAgentGroup(ctx context.Context, group domain.AgentGroup) (*domain.AgentGroup, error)
UpdateAgentGroup(ctx context.Context, id string, group domain.AgentGroup) (*domain.AgentGroup, error)
DeleteAgentGroup(ctx context.Context, id string) error
ListMembers(ctx context.Context, id string) ([]domain.Agent, int64, error)
}
// AgentGroupHandler handles HTTP requests for agent group operations.
@@ -54,7 +55,7 @@ func (h AgentGroupHandler) ListAgentGroups(w http.ResponseWriter, r *http.Reques
}
}
groups, total, err := h.svc.ListAgentGroups(page, perPage)
groups, total, err := h.svc.ListAgentGroups(r.Context(), page, perPage)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agent groups", requestID)
return
@@ -86,7 +87,7 @@ func (h AgentGroupHandler) GetAgentGroup(w http.ResponseWriter, r *http.Request)
return
}
group, err := h.svc.GetAgentGroup(id)
group, err := h.svc.GetAgentGroup(r.Context(), id)
if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
return
@@ -120,7 +121,7 @@ func (h AgentGroupHandler) CreateAgentGroup(w http.ResponseWriter, r *http.Reque
return
}
created, err := h.svc.CreateAgentGroup(group)
created, err := h.svc.CreateAgentGroup(r.Context(), group)
if err != nil {
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
@@ -157,7 +158,7 @@ func (h AgentGroupHandler) UpdateAgentGroup(w http.ResponseWriter, r *http.Reque
return
}
updated, err := h.svc.UpdateAgentGroup(id, group)
updated, err := h.svc.UpdateAgentGroup(r.Context(), id, group)
if err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
@@ -186,7 +187,7 @@ func (h AgentGroupHandler) DeleteAgentGroup(w http.ResponseWriter, r *http.Reque
return
}
if err := h.svc.DeleteAgentGroup(id); err != nil {
if err := h.svc.DeleteAgentGroup(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Agent group not found", requestID)
return
@@ -217,7 +218,7 @@ func (h AgentGroupHandler) ListAgentGroupMembers(w http.ResponseWriter, r *http.
}
id := parts[0]
members, total, err := h.svc.ListMembers(id)
members, total, err := h.svc.ListMembers(r.Context(), id)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list group members", requestID)
return
+20 -19
View File
@@ -1,6 +1,7 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"strconv"
@@ -12,16 +13,16 @@ import (
// AgentService defines the service interface for agent operations.
type AgentService interface {
ListAgents(page, perPage int) ([]domain.Agent, int64, error)
GetAgent(id string) (*domain.Agent, error)
RegisterAgent(agent domain.Agent) (*domain.Agent, error)
Heartbeat(agentID string, metadata *domain.AgentMetadata) error
CSRSubmit(agentID string, csrPEM string) (string, error)
CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error)
CertificatePickup(agentID, certID string) (string, error)
GetWork(agentID string) ([]domain.Job, error)
GetWorkWithTargets(agentID string) ([]domain.WorkItem, error)
UpdateJobStatus(agentID string, jobID string, status string, errMsg string) error
ListAgents(ctx context.Context, page, perPage int) ([]domain.Agent, int64, error)
GetAgent(ctx context.Context, id string) (*domain.Agent, error)
RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error)
Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error
CSRSubmit(ctx context.Context, agentID string, csrPEM string) (string, error)
CSRSubmitForCert(ctx context.Context, agentID string, certID string, csrPEM string) (string, error)
CertificatePickup(ctx context.Context, agentID, certID string) (string, error)
GetWork(ctx context.Context, agentID string) ([]domain.Job, error)
GetWorkWithTargets(ctx context.Context, agentID string) ([]domain.WorkItem, error)
UpdateJobStatus(ctx context.Context, agentID string, jobID string, status string, errMsg string) error
}
// AgentHandler handles HTTP requests for agent operations.
@@ -58,7 +59,7 @@ func (h AgentHandler) ListAgents(w http.ResponseWriter, r *http.Request) {
}
}
agents, total, err := h.svc.ListAgents(page, perPage)
agents, total, err := h.svc.ListAgents(r.Context(), page, perPage)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list agents", requestID)
return
@@ -92,7 +93,7 @@ func (h AgentHandler) GetAgent(w http.ResponseWriter, r *http.Request) {
}
id = parts[0]
agent, err := h.svc.GetAgent(id)
agent, err := h.svc.GetAgent(r.Context(), id)
if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Agent not found", requestID)
return
@@ -131,7 +132,7 @@ func (h AgentHandler) RegisterAgent(w http.ResponseWriter, r *http.Request) {
return
}
created, err := h.svc.RegisterAgent(agent)
created, err := h.svc.RegisterAgent(r.Context(), agent)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to register agent", requestID)
return
@@ -182,7 +183,7 @@ func (h AgentHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
}
}
if err := h.svc.Heartbeat(agentID, metadata); err != nil {
if err := h.svc.Heartbeat(r.Context(), agentID, metadata); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID)
return
}
@@ -234,9 +235,9 @@ func (h AgentHandler) AgentCSRSubmit(w http.ResponseWriter, r *http.Request) {
// If certificate_id is provided, sign the CSR for that specific certificate
if req.CertificateID != "" {
status, err = h.svc.CSRSubmitForCert(agentID, req.CertificateID, req.CSRPEM)
status, err = h.svc.CSRSubmitForCert(r.Context(), agentID, req.CertificateID, req.CSRPEM)
} else {
status, err = h.svc.CSRSubmit(agentID, req.CSRPEM)
status, err = h.svc.CSRSubmit(r.Context(), agentID, req.CSRPEM)
}
if err != nil {
@@ -271,7 +272,7 @@ func (h AgentHandler) AgentCertificatePickup(w http.ResponseWriter, r *http.Requ
agentID := parts[0]
certID := parts[2]
certPEM, err := h.svc.CertificatePickup(agentID, certID)
certPEM, err := h.svc.CertificatePickup(r.Context(), agentID, certID)
if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found or not ready", requestID)
return
@@ -303,7 +304,7 @@ func (h AgentHandler) AgentGetWork(w http.ResponseWriter, r *http.Request) {
}
agentID := parts[0]
workItems, err := h.svc.GetWorkWithTargets(agentID)
workItems, err := h.svc.GetWorkWithTargets(r.Context(), agentID)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get pending work", requestID)
return
@@ -353,7 +354,7 @@ func (h AgentHandler) AgentReportJobStatus(w http.ResponseWriter, r *http.Reques
return
}
if err := h.svc.UpdateJobStatus(agentID, jobID, req.Status, req.Error); err != nil {
if err := h.svc.UpdateJobStatus(r.Context(), agentID, jobID, req.Status, req.Error); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update job status", requestID)
return
}