diff --git a/internal/connector/target/ssh/ssh.go b/internal/connector/target/ssh/ssh.go index 6239fdf..9c23671 100644 --- a/internal/connector/target/ssh/ssh.go +++ b/internal/connector/target/ssh/ssh.go @@ -7,7 +7,9 @@ package ssh import ( "context" "encoding/json" + "errors" "fmt" + "io" "log/slog" "net" "os" @@ -52,8 +54,22 @@ type SSHClient interface { WriteFile(remotePath string, data []byte, mode os.FileMode) error // Execute runs a command on the remote server and returns combined output. Execute(ctx context.Context, command string) (string, error) - // StatFile checks if a remote file exists and returns its size. - StatFile(remotePath string) (int64, error) + // StatFile returns os.FileInfo for a remote file. The Mode field is + // load-bearing for the Bundle 6 pre-deploy snapshot — restoring an + // original file requires the original mode for fidelity. Callers + // detect "file does not exist" via errors.Is(err, os.ErrNotExist). + StatFile(remotePath string) (os.FileInfo, error) + // ReadFile reads the entire contents of a remote file. Used by the + // Bundle 6 pre-deploy snapshot to capture original bytes for + // reload-failure rollback. Callers should StatFile first to bound + // the read size. + ReadFile(remotePath string) ([]byte, error) + // Remove deletes a remote file. Used by the Bundle 6 rollback path + // to clean up first-time-deploy partial state — when reload fails + // and the path didn't exist pre-deploy, the new bytes must come + // off the remote so the daemon doesn't pick them up on a later + // manual restart. + Remove(remotePath string) error // Close closes the SSH connection. Close() error } @@ -192,12 +208,20 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag // DeployCertificate deploys a certificate to the remote server via SSH/SFTP. // // Steps: -// 1. Connect to remote host via SSH -// 2. Write certificate (+ chain if chain_path not set) to cert_path -// 3. 
Write private key to key_path with restricted permissions -// 4. If chain_path is set and chain provided, write chain separately -// 5. If reload_command is set, execute it via SSH -// 6. Close connection +// 1. Connect to remote host via SSH. +// 2. Pre-deploy snapshot (Bundle 6, 2026-05-02 audit): for each path the +// deploy will write to (cert, key, optional chain), capture original +// bytes + mode into in-memory backup buffers. StatFile errors with +// os.ErrNotExist mean the path doesn't exist (rollback = remove); +// other stat errors bail out before any write happens. +// 3. Write certificate (+ chain appended if chain_path not set) to cert_path. +// 4. Write private key to key_path with restricted permissions. +// 5. If chain_path is set and chain provided, write chain separately. +// 6. If reload_command is set, execute it via SSH. +// 7. On reload failure, restore each backed-up file (or Remove if no +// pre-existing) and re-run reload as a best-effort retry. The remote +// ends up in pre-deploy state if the rollback succeeds. +// 8. Close connection. 
func (c *Connector) DeployCertificate(ctx context.Context, request target.DeploymentRequest) (*target.DeploymentResult, error) { c.logger.Info("deploying certificate via SSH", "host", c.config.Host, @@ -220,6 +244,18 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy } defer c.client.Close() + // Validate we have a private key (required for the deploy to proceed) + if request.KeyPEM == "" { + errMsg := "SSH deployment requires private key (KeyPEM)" + c.logger.Error("missing private key") + return &target.DeploymentResult{ + Success: false, + TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), + Message: errMsg, + DeployedAt: time.Now(), + }, fmt.Errorf("%s", errMsg) + } + // Parse file permissions certMode, _ := parsePermissions(c.config.CertMode) keyMode, _ := parsePermissions(c.config.KeyMode) @@ -230,6 +266,79 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy certData += "\n" + request.ChainPEM } + // Bundle 6: determine the paths the deploy will write to. Chain is + // written separately only when ChainPath is configured AND ChainPEM + // is non-empty (otherwise the chain is appended to the cert above). + chainSeparate := c.config.ChainPath != "" && request.ChainPEM != "" + + type pathSpec struct { + key string // metadata key suffix: "cert" / "key" / "chain" + path string + } + writePaths := []pathSpec{ + {"cert", c.config.CertPath}, + {"key", c.config.KeyPath}, + } + if chainSeparate { + writePaths = append(writePaths, pathSpec{"chain", c.config.ChainPath}) + } + + // Bundle 6: pre-deploy snapshot. For each path the deploy will touch, + // StatFile to detect existence; if present, ReadFile into an in-memory + // backup buffer keyed by remote path. Original mode captured for + // fidelity on restore. An absent backup map entry (not a zero-length + // one) = first-time deploy (rollback for that path = Remove). 
+ backups := make(map[string][]byte) + modes := make(map[string]os.FileMode) + backupStatus := map[string]string{ + "cert": "no_pre_existing", + "key": "no_pre_existing", + "chain": "n/a", + } + if chainSeparate { + backupStatus["chain"] = "no_pre_existing" + } + + for _, p := range writePaths { + info, err := c.client.StatFile(p.path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + // First-time deploy for this path. Rollback = Remove. + continue + } + // Real stat error — bail out before writing anything. + errMsg := fmt.Sprintf("pre-deploy stat failed for %s: %v", p.path, err) + c.logger.Error("pre-deploy stat failed", "error", err, "path", p.path) + return &target.DeploymentResult{ + Success: false, + TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), + Message: errMsg, + DeployedAt: time.Now(), + }, fmt.Errorf("%s", errMsg) + } + data, err := c.client.ReadFile(p.path) + if err != nil { + // File exists per stat but read failed — outage signal. Bail. + errMsg := fmt.Sprintf("pre-deploy backup read failed for %s: %v", p.path, err) + c.logger.Error("pre-deploy backup read failed", "error", err, "path", p.path) + return &target.DeploymentResult{ + Success: false, + TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), + Message: errMsg, + DeployedAt: time.Now(), + }, fmt.Errorf("%s", errMsg) + } + backups[p.path] = data + if info != nil { + modes[p.path] = info.Mode().Perm() + } + backupStatus[p.key] = "snapshotted" + c.logger.Debug("pre-deploy snapshot captured", + "path", p.path, + "size_bytes", len(data), + "mode", modes[p.path]) + } + // Write certificate if err := c.client.WriteFile(c.config.CertPath, []byte(certData), certMode); err != nil { errMsg := fmt.Sprintf("failed to write certificate: %v", err) @@ -242,17 +351,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy }, fmt.Errorf("%s", errMsg) } - // Write private key (must have KeyPEM) - if request.KeyPEM == "" { - errMsg := "SSH 
deployment requires private key (KeyPEM)" - c.logger.Error("missing private key") - return &target.DeploymentResult{ - Success: false, - TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), - Message: errMsg, - DeployedAt: time.Now(), - }, fmt.Errorf("%s", errMsg) - } + // Write private key if err := c.client.WriteFile(c.config.KeyPath, []byte(request.KeyPEM), keyMode); err != nil { errMsg := fmt.Sprintf("failed to write private key: %v", err) c.logger.Error("key write failed", "error", err, "path", c.config.KeyPath) @@ -265,7 +364,7 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy } // Write chain separately if chain_path configured - if c.config.ChainPath != "" && request.ChainPEM != "" { + if chainSeparate { if err := c.client.WriteFile(c.config.ChainPath, []byte(request.ChainPEM), certMode); err != nil { errMsg := fmt.Sprintf("failed to write chain: %v", err) c.logger.Error("chain write failed", "error", err, "path", c.config.ChainPath) @@ -283,13 +382,88 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy c.logger.Debug("executing reload command", "command", c.config.ReloadCommand) output, err := c.client.Execute(ctx, c.config.ReloadCommand) if err != nil { - errMsg := fmt.Sprintf("reload command failed: %v (output: %s)", err, output) - c.logger.Error("reload command failed", "error", err, "output", output) + // Bundle 6: reload failed. Walk the writePaths list and either + // restore from the in-memory backup (file existed pre-deploy) + // or Remove (first-time deploy partial state). Re-run reload + // as a best-effort retry once restore completes — if THAT + // succeeds the target is fully back to pre-deploy state. 
+ c.logger.Error("reload command failed; attempting rollback", + "error", err, + "output", output, + "reload_command", c.config.ReloadCommand) + var paths []string + for _, p := range writePaths { + paths = append(paths, p.path) + } + rollbackErr, restoreStatuses := c.restoreFromBackups(ctx, paths, backups, modes) + // Merge per-key restore status into backupStatus so operators + // see whether the rollback ran cleanly per file. restoreFromBackups + // returns statuses keyed by metadata key (cert/key/chain), not + // by remote path. + for _, p := range writePaths { + if s, ok := restoreStatuses[p.key]; ok { + backupStatus[p.key] = s + } + } + + if rollbackErr != nil { + // Both reload AND rollback failed — operator-actionable. + combined := fmt.Errorf("reload failed (%w); rollback also failed (%v); manual operator inspection required", err, rollbackErr) + c.logger.Error("SSH rollback also failed", + "reload_error", err, + "rollback_error", rollbackErr) + return &target.DeploymentResult{ + Success: false, + TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), + Message: combined.Error(), + DeployedAt: time.Now(), + Metadata: buildMetadataWithBackup(c.config, startTime, backupStatus, map[string]string{ + "reload_error": output, + "rollback_error": rollbackErr.Error(), + "rolled_back": "false", + "manual_action_required": "true", + }), + }, combined + } + + // Rollback succeeded. Best-effort retry-reload — if it works, + // the daemon is serving the original cert again. If it fails, + // remote files are pre-deploy but daemon may be in a stuck + // state; surface as wrapped error so the operator knows to + // investigate the daemon, not the files. 
+ retryOutput, retryErr := c.client.Execute(ctx, c.config.ReloadCommand) + if retryErr != nil { + wrapped := fmt.Errorf("reload failed (%w); rolled back files; retry-reload also failed (%v) — daemon may need manual restart", err, retryErr) + c.logger.Error("SSH retry-reload after rollback failed", + "reload_error", err, + "retry_reload_error", retryErr, + "retry_output", retryOutput) + return &target.DeploymentResult{ + Success: false, + TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), + Message: wrapped.Error(), + DeployedAt: time.Now(), + Metadata: buildMetadataWithBackup(c.config, startTime, backupStatus, map[string]string{ + "reload_error": output, + "retry_reload_error": retryOutput, + "rolled_back": "true", + "daemon_state_unknown": "true", + }), + }, wrapped + } + + // Clean recoverable failure: files restored, daemon reloaded + // to pre-deploy state. + errMsg := fmt.Sprintf("reload command failed; rolled back to pre-deploy state: %v (output: %s)", err, output) return &target.DeploymentResult{ Success: false, TargetAddress: fmt.Sprintf("%s:%d", c.config.Host, c.config.Port), Message: errMsg, DeployedAt: time.Now(), + Metadata: buildMetadataWithBackup(c.config, startTime, backupStatus, map[string]string{ + "reload_error": output, + "rolled_back": "true", + }), }, fmt.Errorf("%s", errMsg) } } @@ -306,15 +480,95 @@ func (c *Connector) DeployCertificate(ctx context.Context, request target.Deploy DeploymentID: fmt.Sprintf("ssh-%s-%d", c.config.Host, time.Now().Unix()), Message: fmt.Sprintf("Certificate deployed via SSH to %s", c.config.Host), DeployedAt: time.Now(), - Metadata: map[string]string{ - "host": c.config.Host, - "cert_path": c.config.CertPath, - "key_path": c.config.KeyPath, - "duration_ms": fmt.Sprintf("%d", deploymentDuration.Milliseconds()), - }, + Metadata: buildMetadataWithBackup(c.config, startTime, backupStatus, nil), }, nil } +// restoreFromBackups walks the configured deploy paths and either restores +// each path from the 
in-memory backup (when the file existed pre-deploy) or +// Removes the new bytes (first-time-deploy partial state). Returns the +// first error encountered — caller surfaces the wrapped error to the +// operator. The per-path status map is always populated so callers can +// emit accurate Metadata. +// +// Bundle 6 of the 2026-05-02 deployment-target audit. +func (c *Connector) restoreFromBackups(ctx context.Context, paths []string, backups map[string][]byte, modes map[string]os.FileMode) (error, map[string]string) { + statuses := make(map[string]string, len(paths)) + pathToKey := map[string]string{ + c.config.CertPath: "cert", + c.config.KeyPath: "key", + c.config.ChainPath: "chain", + } + + var firstErr error + for _, path := range paths { + key := pathToKey[path] + if data, ok := backups[path]; ok { + // File existed pre-deploy — restore from backup with the + // original mode (default 0600 if mode capture failed). + mode := modes[path] + if mode == 0 { + mode = 0600 + } + if err := c.client.WriteFile(path, data, mode); err != nil { + wrapped := fmt.Errorf("restore %s: %w", path, err) + if firstErr == nil { + firstErr = wrapped + } + if key != "" { + statuses[key] = "restore_failed" + } + c.logger.Error("rollback restore failed", "error", err, "path", path) + continue + } + if key != "" { + statuses[key] = "restored" + } + c.logger.Info("rollback restored file from backup", "path", path, "size_bytes", len(data)) + } else { + // First-time deploy for this path — Remove the new bytes. 
+ if err := c.client.Remove(path); err != nil { + wrapped := fmt.Errorf("remove %s: %w", path, err) + if firstErr == nil { + firstErr = wrapped + } + if key != "" { + statuses[key] = "remove_failed" + } + c.logger.Error("rollback remove failed", "error", err, "path", path) + continue + } + if key != "" { + statuses[key] = "removed" + } + c.logger.Info("rollback removed first-time-deploy file", "path", path) + } + } + return firstErr, statuses +} + +// buildMetadataWithBackup assembles the per-deploy Metadata map with the +// standard host / cert_path / key_path / duration_ms fields plus the +// per-path backup_status_{cert,key,chain} fields populated from the +// snapshot phase. Extra k/v pairs (e.g. error context) are merged on top. +// +// Bundle 6 of the 2026-05-02 deployment-target audit. +func buildMetadataWithBackup(cfg *Config, startTime time.Time, backupStatus map[string]string, extra map[string]string) map[string]string { + md := map[string]string{ + "host": cfg.Host, + "cert_path": cfg.CertPath, + "key_path": cfg.KeyPath, + "duration_ms": fmt.Sprintf("%d", time.Since(startTime).Milliseconds()), + "backup_status_cert": backupStatus["cert"], + "backup_status_key": backupStatus["key"], + "backup_status_chain": backupStatus["chain"], + } + for k, v := range extra { + md[k] = v + } + return md +} + // ValidateDeployment verifies that the deployed certificate files exist on the remote server. func (c *Connector) ValidateDeployment(ctx context.Context, request target.ValidationRequest) (*target.ValidationResult, error) { c.logger.Info("validating SSH deployment", @@ -532,18 +786,59 @@ func (c *realSSHClient) Execute(ctx context.Context, command string) (string, er return string(output), err } -// StatFile checks if a remote file exists and returns its size. -func (c *realSSHClient) StatFile(remotePath string) (int64, error) { +// StatFile returns os.FileInfo for a remote file via SFTP. 
Bundle 6 evolved +// the signature from int64 (size only) to os.FileInfo so the pre-deploy +// snapshot can capture the original mode for accurate rollback restoration. +// Errors from SFTP wrapping a non-existent-file syscall preserve the +// os.ErrNotExist sentinel through the %w wrap, so callers can use +// errors.Is(err, os.ErrNotExist) to distinguish "file doesn't exist" from +// real stat errors. +func (c *realSSHClient) StatFile(remotePath string) (os.FileInfo, error) { if c.sftpClient == nil { - return 0, fmt.Errorf("SFTP client not connected") + return nil, fmt.Errorf("SFTP client not connected") } info, err := c.sftpClient.Stat(remotePath) if err != nil { - return 0, fmt.Errorf("failed to stat remote file %s: %w", remotePath, err) + return nil, fmt.Errorf("failed to stat remote file %s: %w", remotePath, err) } - return info.Size(), nil + return info, nil +} + +// ReadFile reads the entire contents of a remote file via SFTP. Used by +// Bundle 6's pre-deploy snapshot to capture original bytes for the +// reload-failure rollback path. The read itself is unbounded (io.ReadAll); +// callers that need a size cap must check StatFile's Size() first. +func (c *realSSHClient) ReadFile(remotePath string) ([]byte, error) { + if c.sftpClient == nil { + return nil, fmt.Errorf("SFTP client not connected") + } + + f, err := c.sftpClient.Open(remotePath) + if err != nil { + return nil, fmt.Errorf("sftp open %s: %w", remotePath, err) + } + defer f.Close() + + data, err := io.ReadAll(f) + if err != nil { + return nil, fmt.Errorf("sftp read %s: %w", remotePath, err) + } + return data, nil +} + +// Remove deletes a remote file via SFTP. Used by Bundle 6's rollback path +// to clean up first-time-deploy partial state — when reload fails and the +// path didn't exist pre-deploy, the new bytes must come off the remote. 
+func (c *realSSHClient) Remove(remotePath string) error { + if c.sftpClient == nil { + return fmt.Errorf("SFTP client not connected") + } + if err := c.sftpClient.Remove(remotePath); err != nil { + return fmt.Errorf("sftp remove %s: %w", remotePath, err) + } + return nil } // Close closes the SFTP and SSH connections. diff --git a/internal/connector/target/ssh/ssh_server_fixture_test.go b/internal/connector/target/ssh/ssh_server_fixture_test.go index e8e1bf0..b9a0e10 100644 --- a/internal/connector/target/ssh/ssh_server_fixture_test.go +++ b/internal/connector/target/ssh/ssh_server_fixture_test.go @@ -532,21 +532,21 @@ func TestRealSSHClient_WriteFile_StatFile_RoundTrip(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - size, err := c.StatFile(target) + info, err := c.StatFile(target) if err != nil { t.Fatalf("StatFile: %v", err) } - if size != int64(len(payload)) { - t.Errorf("expected size %d, got %d", len(payload), size) + if info.Size() != int64(len(payload)) { + t.Errorf("expected size %d, got %d", len(payload), info.Size()) } // Verify mode 0640 was set. - info, err := os.Stat(target) + osInfo, err := os.Stat(target) if err != nil { t.Fatalf("os.Stat: %v", err) } - if info.Mode().Perm() != 0640 { - t.Errorf("expected mode 0640, got %v", info.Mode().Perm()) + if osInfo.Mode().Perm() != 0640 { + t.Errorf("expected mode 0640, got %v", osInfo.Mode().Perm()) } // Verify content round-trips. diff --git a/internal/connector/target/ssh/ssh_test.go b/internal/connector/target/ssh/ssh_test.go index ae3ea4a..c8c77cc 100644 --- a/internal/connector/target/ssh/ssh_test.go +++ b/internal/connector/target/ssh/ssh_test.go @@ -7,10 +7,29 @@ import ( "log/slog" "os" "testing" + "time" "github.com/shankar0123/certctl/internal/connector/target" ) +// stubFileInfo implements os.FileInfo for tests that need to return a +// FileInfo from the mock SSHClient's StatFile. 
Bundle 6 of the +// 2026-05-02 deployment-target audit evolved StatFile's signature from +// (int64, error) to (os.FileInfo, error) so the pre-deploy snapshot +// can capture the original mode for accurate rollback restoration. +type stubFileInfo struct { + size int64 + mode os.FileMode + name string +} + +func (s *stubFileInfo) Name() string { return s.name } +func (s *stubFileInfo) Size() int64 { return s.size } +func (s *stubFileInfo) Mode() os.FileMode { return s.mode } +func (s *stubFileInfo) ModTime() time.Time { return time.Time{} } +func (s *stubFileInfo) IsDir() bool { return false } +func (s *stubFileInfo) Sys() any { return nil } + // testLogger returns a slog.Logger for test output. func testLogger() *slog.Logger { return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn})) @@ -19,18 +38,37 @@ func testLogger() *slog.Logger { // --- Mock SSH Client --- // mockSSHClient records all calls and returns configurable results. +// +// Bundle 6 of the 2026-05-02 deployment-target audit added per-path +// response maps (statByPath / readByPath / writeFileErrByPath) so the +// new snapshot/rollback tests can simulate (a) pre-existing remote +// files for the snapshot to read, (b) per-call WriteFile failures to +// inject restore-failure paths, and (c) sequenced Execute errors so +// the reload-then-retry-reload tests can drive both calls +// independently. The legacy global fields (statFileSize / statFileErr +// / writeFileErr / executeErr) are still honored when no per-path +// override matches, so existing tests remain green. 
type mockSSHClient struct { - connectCalls int - connectErr error - writeFileCalls []writeFileCall - writeFileErr error - executeCalls []string - executeOutput string - executeErr error - statFileCalls []string - statFileSize int64 - statFileErr error - closeCalls int + connectCalls int + connectErr error + writeFileCalls []writeFileCall + writeFileErr error + writeFileErrByPath map[string]error // per-path WriteFile error overrides + executeCalls []string + executeOutput string + executeErr error + executeErrSequence []error // per-call Execute errors; falls back to executeErr after exhaustion + executeOutSequence []string // per-call Execute outputs; mirrors executeErrSequence + statFileCalls []string + statFileSize int64 + statFileErr error + statByPath map[string]statResponse // per-path StatFile responses + readByPath map[string][]byte // per-path ReadFile bytes (existence implies success) + readErrByPath map[string]error // per-path ReadFile error overrides + removeCalls []string + removeErr error + removeErrByPath map[string]error + closeCalls int } type writeFileCall struct { @@ -39,6 +77,11 @@ type writeFileCall struct { Mode os.FileMode } +type statResponse struct { + info os.FileInfo + err error +} + func (m *mockSSHClient) Connect(ctx context.Context) error { m.connectCalls++ return m.connectErr @@ -46,17 +89,66 @@ func (m *mockSSHClient) Connect(ctx context.Context) error { func (m *mockSSHClient) WriteFile(remotePath string, data []byte, mode os.FileMode) error { m.writeFileCalls = append(m.writeFileCalls, writeFileCall{Path: remotePath, Data: data, Mode: mode}) + if m.writeFileErrByPath != nil { + if err, ok := m.writeFileErrByPath[remotePath]; ok { + return err + } + } return m.writeFileErr } func (m *mockSSHClient) Execute(ctx context.Context, command string) (string, error) { + idx := len(m.executeCalls) m.executeCalls = append(m.executeCalls, command) + if idx < len(m.executeErrSequence) { + out := "" + if idx < len(m.executeOutSequence) { + out 
= m.executeOutSequence[idx] + } + return out, m.executeErrSequence[idx] + } return m.executeOutput, m.executeErr } -func (m *mockSSHClient) StatFile(remotePath string) (int64, error) { +func (m *mockSSHClient) StatFile(remotePath string) (os.FileInfo, error) { m.statFileCalls = append(m.statFileCalls, remotePath) - return m.statFileSize, m.statFileErr + if m.statByPath != nil { + if resp, ok := m.statByPath[remotePath]; ok { + return resp.info, resp.err + } + } + if m.statFileErr != nil { + return nil, m.statFileErr + } + // Default: synthesise a FileInfo with the legacy size + a sane mode. + return &stubFileInfo{size: m.statFileSize, mode: 0644, name: remotePath}, nil +} + +func (m *mockSSHClient) ReadFile(remotePath string) ([]byte, error) { + if m.readErrByPath != nil { + if err, ok := m.readErrByPath[remotePath]; ok { + return nil, err + } + } + if m.readByPath != nil { + if data, ok := m.readByPath[remotePath]; ok { + return data, nil + } + } + // Default: empty bytes, no error. Tests that don't exercise the + // snapshot path see this fall-through (the read still succeeds so + // the snapshot phase doesn't block their deploy hot path). 
+ return []byte{}, nil +} + +func (m *mockSSHClient) Remove(remotePath string) error { + m.removeCalls = append(m.removeCalls, remotePath) + if m.removeErrByPath != nil { + if err, ok := m.removeErrByPath[remotePath]; ok { + return err + } + } + return m.removeErr } func (m *mockSSHClient) Close() error { @@ -571,6 +663,388 @@ func TestDeployCertificate_ReloadFailure(t *testing.T) { } } +// --- Bundle 6: pre-deploy snapshot + reload-failure rollback --- +// +// These four tests pin the load-bearing rollback contract added in +// Bundle 6 of the 2026-05-02 deployment-target audit: +// - happy rollback path: pre-existing remote bytes restored verbatim; +// - first-time deploy partial-state cleanup via Remove; +// - both reload AND rollback fail → operator-actionable wrapped error; +// - rollback succeeds but the retry-reload after rollback fails → +// daemon-state-unknown wrapped error. + +func TestSSH_ReloadFails_FilesRestored(t *testing.T) { + originalCert := []byte("-----BEGIN CERTIFICATE-----\nORIGINAL_CERT\n-----END CERTIFICATE-----\n") + originalKey := []byte("-----BEGIN PRIVATE KEY-----\nORIGINAL_KEY\n-----END PRIVATE KEY-----\n") + originalChain := []byte("-----BEGIN CERTIFICATE-----\nORIGINAL_CHAIN\n-----END CERTIFICATE-----\n") + + mock := &mockSSHClient{ + // Pre-existing files for all three paths; mode 0644 / 0600 / 0644. + statByPath: map[string]statResponse{ + "/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}}, + "/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}}, + "/etc/ssl/chain.pem": {info: &stubFileInfo{size: int64(len(originalChain)), mode: 0644}}, + }, + readByPath: map[string][]byte{ + "/etc/ssl/cert.pem": originalCert, + "/etc/ssl/key.pem": originalKey, + "/etc/ssl/chain.pem": originalChain, + }, + // First Execute (reload) fails; second Execute (retry-reload after + // restore) succeeds — clean recoverable failure. 
+ executeErrSequence: []error{fmt.Errorf("reload failed: exit status 1"), nil}, + executeOutSequence: []string{"reload error output", "ok"}, + } + + cfg := &Config{ + Host: "server.local", + Port: 22, + CertPath: "/etc/ssl/cert.pem", + KeyPath: "/etc/ssl/key.pem", + ChainPath: "/etc/ssl/chain.pem", + CertMode: "0644", + KeyMode: "0600", + ReloadCommand: "systemctl reload nginx", + } + c := NewWithClient(cfg, mock, testLogger()) + + req := target.DeploymentRequest{ + CertPEM: "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----\n", + KeyPEM: "-----BEGIN PRIVATE KEY-----\nNEW_KEY\n-----END PRIVATE KEY-----\n", + ChainPEM: "-----BEGIN CERTIFICATE-----\nNEW_CHAIN\n-----END CERTIFICATE-----\n", + } + + result, err := c.DeployCertificate(context.Background(), req) + if err == nil { + t.Fatal("expected error when reload fails") + } + if result.Success { + t.Fatal("expected failure result") + } + + // Error must mention reload failure + rollback success. + if !containsString(err.Error(), "reload command failed") && !containsString(err.Error(), "reload failed") { + t.Errorf("expected error to mention reload failure, got: %v", err) + } + if !containsString(err.Error(), "rolled back") { + t.Errorf("expected error to mention 'rolled back', got: %v", err) + } + + // Build a path → bytes view of every WriteFile call for the assertions. + // On the success path the deploy writes new bytes; on the rollback path + // it writes the originals back. We expect each path to be written at + // least twice (once with new bytes, once with originals). 
+ writesByPath := map[string][][]byte{} + for _, w := range mock.writeFileCalls { + writesByPath[w.Path] = append(writesByPath[w.Path], w.Data) + } + + for _, path := range []string{"/etc/ssl/cert.pem", "/etc/ssl/key.pem", "/etc/ssl/chain.pem"} { + writes := writesByPath[path] + if len(writes) < 2 { + t.Errorf("expected at least 2 WriteFile calls for %s (deploy + restore), got %d", path, len(writes)) + continue + } + // Last write to each path is the rollback restore — must equal + // the pre-existing bytes captured in the snapshot. + lastWrite := writes[len(writes)-1] + var want []byte + switch path { + case "/etc/ssl/cert.pem": + want = originalCert + case "/etc/ssl/key.pem": + want = originalKey + case "/etc/ssl/chain.pem": + want = originalChain + } + if string(lastWrite) != string(want) { + t.Errorf("rollback for %s did not restore original bytes:\n got: %q\n want: %q", path, lastWrite, want) + } + } + + // No Remove calls — every path had a pre-existing snapshot to restore from. + if len(mock.removeCalls) != 0 { + t.Errorf("expected 0 Remove calls (all paths had backups), got %d: %v", len(mock.removeCalls), mock.removeCalls) + } + + // Both Execute calls (initial reload + retry-reload after rollback) + // must have run. + if len(mock.executeCalls) != 2 { + t.Errorf("expected 2 Execute calls (reload + retry-reload), got %d", len(mock.executeCalls)) + } + + // Metadata reflects per-path snapshot status. 
+ if result.Metadata["backup_status_cert"] != "restored" { + t.Errorf("expected backup_status_cert=restored, got %q", result.Metadata["backup_status_cert"]) + } + if result.Metadata["backup_status_key"] != "restored" { + t.Errorf("expected backup_status_key=restored, got %q", result.Metadata["backup_status_key"]) + } + if result.Metadata["backup_status_chain"] != "restored" { + t.Errorf("expected backup_status_chain=restored, got %q", result.Metadata["backup_status_chain"]) + } + if result.Metadata["rolled_back"] != "true" { + t.Errorf("expected rolled_back=true, got %q", result.Metadata["rolled_back"]) + } +} + +func TestSSH_NoExistingCert_ReloadFails_NewCertRemoved(t *testing.T) { + mock := &mockSSHClient{ + // All three paths report "no such file" — first-time deploy. + statByPath: map[string]statResponse{ + "/etc/ssl/cert.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)}, + "/etc/ssl/key.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)}, + "/etc/ssl/chain.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)}, + }, + // Reload fails; retry-reload after rollback succeeds. 
+ executeErrSequence: []error{fmt.Errorf("reload failed"), nil}, + executeOutSequence: []string{"reload error", "ok"}, + } + + cfg := &Config{ + Host: "server.local", + Port: 22, + CertPath: "/etc/ssl/cert.pem", + KeyPath: "/etc/ssl/key.pem", + ChainPath: "/etc/ssl/chain.pem", + CertMode: "0644", + KeyMode: "0600", + ReloadCommand: "systemctl reload nginx", + } + c := NewWithClient(cfg, mock, testLogger()) + + req := target.DeploymentRequest{ + CertPEM: "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----\n", + KeyPEM: "-----BEGIN PRIVATE KEY-----\nNEW_KEY\n-----END PRIVATE KEY-----\n", + ChainPEM: "-----BEGIN CERTIFICATE-----\nNEW_CHAIN\n-----END CERTIFICATE-----\n", + } + + result, err := c.DeployCertificate(context.Background(), req) + if err == nil { + t.Fatal("expected error when reload fails") + } + if result.Success { + t.Fatal("expected failure result") + } + + // Rollback for first-time deploys must call Remove on every written path. + expectedRemoves := map[string]bool{ + "/etc/ssl/cert.pem": true, + "/etc/ssl/key.pem": true, + "/etc/ssl/chain.pem": true, + } + if len(mock.removeCalls) != len(expectedRemoves) { + t.Errorf("expected %d Remove calls, got %d: %v", len(expectedRemoves), len(mock.removeCalls), mock.removeCalls) + } + for _, p := range mock.removeCalls { + if !expectedRemoves[p] { + t.Errorf("unexpected Remove path: %s", p) + } + } + + // First-time deploy: WriteFile is called only during the initial + // deploy, never during rollback (no backup to restore from). + expectedWrites := 3 // cert + key + chain (all configured paths) + if len(mock.writeFileCalls) != expectedWrites { + t.Errorf("expected exactly %d WriteFile calls (deploy only, no restore), got %d", expectedWrites, len(mock.writeFileCalls)) + } + + // Metadata reflects "removed" status for all paths. 
+ if result.Metadata["backup_status_cert"] != "removed" { + t.Errorf("expected backup_status_cert=removed, got %q", result.Metadata["backup_status_cert"]) + } + if result.Metadata["backup_status_key"] != "removed" { + t.Errorf("expected backup_status_key=removed, got %q", result.Metadata["backup_status_key"]) + } + if result.Metadata["backup_status_chain"] != "removed" { + t.Errorf("expected backup_status_chain=removed, got %q", result.Metadata["backup_status_chain"]) + } +} + +func TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable(t *testing.T) { + originalCert := []byte("ORIGINAL_CERT") + originalKey := []byte("ORIGINAL_KEY") + + mock := &mockSSHClient{ + statByPath: map[string]statResponse{ + "/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}}, + "/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}}, + }, + readByPath: map[string][]byte{ + "/etc/ssl/cert.pem": originalCert, + "/etc/ssl/key.pem": originalKey, + }, + // Initial deploy WriteFile calls succeed; rollback's WriteFile to + // restore the cert FAILS. This injects the operator-actionable + // case: reload failed AND the restore can't complete. + writeFileErrByPath: map[string]error{}, + executeErrSequence: []error{fmt.Errorf("reload step failed")}, + executeOutSequence: []string{"reload error"}, + } + // Track call count so we can fail only the SECOND WriteFile to + // /etc/ssl/cert.pem (i.e. the restore call, not the initial deploy + // write). Done via a wrapper because writeFileErrByPath is a flat map. 
+ wrapped := &writeOrderTrackingMock{base: mock} + wrapped.failOnNthWriteForPath = map[string]int{ + "/etc/ssl/cert.pem": 2, // 1st = deploy write (succeed); 2nd = restore (fail) + } + + cfg := &Config{ + Host: "server.local", + Port: 22, + CertPath: "/etc/ssl/cert.pem", + KeyPath: "/etc/ssl/key.pem", + CertMode: "0644", + KeyMode: "0600", + ReloadCommand: "systemctl reload nginx", + } + c := NewWithClient(cfg, wrapped, testLogger()) + + req := target.DeploymentRequest{ + CertPEM: "NEW_CERT", + KeyPEM: "NEW_KEY", + } + + result, err := c.DeployCertificate(context.Background(), req) + if err == nil { + t.Fatal("expected error when both reload and rollback fail") + } + if result.Success { + t.Fatal("expected failure result") + } + + // Wrapped error must mention BOTH the reload error and the rollback error. + if !containsString(err.Error(), "reload failed") { + t.Errorf("expected error to mention reload failure, got: %v", err) + } + if !containsString(err.Error(), "rollback also failed") { + t.Errorf("expected error to mention 'rollback also failed', got: %v", err) + } + if !containsString(err.Error(), "manual operator inspection required") { + t.Errorf("expected error to flag manual inspection, got: %v", err) + } + + // Metadata must surface manual_action_required + both error strings. 
+ if result.Metadata["manual_action_required"] != "true" { + t.Errorf("expected manual_action_required=true, got %q", result.Metadata["manual_action_required"]) + } + if result.Metadata["rolled_back"] != "false" { + t.Errorf("expected rolled_back=false, got %q", result.Metadata["rolled_back"]) + } + if result.Metadata["rollback_error"] == "" { + t.Error("expected rollback_error in metadata") + } +} + +func TestSSH_ReloadFails_RestoreThenSecondReloadFails(t *testing.T) { + originalCert := []byte("ORIGINAL_CERT") + originalKey := []byte("ORIGINAL_KEY") + + mock := &mockSSHClient{ + statByPath: map[string]statResponse{ + "/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}}, + "/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}}, + }, + readByPath: map[string][]byte{ + "/etc/ssl/cert.pem": originalCert, + "/etc/ssl/key.pem": originalKey, + }, + // Both Execute calls (initial reload + retry-reload after rollback) + // fail. The remote files are back to pre-deploy state but the + // daemon may be in a stuck/partial state — operator needs to + // know that. 
+ executeErrSequence: []error{fmt.Errorf("reload step 1 failed"), fmt.Errorf("reload step 2 failed")}, + executeOutSequence: []string{"out1", "out2"}, + } + + cfg := &Config{ + Host: "server.local", + Port: 22, + CertPath: "/etc/ssl/cert.pem", + KeyPath: "/etc/ssl/key.pem", + CertMode: "0644", + KeyMode: "0600", + ReloadCommand: "systemctl reload nginx", + } + c := NewWithClient(cfg, mock, testLogger()) + + req := target.DeploymentRequest{ + CertPEM: "NEW_CERT", + KeyPEM: "NEW_KEY", + } + + result, err := c.DeployCertificate(context.Background(), req) + if err == nil { + t.Fatal("expected error when retry-reload after rollback fails") + } + if result.Success { + t.Fatal("expected failure result") + } + + // Wrapped error mentions reload failure, rollback success, and + // retry-reload failure — operator must understand the daemon may + // not be running the original config even though the files are back. + if !containsString(err.Error(), "rolled back files") { + t.Errorf("expected error to mention 'rolled back files', got: %v", err) + } + if !containsString(err.Error(), "retry-reload also failed") { + t.Errorf("expected error to mention retry-reload failure, got: %v", err) + } + if !containsString(err.Error(), "daemon may need manual restart") { + t.Errorf("expected error to flag daemon state, got: %v", err) + } + + // Metadata flags daemon_state_unknown + rolled_back=true (files OK). + if result.Metadata["daemon_state_unknown"] != "true" { + t.Errorf("expected daemon_state_unknown=true, got %q", result.Metadata["daemon_state_unknown"]) + } + if result.Metadata["rolled_back"] != "true" { + t.Errorf("expected rolled_back=true, got %q", result.Metadata["rolled_back"]) + } + + // Both Execute calls happened; both WriteFile-on-restore calls + // happened (cert + key restored). 
+ if len(mock.executeCalls) != 2 { + t.Errorf("expected 2 Execute calls, got %d", len(mock.executeCalls)) + } +} + +// writeOrderTrackingMock wraps mockSSHClient to fail the Nth WriteFile +// for a given path. Used by TestSSH_ReloadFails_RollbackAlsoFails_- +// OperatorActionable to fail the restore (2nd write) while letting the +// initial deploy (1st write) succeed for the same path. +type writeOrderTrackingMock struct { + base *mockSSHClient + writeCountByPath map[string]int + failOnNthWriteForPath map[string]int +} + +func (w *writeOrderTrackingMock) Connect(ctx context.Context) error { return w.base.Connect(ctx) } +func (w *writeOrderTrackingMock) WriteFile(remotePath string, data []byte, mode os.FileMode) error { + if w.writeCountByPath == nil { + w.writeCountByPath = map[string]int{} + } + w.writeCountByPath[remotePath]++ + w.base.writeFileCalls = append(w.base.writeFileCalls, writeFileCall{Path: remotePath, Data: data, Mode: mode}) + if n, ok := w.failOnNthWriteForPath[remotePath]; ok { + if w.writeCountByPath[remotePath] == n { + return fmt.Errorf("injected write failure on call %d to %s", n, remotePath) + } + } + return nil +} +func (w *writeOrderTrackingMock) Execute(ctx context.Context, cmd string) (string, error) { + return w.base.Execute(ctx, cmd) +} +func (w *writeOrderTrackingMock) StatFile(remotePath string) (os.FileInfo, error) { + return w.base.StatFile(remotePath) +} +func (w *writeOrderTrackingMock) ReadFile(remotePath string) ([]byte, error) { + return w.base.ReadFile(remotePath) +} +func (w *writeOrderTrackingMock) Remove(remotePath string) error { return w.base.Remove(remotePath) } +func (w *writeOrderTrackingMock) Close() error { return w.base.Close() } + // --- ValidateDeployment tests --- func TestValidateDeployment_Success(t *testing.T) { @@ -882,13 +1356,24 @@ func (m *conditionalStatMockSSHClient) Execute(ctx context.Context, command stri return m.base.Execute(ctx, command) } -func (m *conditionalStatMockSSHClient) StatFile(remotePath 
string) (int64, error) { +func (m *conditionalStatMockSSHClient) StatFile(remotePath string) (os.FileInfo, error) { m.callCount++ - // First call succeeds (cert), second call fails (key) + // First call succeeds (cert), second call fails (key) — wrap + // os.ErrNotExist so the connector's errors.Is check propagates the + // "file not found" semantics through the Bundle 6 stat-error + // handling. if m.callCount == 2 { - return 0, fmt.Errorf("file not found") + return nil, fmt.Errorf("file not found: %w", os.ErrNotExist) } - return 1024, nil + return &stubFileInfo{size: 1024, mode: 0644}, nil +} + +func (m *conditionalStatMockSSHClient) ReadFile(remotePath string) ([]byte, error) { + return m.base.ReadFile(remotePath) +} + +func (m *conditionalStatMockSSHClient) Remove(remotePath string) error { + return m.base.Remove(remotePath) } func (m *conditionalStatMockSSHClient) Close() error { diff --git a/internal/connector/target/ssh/validate_only_test.go b/internal/connector/target/ssh/validate_only_test.go index 5526e13..b068b24 100644 --- a/internal/connector/target/ssh/validate_only_test.go +++ b/internal/connector/target/ssh/validate_only_test.go @@ -20,7 +20,9 @@ func (s *stubSSHClient) Connect(_ context.Context) error { r func (s *stubSSHClient) Close() error { return nil } func (s *stubSSHClient) WriteFile(_ string, _ []byte, _ os.FileMode) error { return nil } func (s *stubSSHClient) Execute(_ context.Context, _ string) (string, error) { return "", nil } -func (s *stubSSHClient) StatFile(_ string) (int64, error) { return 0, nil } +func (s *stubSSHClient) StatFile(_ string) (os.FileInfo, error) { return nil, os.ErrNotExist } +func (s *stubSSHClient) ReadFile(_ string) ([]byte, error) { return nil, os.ErrNotExist } +func (s *stubSSHClient) Remove(_ string) error { return nil } func TestSSH_ValidateOnly_Connect_Succeeds(t *testing.T) { c := NewWithClient(&Config{Host: "h", User: "u"}, &stubSSHClient{}, nil)