mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 15:21:35 +00:00
feat: M10 — agent metadata collection, Apache httpd + HAProxy target connectors
Agents now report OS, architecture, IP address, hostname, and version via heartbeat using runtime.GOOS, runtime.GOARCH, and net.Dial. New migration adds columns to agents table. Heartbeat handler, service, and repository updated to accept and persist metadata. GUI shows OS/Arch in agent list and full system info in agent detail page. Apache httpd connector: separate cert/chain/key files, apachectl configtest validation, graceful reload. HAProxy connector: combined PEM file (cert+chain+key), optional config validation, reload. Both wired into agent binary's target connector switch. 14 tests for new connectors. All existing tests updated for new Heartbeat/UpdateHeartbeat signatures. Docs updated across README, architecture, concepts, and connectors guides. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -16,7 +16,7 @@ type MockAgentService struct {
|
||||
ListAgentsFn func(page, perPage int) ([]domain.Agent, int64, error)
|
||||
GetAgentFn func(id string) (*domain.Agent, error)
|
||||
RegisterAgentFn func(agent domain.Agent) (*domain.Agent, error)
|
||||
HeartbeatFn func(agentID string) error
|
||||
HeartbeatFn func(agentID string, metadata *domain.AgentMetadata) error
|
||||
CSRSubmitFn func(agentID string, csrPEM string) (string, error)
|
||||
CSRSubmitForCertFn func(agentID string, certID string, csrPEM string) (string, error)
|
||||
CertificatePickupFn func(agentID, certID string) (string, error)
|
||||
@@ -46,9 +46,9 @@ func (m *MockAgentService) RegisterAgent(agent domain.Agent) (*domain.Agent, err
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService) Heartbeat(agentID string) error {
|
||||
func (m *MockAgentService) Heartbeat(agentID string, metadata *domain.AgentMetadata) error {
|
||||
if m.HeartbeatFn != nil {
|
||||
return m.HeartbeatFn(agentID)
|
||||
return m.HeartbeatFn(agentID, metadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -309,7 +309,7 @@ func TestRegisterAgent_InvalidBody(t *testing.T) {
|
||||
// Test Heartbeat - success case
|
||||
func TestHeartbeat_Success(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
HeartbeatFn: func(agentID string) error {
|
||||
HeartbeatFn: func(agentID string, metadata *domain.AgentMetadata) error {
|
||||
if agentID == "a-prod-001" {
|
||||
return nil
|
||||
}
|
||||
@@ -341,7 +341,7 @@ func TestHeartbeat_Success(t *testing.T) {
|
||||
// Test Heartbeat - service error
|
||||
func TestHeartbeat_ServiceError(t *testing.T) {
|
||||
mock := &MockAgentService{
|
||||
HeartbeatFn: func(agentID string) error {
|
||||
HeartbeatFn: func(agentID string, metadata *domain.AgentMetadata) error {
|
||||
return ErrMockServiceFailed
|
||||
},
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ type AgentService interface {
|
||||
ListAgents(page, perPage int) ([]domain.Agent, int64, error)
|
||||
GetAgent(id string) (*domain.Agent, error)
|
||||
RegisterAgent(agent domain.Agent) (*domain.Agent, error)
|
||||
Heartbeat(agentID string) error
|
||||
Heartbeat(agentID string, metadata *domain.AgentMetadata) error
|
||||
CSRSubmit(agentID string, csrPEM string) (string, error)
|
||||
CSRSubmitForCert(agentID string, certID string, csrPEM string) (string, error)
|
||||
CertificatePickup(agentID, certID string) (string, error)
|
||||
@@ -159,7 +159,30 @@ func (h AgentHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
agentID := parts[0]
|
||||
|
||||
if err := h.svc.Heartbeat(agentID); err != nil {
|
||||
// Parse optional metadata from request body
|
||||
var metadata *domain.AgentMetadata
|
||||
if r.Body != nil {
|
||||
var body struct {
|
||||
Version string `json:"version"`
|
||||
Hostname string `json:"hostname"`
|
||||
OS string `json:"os"`
|
||||
Architecture string `json:"architecture"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err == nil {
|
||||
if body.Version != "" || body.Hostname != "" || body.OS != "" || body.Architecture != "" || body.IPAddress != "" {
|
||||
metadata = &domain.AgentMetadata{
|
||||
Version: body.Version,
|
||||
Hostname: body.Hostname,
|
||||
OS: body.OS,
|
||||
Architecture: body.Architecture,
|
||||
IPAddress: body.IPAddress,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.svc.Heartbeat(agentID, metadata); err != nil {
|
||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
package apache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
)
|
||||
|
||||
// Config represents the Apache httpd deployment target configuration.
|
||||
// This configuration is used on the agent side to deploy certificates to Apache.
|
||||
type Config struct {
|
||||
CertPath string `json:"cert_path"` // Path where cert will be written (e.g., /etc/apache2/ssl/cert.pem)
|
||||
KeyPath string `json:"key_path"` // Path where private key will be written
|
||||
ChainPath string `json:"chain_path"` // Path where CA chain will be written
|
||||
ReloadCommand string `json:"reload_command"` // Command to reload Apache (e.g., "apachectl graceful" or "systemctl reload apache2")
|
||||
ValidateCommand string `json:"validate_command"` // Command to validate Apache config (e.g., "apachectl configtest")
|
||||
}
|
||||
|
||||
// Connector implements the target.Connector interface for Apache httpd servers.
|
||||
// This connector runs on the AGENT side and handles local certificate deployment.
|
||||
type Connector struct {
|
||||
config *Config
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a new Apache target connector with the given configuration and logger.
|
||||
func New(config *Config, logger *slog.Logger) *Connector {
|
||||
return &Connector{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig checks that all required configuration paths and commands are valid.
|
||||
func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessage) error {
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(rawConfig, &cfg); err != nil {
|
||||
return fmt.Errorf("invalid Apache config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.CertPath == "" || cfg.ChainPath == "" {
|
||||
return fmt.Errorf("Apache cert_path and chain_path are required")
|
||||
}
|
||||
|
||||
if cfg.ReloadCommand == "" || cfg.ValidateCommand == "" {
|
||||
return fmt.Errorf("Apache reload_command and validate_command are required")
|
||||
}
|
||||
|
||||
c.logger.Info("validating Apache configuration",
|
||||
"cert_path", cfg.CertPath,
|
||||
"chain_path", cfg.ChainPath)
|
||||
|
||||
// Verify parent directory exists
|
||||
certDir := filepath.Dir(cfg.CertPath)
|
||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("Apache cert directory does not exist: %s", certDir)
|
||||
}
|
||||
|
||||
// Verify validate command works
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
|
||||
if err := cmd.Run(); err != nil {
|
||||
c.logger.Warn("Apache config validation failed during config check",
|
||||
"error", err,
|
||||
"validate_command", cfg.ValidateCommand)
|
||||
// Don't fail; Apache might not be installed yet
|
||||
}
|
||||
|
||||
c.config = &cfg
|
||||
c.logger.Info("Apache configuration validated")
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeployCertificate writes the certificate, key, and chain to configured paths
|
||||
// and reloads Apache to pick up the new certificates.
|
||||
//
|
||||
// Steps:
|
||||
// 1. Write certificate to cert_path with mode 0644
|
||||
// 2. Write private key to key_path with mode 0600 (owner-only read)
|
||||
// 3. Write chain to chain_path with mode 0644
|
||||
// 4. Validate Apache configuration with configtest
|
||||
// 5. Execute graceful reload command
|
||||
func (c *Connector) DeployCertificate(ctx context.Context, request target.DeploymentRequest) (*target.DeploymentResult, error) {
|
||||
c.logger.Info("deploying certificate to Apache httpd",
|
||||
"cert_path", c.config.CertPath,
|
||||
"chain_path", c.config.ChainPath)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Write certificate (0644: rw-r--r--)
|
||||
if err := os.WriteFile(c.config.CertPath, []byte(request.CertPEM), 0644); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to write certificate: %v", err)
|
||||
c.logger.Error("certificate deployment failed", "error", err)
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.CertPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Write private key with secure permissions (0600: rw-------)
|
||||
if c.config.KeyPath != "" && request.KeyPEM != "" {
|
||||
if err := os.WriteFile(c.config.KeyPath, []byte(request.KeyPEM), 0600); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to write private key: %v", err)
|
||||
c.logger.Error("key deployment failed", "error", err)
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.KeyPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// Write chain (0644: rw-r--r--)
|
||||
if err := os.WriteFile(c.config.ChainPath, []byte(request.ChainPEM), 0644); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to write chain: %v", err)
|
||||
c.logger.Error("chain deployment failed", "error", err)
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.ChainPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Validate Apache configuration before reload
|
||||
c.logger.Debug("validating Apache configuration", "validate_command", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("Apache validation failed", "error", err, "output", string(output))
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.CertPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Graceful reload
|
||||
c.logger.Debug("reloading Apache", "reload_command", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
|
||||
if output, err := reloadCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Apache reload failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("Apache reload failed", "error", err, "output", string(output))
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.CertPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
deploymentDuration := time.Since(startTime)
|
||||
c.logger.Info("certificate deployed to Apache successfully",
|
||||
"duration", deploymentDuration.String(),
|
||||
"cert_path", c.config.CertPath)
|
||||
|
||||
return &target.DeploymentResult{
|
||||
Success: true,
|
||||
TargetAddress: c.config.CertPath,
|
||||
DeploymentID: fmt.Sprintf("apache-%d", time.Now().Unix()),
|
||||
Message: "Certificate deployed and Apache reloaded successfully",
|
||||
DeployedAt: time.Now(),
|
||||
Metadata: map[string]string{
|
||||
"cert_path": c.config.CertPath,
|
||||
"chain_path": c.config.ChainPath,
|
||||
"duration_ms": fmt.Sprintf("%d", deploymentDuration.Milliseconds()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateDeployment verifies that the deployed certificate is valid and accessible.
|
||||
func (c *Connector) ValidateDeployment(ctx context.Context, request target.ValidationRequest) (*target.ValidationResult, error) {
|
||||
c.logger.Info("validating Apache deployment",
|
||||
"certificate_id", request.CertificateID,
|
||||
"serial", request.Serial)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Validate Apache configuration
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Apache config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
return &target.ValidationResult{
|
||||
Valid: false,
|
||||
Serial: request.Serial,
|
||||
TargetAddress: c.config.CertPath,
|
||||
Message: errMsg,
|
||||
ValidatedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Verify certificate file exists and is readable
|
||||
if _, err := os.Stat(c.config.CertPath); os.IsNotExist(err) {
|
||||
errMsg := fmt.Sprintf("certificate file not found: %s", c.config.CertPath)
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
return &target.ValidationResult{
|
||||
Valid: false,
|
||||
Serial: request.Serial,
|
||||
TargetAddress: c.config.CertPath,
|
||||
Message: errMsg,
|
||||
ValidatedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
validationDuration := time.Since(startTime)
|
||||
c.logger.Info("Apache deployment validated successfully",
|
||||
"duration", validationDuration.String())
|
||||
|
||||
return &target.ValidationResult{
|
||||
Valid: true,
|
||||
Serial: request.Serial,
|
||||
TargetAddress: c.config.CertPath,
|
||||
Message: "Apache configuration valid and certificate accessible",
|
||||
ValidatedAt: time.Now(),
|
||||
Metadata: map[string]string{
|
||||
"validate_command": c.config.ValidateCommand,
|
||||
"duration_ms": fmt.Sprintf("%d", validationDuration.Milliseconds()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package apache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
"github.com/shankar0123/certctl/internal/connector/target/apache"
|
||||
)
|
||||
|
||||
func TestApacheConnector_ValidateConfig(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := apache.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
KeyPath: filepath.Join(tmpDir, "key.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "echo reload",
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := apache.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing cert_path", func(t *testing.T) {
|
||||
cfg := apache.Config{
|
||||
ChainPath: "/tmp/chain.pem",
|
||||
ReloadCommand: "echo reload",
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := apache.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing cert_path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing reload_command", func(t *testing.T) {
|
||||
cfg := apache.Config{
|
||||
CertPath: "/tmp/cert.pem",
|
||||
ChainPath: "/tmp/chain.pem",
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := apache.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing reload_command")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON", func(t *testing.T) {
|
||||
connector := apache.New(&apache.Config{}, logger)
|
||||
err := connector.ValidateConfig(ctx, json.RawMessage(`{invalid}`))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestApacheConnector_DeployCertificate(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("successful deployment", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &apache.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
KeyPath: filepath.Join(tmpDir, "key.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "echo reload",
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := apache.New(cfg, logger)
|
||||
|
||||
req := target.DeploymentRequest{
|
||||
CertPEM: "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
KeyPEM: "-----BEGIN EC PRIVATE KEY-----\ntest\n-----END EC PRIVATE KEY-----",
|
||||
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
|
||||
}
|
||||
|
||||
result, err := connector.DeployCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("DeployCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("expected success, got: %s", result.Message)
|
||||
}
|
||||
|
||||
// Verify files were written
|
||||
certData, err := os.ReadFile(cfg.CertPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read cert file: %v", err)
|
||||
}
|
||||
if string(certData) != req.CertPEM {
|
||||
t.Errorf("cert content mismatch")
|
||||
}
|
||||
|
||||
// Verify key has secure permissions
|
||||
info, err := os.Stat(cfg.KeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to stat key file: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("expected key permissions 0600, got %v", info.Mode().Perm())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate command fails", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &apache.Config{
|
||||
CertPath: filepath.Join(tmpDir, "cert.pem"),
|
||||
KeyPath: filepath.Join(tmpDir, "key.pem"),
|
||||
ChainPath: filepath.Join(tmpDir, "chain.pem"),
|
||||
ReloadCommand: "echo reload",
|
||||
ValidateCommand: "false", // always fails
|
||||
}
|
||||
|
||||
connector := apache.New(cfg, logger)
|
||||
|
||||
req := target.DeploymentRequest{
|
||||
CertPEM: "cert",
|
||||
ChainPEM: "chain",
|
||||
}
|
||||
|
||||
result, err := connector.DeployCertificate(ctx, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when validate command fails")
|
||||
}
|
||||
if result.Success {
|
||||
t.Fatal("expected failure result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestApacheConnector_ValidateDeployment(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("valid deployment", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
os.WriteFile(certPath, []byte("cert"), 0644)
|
||||
|
||||
cfg := &apache.Config{
|
||||
CertPath: certPath,
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := apache.New(cfg, logger)
|
||||
|
||||
result, err := connector.ValidateDeployment(ctx, target.ValidationRequest{
|
||||
CertificateID: "mc-test",
|
||||
Serial: "123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateDeployment failed: %v", err)
|
||||
}
|
||||
if !result.Valid {
|
||||
t.Fatal("expected valid deployment")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing cert file", func(t *testing.T) {
|
||||
cfg := &apache.Config{
|
||||
CertPath: "/nonexistent/cert.pem",
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := apache.New(cfg, logger)
|
||||
|
||||
result, err := connector.ValidateDeployment(ctx, target.ValidationRequest{
|
||||
CertificateID: "mc-test",
|
||||
Serial: "123",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing cert file")
|
||||
}
|
||||
if result.Valid {
|
||||
t.Fatal("expected invalid result")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package haproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
)
|
||||
|
||||
// Config represents the HAProxy deployment target configuration.
|
||||
// HAProxy expects a combined PEM file containing the certificate, chain, and private key
|
||||
// concatenated in a single file.
|
||||
type Config struct {
|
||||
PEMPath string `json:"pem_path"` // Path for combined PEM (cert + chain + key)
|
||||
ReloadCommand string `json:"reload_command"` // Command to reload HAProxy (e.g., "systemctl reload haproxy")
|
||||
ValidateCommand string `json:"validate_command"` // Command to validate config (e.g., "haproxy -c -f /etc/haproxy/haproxy.cfg")
|
||||
}
|
||||
|
||||
// Connector implements the target.Connector interface for HAProxy servers.
|
||||
// This connector runs on the AGENT side and handles local certificate deployment.
|
||||
// HAProxy uses a combined PEM file (cert + chain + key) unlike NGINX/Apache which use
|
||||
// separate files.
|
||||
type Connector struct {
|
||||
config *Config
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a new HAProxy target connector with the given configuration and logger.
|
||||
func New(config *Config, logger *slog.Logger) *Connector {
|
||||
return &Connector{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig checks that all required configuration paths and commands are valid.
|
||||
func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessage) error {
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(rawConfig, &cfg); err != nil {
|
||||
return fmt.Errorf("invalid HAProxy config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.PEMPath == "" {
|
||||
return fmt.Errorf("HAProxy pem_path is required")
|
||||
}
|
||||
|
||||
if cfg.ReloadCommand == "" {
|
||||
return fmt.Errorf("HAProxy reload_command is required")
|
||||
}
|
||||
|
||||
c.logger.Info("validating HAProxy configuration",
|
||||
"pem_path", cfg.PEMPath)
|
||||
|
||||
// Verify validate command works if provided
|
||||
if cfg.ValidateCommand != "" {
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", cfg.ValidateCommand)
|
||||
if err := cmd.Run(); err != nil {
|
||||
c.logger.Warn("HAProxy config validation failed during config check",
|
||||
"error", err,
|
||||
"validate_command", cfg.ValidateCommand)
|
||||
// Don't fail; HAProxy might not be installed yet
|
||||
}
|
||||
}
|
||||
|
||||
c.config = &cfg
|
||||
c.logger.Info("HAProxy configuration validated")
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeployCertificate creates a combined PEM file (cert + chain + key) and reloads HAProxy.
|
||||
//
|
||||
// HAProxy requires all TLS material in a single file, concatenated in this order:
|
||||
// 1. Server certificate
|
||||
// 2. Intermediate/chain certificates
|
||||
// 3. Private key
|
||||
//
|
||||
// Steps:
|
||||
// 1. Build combined PEM (cert + chain + key)
|
||||
// 2. Write to pem_path with mode 0600 (contains private key)
|
||||
// 3. Optionally validate HAProxy configuration
|
||||
// 4. Execute reload command
|
||||
func (c *Connector) DeployCertificate(ctx context.Context, request target.DeploymentRequest) (*target.DeploymentResult, error) {
|
||||
c.logger.Info("deploying certificate to HAProxy",
|
||||
"pem_path", c.config.PEMPath)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Build combined PEM: cert + chain + key
|
||||
combinedPEM := request.CertPEM + "\n"
|
||||
if request.ChainPEM != "" {
|
||||
combinedPEM += request.ChainPEM + "\n"
|
||||
}
|
||||
if request.KeyPEM != "" {
|
||||
combinedPEM += request.KeyPEM + "\n"
|
||||
}
|
||||
|
||||
// Write combined PEM with secure permissions (0600: contains private key)
|
||||
if err := os.WriteFile(c.config.PEMPath, []byte(combinedPEM), 0600); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to write combined PEM: %v", err)
|
||||
c.logger.Error("PEM deployment failed", "error", err)
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
// Validate HAProxy configuration if validate command is configured
|
||||
if c.config.ValidateCommand != "" {
|
||||
c.logger.Debug("validating HAProxy configuration", "validate_command", c.config.ValidateCommand)
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("HAProxy validation failed", "error", err, "output", string(output))
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// Reload HAProxy
|
||||
c.logger.Debug("reloading HAProxy", "reload_command", c.config.ReloadCommand)
|
||||
reloadCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ReloadCommand)
|
||||
if output, err := reloadCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("HAProxy reload failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("HAProxy reload failed", "error", err, "output", string(output))
|
||||
return &target.DeploymentResult{
|
||||
Success: false,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
Message: errMsg,
|
||||
DeployedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
deploymentDuration := time.Since(startTime)
|
||||
c.logger.Info("certificate deployed to HAProxy successfully",
|
||||
"duration", deploymentDuration.String(),
|
||||
"pem_path", c.config.PEMPath)
|
||||
|
||||
return &target.DeploymentResult{
|
||||
Success: true,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
DeploymentID: fmt.Sprintf("haproxy-%d", time.Now().Unix()),
|
||||
Message: "Combined PEM deployed and HAProxy reloaded successfully",
|
||||
DeployedAt: time.Now(),
|
||||
Metadata: map[string]string{
|
||||
"pem_path": c.config.PEMPath,
|
||||
"duration_ms": fmt.Sprintf("%d", deploymentDuration.Milliseconds()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateDeployment verifies that the deployed certificate is valid and accessible.
|
||||
func (c *Connector) ValidateDeployment(ctx context.Context, request target.ValidationRequest) (*target.ValidationResult, error) {
|
||||
c.logger.Info("validating HAProxy deployment",
|
||||
"certificate_id", request.CertificateID,
|
||||
"serial", request.Serial)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Validate HAProxy configuration if command provided
|
||||
if c.config.ValidateCommand != "" {
|
||||
validateCmd := exec.CommandContext(ctx, "sh", "-c", c.config.ValidateCommand)
|
||||
if output, err := validateCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("HAProxy config validation failed: %v (output: %s)", err, string(output))
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
return &target.ValidationResult{
|
||||
Valid: false,
|
||||
Serial: request.Serial,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
Message: errMsg,
|
||||
ValidatedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify combined PEM file exists and is readable
|
||||
if _, err := os.Stat(c.config.PEMPath); os.IsNotExist(err) {
|
||||
errMsg := fmt.Sprintf("combined PEM file not found: %s", c.config.PEMPath)
|
||||
c.logger.Error("validation failed", "error", err)
|
||||
return &target.ValidationResult{
|
||||
Valid: false,
|
||||
Serial: request.Serial,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
Message: errMsg,
|
||||
ValidatedAt: time.Now(),
|
||||
}, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
validationDuration := time.Since(startTime)
|
||||
c.logger.Info("HAProxy deployment validated successfully",
|
||||
"duration", validationDuration.String())
|
||||
|
||||
return &target.ValidationResult{
|
||||
Valid: true,
|
||||
Serial: request.Serial,
|
||||
TargetAddress: c.config.PEMPath,
|
||||
Message: "HAProxy configuration valid and PEM accessible",
|
||||
ValidatedAt: time.Now(),
|
||||
Metadata: map[string]string{
|
||||
"pem_path": c.config.PEMPath,
|
||||
"duration_ms": fmt.Sprintf("%d", validationDuration.Milliseconds()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package haproxy_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/target"
|
||||
"github.com/shankar0123/certctl/internal/connector/target/haproxy"
|
||||
)
|
||||
|
||||
func TestHAProxyConnector_ValidateConfig(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := haproxy.Config{
|
||||
PEMPath: "/tmp/haproxy/cert.pem",
|
||||
ReloadCommand: "echo reload",
|
||||
}
|
||||
|
||||
connector := haproxy.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing pem_path", func(t *testing.T) {
|
||||
cfg := haproxy.Config{
|
||||
ReloadCommand: "echo reload",
|
||||
}
|
||||
|
||||
connector := haproxy.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing pem_path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing reload_command", func(t *testing.T) {
|
||||
cfg := haproxy.Config{
|
||||
PEMPath: "/tmp/cert.pem",
|
||||
}
|
||||
|
||||
connector := haproxy.New(&cfg, logger)
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing reload_command")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON", func(t *testing.T) {
|
||||
connector := haproxy.New(&haproxy.Config{}, logger)
|
||||
err := connector.ValidateConfig(ctx, json.RawMessage(`{invalid}`))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHAProxyConnector_DeployCertificate(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("successful deployment with combined PEM", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
pemPath := filepath.Join(tmpDir, "combined.pem")
|
||||
|
||||
cfg := &haproxy.Config{
|
||||
PEMPath: pemPath,
|
||||
ReloadCommand: "echo reload",
|
||||
}
|
||||
|
||||
connector := haproxy.New(cfg, logger)
|
||||
|
||||
certPEM := "-----BEGIN CERTIFICATE-----\ncert\n-----END CERTIFICATE-----"
|
||||
chainPEM := "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----"
|
||||
keyPEM := "-----BEGIN EC PRIVATE KEY-----\nkey\n-----END EC PRIVATE KEY-----"
|
||||
|
||||
req := target.DeploymentRequest{
|
||||
CertPEM: certPEM,
|
||||
KeyPEM: keyPEM,
|
||||
ChainPEM: chainPEM,
|
||||
}
|
||||
|
||||
result, err := connector.DeployCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("DeployCertificate failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("expected success, got: %s", result.Message)
|
||||
}
|
||||
|
||||
// Verify combined PEM was written
|
||||
data, err := os.ReadFile(pemPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read PEM file: %v", err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "cert") {
|
||||
t.Error("combined PEM missing certificate")
|
||||
}
|
||||
if !strings.Contains(content, "chain") {
|
||||
t.Error("combined PEM missing chain")
|
||||
}
|
||||
if !strings.Contains(content, "key") {
|
||||
t.Error("combined PEM missing key")
|
||||
}
|
||||
|
||||
// Verify secure permissions (contains private key)
|
||||
info, err := os.Stat(pemPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to stat PEM file: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("expected PEM permissions 0600, got %v", info.Mode().Perm())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reload command fails", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
pemPath := filepath.Join(tmpDir, "combined.pem")
|
||||
|
||||
cfg := &haproxy.Config{
|
||||
PEMPath: pemPath,
|
||||
ReloadCommand: "false", // always fails
|
||||
}
|
||||
|
||||
connector := haproxy.New(cfg, logger)
|
||||
|
||||
req := target.DeploymentRequest{
|
||||
CertPEM: "cert",
|
||||
}
|
||||
|
||||
result, err := connector.DeployCertificate(ctx, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when reload command fails")
|
||||
}
|
||||
if result.Success {
|
||||
t.Fatal("expected failure result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHAProxyConnector_ValidateDeployment(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("valid deployment", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
pemPath := filepath.Join(tmpDir, "combined.pem")
|
||||
os.WriteFile(pemPath, []byte("combined-pem-content"), 0600)
|
||||
|
||||
cfg := &haproxy.Config{
|
||||
PEMPath: pemPath,
|
||||
ReloadCommand: "echo reload",
|
||||
ValidateCommand: "echo ok",
|
||||
}
|
||||
|
||||
connector := haproxy.New(cfg, logger)
|
||||
|
||||
result, err := connector.ValidateDeployment(ctx, target.ValidationRequest{
|
||||
CertificateID: "mc-test",
|
||||
Serial: "123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateDeployment failed: %v", err)
|
||||
}
|
||||
if !result.Valid {
|
||||
t.Fatal("expected valid deployment")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing PEM file", func(t *testing.T) {
|
||||
cfg := &haproxy.Config{
|
||||
PEMPath: "/nonexistent/combined.pem",
|
||||
ReloadCommand: "echo reload",
|
||||
}
|
||||
|
||||
connector := haproxy.New(cfg, logger)
|
||||
|
||||
result, err := connector.ValidateDeployment(ctx, target.ValidationRequest{
|
||||
CertificateID: "mc-test",
|
||||
Serial: "123",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing PEM file")
|
||||
}
|
||||
if result.Valid {
|
||||
t.Fatal("expected invalid result")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -37,6 +37,19 @@ type Agent struct {
|
||||
LastHeartbeatAt *time.Time `json:"last_heartbeat_at,omitempty"`
|
||||
RegisteredAt time.Time `json:"registered_at"`
|
||||
APIKeyHash string `json:"api_key_hash"`
|
||||
OS string `json:"os"`
|
||||
Architecture string `json:"architecture"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// AgentMetadata contains runtime metadata reported by agents via heartbeat.
|
||||
type AgentMetadata struct {
|
||||
OS string `json:"os"`
|
||||
Architecture string `json:"architecture"`
|
||||
Hostname string `json:"hostname"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// AgentStatus represents the operational status of an agent.
|
||||
@@ -60,7 +73,9 @@ const (
|
||||
type TargetType string
|
||||
|
||||
const (
|
||||
TargetTypeNGINX TargetType = "NGINX"
|
||||
TargetTypeF5 TargetType = "F5"
|
||||
TargetTypeIIS TargetType = "IIS"
|
||||
TargetTypeNGINX TargetType = "NGINX"
|
||||
TargetTypeApache TargetType = "Apache"
|
||||
TargetTypeHAProxy TargetType = "HAProxy"
|
||||
TargetTypeF5 TargetType = "F5"
|
||||
TargetTypeIIS TargetType = "IIS"
|
||||
)
|
||||
|
||||
@@ -684,7 +684,7 @@ func (m *mockAgentRepository) Delete(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepository) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||
func (m *mockAgentRepository) UpdateHeartbeat(ctx context.Context, id string, metadata *domain.AgentMetadata) error {
|
||||
agent, ok := m.agents[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("agent not found")
|
||||
|
||||
@@ -69,8 +69,8 @@ type AgentRepository interface {
|
||||
Update(ctx context.Context, agent *domain.Agent) error
|
||||
// Delete removes an agent.
|
||||
Delete(ctx context.Context, id string) error
|
||||
// UpdateHeartbeat updates the agent's last heartbeat timestamp.
|
||||
UpdateHeartbeat(ctx context.Context, id string) error
|
||||
// UpdateHeartbeat updates the agent's last heartbeat timestamp and metadata.
|
||||
UpdateHeartbeat(ctx context.Context, id string, metadata *domain.AgentMetadata) error
|
||||
// GetByAPIKey retrieves an agent by hashed API key.
|
||||
GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,8 @@ func NewAgentRepository(db *sql.DB) *AgentRepository {
|
||||
// List returns all agents
|
||||
func (r *AgentRepository) List(ctx context.Context) ([]*domain.Agent, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash,
|
||||
os, architecture, ip_address, version
|
||||
FROM agents
|
||||
ORDER BY registered_at DESC
|
||||
`)
|
||||
@@ -52,7 +53,8 @@ func (r *AgentRepository) List(ctx context.Context) ([]*domain.Agent, error) {
|
||||
// Get retrieves an agent by ID
|
||||
func (r *AgentRepository) Get(ctx context.Context, id string) (*domain.Agent, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash,
|
||||
os, architecture, ip_address, version
|
||||
FROM agents
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
@@ -75,11 +77,13 @@ func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash,
|
||||
os, architecture, ip_address, version)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id
|
||||
`, agent.ID, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt,
|
||||
agent.RegisteredAt, agent.APIKeyHash).Scan(&agent.ID)
|
||||
agent.RegisteredAt, agent.APIKeyHash,
|
||||
agent.OS, agent.Architecture, agent.IPAddress, agent.Version).Scan(&agent.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent: %w", err)
|
||||
@@ -96,9 +100,14 @@ func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error
|
||||
hostname = $2,
|
||||
status = $3,
|
||||
last_heartbeat_at = $4,
|
||||
api_key_hash = $5
|
||||
WHERE id = $6
|
||||
`, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt, agent.APIKeyHash, agent.ID)
|
||||
api_key_hash = $5,
|
||||
os = $6,
|
||||
architecture = $7,
|
||||
ip_address = $8,
|
||||
version = $9
|
||||
WHERE id = $10
|
||||
`, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt, agent.APIKeyHash,
|
||||
agent.OS, agent.Architecture, agent.IPAddress, agent.Version, agent.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update agent: %w", err)
|
||||
@@ -136,11 +145,27 @@ func (r *AgentRepository) Delete(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHeartbeat updates the agent's last heartbeat timestamp
|
||||
func (r *AgentRepository) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE agents SET last_heartbeat_at = $1 WHERE id = $2
|
||||
`, time.Now(), id)
|
||||
// UpdateHeartbeat updates the agent's last heartbeat timestamp and metadata
|
||||
func (r *AgentRepository) UpdateHeartbeat(ctx context.Context, id string, metadata *domain.AgentMetadata) error {
|
||||
var result sql.Result
|
||||
var err error
|
||||
|
||||
if metadata != nil {
|
||||
result, err = r.db.ExecContext(ctx, `
|
||||
UPDATE agents SET
|
||||
last_heartbeat_at = $1,
|
||||
hostname = CASE WHEN $3 = '' THEN hostname ELSE $3 END,
|
||||
os = CASE WHEN $4 = '' THEN os ELSE $4 END,
|
||||
architecture = CASE WHEN $5 = '' THEN architecture ELSE $5 END,
|
||||
ip_address = CASE WHEN $6 = '' THEN ip_address ELSE $6 END,
|
||||
version = CASE WHEN $7 = '' THEN version ELSE $7 END
|
||||
WHERE id = $2
|
||||
`, time.Now(), id, metadata.Hostname, metadata.OS, metadata.Architecture, metadata.IPAddress, metadata.Version)
|
||||
} else {
|
||||
result, err = r.db.ExecContext(ctx, `
|
||||
UPDATE agents SET last_heartbeat_at = $1 WHERE id = $2
|
||||
`, time.Now(), id)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update heartbeat: %w", err)
|
||||
@@ -161,7 +186,8 @@ func (r *AgentRepository) UpdateHeartbeat(ctx context.Context, id string) error
|
||||
// GetByAPIKey retrieves an agent by hashed API key
|
||||
func (r *AgentRepository) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash,
|
||||
os, architecture, ip_address, version
|
||||
FROM agents
|
||||
WHERE api_key_hash = $1
|
||||
`, keyHash)
|
||||
@@ -183,7 +209,8 @@ func scanAgent(scanner interface {
|
||||
}) (*domain.Agent, error) {
|
||||
var agent domain.Agent
|
||||
err := scanner.Scan(&agent.ID, &agent.Name, &agent.Hostname, &agent.Status,
|
||||
&agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash)
|
||||
&agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash,
|
||||
&agent.OS, &agent.Architecture, &agent.IPAddress, &agent.Version)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan agent: %w", err)
|
||||
|
||||
@@ -81,15 +81,15 @@ func (s *AgentService) Register(ctx context.Context, name string, hostname strin
|
||||
return agent, apiKey, nil
|
||||
}
|
||||
|
||||
// HeartbeatWithContext updates an agent's last seen time and status.
|
||||
func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string) error {
|
||||
// HeartbeatWithContext updates an agent's last seen time, status, and metadata.
|
||||
func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
|
||||
agent, err := s.agentRepo.Get(ctx, agentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch agent: %w", err)
|
||||
}
|
||||
|
||||
// Update heartbeat
|
||||
if err := s.agentRepo.UpdateHeartbeat(ctx, agentID); err != nil {
|
||||
// Update heartbeat and metadata
|
||||
if err := s.agentRepo.UpdateHeartbeat(ctx, agentID, metadata); err != nil {
|
||||
return fmt.Errorf("failed to update heartbeat: %w", err)
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string)
|
||||
}
|
||||
|
||||
// Heartbeat updates agent heartbeat (handler interface method).
|
||||
func (s *AgentService) Heartbeat(agentID string) error {
|
||||
return s.HeartbeatWithContext(context.Background(), agentID)
|
||||
func (s *AgentService) Heartbeat(agentID string, metadata *domain.AgentMetadata) error {
|
||||
return s.HeartbeatWithContext(context.Background(), agentID, metadata)
|
||||
}
|
||||
|
||||
// SubmitCSR validates and processes a Certificate Signing Request from an agent.
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestHeartbeat(t *testing.T) {
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
||||
|
||||
err := agentService.HeartbeatWithContext(ctx, "agent-001")
|
||||
err := agentService.HeartbeatWithContext(ctx, "agent-001", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat failed: %v", err)
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func TestHeartbeat_NotFound(t *testing.T) {
|
||||
|
||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
||||
|
||||
err := agentService.HeartbeatWithContext(ctx, "nonexistent")
|
||||
err := agentService.HeartbeatWithContext(ctx, "nonexistent", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent agent")
|
||||
}
|
||||
|
||||
@@ -477,7 +477,7 @@ func (m *mockAgentRepo) Delete(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRepo) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||
func (m *mockAgentRepo) UpdateHeartbeat(ctx context.Context, id string, metadata *domain.AgentMetadata) error {
|
||||
if m.UpdateHeartbeatErr != nil {
|
||||
return m.UpdateHeartbeatErr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user