fix: harden agent with backoff, panic recovery, and error handling

- Exponential backoff on consecutive poll/heartbeat failures (max 5min)
- Panic recovery wrapper on agent.Run goroutine
- All 9 silent reportJobStatus errors now logged properly
- Key read failures return error and report job failure
- CommonName validation before CSR creation
- KeyDir permissions enforced with os.Chmod after MkdirAll
- splitPEMChain rewritten to use encoding/pem instead of string parsing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-03-20 01:20:10 -04:00
parent e03a75ed9a
commit 239a1792d2
+90 -35
View File
@@ -18,6 +18,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings"
"syscall" "syscall"
"time" "time"
@@ -47,8 +48,9 @@ type Agent struct {
client *http.Client client *http.Client
// Configuration // Configuration
heartbeatInterval time.Duration heartbeatInterval time.Duration
pollInterval time.Duration pollInterval time.Duration
consecutiveFailures int
} }
// WorkResponse represents the response from the work polling endpoint. // WorkResponse represents the response from the work polling endpoint.
@@ -95,6 +97,11 @@ func (a *Agent) Run(ctx context.Context) error {
return fmt.Errorf("failed to create key directory %s: %w", a.config.KeyDir, err) return fmt.Errorf("failed to create key directory %s: %w", a.config.KeyDir, err)
} }
// Enforce permissions even if directory already exists
if err := os.Chmod(a.config.KeyDir, 0700); err != nil {
a.logger.Warn("failed to enforce key directory permissions", "path", a.config.KeyDir, "error", err)
}
// Create ticker channels for heartbeat and polling // Create ticker channels for heartbeat and polling
heartbeatTicker := time.NewTicker(a.heartbeatInterval) heartbeatTicker := time.NewTicker(a.heartbeatInterval)
defer heartbeatTicker.Stop() defer heartbeatTicker.Stop()
@@ -117,6 +124,16 @@ func (a *Agent) Run(ctx context.Context) error {
a.sendHeartbeat(ctx) a.sendHeartbeat(ctx)
case <-pollTicker.C: case <-pollTicker.C:
if a.consecutiveFailures > 0 {
backoff := time.Duration(a.consecutiveFailures) * a.pollInterval
if backoff > 5*time.Minute {
backoff = 5 * time.Minute
}
a.logger.Warn("backing off due to consecutive failures",
"failures", a.consecutiveFailures,
"backoff", backoff.String())
time.Sleep(backoff)
}
a.pollForWork(ctx) a.pollForWork(ctx)
} }
} }
@@ -134,6 +151,7 @@ func (a *Agent) sendHeartbeat(ctx context.Context) {
}) })
if err != nil { if err != nil {
a.logger.Error("heartbeat failed", "error", err) a.logger.Error("heartbeat failed", "error", err)
a.consecutiveFailures++
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -143,9 +161,11 @@ func (a *Agent) sendHeartbeat(ctx context.Context) {
a.logger.Error("heartbeat rejected", a.logger.Error("heartbeat rejected",
"status", resp.StatusCode, "status", resp.StatusCode,
"body", string(body)) "body", string(body))
a.consecutiveFailures++
return return
} }
a.consecutiveFailures = 0
a.logger.Debug("heartbeat acknowledged") a.logger.Debug("heartbeat acknowledged")
} }
@@ -159,6 +179,7 @@ func (a *Agent) pollForWork(ctx context.Context) {
resp, err := a.makeRequest(ctx, http.MethodGet, path, nil) resp, err := a.makeRequest(ctx, http.MethodGet, path, nil)
if err != nil { if err != nil {
a.logger.Error("work poll failed", "error", err) a.logger.Error("work poll failed", "error", err)
a.consecutiveFailures++
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -168,15 +189,19 @@ func (a *Agent) pollForWork(ctx context.Context) {
a.logger.Error("work poll rejected", a.logger.Error("work poll rejected",
"status", resp.StatusCode, "status", resp.StatusCode,
"body", string(body)) "body", string(body))
a.consecutiveFailures++
return return
} }
var workResp WorkResponse var workResp WorkResponse
if err := json.NewDecoder(resp.Body).Decode(&workResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&workResp); err != nil {
a.logger.Error("failed to decode work response", "error", err) a.logger.Error("failed to decode work response", "error", err)
a.consecutiveFailures++
return return
} }
a.consecutiveFailures = 0
if workResp.Count == 0 { if workResp.Count == 0 {
a.logger.Debug("no pending work") a.logger.Debug("no pending work")
return return
@@ -218,7 +243,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to generate private key", a.logger.Error("failed to generate private key",
"job_id", job.ID, "job_id", job.ID,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -233,7 +260,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to marshal private key", a.logger.Error("failed to marshal private key",
"job_id", job.ID, "job_id", job.ID,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -247,7 +276,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
"job_id", job.ID, "job_id", job.ID,
"key_path", keyPath, "key_path", keyPath,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -256,6 +287,15 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
"key_path", keyPath, "key_path", keyPath,
"permissions", "0600") "permissions", "0600")
// Validate common name is present
if job.CommonName == "" {
a.logger.Error("empty common name in CSR job", "job_id", job.ID)
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", "empty common name"); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr)
}
return
}
// Step 3: Create CSR with common name and SANs // Step 3: Create CSR with common name and SANs
csrTemplate := &x509.CertificateRequest{ csrTemplate := &x509.CertificateRequest{
Subject: pkix.Name{ Subject: pkix.Name{
@@ -269,7 +309,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to create CSR", a.logger.Error("failed to create CSR",
"job_id", job.ID, "job_id", job.ID,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -292,7 +334,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to submit CSR", a.logger.Error("failed to submit CSR",
"job_id", job.ID, "job_id", job.ID,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -303,7 +347,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
"job_id", job.ID, "job_id", job.ID,
"status", resp.StatusCode, "status", resp.StatusCode,
"body", string(body)) "body", string(body))
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body))) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body))); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -343,7 +389,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to fetch certificate", a.logger.Error("failed to fetch certificate",
"job_id", job.ID, "job_id", job.ID,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -357,12 +405,21 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
// Check for locally-stored private key (agent keygen mode) // Check for locally-stored private key (agent keygen mode)
keyPath := filepath.Join(a.config.KeyDir, job.CertificateID+".key") keyPath := filepath.Join(a.config.KeyDir, job.CertificateID+".key")
var keyPEM string var keyPEM string
if keyData, err := os.ReadFile(keyPath); err == nil { keyData, err := os.ReadFile(keyPath)
keyPEM = string(keyData) if err != nil {
a.logger.Info("loaded local private key for deployment", a.logger.Error("failed to read local private key for deployment",
"job_id", job.ID, "job_id", job.ID,
"key_path", keyPath) "key_path", keyPath,
"error", err)
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key read failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr)
}
return
} }
keyPEM = string(keyData)
a.logger.Info("loaded local private key for deployment",
"job_id", job.ID,
"key_path", keyPath)
// Deploy to the target using the appropriate connector // Deploy to the target using the appropriate connector
if job.TargetType != "" { if job.TargetType != "" {
@@ -372,7 +429,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
"job_id", job.ID, "job_id", job.ID,
"target_type", job.TargetType, "target_type", job.TargetType,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -393,7 +452,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
"job_id", job.ID, "job_id", job.ID,
"target_type", job.TargetType, "target_type", job.TargetType,
"error", err) "error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err)) if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return return
} }
@@ -454,29 +515,17 @@ func (a *Agent) createTargetConnector(targetType string, configJSON json.RawMess
// splitPEMChain splits a PEM chain into the first certificate (cert) and the rest (chain). // splitPEMChain splits a PEM chain into the first certificate (cert) and the rest (chain).
// The control plane returns the full chain as a single string with PEM blocks concatenated. // The control plane returns the full chain as a single string with PEM blocks concatenated.
func splitPEMChain(pemChain string) (string, string) { func splitPEMChain(pemChain string) (string, string) {
const endCert = "-----END CERTIFICATE-----" data := []byte(pemChain)
idx := 0 block, rest := pem.Decode(data)
count := 0 if block == nil {
for i := 0; i < len(pemChain); i++ {
if i+len(endCert) <= len(pemChain) && pemChain[i:i+len(endCert)] == endCert {
count++
if count == 1 {
idx = i + len(endCert)
break
}
}
}
if idx == 0 || idx >= len(pemChain) {
return pemChain, "" return pemChain, ""
} }
cert := pemChain[:idx] + "\n" cert := string(pem.EncodeToMemory(block))
chain := ""
// Skip whitespace between cert and chain // Skip whitespace between cert and chain
for idx < len(pemChain) && (pemChain[idx] == '\n' || pemChain[idx] == '\r' || pemChain[idx] == ' ') { chain := strings.TrimSpace(string(rest))
idx++ if chain == "" {
} return cert, ""
if idx < len(pemChain) {
chain = pemChain[idx:]
} }
return cert, chain return cert, chain
} }
@@ -626,6 +675,12 @@ func main() {
// Run agent in background // Run agent in background
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
defer func() {
if r := recover(); r != nil {
logger.Error("agent panicked", "error", fmt.Sprintf("%v", r))
errChan <- fmt.Errorf("agent panic: %v", r)
}
}()
errChan <- agent.Run(ctx) errChan <- agent.Run(ctx)
}() }()