diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 8383b4a..3578570 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -18,6 +18,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "time" @@ -47,8 +48,9 @@ type Agent struct { client *http.Client // Configuration - heartbeatInterval time.Duration - pollInterval time.Duration + heartbeatInterval time.Duration + pollInterval time.Duration + consecutiveFailures int } // WorkResponse represents the response from the work polling endpoint. @@ -95,6 +97,11 @@ func (a *Agent) Run(ctx context.Context) error { return fmt.Errorf("failed to create key directory %s: %w", a.config.KeyDir, err) } + // Enforce permissions even if directory already exists + if err := os.Chmod(a.config.KeyDir, 0700); err != nil { + a.logger.Warn("failed to enforce key directory permissions", "path", a.config.KeyDir, "error", err) + } + // Create ticker channels for heartbeat and polling heartbeatTicker := time.NewTicker(a.heartbeatInterval) defer heartbeatTicker.Stop() @@ -117,6 +124,16 @@ func (a *Agent) Run(ctx context.Context) error { a.sendHeartbeat(ctx) case <-pollTicker.C: + if a.consecutiveFailures > 0 { + backoff := time.Duration(a.consecutiveFailures) * a.pollInterval + if backoff > 5*time.Minute { + backoff = 5 * time.Minute + } + a.logger.Warn("backing off due to consecutive failures", + "failures", a.consecutiveFailures, + "backoff", backoff.String()) + time.Sleep(backoff) + } a.pollForWork(ctx) } } @@ -134,6 +151,7 @@ func (a *Agent) sendHeartbeat(ctx context.Context) { }) if err != nil { a.logger.Error("heartbeat failed", "error", err) + a.consecutiveFailures++ return } defer resp.Body.Close() @@ -143,9 +161,11 @@ func (a *Agent) sendHeartbeat(ctx context.Context) { a.logger.Error("heartbeat rejected", "status", resp.StatusCode, "body", string(body)) + a.consecutiveFailures++ return } + a.consecutiveFailures = 0 a.logger.Debug("heartbeat acknowledged") } @@ -159,6 +179,7 @@ func (a *Agent) pollForWork(ctx 
context.Context) { resp, err := a.makeRequest(ctx, http.MethodGet, path, nil) if err != nil { a.logger.Error("work poll failed", "error", err) + a.consecutiveFailures++ return } defer resp.Body.Close() @@ -168,15 +189,19 @@ func (a *Agent) pollForWork(ctx context.Context) { a.logger.Error("work poll rejected", "status", resp.StatusCode, "body", string(body)) + a.consecutiveFailures++ return } var workResp WorkResponse if err := json.NewDecoder(resp.Body).Decode(&workResp); err != nil { a.logger.Error("failed to decode work response", "error", err) + a.consecutiveFailures++ return } + a.consecutiveFailures = 0 + if workResp.Count == 0 { a.logger.Debug("no pending work") return @@ -218,7 +243,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { a.logger.Error("failed to generate private key", "job_id", job.ID, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key generation failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -233,7 +260,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { a.logger.Error("failed to marshal private key", "job_id", job.ID, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key marshal failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -247,7 +276,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { "job_id", job.ID, "key_path", keyPath, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key storage failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", 
fmt.Sprintf("key storage failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -256,6 +287,15 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { "key_path", keyPath, "permissions", "0600") + // Validate common name is present + if job.CommonName == "" { + a.logger.Error("empty common name in CSR job", "job_id", job.ID) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", "empty common name"); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr) + } + return + } + // Step 3: Create CSR with common name and SANs csrTemplate := &x509.CertificateRequest{ Subject: pkix.Name{ @@ -269,7 +309,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { a.logger.Error("failed to create CSR", "job_id", job.ID, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR creation failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -292,7 +334,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { a.logger.Error("failed to submit CSR", "job_id", job.ID, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR submission failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } defer resp.Body.Close() @@ -303,7 +347,9 @@ func (a *Agent) executeCSRJob(ctx context.Context, job JobItem) { "job_id", job.ID, "status", resp.StatusCode, "body", string(body)) - _ = a.reportJobStatus(ctx, 
job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body))) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("CSR rejected: %s", string(body))); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -343,7 +389,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) { a.logger.Error("failed to fetch certificate", "job_id", job.ID, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("cert fetch failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -357,12 +405,21 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) { // Check for locally-stored private key (agent keygen mode) keyPath := filepath.Join(a.config.KeyDir, job.CertificateID+".key") var keyPEM string - if keyData, err := os.ReadFile(keyPath); err == nil { - keyPEM = string(keyData) - a.logger.Info("loaded local private key for deployment", + keyData, err := os.ReadFile(keyPath) + if err != nil { + a.logger.Error("failed to read local private key for deployment", "job_id", job.ID, - "key_path", keyPath) + "key_path", keyPath, + "error", err) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("key read failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "error", reportErr) + } + return } + keyPEM = string(keyData) + a.logger.Info("loaded local private key for deployment", + "job_id", job.ID, + "key_path", keyPath) // Deploy to the target using the appropriate connector if job.TargetType != "" { @@ -372,7 +429,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) { "job_id", job.ID, "target_type", 
job.TargetType, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("connector init failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -393,7 +452,9 @@ func (a *Agent) executeDeploymentJob(ctx context.Context, job JobItem) { "job_id", job.ID, "target_type", job.TargetType, "error", err) - _ = a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err)) + if reportErr := a.reportJobStatus(ctx, job.ID, "Failed", fmt.Sprintf("deployment failed: %v", err)); reportErr != nil { + a.logger.Error("failed to report job status to server", "job_id", job.ID, "status", "Failed", "error", reportErr) + } return } @@ -454,29 +515,17 @@ func (a *Agent) createTargetConnector(targetType string, configJSON json.RawMess // splitPEMChain splits a PEM chain into the first certificate (cert) and the rest (chain). // The control plane returns the full chain as a single string with PEM blocks concatenated. 
// splitPEMChain splits a PEM bundle into the leaf certificate (the first
// CERTIFICATE block) and the remainder of the bundle (the chain).
//
// Returns:
//   - cert: the first CERTIFICATE block, re-encoded by encoding/pem so it
//     always ends with exactly one newline.
//   - chain: everything following that block with surrounding whitespace
//     trimmed and, when non-empty, terminated by a single newline so it can be
//     concatenated safely with other PEM data. Empty when nothing follows.
//
// If the input contains no decodable CERTIFICATE block, the input is returned
// unchanged as cert and chain is empty.
func splitPEMChain(pemChain string) (string, string) {
	rest := []byte(pemChain)
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			// No (further) PEM block found: nothing to split.
			return pemChain, ""
		}
		if block.Type != "CERTIFICATE" {
			// Skip non-certificate blocks (e.g. a key placed ahead of the
			// certificates) instead of mistaking them for the leaf.
			continue
		}

		leaf := pem.EncodeToMemory(block)
		if leaf == nil {
			// EncodeToMemory reports un-encodable headers with nil; fall back
			// to the untouched input rather than returning an empty cert.
			return pemChain, ""
		}

		chain := strings.TrimSpace(string(rest))
		if chain != "" {
			// Keep the chain newline-terminated, matching the leaf.
			chain += "\n"
		}
		return string(leaf), chain
	}
}