fix: harden agent with backoff, panic recovery, and error handling

- Linear backoff on consecutive poll/heartbeat failures (failure count x poll interval, capped at 5min)
- Panic recovery wrapper on agent.Run goroutine
- All 9 silent reportJobStatus errors now logged properly
- Key read failures are logged, abort the deployment job, and report job failure
- CommonName validation before CSR creation
- KeyDir permissions enforced with os.Chmod after MkdirAll
- splitPEMChain rewritten to use encoding/pem instead of string parsing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-03-20 01:20:10 -04:00
parent e03a75ed9a
commit 239a1792d2
+90 -35
View File
@@ -18,6 +18,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
@@ -47,8 +48,9 @@ type Agent struct {
client *http.Client
// Configuration
heartbeatInterval time.Duration
pollInterval time.Duration
heartbeatInterval time.Duration
pollInterval time.Duration
consecutiveFailures int
}
// WorkResponse represents the response from the work polling endpoint.
@@ -95,6 +97,11 @@ func (a *Agent) Run(ctx context.Context) error {
return fmt.Errorf("failed to create key directory %s: %w", a.config.KeyDir, err)
}
// Enforce permissions even if directory already exists
if err := os.Chmod(a.config.KeyDir, 0700); err != nil {
a.logger.Warn("failed to enforce key directory permissions", "path", a.config.KeyDir, "error", err)
}
// Create ticker channels for heartbeat and polling
heartbeatTicker := time.NewTicker(a.heartbeatInterval)
defer heartbeatTicker.Stop()
@@ -117,6 +124,16 @@ func (a *Agent) Run(ctx context.Context) error {
a.sendHeartbeat(ctx)
case <-pollTicker.C:
if a.consecutiveFailures > 0 {
backoff := time.Duration(a.consecutiveFailures) * a.pollInterval
if backoff > 5*time.Minute {
backoff = 5 * time.Minute
}
a.logger.Warn("backing off due to consecutive failures",
"failures", a.consecutiveFailures,
"backoff", backoff.String())
time.Sleep(backoff)
}
a.pollForWork(ctx)
}
}
@@ -134,6 +151,7 @@ func (a *Agent) sendHeartbeat(ctx context.Context) {
})
if err != nil {
a.logger.Error("heartbeat failed", "error", err)
a.consecutiveFailures++
return
}
defer resp.Body.Close()
@@ -143,9 +161,11 @@ func (a *Agent) sendHeartbeat(ctx context.Context) {
a.logger.Error("heartbeat rejected",
"status", resp.StatusCode,
"body", string(body))
a.consecutiveFailures++
return
}
a.consecutiveFailures = 0
a.logger.Debug("heartbeat acknowledged")
}
@@ -159,6 +179,7 @@ func (a *Agent) pollForWork(ctx context.Context) {
resp, err := a.makeRequest(ctx, http.MethodGet, path, nil)
if err != nil {
a.logger.Error("work poll failed", "error", err)
a.consecutiveFailures++
return
}
defer resp.Body.Close()
@@ -168,15 +189,19 @@ func (a *Agent) pollForWork(ctx context.Context) {
a.logger.Error("work poll rejected",
"status", resp.StatusCode,
"body", string(body))
a.consecutiveFailures++
return
}
var workResp WorkResponse
if err := json.NewDecoder(resp.Body).Decode(&workResp); err != nil {
a.logger.Error("failed to decode work response", "error", err)
a.consecutiveFailures++
return
}
a.consecutiveFailures = 0
if workResp.Count == 0 {
a.logger.Debug("no pending work")
return
@@ -218,7 +243,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to generate private key",
"job_id", job.ID,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -233,7 +260,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to marshal private key",
"job_id", job.ID,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -247,7 +276,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
"job_id", job.ID,
"key_path", keyPath,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -256,6 +287,15 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
"key_path", keyPath,
"permissions", "0600")
// Validate common name is present
if job.CommonName == "" {
a.logger.Error("empty common name in CSR job", "job_id", job.ID)
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", "empty common name"); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr)
}
return
}
// Step 3: Create CSR with common name and SANs
csrTemplate := &x509.CertificateRequest{
Subject: pkix.Name{
@@ -269,7 +309,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to create CSR",
"job_id", job.ID,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -292,7 +334,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to submit CSR",
"job_id", job.ID,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
defer resp.Body.Close()
@@ -303,7 +347,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
"job_id", job.ID,
"status", resp.StatusCode,
"body", string(body))
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body)))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body))); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -343,7 +389,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
a.logger.Error("failed to fetch certificate",
"job_id", job.ID,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -357,12 +405,21 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
// Check for locally-stored private key (agent keygen mode)
keyPath := filepath.Join(a.config.KeyDir, job.CertificateID+".key")
var keyPEM string
if keyData, err := os.ReadFile(keyPath); err == nil {
keyPEM = string(keyData)
a.logger.Info("loaded local private key for deployment",
keyData, err := os.ReadFile(keyPath)
if err != nil {
a.logger.Error("failed to read local private key for deployment",
"job_id", job.ID,
"key_path", keyPath)
"key_path", keyPath,
"error", err)
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key read failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr)
}
return
}
keyPEM = string(keyData)
a.logger.Info("loaded local private key for deployment",
"job_id", job.ID,
"key_path", keyPath)
// Deploy to the target using the appropriate connector
if job.TargetType != "" {
@@ -372,7 +429,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
"job_id", job.ID,
"target_type", job.TargetType,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -393,7 +452,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
"job_id", job.ID,
"target_type", job.TargetType,
"error", err)
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err))
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err)); reportErr != nil {
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
}
return
}
@@ -454,29 +515,17 @@ func (a *Agent) createTargetConnector(targetType string, configJSON json.RawMess
// splitPEMChain splits a PEM chain into the first certificate (cert) and the rest (chain).
// The control plane returns the full chain as a single string with PEM blocks concatenated.
//
// The first PEM block found in pemChain is re-encoded and returned as cert; everything
// that follows it is returned as chain with surrounding whitespace trimmed. If pemChain
// contains no decodable PEM block, it is returned unchanged as cert and chain is empty.
func splitPEMChain(pemChain string) (string, string) {
	block, rest := pem.Decode([]byte(pemChain))
	if block == nil {
		// Nothing parseable: hand the input back untouched so the caller decides what to do.
		return pemChain, ""
	}

	cert := string(pem.EncodeToMemory(block))
	// TrimSpace already yields "" when only whitespace follows the first block,
	// so no separate empty-chain branch is needed.
	chain := strings.TrimSpace(string(rest))
	return cert, chain
}
@@ -626,6 +675,12 @@ func main() {
// Run agent in background
errChan := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
logger.Error("agent panicked", "error", fmt.Sprintf("%v", r))
errChan <- fmt.Errorf("agent panic: %v", r)
}
}()
errChan <- agent.Run(ctx)
}()