certctl/internal/connector/target/ssh/ssh_test.go
shankar0123 636de7f6b5 ssh: pre-deploy snapshot + reload-failure rollback
Closes Bundle 6 of the 2026-05-02 deployment-target coverage audit
(see cowork/deployment-target-audit-2026-05-02/RESULTS.md). Pre-fix,
DeployCertificate at ssh.go:201-316 wrote new cert/key/chain via
SFTP then ran the operator's reload command. If reload failed, the
new files stayed on the remote — partial-success state with no
rollback path. docs/deployment-atomicity.md L92 promised "Pre-deploy
SCP backup of remote files"; the code didn't deliver.

This commit:

1. Pre-deploy snapshot. Before any WriteFile, iterate the deploy's
   target paths (cert, key, optional chain). For each path:
   - StatFile to detect existence. errors.Is(err, os.ErrNotExist)
     means first-time deploy (rollback = Remove). Other stat
     errors bail out before any write happens.
   - ReadFile into an in-memory backups map[string][]byte keyed
     by remote path. Original mode captured into a parallel
     modes map for restore fidelity.

2. SSHClient interface evolution — three changes:
   - StatFile(path) (os.FileInfo, error) — was (int64, error).
     FileInfo carries Mode() needed for accurate restore. Existing
     fixture tests updated to call info.Size() instead of the
     bare size value.
   - ReadFile(path) ([]byte, error) — new method; SFTP Open + read
     via io.ReadAll. realSSHClient implements via sftpClient.Open.
   - Remove(path) error — new method; SFTP Remove. Used by the
     rollback path to clean up first-time-deploy partial state.

3. On-reload-failure rollback. Replace the bare error-return at
   L282-295 with restoreFromBackups + retry-reload escalation:
   - For paths in the snapshot map, WriteFile the original bytes
     with the original mode (0600 fallback if mode capture was
     incomplete).
   - For paths that didn't exist pre-deploy, Remove the new file.
   - Re-run the reload command (best-effort second attempt). If
     it succeeds, the target is back to pre-deploy state. If it
     fails, the remote is in pre-deploy file state but the daemon
     may be stuck — surface as wrapped error so the operator
     knows where to look.

4. DeploymentResult.Metadata gains backup_status_{cert,key,chain}
   so operators can see per-path snapshot state on both success
   ("snapshotted" / "no_pre_existing" / "n/a") and failure
   ("restored" / "removed" / "restore_failed" / "remove_failed").
   buildMetadataWithBackup helper centralises the metadata
   shape so success and failure paths emit a consistent set
   of keys.

5. Helper extraction. restoreFromBackups(ctx, paths, backups,
   modes) is a private method on Connector; returns the first
   error + per-key restore status map for clean test seams.

DeploymentResult shape on failure:
- rollback OK + retry-reload OK → Success=false, "reload command
  failed; rolled back to pre-deploy state" (clean recoverable
  failure; remote fully restored, daemon serving original cert).
- rollback OK + retry-reload FAIL → wrapped error noting "rolled
  back files; retry-reload also failed; daemon may need manual
  restart". Metadata flags daemon_state_unknown=true.
- rollback FAIL → operator-actionable wrapped error containing
  BOTH the reload error AND the rollback error; metadata flags
  manual_action_required=true.

Tests added to ssh_test.go (4 new tests, ~330 LOC):
- TestSSH_ReloadFails_FilesRestored — happy rollback path with
  pre-existing remote bytes for cert/key/chain. Asserts every
  path's last WriteFile call contains the captured backup bytes
  verbatim, no Remove calls fired (all paths had snapshots), and
  metadata reports backup_status=restored for each path.
- TestSSH_NoExistingCert_ReloadFails_NewCertRemoved — first-time
  deploy variant. StatFile returns os.ErrNotExist for every path;
  rollback Removes each written file but performs no WriteFile
  during restore (no backup to restore from). Asserts exactly 3
  WriteFile calls (deploy only) and 3 Remove calls (rollback).
- TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable —
  uses a writeOrderTrackingMock to fail the SECOND WriteFile to
  the cert path (i.e. the restore call, not the initial deploy).
  Asserts wrapped error contains both the reload error and the
  rollback error, and metadata flags manual_action_required=true.
- TestSSH_ReloadFails_RestoreThenSecondReloadFails — partial-
  recovery escalation. Rollback succeeds but the post-restore
  retry-reload fails. Asserts wrapped error mentions "rolled back
  files; retry-reload also failed" and metadata flags
  daemon_state_unknown=true.

Existing tests preserved by extending mockSSHClient with backward-
compatible per-path response maps (statByPath / readByPath /
writeFileErrByPath / executeErrSequence). Legacy global fields
(statFileSize / statFileErr / writeFileErr / executeErr) still
work when no per-path override matches, so TestValidateConfig_*
and TestDeployCertificate_Success_* don't need changes.

docs/deployment-atomicity.md L92 unchanged from today's text —
Bundle 1 doc-realignment hasn't shipped, so the "Pre-deploy SCP
backup of remote files" line was never softened. Post-Bundle-6
the claim is honest (was aspirational pre-fix).

Verified locally (sandbox lacks staticcheck install due to disk
pressure; CI runs the full lint gate):
- gofmt -l ./internal/connector/target/ssh/ → clean
- go vet ./internal/connector/target/ssh/ → clean
- go build ./internal/connector/target/ssh/... → clean
- go build ./cmd/agent/... → clean
- go test -race -count=1 ./internal/connector/target/ssh/ → green

Audit reference: cowork/deployment-target-audit-2026-05-02/RESULTS.md
Bundle 6.
2026-05-02 17:13:38 +00:00

1417 lines
44 KiB
Go

package ssh

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"testing"
	"time"

	"github.com/shankar0123/certctl/internal/connector/target"
)

// stubFileInfo implements os.FileInfo for tests that need to return a
// FileInfo from the mock SSHClient's StatFile. Bundle 6 of the
// 2026-05-02 deployment-target audit changed StatFile's signature from
// (int64, error) to (os.FileInfo, error) so the pre-deploy snapshot
// can capture the original mode for accurate rollback restoration.
type stubFileInfo struct {
	size int64
	mode os.FileMode
	name string
}

func (s *stubFileInfo) Name() string       { return s.name }
func (s *stubFileInfo) Size() int64        { return s.size }
func (s *stubFileInfo) Mode() os.FileMode  { return s.mode }
func (s *stubFileInfo) ModTime() time.Time { return time.Time{} }
func (s *stubFileInfo) IsDir() bool        { return false }
func (s *stubFileInfo) Sys() any           { return nil }

// testLogger returns a slog.Logger that writes warnings and above to stdout.
func testLogger() *slog.Logger {
	return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
}

// --- Mock SSH Client ---

// mockSSHClient records all calls and returns configurable results.
//
// Bundle 6 of the 2026-05-02 deployment-target audit added per-path
// response maps (statByPath / readByPath / readErrByPath /
// writeFileErrByPath / removeErrByPath) and per-call Execute sequences
// (executeErrSequence / executeOutSequence). With these, the new
// snapshot/rollback tests can simulate (a) pre-existing remote files for
// the snapshot to read, (b) per-path WriteFile, ReadFile and Remove
// failures, and (c) sequenced Execute results, so the
// reload-then-retry-reload tests can drive both calls independently.
// The legacy global fields (statFileSize / statFileErr / writeFileErr /
// executeErr) are still honored when no per-path or per-call override
// matches, so existing tests remain green.
type mockSSHClient struct {
	connectCalls       int
	connectErr         error
	writeFileCalls     []writeFileCall
	writeFileErr       error
	writeFileErrByPath map[string]error // per-path WriteFile error overrides
	executeCalls       []string
	executeOutput      string
	executeErr         error
	executeErrSequence []error  // per-call Execute errors; falls back to executeErr after exhaustion
	executeOutSequence []string // per-call Execute outputs; mirrors executeErrSequence
	statFileCalls      []string
	statFileSize       int64
	statFileErr        error
	statByPath         map[string]statResponse // per-path StatFile responses
	readByPath         map[string][]byte       // per-path ReadFile bytes (existence implies success)
	readErrByPath      map[string]error        // per-path ReadFile error overrides
	removeCalls        []string
	removeErr          error
	removeErrByPath    map[string]error // per-path Remove error overrides
	closeCalls         int
}

type writeFileCall struct {
	Path string
	Data []byte
	Mode os.FileMode
}

type statResponse struct {
	info os.FileInfo
	err  error
}

func (m *mockSSHClient) Connect(ctx context.Context) error {
	m.connectCalls++
	return m.connectErr
}

func (m *mockSSHClient) WriteFile(remotePath string, data []byte, mode os.FileMode) error {
	m.writeFileCalls = append(m.writeFileCalls, writeFileCall{Path: remotePath, Data: data, Mode: mode})
	if m.writeFileErrByPath != nil {
		if err, ok := m.writeFileErrByPath[remotePath]; ok {
			return err
		}
	}
	return m.writeFileErr
}

func (m *mockSSHClient) Execute(ctx context.Context, command string) (string, error) {
	idx := len(m.executeCalls)
	m.executeCalls = append(m.executeCalls, command)
	if idx < len(m.executeErrSequence) {
		out := ""
		if idx < len(m.executeOutSequence) {
			out = m.executeOutSequence[idx]
		}
		return out, m.executeErrSequence[idx]
	}
	return m.executeOutput, m.executeErr
}

func (m *mockSSHClient) StatFile(remotePath string) (os.FileInfo, error) {
	m.statFileCalls = append(m.statFileCalls, remotePath)
	if m.statByPath != nil {
		if resp, ok := m.statByPath[remotePath]; ok {
			return resp.info, resp.err
		}
	}
	if m.statFileErr != nil {
		return nil, m.statFileErr
	}
	// Default: synthesise a FileInfo with the legacy size + a sane mode.
	return &stubFileInfo{size: m.statFileSize, mode: 0644, name: remotePath}, nil
}

func (m *mockSSHClient) ReadFile(remotePath string) ([]byte, error) {
	if m.readErrByPath != nil {
		if err, ok := m.readErrByPath[remotePath]; ok {
			return nil, err
		}
	}
	if m.readByPath != nil {
		if data, ok := m.readByPath[remotePath]; ok {
			return data, nil
		}
	}
	// Default: empty bytes, no error. Tests that don't exercise the
	// snapshot path see this fall-through (the read still succeeds so
	// the snapshot phase doesn't block their deploy hot path).
	return []byte{}, nil
}

func (m *mockSSHClient) Remove(remotePath string) error {
	m.removeCalls = append(m.removeCalls, remotePath)
	if m.removeErrByPath != nil {
		if err, ok := m.removeErrByPath[remotePath]; ok {
			return err
		}
	}
	return m.removeErr
}

func (m *mockSSHClient) Close() error {
	m.closeCalls++
	return nil
}

// --- ValidateConfig tests ---

func TestValidateConfig_Success_KeyAuth(t *testing.T) {
	// Create a temporary key file.
	keyFile := createTempKeyFile(t)
	cfg := map[string]interface{}{
		"host":             "server.example.com",
		"user":             "deploy",
		"auth_method":      "key",
		"private_key_path": keyFile,
		"cert_path":        "/etc/ssl/certs/cert.pem",
		"key_path":         "/etc/ssl/private/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if c.config.Port != 22 {
		t.Errorf("expected default port 22, got %d", c.config.Port)
	}
	if c.config.CertMode != "0644" {
		t.Errorf("expected default cert_mode 0644, got %s", c.config.CertMode)
	}
	if c.config.KeyMode != "0600" {
		t.Errorf("expected default key_mode 0600, got %s", c.config.KeyMode)
	}
	if c.config.Timeout != 30 {
		t.Errorf("expected default timeout 30, got %d", c.config.Timeout)
	}
}

func TestValidateConfig_Success_InlineKey(t *testing.T) {
	cfg := map[string]interface{}{
		"host":        "10.0.0.5",
		"user":        "root",
		"auth_method": "key",
		"private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nfakekey\n-----END OPENSSH PRIVATE KEY-----",
		"cert_path":   "/etc/ssl/cert.pem",
		"key_path":    "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
}

func TestValidateConfig_Success_PasswordAuth(t *testing.T) {
	cfg := map[string]interface{}{
		"host":        "server.local",
		"user":        "deploy",
		"auth_method": "password",
		"password":    "s3cret",
		"cert_path":   "/etc/ssl/cert.pem",
		"key_path":    "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
}

func TestValidateConfig_InvalidJSON(t *testing.T) {
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	err := c.ValidateConfig(context.Background(), json.RawMessage(`{invalid`))
	if err == nil {
		t.Fatal("expected error for invalid JSON")
	}
}

func TestValidateConfig_MissingHost(t *testing.T) {
	cfg := map[string]interface{}{
		"user":      "deploy",
		"cert_path": "/etc/ssl/cert.pem",
		"key_path":  "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for missing host")
	}
}

func TestValidateConfig_MissingUser(t *testing.T) {
	cfg := map[string]interface{}{
		"host":      "server.local",
		"cert_path": "/etc/ssl/cert.pem",
		"key_path":  "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for missing user")
	}
}

func TestValidateConfig_MissingCertPath(t *testing.T) {
	cfg := map[string]interface{}{
		"host":     "server.local",
		"user":     "deploy",
		"key_path": "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for missing cert_path")
	}
}

func TestValidateConfig_MissingKeyPath(t *testing.T) {
	cfg := map[string]interface{}{
		"host":      "server.local",
		"user":      "deploy",
		"cert_path": "/etc/ssl/cert.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for missing key_path")
	}
}

func TestValidateConfig_KeyAuth_MissingKey(t *testing.T) {
	cfg := map[string]interface{}{
		"host":        "server.local",
		"user":        "deploy",
		"auth_method": "key",
		"cert_path":   "/etc/ssl/cert.pem",
		"key_path":    "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for key auth missing both private_key and private_key_path")
	}
}

func TestValidateConfig_PasswordAuth_MissingPassword(t *testing.T) {
	cfg := map[string]interface{}{
		"host":        "server.local",
		"user":        "deploy",
		"auth_method": "password",
		"cert_path":   "/etc/ssl/cert.pem",
		"key_path":    "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for password auth missing password")
	}
}

func TestValidateConfig_InvalidHost(t *testing.T) {
	cfg := map[string]interface{}{
		"host":        "server;rm -rf /",
		"user":        "deploy",
		"cert_path":   "/etc/ssl/cert.pem",
		"key_path":    "/etc/ssl/key.pem",
		"private_key": "fake",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for host with shell metacharacters")
	}
}

func TestValidateConfig_InvalidPermissions(t *testing.T) {
	keyFile := createTempKeyFile(t)
	cfg := map[string]interface{}{
		"host":             "server.local",
		"user":             "deploy",
		"private_key_path": keyFile,
		"cert_path":        "/etc/ssl/cert.pem",
		"key_path":         "/etc/ssl/key.pem",
		"cert_mode":        "999",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for invalid cert_mode")
	}
}

func TestValidateConfig_ReloadCommandInjection(t *testing.T) {
	tests := []struct {
		name    string
		command string
	}{
		{"semicolon", "systemctl reload nginx; rm -rf /"},
		{"pipe", "systemctl reload nginx | cat"},
		{"backtick", "systemctl reload `malicious`"},
		{"command substitution", "systemctl reload $(evil)"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			keyFile := createTempKeyFile(t)
			cfg := map[string]interface{}{
				"host":             "server.local",
				"user":             "deploy",
				"private_key_path": keyFile,
				"cert_path":        "/etc/ssl/cert.pem",
				"key_path":         "/etc/ssl/key.pem",
				"reload_command":   tc.command,
			}
			c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
			raw, _ := json.Marshal(cfg)
			err := c.ValidateConfig(context.Background(), raw)
			if err == nil {
				t.Fatalf("expected error for reload command injection: %q", tc.command)
			}
		})
	}
}

func TestValidateConfig_InvalidAuthMethod(t *testing.T) {
	cfg := map[string]interface{}{
		"host":        "server.local",
		"user":        "deploy",
		"auth_method": "kerberos",
		"cert_path":   "/etc/ssl/cert.pem",
		"key_path":    "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for invalid auth method")
	}
}

func TestValidateConfig_KeyFileNotFound(t *testing.T) {
	cfg := map[string]interface{}{
		"host":             "server.local",
		"user":             "deploy",
		"auth_method":      "key",
		"private_key_path": "/nonexistent/key.pem",
		"cert_path":        "/etc/ssl/cert.pem",
		"key_path":         "/etc/ssl/key.pem",
	}
	c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
	raw, _ := json.Marshal(cfg)
	err := c.ValidateConfig(context.Background(), raw)
	if err == nil {
		t.Fatal("expected error for nonexistent key file")
	}
}

// --- DeployCertificate tests ---

func TestDeployCertificate_Success_NoChainPath(t *testing.T) {
	mock := &mockSSHClient{statFileSize: 1024}
	cfg := &Config{
		Host:     "server.local",
		Port:     22,
		CertPath: "/etc/ssl/cert.pem",
		KeyPath:  "/etc/ssl/key.pem",
		CertMode: "0644",
		KeyMode:  "0600",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM:  "-----BEGIN CERTIFICATE-----\ncert\n-----END CERTIFICATE-----",
		KeyPEM:   "-----BEGIN PRIVATE KEY-----\nkey\n-----END PRIVATE KEY-----",
		ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if !result.Success {
		t.Fatalf("expected success, got %s", result.Message)
	}

	// Should have 2 writes (cert with chain appended, key).
	if len(mock.writeFileCalls) != 2 {
		t.Fatalf("expected 2 write calls, got %d", len(mock.writeFileCalls))
	}

	// Cert should include chain (fullchain).
	certWrite := mock.writeFileCalls[0]
	if certWrite.Path != "/etc/ssl/cert.pem" {
		t.Errorf("expected cert path /etc/ssl/cert.pem, got %s", certWrite.Path)
	}
	if certWrite.Mode != 0644 {
		t.Errorf("expected cert mode 0644, got %v", certWrite.Mode)
	}
	certContent := string(certWrite.Data)
	if len(certContent) == 0 {
		t.Error("cert data should not be empty")
	}

	// Key write.
	keyWrite := mock.writeFileCalls[1]
	if keyWrite.Path != "/etc/ssl/key.pem" {
		t.Errorf("expected key path /etc/ssl/key.pem, got %s", keyWrite.Path)
	}
	if keyWrite.Mode != 0600 {
		t.Errorf("expected key mode 0600, got %v", keyWrite.Mode)
	}

	// Metadata.
	if result.Metadata["host"] != "server.local" {
		t.Errorf("expected host metadata server.local, got %s", result.Metadata["host"])
	}
}

func TestDeployCertificate_Success_SeparateChain(t *testing.T) {
	mock := &mockSSHClient{}
	cfg := &Config{
		Host:      "server.local",
		Port:      22,
		CertPath:  "/etc/ssl/cert.pem",
		KeyPath:   "/etc/ssl/key.pem",
		ChainPath: "/etc/ssl/chain.pem",
		CertMode:  "0644",
		KeyMode:   "0600",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM:  "cert-data",
		KeyPEM:   "key-data",
		ChainPEM: "chain-data",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if !result.Success {
		t.Fatalf("expected success, got %s", result.Message)
	}

	// Should have 3 writes (cert, key, chain).
	if len(mock.writeFileCalls) != 3 {
		t.Fatalf("expected 3 write calls, got %d", len(mock.writeFileCalls))
	}

	// Chain should be written separately.
	chainWrite := mock.writeFileCalls[2]
	if chainWrite.Path != "/etc/ssl/chain.pem" {
		t.Errorf("expected chain path /etc/ssl/chain.pem, got %s", chainWrite.Path)
	}
}

func TestDeployCertificate_Success_WithReload(t *testing.T) {
	mock := &mockSSHClient{executeOutput: "ok"}
	cfg := &Config{
		Host:          "server.local",
		Port:          22,
		CertPath:      "/etc/ssl/cert.pem",
		KeyPath:       "/etc/ssl/key.pem",
		CertMode:      "0644",
		KeyMode:       "0600",
		ReloadCommand: "systemctl reload nginx",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM: "cert",
		KeyPEM:  "key",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if !result.Success {
		t.Fatalf("expected success, got %s", result.Message)
	}

	// Should have executed the reload command.
	if len(mock.executeCalls) != 1 {
		t.Fatalf("expected 1 execute call, got %d", len(mock.executeCalls))
	}
	if mock.executeCalls[0] != "systemctl reload nginx" {
		t.Errorf("expected reload command, got %s", mock.executeCalls[0])
	}
}

func TestDeployCertificate_MissingKeyPEM(t *testing.T) {
	mock := &mockSSHClient{}
	cfg := &Config{
		Host:     "server.local",
		Port:     22,
		CertPath: "/etc/ssl/cert.pem",
		KeyPath:  "/etc/ssl/key.pem",
		CertMode: "0644",
		KeyMode:  "0600",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM: "cert",
		KeyPEM:  "", // Missing
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for missing KeyPEM")
	}
	if result.Success {
		t.Fatal("expected failure result")
	}
}

func TestDeployCertificate_ConnectionFailure(t *testing.T) {
	mock := &mockSSHClient{connectErr: fmt.Errorf("connection refused")}
	cfg := &Config{
		Host:     "unreachable.local",
		Port:     22,
		CertPath: "/etc/ssl/cert.pem",
		KeyPath:  "/etc/ssl/key.pem",
		CertMode: "0644",
		KeyMode:  "0600",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM: "cert",
		KeyPEM:  "key",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for connection failure")
	}
	if result.Success {
		t.Fatal("expected failure result")
	}
}

func TestDeployCertificate_WriteFailure(t *testing.T) {
	mock := &mockSSHClient{writeFileErr: fmt.Errorf("permission denied")}
	cfg := &Config{
		Host:     "server.local",
		Port:     22,
		CertPath: "/etc/ssl/cert.pem",
		KeyPath:  "/etc/ssl/key.pem",
		CertMode: "0644",
		KeyMode:  "0600",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM: "cert",
		KeyPEM:  "key",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for write failure")
	}
	if result.Success {
		t.Fatal("expected failure result")
	}
}

func TestDeployCertificate_ReloadFailure(t *testing.T) {
	mock := &mockSSHClient{executeErr: fmt.Errorf("reload failed: exit status 1"), executeOutput: "error"}
	cfg := &Config{
		Host:          "server.local",
		Port:          22,
		CertPath:      "/etc/ssl/cert.pem",
		KeyPath:       "/etc/ssl/key.pem",
		CertMode:      "0644",
		KeyMode:       "0600",
		ReloadCommand: "systemctl reload nginx",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM: "cert",
		KeyPEM:  "key",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for reload failure")
	}
	if result.Success {
		t.Fatal("expected failure result")
	}
}

// --- Bundle 6: pre-deploy snapshot + reload-failure rollback ---
//
// These four tests pin the load-bearing rollback contract added in
// Bundle 6 of the 2026-05-02 deployment-target audit:
//   - happy rollback path: pre-existing remote bytes restored verbatim;
//   - first-time deploy: partial state cleaned up via Remove;
//   - both reload AND rollback fail → operator-actionable wrapped error;
//   - rollback succeeds but the retry-reload after rollback fails →
//     daemon-state-unknown wrapped error.

func TestSSH_ReloadFails_FilesRestored(t *testing.T) {
	originalCert := []byte("-----BEGIN CERTIFICATE-----\nORIGINAL_CERT\n-----END CERTIFICATE-----\n")
	originalKey := []byte("-----BEGIN PRIVATE KEY-----\nORIGINAL_KEY\n-----END PRIVATE KEY-----\n")
	originalChain := []byte("-----BEGIN CERTIFICATE-----\nORIGINAL_CHAIN\n-----END CERTIFICATE-----\n")
	mock := &mockSSHClient{
		// Pre-existing files for all three paths; mode 0644 / 0600 / 0644.
		statByPath: map[string]statResponse{
			"/etc/ssl/cert.pem":  {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}},
			"/etc/ssl/key.pem":   {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}},
			"/etc/ssl/chain.pem": {info: &stubFileInfo{size: int64(len(originalChain)), mode: 0644}},
		},
		readByPath: map[string][]byte{
			"/etc/ssl/cert.pem":  originalCert,
			"/etc/ssl/key.pem":   originalKey,
			"/etc/ssl/chain.pem": originalChain,
		},
		// First Execute (reload) fails; second Execute (retry-reload after
		// restore) succeeds: a clean recoverable failure.
		executeErrSequence: []error{fmt.Errorf("reload failed: exit status 1"), nil},
		executeOutSequence: []string{"reload error output", "ok"},
	}
	cfg := &Config{
		Host:          "server.local",
		Port:          22,
		CertPath:      "/etc/ssl/cert.pem",
		KeyPath:       "/etc/ssl/key.pem",
		ChainPath:     "/etc/ssl/chain.pem",
		CertMode:      "0644",
		KeyMode:       "0600",
		ReloadCommand: "systemctl reload nginx",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM:  "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----\n",
		KeyPEM:   "-----BEGIN PRIVATE KEY-----\nNEW_KEY\n-----END PRIVATE KEY-----\n",
		ChainPEM: "-----BEGIN CERTIFICATE-----\nNEW_CHAIN\n-----END CERTIFICATE-----\n",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err == nil {
		t.Fatal("expected error when reload fails")
	}
	if result.Success {
		t.Fatal("expected failure result")
	}

	// Error must mention the reload failure and the successful rollback.
	if !containsString(err.Error(), "reload command failed") && !containsString(err.Error(), "reload failed") {
		t.Errorf("expected error to mention reload failure, got: %v", err)
	}
	if !containsString(err.Error(), "rolled back") {
		t.Errorf("expected error to mention 'rolled back', got: %v", err)
	}

	// Index every WriteFile call by path. The initial deploy writes the new
	// bytes and the rollback writes the originals back, so each path should
	// be written at least twice.
	writesByPath := map[string][][]byte{}
	for _, w := range mock.writeFileCalls {
		writesByPath[w.Path] = append(writesByPath[w.Path], w.Data)
	}
	for _, path := range []string{"/etc/ssl/cert.pem", "/etc/ssl/key.pem", "/etc/ssl/chain.pem"} {
		writes := writesByPath[path]
		if len(writes) < 2 {
			t.Errorf("expected at least 2 WriteFile calls for %s (deploy + restore), got %d", path, len(writes))
			continue
		}
		// The last write to each path is the rollback restore; it must
		// equal the pre-existing bytes captured in the snapshot.
		lastWrite := writes[len(writes)-1]
		var want []byte
		switch path {
		case "/etc/ssl/cert.pem":
			want = originalCert
		case "/etc/ssl/key.pem":
			want = originalKey
		case "/etc/ssl/chain.pem":
			want = originalChain
		}
		if string(lastWrite) != string(want) {
			t.Errorf("rollback for %s did not restore original bytes:\n got: %q\n want: %q", path, lastWrite, want)
		}
	}

	// No Remove calls: every path had a pre-existing snapshot to restore from.
	if len(mock.removeCalls) != 0 {
		t.Errorf("expected 0 Remove calls (all paths had backups), got %d: %v", len(mock.removeCalls), mock.removeCalls)
	}

	// Both Execute calls (initial reload + retry-reload after rollback)
	// must have run.
	if len(mock.executeCalls) != 2 {
		t.Errorf("expected 2 Execute calls (reload + retry-reload), got %d", len(mock.executeCalls))
	}

	// Metadata reflects per-path snapshot status.
	if result.Metadata["backup_status_cert"] != "restored" {
		t.Errorf("expected backup_status_cert=restored, got %q", result.Metadata["backup_status_cert"])
	}
	if result.Metadata["backup_status_key"] != "restored" {
		t.Errorf("expected backup_status_key=restored, got %q", result.Metadata["backup_status_key"])
	}
	if result.Metadata["backup_status_chain"] != "restored" {
		t.Errorf("expected backup_status_chain=restored, got %q", result.Metadata["backup_status_chain"])
	}
	if result.Metadata["rolled_back"] != "true" {
		t.Errorf("expected rolled_back=true, got %q", result.Metadata["rolled_back"])
	}
}

func TestSSH_NoExistingCert_ReloadFails_NewCertRemoved(t *testing.T) {
	mock := &mockSSHClient{
		// All three paths report "no such file": a first-time deploy.
		statByPath: map[string]statResponse{
			"/etc/ssl/cert.pem":  {err: fmt.Errorf("stat: %w", os.ErrNotExist)},
			"/etc/ssl/key.pem":   {err: fmt.Errorf("stat: %w", os.ErrNotExist)},
			"/etc/ssl/chain.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)},
		},
		// Reload fails; retry-reload after rollback succeeds.
		executeErrSequence: []error{fmt.Errorf("reload failed"), nil},
		executeOutSequence: []string{"reload error", "ok"},
	}
	cfg := &Config{
		Host:          "server.local",
		Port:          22,
		CertPath:      "/etc/ssl/cert.pem",
		KeyPath:       "/etc/ssl/key.pem",
		ChainPath:     "/etc/ssl/chain.pem",
		CertMode:      "0644",
		KeyMode:       "0600",
		ReloadCommand: "systemctl reload nginx",
	}
	c := NewWithClient(cfg, mock, testLogger())
	req := target.DeploymentRequest{
		CertPEM:  "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----\n",
		KeyPEM:   "-----BEGIN PRIVATE KEY-----\nNEW_KEY\n-----END PRIVATE KEY-----\n",
		ChainPEM: "-----BEGIN CERTIFICATE-----\nNEW_CHAIN\n-----END CERTIFICATE-----\n",
	}
	result, err := c.DeployCertificate(context.Background(), req)
	if err == nil {
		t.Fatal("expected error when reload fails")
	}
	if result.Success {
		t.Fatal("expected failure result")
	}

	// Rollback for first-time deploys must call Remove on every written path.
	expectedRemoves := map[string]bool{
		"/etc/ssl/cert.pem":  true,
		"/etc/ssl/key.pem":   true,
		"/etc/ssl/chain.pem": true,
	}
	if len(mock.removeCalls) != len(expectedRemoves) {
		t.Errorf("expected %d Remove calls, got %d: %v", len(expectedRemoves), len(mock.removeCalls), mock.removeCalls)
	}
	for _, p := range mock.removeCalls {
		if !expectedRemoves[p] {
			t.Errorf("unexpected Remove path: %s", p)
		}
	}

	// First-time deploy: WriteFile is called only during the initial
	// deploy, never during rollback (no backup to restore from).
	expectedWrites := 3 // cert + key + chain (all configured paths)
	if len(mock.writeFileCalls) != expectedWrites {
		t.Errorf("expected exactly %d WriteFile calls (deploy only, no restore), got %d", expectedWrites, len(mock.writeFileCalls))
	}

	// Metadata reflects "removed" status for all paths.
	if result.Metadata["backup_status_cert"] != "removed" {
		t.Errorf("expected backup_status_cert=removed, got %q", result.Metadata["backup_status_cert"])
	}
	if result.Metadata["backup_status_key"] != "removed" {
		t.Errorf("expected backup_status_key=removed, got %q", result.Metadata["backup_status_key"])
	}
	if result.Metadata["backup_status_chain"] != "removed" {
		t.Errorf("expected backup_status_chain=removed, got %q", result.Metadata["backup_status_chain"])
	}
}
func TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable(t *testing.T) {
originalCert := []byte("ORIGINAL_CERT")
originalKey := []byte("ORIGINAL_KEY")
mock := &mockSSHClient{
statByPath: map[string]statResponse{
"/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}},
"/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}},
},
readByPath: map[string][]byte{
"/etc/ssl/cert.pem": originalCert,
"/etc/ssl/key.pem": originalKey,
},
// Initial deploy WriteFile calls succeed; rollback's WriteFile to
// restore the cert FAILS. This injects the operator-actionable
// case: reload failed AND the restore can't complete.
writeFileErrByPath: map[string]error{},
executeErrSequence: []error{fmt.Errorf("reload step failed")},
executeOutSequence: []string{"reload error"},
}
// Count writes per path so that only the SECOND WriteFile to
// /etc/ssl/cert.pem (the restore, not the initial deploy write) fails.
// A wrapper is needed because writeFileErrByPath is a flat per-path map
// and cannot distinguish the two writes.
wrapped := &writeOrderTrackingMock{base: mock}
wrapped.failOnNthWriteForPath = map[string]int{
"/etc/ssl/cert.pem": 2, // 1st = deploy write (succeed); 2nd = restore (fail)
}
cfg := &Config{
Host: "server.local",
Port: 22,
CertPath: "/etc/ssl/cert.pem",
KeyPath: "/etc/ssl/key.pem",
CertMode: "0644",
KeyMode: "0600",
ReloadCommand: "systemctl reload nginx",
}
c := NewWithClient(cfg, wrapped, testLogger())
req := target.DeploymentRequest{
CertPEM: "NEW_CERT",
KeyPEM: "NEW_KEY",
}
result, err := c.DeployCertificate(context.Background(), req)
if err == nil {
t.Fatal("expected error when both reload and rollback fail")
}
if result.Success {
t.Fatal("expected failure result")
}
// Wrapped error must mention BOTH the reload error and the rollback error.
if !containsString(err.Error(), "reload failed") {
t.Errorf("expected error to mention reload failure, got: %v", err)
}
if !containsString(err.Error(), "rollback also failed") {
t.Errorf("expected error to mention 'rollback also failed', got: %v", err)
}
if !containsString(err.Error(), "manual operator inspection required") {
t.Errorf("expected error to flag manual inspection, got: %v", err)
}
// Metadata must surface manual_action_required, rolled_back=false, and the rollback error.
if result.Metadata["manual_action_required"] != "true" {
t.Errorf("expected manual_action_required=true, got %q", result.Metadata["manual_action_required"])
}
if result.Metadata["rolled_back"] != "false" {
t.Errorf("expected rolled_back=false, got %q", result.Metadata["rolled_back"])
}
if result.Metadata["rollback_error"] == "" {
t.Error("expected rollback_error in metadata")
}
}
func TestSSH_ReloadFails_RestoreThenSecondReloadFails(t *testing.T) {
originalCert := []byte("ORIGINAL_CERT")
originalKey := []byte("ORIGINAL_KEY")
mock := &mockSSHClient{
statByPath: map[string]statResponse{
"/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}},
"/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}},
},
readByPath: map[string][]byte{
"/etc/ssl/cert.pem": originalCert,
"/etc/ssl/key.pem": originalKey,
},
// Both Execute calls (the initial reload and the retry-reload after
// rollback) fail. The remote files are back to their pre-deploy state,
// but the daemon may be stuck in a partial state, and the operator
// needs to know that.
executeErrSequence: []error{fmt.Errorf("reload step 1 failed"), fmt.Errorf("reload step 2 failed")},
executeOutSequence: []string{"out1", "out2"},
}
cfg := &Config{
Host: "server.local",
Port: 22,
CertPath: "/etc/ssl/cert.pem",
KeyPath: "/etc/ssl/key.pem",
CertMode: "0644",
KeyMode: "0600",
ReloadCommand: "systemctl reload nginx",
}
c := NewWithClient(cfg, mock, testLogger())
req := target.DeploymentRequest{
CertPEM: "NEW_CERT",
KeyPEM: "NEW_KEY",
}
result, err := c.DeployCertificate(context.Background(), req)
if err == nil {
t.Fatal("expected error when retry-reload after rollback fails")
}
if result.Success {
t.Fatal("expected failure result")
}
// The wrapped error must report that the files were rolled back, that
// the retry-reload also failed, and that the daemon may need a manual
// restart: the files are back to their pre-deploy state, but the daemon
// may not be running the original config.
if !containsString(err.Error(), "rolled back files") {
t.Errorf("expected error to mention 'rolled back files', got: %v", err)
}
if !containsString(err.Error(), "retry-reload also failed") {
t.Errorf("expected error to mention retry-reload failure, got: %v", err)
}
if !containsString(err.Error(), "daemon may need manual restart") {
t.Errorf("expected error to flag daemon state, got: %v", err)
}
// Metadata flags daemon_state_unknown + rolled_back=true (files OK).
if result.Metadata["daemon_state_unknown"] != "true" {
t.Errorf("expected daemon_state_unknown=true, got %q", result.Metadata["daemon_state_unknown"])
}
if result.Metadata["rolled_back"] != "true" {
t.Errorf("expected rolled_back=true, got %q", result.Metadata["rolled_back"])
}
// Both Execute calls happened: the initial reload and the retry-reload
// that runs after the cert and key were restored.
if len(mock.executeCalls) != 2 {
t.Errorf("expected 2 Execute calls, got %d", len(mock.executeCalls))
}
}
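// writeOrderTrackingMock wraps mockSSHClient to fail the Nth WriteFile
// for a given path. TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable
// uses it to fail the restore (2nd write) while letting the initial
// deploy (1st write) to the same path succeed. WriteFile does not
// delegate to base, so base.writeFileErrByPath is ignored, but every call
// is still recorded in base.writeFileCalls. All other methods delegate to base.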
type writeOrderTrackingMock struct {
base *mockSSHClient
writeCountByPath map[string]int
failOnNthWriteForPath map[string]int
}
func (w *writeOrderTrackingMock) Connect(ctx context.Context) error { return w.base.Connect(ctx) }
func (w *writeOrderTrackingMock) WriteFile(remotePath string, data []byte, mode os.FileMode) error {
if w.writeCountByPath == nil {
w.writeCountByPath = map[string]int{}
}
w.writeCountByPath[remotePath]++
w.base.writeFileCalls = append(w.base.writeFileCalls, writeFileCall{Path: remotePath, Data: data, Mode: mode})
if n, ok := w.failOnNthWriteForPath[remotePath]; ok {
if w.writeCountByPath[remotePath] == n {
return fmt.Errorf("injected write failure on call %d to %s", n, remotePath)
}
}
return nil
}
func (w *writeOrderTrackingMock) Execute(ctx context.Context, cmd string) (string, error) {
return w.base.Execute(ctx, cmd)
}
func (w *writeOrderTrackingMock) StatFile(remotePath string) (os.FileInfo, error) {
return w.base.StatFile(remotePath)
}
func (w *writeOrderTrackingMock) ReadFile(remotePath string) ([]byte, error) {
return w.base.ReadFile(remotePath)
}
func (w *writeOrderTrackingMock) Remove(remotePath string) error { return w.base.Remove(remotePath) }
func (w *writeOrderTrackingMock) Close() error { return w.base.Close() }
// --- ValidateDeployment tests ---
func TestValidateDeployment_Success(t *testing.T) {
mock := &mockSSHClient{statFileSize: 2048}
cfg := &Config{
Host: "server.local",
Port: 22,
CertPath: "/etc/ssl/cert.pem",
KeyPath: "/etc/ssl/key.pem",
CertMode: "0644",
KeyMode: "0600",
}
c := NewWithClient(cfg, mock, testLogger())
req := target.ValidationRequest{
CertificateID: "mc-test",
Serial: "ABC123",
}
result, err := c.ValidateDeployment(context.Background(), req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !result.Valid {
t.Fatalf("expected valid, got %s", result.Message)
}
// ValidateDeployment should stat the cert first, then the key.
if len(mock.statFileCalls) != 2 {
t.Fatalf("expected 2 stat calls, got %d", len(mock.statFileCalls))
}
if mock.statFileCalls[0] != "/etc/ssl/cert.pem" {
t.Errorf("expected cert path, got %s", mock.statFileCalls[0])
}
if mock.statFileCalls[1] != "/etc/ssl/key.pem" {
t.Errorf("expected key path, got %s", mock.statFileCalls[1])
}
}
func TestValidateDeployment_CertNotFound(t *testing.T) {
mock := &mockSSHClient{statFileErr: fmt.Errorf("file not found: %w", os.ErrNotExist)}
cfg := &Config{
Host: "server.local",
Port: 22,
CertPath: "/etc/ssl/cert.pem",
KeyPath: "/etc/ssl/key.pem",
CertMode: "0644",
KeyMode: "0600",
}
c := NewWithClient(cfg, mock, testLogger())
req := target.ValidationRequest{
CertificateID: "mc-test",
Serial: "ABC123",
}
result, err := c.ValidateDeployment(context.Background(), req)
if err == nil {
t.Fatal("expected error for missing cert")
}
if result.Valid {
t.Fatal("expected invalid result")
}
}
func TestValidateDeployment_ConnectionFailure(t *testing.T) {
mock := &mockSSHClient{connectErr: fmt.Errorf("connection refused")}
cfg := &Config{
Host: "unreachable.local",
Port: 22,
CertPath: "/etc/ssl/cert.pem",
KeyPath: "/etc/ssl/key.pem",
CertMode: "0644",
KeyMode: "0600",
}
c := NewWithClient(cfg, mock, testLogger())
req := target.ValidationRequest{
CertificateID: "mc-test",
Serial: "ABC123",
}
result, err := c.ValidateDeployment(context.Background(), req)
if err == nil {
t.Fatal("expected error for connection failure")
}
if result.Valid {
t.Fatal("expected invalid result")
}
}
// --- Helper tests ---
func TestParsePermissions(t *testing.T) {
tests := []struct {
input string
expected os.FileMode
wantErr bool
}{
{"0644", 0644, false},
{"0600", 0600, false},
{"0755", 0755, false},
{"invalid", 0, true},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
mode, err := parsePermissions(tc.input)
if tc.wantErr && err == nil {
t.Fatal("expected error")
}
if !tc.wantErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !tc.wantErr && mode != tc.expected {
t.Errorf("expected %v, got %v", tc.expected, mode)
}
})
}
}
func TestApplyDefaults(t *testing.T) {
cfg := &Config{}
applyDefaults(cfg)
if cfg.Port != 22 {
t.Errorf("expected port 22, got %d", cfg.Port)
}
if cfg.AuthMethod != "key" {
t.Errorf("expected auth_method key, got %s", cfg.AuthMethod)
}
if cfg.CertMode != "0644" {
t.Errorf("expected cert_mode 0644, got %s", cfg.CertMode)
}
if cfg.KeyMode != "0600" {
t.Errorf("expected key_mode 0600, got %s", cfg.KeyMode)
}
if cfg.Timeout != 30 {
t.Errorf("expected timeout 30, got %d", cfg.Timeout)
}
}
// TestDeployCertificate_FullChainMode tests that when ChainPath is not set but
// ChainPEM is provided, the chain is appended to the certificate data before writing.
func TestDeployCertificate_FullChainMode(t *testing.T) {
keyFile := createTempKeyFile(t)
cfg := &Config{
Host: "example.com",
Port: 22,
User: "deploy",
AuthMethod: "key",
PrivateKeyPath: keyFile,
CertPath: "/etc/ssl/certs/cert.pem",
KeyPath: "/etc/ssl/private/key.pem",
ChainPath: "", // Not set, so chain should be appended to cert
CertMode: "0644",
KeyMode: "0600",
Timeout: 30,
}
mock := &mockSSHClient{}
connector := NewWithClient(cfg, mock, testLogger())
deployReq := target.DeploymentRequest{
CertPEM: "-----BEGIN CERTIFICATE-----\nMIIBk...\n-----END CERTIFICATE-----",
KeyPEM: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----",
ChainPEM: "-----BEGIN CERTIFICATE-----\nMIIBj...\n-----END CERTIFICATE-----",
}
result, err := connector.DeployCertificate(context.Background(), deployReq)
if err != nil {
t.Fatalf("deployment failed: %v", err)
}
if !result.Success {
t.Fatalf("deployment result was not successful: %s", result.Message)
}
// The cert file must contain the leaf cert followed by the chain.
if len(mock.writeFileCalls) < 2 {
t.Fatalf("expected at least 2 WriteFile calls, got %d", len(mock.writeFileCalls))
}
certWriteCall := mock.writeFileCalls[0]
if certWriteCall.Path != "/etc/ssl/certs/cert.pem" {
t.Errorf("expected cert path /etc/ssl/certs/cert.pem, got %s", certWriteCall.Path)
}
certData := string(certWriteCall.Data)
leafIdx := stringIndex(certData, deployReq.CertPEM)
chainIdx := stringIndex(certData, deployReq.ChainPEM)
if leafIdx == -1 || chainIdx == -1 {
t.Errorf("cert data should contain both the leaf cert and the chain, got: %q", certData)
} else if chainIdx < leafIdx {
t.Errorf("chain should be appended after the leaf cert, got: %q", certData)
}
// Verify chain was not written separately (since ChainPath is empty)
if len(mock.writeFileCalls) > 2 {
t.Errorf("expected only 2 WriteFile calls (cert + key), got %d", len(mock.writeFileCalls))
}
}
// TestDeployCertificate_Permissions tests that the correct file permissions are
// passed to WriteFile for both certificate and key files.
func TestDeployCertificate_Permissions(t *testing.T) {
keyFile := createTempKeyFile(t)
cfg := &Config{
Host: "example.com",
Port: 22,
User: "deploy",
AuthMethod: "key",
PrivateKeyPath: keyFile,
CertPath: "/etc/ssl/certs/cert.pem",
KeyPath: "/etc/ssl/private/key.pem",
ChainPath: "",
CertMode: "0644",
KeyMode: "0600",
Timeout: 30,
}
mock := &mockSSHClient{}
connector := NewWithClient(cfg, mock, testLogger())
deployReq := target.DeploymentRequest{
CertPEM: "-----BEGIN CERTIFICATE-----\nMIIBk...\n-----END CERTIFICATE-----",
KeyPEM: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----",
ChainPEM: "",
}
_, err := connector.DeployCertificate(context.Background(), deployReq)
if err != nil {
t.Fatalf("deployment failed: %v", err)
}
if len(mock.writeFileCalls) < 2 {
t.Fatalf("expected at least 2 WriteFile calls, got %d", len(mock.writeFileCalls))
}
// Check cert file permissions (0644 = rw-r--r--)
certMode := mock.writeFileCalls[0].Mode
expectedCertMode := os.FileMode(0644)
if certMode != expectedCertMode {
t.Errorf("expected cert mode 0644, got %o", certMode)
}
// Check key file permissions (0600 = rw-------)
keyMode := mock.writeFileCalls[1].Mode
expectedKeyMode := os.FileMode(0600)
if keyMode != expectedKeyMode {
t.Errorf("expected key mode 0600, got %o", keyMode)
}
}
// TestValidateDeployment_KeyNotFound tests that ValidateDeployment fails when
// the key file is not found on the remote server.
func TestValidateDeployment_KeyNotFound(t *testing.T) {
keyFile := createTempKeyFile(t)
cfg := &Config{
Host: "example.com",
Port: 22,
User: "deploy",
AuthMethod: "key",
PrivateKeyPath: keyFile,
CertPath: "/etc/ssl/certs/cert.pem",
KeyPath: "/etc/ssl/private/key.pem",
ChainPath: "",
CertMode: "0644",
KeyMode: "0600",
Timeout: 30,
}
// Create a custom mock that succeeds for cert but fails for key
mock := &conditionalStatMockSSHClient{
base: &mockSSHClient{},
}
connector := NewWithClient(cfg, mock, testLogger())
valReq := target.ValidationRequest{
Serial: "11111",
}
result, err := connector.ValidateDeployment(context.Background(), valReq)
if err == nil {
t.Error("expected validation to fail when key file is not found")
}
if result.Valid {
t.Error("expected Valid=false when key file is missing")
}
if !containsString(result.Message, "key file not found") {
t.Errorf("expected 'key file not found' in message, got: %s", result.Message)
}
}
// conditionalStatMockSSHClient wraps mockSSHClient and fails the second
// StatFile call. ValidateDeployment stats the cert first, so that call is the key.
type conditionalStatMockSSHClient struct {
base *mockSSHClient
callCount int
}
func (m *conditionalStatMockSSHClient) Connect(ctx context.Context) error {
return m.base.Connect(ctx)
}
func (m *conditionalStatMockSSHClient) WriteFile(remotePath string, data []byte, mode os.FileMode) error {
return m.base.WriteFile(remotePath, data, mode)
}
func (m *conditionalStatMockSSHClient) Execute(ctx context.Context, command string) (string, error) {
return m.base.Execute(ctx, command)
}
func (m *conditionalStatMockSSHClient) StatFile(remotePath string) (os.FileInfo, error) {
m.callCount++
// The first call (cert) succeeds and the second (key) fails. The error
// wraps os.ErrNotExist so the connector's errors.Is check treats it as
// "file not found" under the Bundle 6 stat-error handling.
if m.callCount == 2 {
return nil, fmt.Errorf("file not found: %w", os.ErrNotExist)
}
return &stubFileInfo{size: 1024, mode: 0644}, nil
}
func (m *conditionalStatMockSSHClient) ReadFile(remotePath string) ([]byte, error) {
return m.base.ReadFile(remotePath)
}
func (m *conditionalStatMockSSHClient) Remove(remotePath string) error {
return m.base.Remove(remotePath)
}
func (m *conditionalStatMockSSHClient) Close() error {
return m.base.Close()
}
// --- Helpers ---
// createTempKeyFile creates a temporary file that simulates an SSH private key.
func createTempKeyFile(t *testing.T) string {
t.Helper()
dir := t.TempDir()
keyFile := dir + "/id_rsa"
if err := os.WriteFile(keyFile, []byte("fake-key-data"), 0600); err != nil {
t.Fatalf("failed to create temp key file: %v", err)
}
return keyFile
}
// containsString reports whether substr occurs in s (same semantics as strings.Contains).
func containsString(s, substr string) bool {
return len(s) >= len(substr) && stringIndex(s, substr) != -1
}
// stringIndex returns the index of the first occurrence of substr in s, or -1 if not found.
func stringIndex(s, substr string) int {
for i := 0; i <= len(s)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
if s[i+j] != substr[j] {
match = false
break
}
}
if match {
return i
}
}
return -1
}