mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 13:41:30 +00:00
fix: harden agent with backoff, panic recovery, and error handling
- Exponential backoff on consecutive poll/heartbeat failures (max 5min)
- Panic recovery wrapper on agent.Run goroutine
- All 9 silent reportJobStatus errors now logged properly
- Key read failures return error and report job failure
- CommonName validation before CSR creation
- KeyDir permissions enforced with os.Chmod after MkdirAll
- splitPEMChain rewritten to use encoding/pem instead of string parsing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+90
-35
@@ -18,6 +18,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -47,8 +48,9 @@ type Agent struct {
|
||||
client *http.Client
|
||||
|
||||
// Configuration
|
||||
heartbeatInterval time.Duration
|
||||
pollInterval time.Duration
|
||||
heartbeatInterval time.Duration
|
||||
pollInterval time.Duration
|
||||
consecutiveFailures int
|
||||
}
|
||||
|
||||
// WorkResponse represents the response from the work polling endpoint.
|
||||
@@ -95,6 +97,11 @@ func (a *Agent) Run(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to create key directory %s: %w", a.config.KeyDir, err)
|
||||
}
|
||||
|
||||
// Enforce permissions even if directory already exists
|
||||
if err := os.Chmod(a.config.KeyDir, 0700); err != nil {
|
||||
a.logger.Warn("failed to enforce key directory permissions", "path", a.config.KeyDir, "error", err)
|
||||
}
|
||||
|
||||
// Create ticker channels for heartbeat and polling
|
||||
heartbeatTicker := time.NewTicker(a.heartbeatInterval)
|
||||
defer heartbeatTicker.Stop()
|
||||
@@ -117,6 +124,16 @@ func (a *Agent) Run(ctx context.Context) error {
|
||||
a.sendHeartbeat(ctx)
|
||||
|
||||
case <-pollTicker.C:
|
||||
if a.consecutiveFailures > 0 {
|
||||
backoff := time.Duration(a.consecutiveFailures) * a.pollInterval
|
||||
if backoff > 5*time.Minute {
|
||||
backoff = 5 * time.Minute
|
||||
}
|
||||
a.logger.Warn("backing off due to consecutive failures",
|
||||
"failures", a.consecutiveFailures,
|
||||
"backoff", backoff.String())
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
a.pollForWork(ctx)
|
||||
}
|
||||
}
|
||||
@@ -134,6 +151,7 @@ func (a *Agent) sendHeartbeat(ctx context.Context) {
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Error("heartbeat failed", "error", err)
|
||||
a.consecutiveFailures++
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -143,9 +161,11 @@ func (a *Agent) sendHeartbeat(ctx context.Context) {
|
||||
a.logger.Error("heartbeat rejected",
|
||||
"status", resp.StatusCode,
|
||||
"body", string(body))
|
||||
a.consecutiveFailures++
|
||||
return
|
||||
}
|
||||
|
||||
a.consecutiveFailures = 0
|
||||
a.logger.Debug("heartbeat acknowledged")
|
||||
}
|
||||
|
||||
@@ -159,6 +179,7 @@ func (a *Agent) pollForWork(ctx context.Context) {
|
||||
resp, err := a.makeRequest(ctx, http.MethodGet, path, nil)
|
||||
if err != nil {
|
||||
a.logger.Error("work poll failed", "error", err)
|
||||
a.consecutiveFailures++
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -168,15 +189,19 @@ func (a *Agent) pollForWork(ctx context.Context) {
|
||||
a.logger.Error("work poll rejected",
|
||||
"status", resp.StatusCode,
|
||||
"body", string(body))
|
||||
a.consecutiveFailures++
|
||||
return
|
||||
}
|
||||
|
||||
var workResp WorkResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&workResp); err != nil {
|
||||
a.logger.Error("failed to decode work response", "error", err)
|
||||
a.consecutiveFailures++
|
||||
return
|
||||
}
|
||||
|
||||
a.consecutiveFailures = 0
|
||||
|
||||
if workResp.Count == 0 {
|
||||
a.logger.Debug("no pending work")
|
||||
return
|
||||
@@ -218,7 +243,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
a.logger.Error("failed to generate private key",
|
||||
"job_id", job.ID,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,7 +260,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
a.logger.Error("failed to marshal private key",
|
||||
"job_id", job.ID,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -247,7 +276,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
"job_id", job.ID,
|
||||
"key_path", keyPath,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -256,6 +287,15 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
"key_path", keyPath,
|
||||
"permissions", "0600")
|
||||
|
||||
// Validate common name is present
|
||||
if job.CommonName == "" {
|
||||
a.logger.Error("empty common name in CSR job", "job_id", job.ID)
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", "empty common name"); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Step 3: Create CSR with common name and SANs
|
||||
csrTemplate := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
@@ -269,7 +309,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
a.logger.Error("failed to create CSR",
|
||||
"job_id", job.ID,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -292,7 +334,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
a.logger.Error("failed to submit CSR",
|
||||
"job_id", job.ID,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -303,7 +347,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) {
|
||||
"job_id", job.ID,
|
||||
"status", resp.StatusCode,
|
||||
"body", string(body))
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body)))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body))); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -343,7 +389,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
|
||||
a.logger.Error("failed to fetch certificate",
|
||||
"job_id", job.ID,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -357,12 +405,21 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
|
||||
// Check for locally-stored private key (agent keygen mode)
|
||||
keyPath := filepath.Join(a.config.KeyDir, job.CertificateID+".key")
|
||||
var keyPEM string
|
||||
if keyData, err := os.ReadFile(keyPath); err == nil {
|
||||
keyPEM = string(keyData)
|
||||
a.logger.Info("loaded local private key for deployment",
|
||||
keyData, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
a.logger.Error("failed to read local private key for deployment",
|
||||
"job_id", job.ID,
|
||||
"key_path", keyPath)
|
||||
"key_path", keyPath,
|
||||
"error", err)
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key read failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
keyPEM = string(keyData)
|
||||
a.logger.Info("loaded local private key for deployment",
|
||||
"job_id", job.ID,
|
||||
"key_path", keyPath)
|
||||
|
||||
// Deploy to the target using the appropriate connector
|
||||
if job.TargetType != "" {
|
||||
@@ -372,7 +429,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
|
||||
"job_id", job.ID,
|
||||
"target_type", job.TargetType,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -393,7 +452,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) {
|
||||
"job_id", job.ID,
|
||||
"target_type", job.TargetType,
|
||||
"error", err)
|
||||
_ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err))
|
||||
if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err)); reportErr != nil {
|
||||
a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -454,29 +515,17 @@ func (a *Agent) createTargetConnector(targetType string, configJSON json.RawMess
|
||||
// splitPEMChain splits a PEM chain into the first certificate (cert) and the rest (chain).
|
||||
// The control plane returns the full chain as a single string with PEM blocks concatenated.
|
||||
func splitPEMChain(pemChain string) (string, string) {
|
||||
const endCert = "-----END CERTIFICATE-----"
|
||||
idx := 0
|
||||
count := 0
|
||||
for i := 0; i < len(pemChain); i++ {
|
||||
if i+len(endCert) <= len(pemChain) && pemChain[i:i+len(endCert)] == endCert {
|
||||
count++
|
||||
if count == 1 {
|
||||
idx = i + len(endCert)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx == 0 || idx >= len(pemChain) {
|
||||
data := []byte(pemChain)
|
||||
block, rest := pem.Decode(data)
|
||||
if block == nil {
|
||||
return pemChain, ""
|
||||
}
|
||||
cert := pemChain[:idx] + "\n"
|
||||
chain := ""
|
||||
cert := string(pem.EncodeToMemory(block))
|
||||
|
||||
// Skip whitespace between cert and chain
|
||||
for idx < len(pemChain) && (pemChain[idx] == '\n' || pemChain[idx] == '\r' || pemChain[idx] == ' ') {
|
||||
idx++
|
||||
}
|
||||
if idx < len(pemChain) {
|
||||
chain = pemChain[idx:]
|
||||
chain := strings.TrimSpace(string(rest))
|
||||
if chain == "" {
|
||||
return cert, ""
|
||||
}
|
||||
return cert, chain
|
||||
}
|
||||
@@ -626,6 +675,12 @@ func main() {
|
||||
// Run agent in background
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("agent panicked", "error", fmt.Sprintf("%v", r))
|
||||
errChan <- fmt.Errorf("agent panic: %v", r)
|
||||
}
|
||||
}()
|
||||
errChan <- agent.Run(ctx)
|
||||
}()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user