mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 15:01:32 +00:00
636de7f6b5
Closes Bundle 6 of the 2026-05-02 deployment-target coverage audit
(see cowork/deployment-target-audit-2026-05-02/RESULTS.md). Pre-fix,
DeployCertificate at ssh.go:201-316 wrote new cert/key/chain via
SFTP then ran the operator's reload command. If reload failed, the
new files stayed on the remote — partial-success state with no
rollback path. docs/deployment-atomicity.md L92 promised "Pre-deploy
SCP backup of remote files"; the code didn't deliver.
This commit:
1. Pre-deploy snapshot. Before any WriteFile, iterate the deploy's
target paths (cert, key, optional chain). For each path:
- StatFile to detect existence. errors.Is(err, os.ErrNotExist)
means first-time deploy (rollback = Remove). Other stat
errors bail out before any write happens.
- ReadFile into an in-memory backups map[string][]byte keyed
by remote path. Original mode captured into a parallel
modes map for restore fidelity.
2. SSHClient interface evolution — three changes:
- StatFile(path) (os.FileInfo, error) — was (int64, error).
FileInfo carries Mode() needed for accurate restore. Existing
fixture tests updated to call info.Size() instead of the
bare size value.
- ReadFile(path) ([]byte, error) — new method; SFTP Open + read
via io.ReadAll. realSSHClient implements via sftpClient.Open.
- Remove(path) error — new method; SFTP Remove. Used by the
rollback path to clean up first-time-deploy partial state.
3. On-reload-failure rollback. Replace the bare error-return at
L282-295 with restoreFromBackups + retry-reload escalation:
- For paths in the snapshot map, WriteFile the original bytes
with the original mode (0600 fallback if mode capture was
incomplete).
- For paths that didn't exist pre-deploy, Remove the new file.
- Re-run the reload command (best-effort second attempt). If
it succeeds, the target is back to pre-deploy state. If it
fails, the remote is in pre-deploy file state but the daemon
may be stuck — surface as wrapped error so the operator
knows where to look.
4. DeploymentResult.Metadata gains backup_status_{cert,key,chain}
so operators can see per-path snapshot state on both success
("snapshotted" / "no_pre_existing" / "n/a") and failure
("restored" / "removed" / "restore_failed" / "remove_failed").
buildMetadataWithBackup helper centralises the metadata
shape so success and failure paths emit a consistent set
of keys.
5. Helper extraction. restoreFromBackups(ctx, paths, backups,
modes) is a private method on Connector; returns the first
error + per-key restore status map for clean test seams.
DeploymentResult shape on failure:
- rollback OK + retry-reload OK → Success=false, "reload command
failed; rolled back to pre-deploy state" (clean recoverable
failure; remote fully restored, daemon serving original cert).
- rollback OK + retry-reload FAIL → wrapped error noting "rolled
back files; retry-reload also failed; daemon may need manual
restart". Metadata flags daemon_state_unknown=true.
- rollback FAIL → operator-actionable wrapped error containing
BOTH the reload error AND the rollback error; metadata flags
manual_action_required=true.
Tests added to ssh_test.go (4 new tests, ~330 LOC):
- TestSSH_ReloadFails_FilesRestored — happy rollback path with
pre-existing remote bytes for cert/key/chain. Asserts every
path's last WriteFile call contains the captured backup bytes
verbatim, no Remove calls fired (all paths had snapshots), and
metadata reports backup_status=restored for each path.
- TestSSH_NoExistingCert_ReloadFails_NewCertRemoved — first-time
deploy variant. StatFile returns os.ErrNotExist for every path;
rollback Removes each written file but performs no WriteFile
during restore (no backup to restore from). Asserts exactly 3
WriteFile calls (deploy only) and 3 Remove calls (rollback).
- TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable —
uses a writeOrderTrackingMock to fail the SECOND WriteFile to
the cert path (i.e. the restore call, not the initial deploy).
Asserts wrapped error contains both the reload error and the
rollback error, and metadata flags manual_action_required=true.
- TestSSH_ReloadFails_RestoreThenSecondReloadFails — partial-
recovery escalation. Rollback succeeds but the post-restore
retry-reload fails. Asserts wrapped error mentions "rolled back
files; retry-reload also failed" and metadata flags
daemon_state_unknown=true.
Existing tests preserved by extending mockSSHClient with backward-
compatible per-path response maps (statByPath / readByPath /
writeFileErrByPath / executeErrSequence). Legacy global fields
(statFileSize / statFileErr / writeFileErr / executeErr) still
work when no per-path override matches, so TestValidateConfig_*
and TestDeployCertificate_Success_* don't need changes.
docs/deployment-atomicity.md L92 unchanged from today's text —
Bundle 1 doc-realignment hasn't shipped, so the "Pre-deploy SCP
backup of remote files" line was never softened. Post-Bundle-6
the claim is honest (was aspirational pre-fix).
Verified locally (sandbox lacks staticcheck install due to disk
pressure; CI runs the full lint gate):
- gofmt -l ./internal/connector/target/ssh/ clean
- go vet ./internal/connector/target/ssh/ clean
- go build ./internal/connector/target/ssh/... clean
- go build ./cmd/agent/... clean
- go test -race -count=1 ./internal/connector/target/ssh/ green
Audit reference: cowork/deployment-target-audit-2026-05-02/RESULTS.md
Bundle 6.
1417 lines
44 KiB
Go
1417 lines
44 KiB
Go
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/shankar0123/certctl/internal/connector/target"
|
|
)
|
|
|
|
// stubFileInfo implements os.FileInfo for tests that need to return a
|
|
// FileInfo from the mock SSHClient's StatFile. Bundle 6 of the
|
|
// 2026-05-02 deployment-target audit evolved StatFile's signature from
|
|
// (int64, error) to (os.FileInfo, error) so the pre-deploy snapshot
|
|
// can capture the original mode for accurate rollback restoration.
|
|
type stubFileInfo struct {
|
|
size int64
|
|
mode os.FileMode
|
|
name string
|
|
}
|
|
|
|
func (s *stubFileInfo) Name() string { return s.name }
|
|
func (s *stubFileInfo) Size() int64 { return s.size }
|
|
func (s *stubFileInfo) Mode() os.FileMode { return s.mode }
|
|
func (s *stubFileInfo) ModTime() time.Time { return time.Time{} }
|
|
func (s *stubFileInfo) IsDir() bool { return false }
|
|
func (s *stubFileInfo) Sys() any { return nil }
|
|
|
|
// testLogger returns a slog.Logger for test output.
|
|
func testLogger() *slog.Logger {
|
|
return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
|
|
}
|
|
|
|
// --- Mock SSH Client ---
|
|
|
|
// mockSSHClient records all calls and returns configurable results.
|
|
//
|
|
// Bundle 6 of the 2026-05-02 deployment-target audit added per-path
|
|
// response maps (statByPath / readByPath / writeFileErrByPath) so the
|
|
// new snapshot/rollback tests can simulate (a) pre-existing remote
|
|
// files for the snapshot to read, (b) per-call WriteFile failures to
|
|
// inject restore-failure paths, and (c) sequenced Execute errors so
|
|
// the reload-then-retry-reload tests can drive both calls
|
|
// independently. The legacy global fields (statFileSize / statFileErr
|
|
// / writeFileErr / executeErr) are still honored when no per-path
|
|
// override matches, so existing tests remain green.
|
|
type mockSSHClient struct {
|
|
connectCalls int
|
|
connectErr error
|
|
writeFileCalls []writeFileCall
|
|
writeFileErr error
|
|
writeFileErrByPath map[string]error // per-path WriteFile error overrides
|
|
executeCalls []string
|
|
executeOutput string
|
|
executeErr error
|
|
executeErrSequence []error // per-call Execute errors; falls back to executeErr after exhaustion
|
|
executeOutSequence []string // per-call Execute outputs; mirrors executeErrSequence
|
|
statFileCalls []string
|
|
statFileSize int64
|
|
statFileErr error
|
|
statByPath map[string]statResponse // per-path StatFile responses
|
|
readByPath map[string][]byte // per-path ReadFile bytes (existence implies success)
|
|
readErrByPath map[string]error // per-path ReadFile error overrides
|
|
removeCalls []string
|
|
removeErr error
|
|
removeErrByPath map[string]error
|
|
closeCalls int
|
|
}
|
|
|
|
type writeFileCall struct {
|
|
Path string
|
|
Data []byte
|
|
Mode os.FileMode
|
|
}
|
|
|
|
type statResponse struct {
|
|
info os.FileInfo
|
|
err error
|
|
}
|
|
|
|
func (m *mockSSHClient) Connect(ctx context.Context) error {
|
|
m.connectCalls++
|
|
return m.connectErr
|
|
}
|
|
|
|
func (m *mockSSHClient) WriteFile(remotePath string, data []byte, mode os.FileMode) error {
|
|
m.writeFileCalls = append(m.writeFileCalls, writeFileCall{Path: remotePath, Data: data, Mode: mode})
|
|
if m.writeFileErrByPath != nil {
|
|
if err, ok := m.writeFileErrByPath[remotePath]; ok {
|
|
return err
|
|
}
|
|
}
|
|
return m.writeFileErr
|
|
}
|
|
|
|
func (m *mockSSHClient) Execute(ctx context.Context, command string) (string, error) {
|
|
idx := len(m.executeCalls)
|
|
m.executeCalls = append(m.executeCalls, command)
|
|
if idx < len(m.executeErrSequence) {
|
|
out := ""
|
|
if idx < len(m.executeOutSequence) {
|
|
out = m.executeOutSequence[idx]
|
|
}
|
|
return out, m.executeErrSequence[idx]
|
|
}
|
|
return m.executeOutput, m.executeErr
|
|
}
|
|
|
|
func (m *mockSSHClient) StatFile(remotePath string) (os.FileInfo, error) {
|
|
m.statFileCalls = append(m.statFileCalls, remotePath)
|
|
if m.statByPath != nil {
|
|
if resp, ok := m.statByPath[remotePath]; ok {
|
|
return resp.info, resp.err
|
|
}
|
|
}
|
|
if m.statFileErr != nil {
|
|
return nil, m.statFileErr
|
|
}
|
|
// Default: synthesise a FileInfo with the legacy size + a sane mode.
|
|
return &stubFileInfo{size: m.statFileSize, mode: 0644, name: remotePath}, nil
|
|
}
|
|
|
|
func (m *mockSSHClient) ReadFile(remotePath string) ([]byte, error) {
|
|
if m.readErrByPath != nil {
|
|
if err, ok := m.readErrByPath[remotePath]; ok {
|
|
return nil, err
|
|
}
|
|
}
|
|
if m.readByPath != nil {
|
|
if data, ok := m.readByPath[remotePath]; ok {
|
|
return data, nil
|
|
}
|
|
}
|
|
// Default: empty bytes, no error. Tests that don't exercise the
|
|
// snapshot path see this fall-through (the read still succeeds so
|
|
// the snapshot phase doesn't block their deploy hot path).
|
|
return []byte{}, nil
|
|
}
|
|
|
|
func (m *mockSSHClient) Remove(remotePath string) error {
|
|
m.removeCalls = append(m.removeCalls, remotePath)
|
|
if m.removeErrByPath != nil {
|
|
if err, ok := m.removeErrByPath[remotePath]; ok {
|
|
return err
|
|
}
|
|
}
|
|
return m.removeErr
|
|
}
|
|
|
|
func (m *mockSSHClient) Close() error {
|
|
m.closeCalls++
|
|
return nil
|
|
}
|
|
|
|
// --- ValidateConfig tests ---
|
|
|
|
func TestValidateConfig_Success_KeyAuth(t *testing.T) {
|
|
// Create a temporary key file
|
|
keyFile := createTempKeyFile(t)
|
|
|
|
cfg := map[string]interface{}{
|
|
"host": "server.example.com",
|
|
"user": "deploy",
|
|
"auth_method": "key",
|
|
"private_key_path": keyFile,
|
|
"cert_path": "/etc/ssl/certs/cert.pem",
|
|
"key_path": "/etc/ssl/private/key.pem",
|
|
}
|
|
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
if c.config.Port != 22 {
|
|
t.Errorf("expected default port 22, got %d", c.config.Port)
|
|
}
|
|
if c.config.CertMode != "0644" {
|
|
t.Errorf("expected default cert_mode 0644, got %s", c.config.CertMode)
|
|
}
|
|
if c.config.KeyMode != "0600" {
|
|
t.Errorf("expected default key_mode 0600, got %s", c.config.KeyMode)
|
|
}
|
|
if c.config.Timeout != 30 {
|
|
t.Errorf("expected default timeout 30, got %d", c.config.Timeout)
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_Success_InlineKey(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "10.0.0.5",
|
|
"user": "root",
|
|
"auth_method": "key",
|
|
"private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\nfakekey\n-----END OPENSSH PRIVATE KEY-----",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_Success_PasswordAuth(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"auth_method": "password",
|
|
"password": "s3cret",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_InvalidJSON(t *testing.T) {
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
err := c.ValidateConfig(context.Background(), json.RawMessage(`{invalid`))
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid JSON")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_MissingHost(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"user": "deploy",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing host")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_MissingUser(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing user")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_MissingCertPath(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing cert_path")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_MissingKeyPath(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing key_path")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_KeyAuth_MissingKey(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"auth_method": "key",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for key auth missing both private_key and private_key_path")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_PasswordAuth_MissingPassword(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"auth_method": "password",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for password auth missing password")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_InvalidHost(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server;rm -rf /",
|
|
"user": "deploy",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
"private_key": "fake",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for host with shell metacharacters")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_InvalidPermissions(t *testing.T) {
|
|
keyFile := createTempKeyFile(t)
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"private_key_path": keyFile,
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
"cert_mode": "999",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid cert_mode")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_ReloadCommandInjection(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
command string
|
|
}{
|
|
{"semicolon", "systemctl reload nginx; rm -rf /"},
|
|
{"pipe", "systemctl reload nginx | cat"},
|
|
{"backtick", "systemctl reload `malicious`"},
|
|
{"command substitution", "systemctl reload $(evil)"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
keyFile := createTempKeyFile(t)
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"private_key_path": keyFile,
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
"reload_command": tc.command,
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatalf("expected error for reload command injection: %q", tc.command)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_InvalidAuthMethod(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"auth_method": "kerberos",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid auth method")
|
|
}
|
|
}
|
|
|
|
func TestValidateConfig_KeyFileNotFound(t *testing.T) {
|
|
cfg := map[string]interface{}{
|
|
"host": "server.local",
|
|
"user": "deploy",
|
|
"auth_method": "key",
|
|
"private_key_path": "/nonexistent/key.pem",
|
|
"cert_path": "/etc/ssl/cert.pem",
|
|
"key_path": "/etc/ssl/key.pem",
|
|
}
|
|
c := NewWithClient(&Config{}, &mockSSHClient{}, testLogger())
|
|
raw, _ := json.Marshal(cfg)
|
|
err := c.ValidateConfig(context.Background(), raw)
|
|
if err == nil {
|
|
t.Fatal("expected error for nonexistent key file")
|
|
}
|
|
}
|
|
|
|
// --- DeployCertificate tests ---
|
|
|
|
func TestDeployCertificate_Success_NoChainPath(t *testing.T) {
|
|
mock := &mockSSHClient{statFileSize: 1024}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "-----BEGIN CERTIFICATE-----\ncert\n-----END CERTIFICATE-----",
|
|
KeyPEM: "-----BEGIN PRIVATE KEY-----\nkey\n-----END PRIVATE KEY-----",
|
|
ChainPEM: "-----BEGIN CERTIFICATE-----\nchain\n-----END CERTIFICATE-----",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !result.Success {
|
|
t.Fatalf("expected success, got %s", result.Message)
|
|
}
|
|
|
|
// Should have 2 writes (cert with chain appended, key)
|
|
if len(mock.writeFileCalls) != 2 {
|
|
t.Fatalf("expected 2 write calls, got %d", len(mock.writeFileCalls))
|
|
}
|
|
|
|
// Cert should include chain (fullchain)
|
|
certWrite := mock.writeFileCalls[0]
|
|
if certWrite.Path != "/etc/ssl/cert.pem" {
|
|
t.Errorf("expected cert path /etc/ssl/cert.pem, got %s", certWrite.Path)
|
|
}
|
|
if certWrite.Mode != 0644 {
|
|
t.Errorf("expected cert mode 0644, got %v", certWrite.Mode)
|
|
}
|
|
certContent := string(certWrite.Data)
|
|
if len(certContent) == 0 {
|
|
t.Error("cert data should not be empty")
|
|
}
|
|
|
|
// Key write
|
|
keyWrite := mock.writeFileCalls[1]
|
|
if keyWrite.Path != "/etc/ssl/key.pem" {
|
|
t.Errorf("expected key path /etc/ssl/key.pem, got %s", keyWrite.Path)
|
|
}
|
|
if keyWrite.Mode != 0600 {
|
|
t.Errorf("expected key mode 0600, got %v", keyWrite.Mode)
|
|
}
|
|
|
|
// Metadata
|
|
if result.Metadata["host"] != "server.local" {
|
|
t.Errorf("expected host metadata server.local, got %s", result.Metadata["host"])
|
|
}
|
|
}
|
|
|
|
func TestDeployCertificate_Success_SeparateChain(t *testing.T) {
|
|
mock := &mockSSHClient{}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
ChainPath: "/etc/ssl/chain.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "cert-data",
|
|
KeyPEM: "key-data",
|
|
ChainPEM: "chain-data",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !result.Success {
|
|
t.Fatalf("expected success, got %s", result.Message)
|
|
}
|
|
|
|
// Should have 3 writes (cert, key, chain)
|
|
if len(mock.writeFileCalls) != 3 {
|
|
t.Fatalf("expected 3 write calls, got %d", len(mock.writeFileCalls))
|
|
}
|
|
|
|
// Chain should be separate
|
|
chainWrite := mock.writeFileCalls[2]
|
|
if chainWrite.Path != "/etc/ssl/chain.pem" {
|
|
t.Errorf("expected chain path /etc/ssl/chain.pem, got %s", chainWrite.Path)
|
|
}
|
|
}
|
|
|
|
func TestDeployCertificate_Success_WithReload(t *testing.T) {
|
|
mock := &mockSSHClient{executeOutput: "ok"}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
ReloadCommand: "systemctl reload nginx",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "cert",
|
|
KeyPEM: "key",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !result.Success {
|
|
t.Fatalf("expected success, got %s", result.Message)
|
|
}
|
|
|
|
// Should have executed reload command
|
|
if len(mock.executeCalls) != 1 {
|
|
t.Fatalf("expected 1 execute call, got %d", len(mock.executeCalls))
|
|
}
|
|
if mock.executeCalls[0] != "systemctl reload nginx" {
|
|
t.Errorf("expected reload command, got %s", mock.executeCalls[0])
|
|
}
|
|
}
|
|
|
|
func TestDeployCertificate_MissingKeyPEM(t *testing.T) {
|
|
mock := &mockSSHClient{}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "cert",
|
|
KeyPEM: "", // Missing
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing KeyPEM")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
}
|
|
|
|
func TestDeployCertificate_ConnectionFailure(t *testing.T) {
|
|
mock := &mockSSHClient{connectErr: fmt.Errorf("connection refused")}
|
|
cfg := &Config{
|
|
Host: "unreachable.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "cert",
|
|
KeyPEM: "key",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error for connection failure")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
}
|
|
|
|
func TestDeployCertificate_WriteFailure(t *testing.T) {
|
|
mock := &mockSSHClient{writeFileErr: fmt.Errorf("permission denied")}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "cert",
|
|
KeyPEM: "key",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error for write failure")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
}
|
|
|
|
func TestDeployCertificate_ReloadFailure(t *testing.T) {
|
|
mock := &mockSSHClient{executeErr: fmt.Errorf("reload failed: exit status 1"), executeOutput: "error"}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
ReloadCommand: "systemctl reload nginx",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "cert",
|
|
KeyPEM: "key",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error for reload failure")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
}
|
|
|
|
// --- Bundle 6: pre-deploy snapshot + reload-failure rollback ---
|
|
//
|
|
// These four tests pin the load-bearing rollback contract added in
|
|
// Bundle 6 of the 2026-05-02 deployment-target audit:
|
|
// - happy rollback path: pre-existing remote bytes restored verbatim;
|
|
// - first-time deploy partial-state cleanup via Remove;
|
|
// - both reload AND rollback fail → operator-actionable wrapped error;
|
|
// - rollback succeeds but the retry-reload after rollback fails →
|
|
// daemon-state-unknown wrapped error.
|
|
|
|
func TestSSH_ReloadFails_FilesRestored(t *testing.T) {
|
|
originalCert := []byte("-----BEGIN CERTIFICATE-----\nORIGINAL_CERT\n-----END CERTIFICATE-----\n")
|
|
originalKey := []byte("-----BEGIN PRIVATE KEY-----\nORIGINAL_KEY\n-----END PRIVATE KEY-----\n")
|
|
originalChain := []byte("-----BEGIN CERTIFICATE-----\nORIGINAL_CHAIN\n-----END CERTIFICATE-----\n")
|
|
|
|
mock := &mockSSHClient{
|
|
// Pre-existing files for all three paths; mode 0644 / 0600 / 0644.
|
|
statByPath: map[string]statResponse{
|
|
"/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}},
|
|
"/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}},
|
|
"/etc/ssl/chain.pem": {info: &stubFileInfo{size: int64(len(originalChain)), mode: 0644}},
|
|
},
|
|
readByPath: map[string][]byte{
|
|
"/etc/ssl/cert.pem": originalCert,
|
|
"/etc/ssl/key.pem": originalKey,
|
|
"/etc/ssl/chain.pem": originalChain,
|
|
},
|
|
// First Execute (reload) fails; second Execute (retry-reload after
|
|
// restore) succeeds — clean recoverable failure.
|
|
executeErrSequence: []error{fmt.Errorf("reload failed: exit status 1"), nil},
|
|
executeOutSequence: []string{"reload error output", "ok"},
|
|
}
|
|
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
ChainPath: "/etc/ssl/chain.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
ReloadCommand: "systemctl reload nginx",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----\n",
|
|
KeyPEM: "-----BEGIN PRIVATE KEY-----\nNEW_KEY\n-----END PRIVATE KEY-----\n",
|
|
ChainPEM: "-----BEGIN CERTIFICATE-----\nNEW_CHAIN\n-----END CERTIFICATE-----\n",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error when reload fails")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
|
|
// Error must mention reload failure + rollback success.
|
|
if !containsString(err.Error(), "reload command failed") && !containsString(err.Error(), "reload failed") {
|
|
t.Errorf("expected error to mention reload failure, got: %v", err)
|
|
}
|
|
if !containsString(err.Error(), "rolled back") {
|
|
t.Errorf("expected error to mention 'rolled back', got: %v", err)
|
|
}
|
|
|
|
// Build a path → bytes view of every WriteFile call for the assertions.
|
|
// On the success path the deploy writes new bytes; on the rollback path
|
|
// it writes the originals back. We expect each path to be written at
|
|
// least twice (once with new bytes, once with originals).
|
|
writesByPath := map[string][][]byte{}
|
|
for _, w := range mock.writeFileCalls {
|
|
writesByPath[w.Path] = append(writesByPath[w.Path], w.Data)
|
|
}
|
|
|
|
for _, path := range []string{"/etc/ssl/cert.pem", "/etc/ssl/key.pem", "/etc/ssl/chain.pem"} {
|
|
writes := writesByPath[path]
|
|
if len(writes) < 2 {
|
|
t.Errorf("expected at least 2 WriteFile calls for %s (deploy + restore), got %d", path, len(writes))
|
|
continue
|
|
}
|
|
// Last write to each path is the rollback restore — must equal
|
|
// the pre-existing bytes captured in the snapshot.
|
|
lastWrite := writes[len(writes)-1]
|
|
var want []byte
|
|
switch path {
|
|
case "/etc/ssl/cert.pem":
|
|
want = originalCert
|
|
case "/etc/ssl/key.pem":
|
|
want = originalKey
|
|
case "/etc/ssl/chain.pem":
|
|
want = originalChain
|
|
}
|
|
if string(lastWrite) != string(want) {
|
|
t.Errorf("rollback for %s did not restore original bytes:\n got: %q\n want: %q", path, lastWrite, want)
|
|
}
|
|
}
|
|
|
|
// No Remove calls — every path had a pre-existing snapshot to restore from.
|
|
if len(mock.removeCalls) != 0 {
|
|
t.Errorf("expected 0 Remove calls (all paths had backups), got %d: %v", len(mock.removeCalls), mock.removeCalls)
|
|
}
|
|
|
|
// Both Execute calls (initial reload + retry-reload after rollback)
|
|
// must have run.
|
|
if len(mock.executeCalls) != 2 {
|
|
t.Errorf("expected 2 Execute calls (reload + retry-reload), got %d", len(mock.executeCalls))
|
|
}
|
|
|
|
// Metadata reflects per-path snapshot status.
|
|
if result.Metadata["backup_status_cert"] != "restored" {
|
|
t.Errorf("expected backup_status_cert=restored, got %q", result.Metadata["backup_status_cert"])
|
|
}
|
|
if result.Metadata["backup_status_key"] != "restored" {
|
|
t.Errorf("expected backup_status_key=restored, got %q", result.Metadata["backup_status_key"])
|
|
}
|
|
if result.Metadata["backup_status_chain"] != "restored" {
|
|
t.Errorf("expected backup_status_chain=restored, got %q", result.Metadata["backup_status_chain"])
|
|
}
|
|
if result.Metadata["rolled_back"] != "true" {
|
|
t.Errorf("expected rolled_back=true, got %q", result.Metadata["rolled_back"])
|
|
}
|
|
}
|
|
|
|
func TestSSH_NoExistingCert_ReloadFails_NewCertRemoved(t *testing.T) {
|
|
mock := &mockSSHClient{
|
|
// All three paths report "no such file" — first-time deploy.
|
|
statByPath: map[string]statResponse{
|
|
"/etc/ssl/cert.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)},
|
|
"/etc/ssl/key.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)},
|
|
"/etc/ssl/chain.pem": {err: fmt.Errorf("stat: %w", os.ErrNotExist)},
|
|
},
|
|
// Reload fails; retry-reload after rollback succeeds.
|
|
executeErrSequence: []error{fmt.Errorf("reload failed"), nil},
|
|
executeOutSequence: []string{"reload error", "ok"},
|
|
}
|
|
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
ChainPath: "/etc/ssl/chain.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
ReloadCommand: "systemctl reload nginx",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----\n",
|
|
KeyPEM: "-----BEGIN PRIVATE KEY-----\nNEW_KEY\n-----END PRIVATE KEY-----\n",
|
|
ChainPEM: "-----BEGIN CERTIFICATE-----\nNEW_CHAIN\n-----END CERTIFICATE-----\n",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error when reload fails")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
|
|
// Rollback for first-time deploys must call Remove on every written path.
|
|
expectedRemoves := map[string]bool{
|
|
"/etc/ssl/cert.pem": true,
|
|
"/etc/ssl/key.pem": true,
|
|
"/etc/ssl/chain.pem": true,
|
|
}
|
|
if len(mock.removeCalls) != len(expectedRemoves) {
|
|
t.Errorf("expected %d Remove calls, got %d: %v", len(expectedRemoves), len(mock.removeCalls), mock.removeCalls)
|
|
}
|
|
for _, p := range mock.removeCalls {
|
|
if !expectedRemoves[p] {
|
|
t.Errorf("unexpected Remove path: %s", p)
|
|
}
|
|
}
|
|
|
|
// First-time deploy: WriteFile is called only during the initial
|
|
// deploy, never during rollback (no backup to restore from).
|
|
expectedWrites := 3 // cert + key + chain (all configured paths)
|
|
if len(mock.writeFileCalls) != expectedWrites {
|
|
t.Errorf("expected exactly %d WriteFile calls (deploy only, no restore), got %d", expectedWrites, len(mock.writeFileCalls))
|
|
}
|
|
|
|
// Metadata reflects "removed" status for all paths.
|
|
if result.Metadata["backup_status_cert"] != "removed" {
|
|
t.Errorf("expected backup_status_cert=removed, got %q", result.Metadata["backup_status_cert"])
|
|
}
|
|
if result.Metadata["backup_status_key"] != "removed" {
|
|
t.Errorf("expected backup_status_key=removed, got %q", result.Metadata["backup_status_key"])
|
|
}
|
|
if result.Metadata["backup_status_chain"] != "removed" {
|
|
t.Errorf("expected backup_status_chain=removed, got %q", result.Metadata["backup_status_chain"])
|
|
}
|
|
}
|
|
|
|
func TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable(t *testing.T) {
|
|
originalCert := []byte("ORIGINAL_CERT")
|
|
originalKey := []byte("ORIGINAL_KEY")
|
|
|
|
mock := &mockSSHClient{
|
|
statByPath: map[string]statResponse{
|
|
"/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}},
|
|
"/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}},
|
|
},
|
|
readByPath: map[string][]byte{
|
|
"/etc/ssl/cert.pem": originalCert,
|
|
"/etc/ssl/key.pem": originalKey,
|
|
},
|
|
// Initial deploy WriteFile calls succeed; rollback's WriteFile to
|
|
// restore the cert FAILS. This injects the operator-actionable
|
|
// case: reload failed AND the restore can't complete.
|
|
writeFileErrByPath: map[string]error{},
|
|
executeErrSequence: []error{fmt.Errorf("reload step failed")},
|
|
executeOutSequence: []string{"reload error"},
|
|
}
|
|
// Track call count so we can fail only the SECOND WriteFile to
|
|
// /etc/ssl/cert.pem (i.e. the restore call, not the initial deploy
|
|
// write). Done via a wrapper because writeFileErrByPath is a flat map.
|
|
wrapped := &writeOrderTrackingMock{base: mock}
|
|
wrapped.failOnNthWriteForPath = map[string]int{
|
|
"/etc/ssl/cert.pem": 2, // 1st = deploy write (succeed); 2nd = restore (fail)
|
|
}
|
|
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
ReloadCommand: "systemctl reload nginx",
|
|
}
|
|
c := NewWithClient(cfg, wrapped, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "NEW_CERT",
|
|
KeyPEM: "NEW_KEY",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error when both reload and rollback fail")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
|
|
// Wrapped error must mention BOTH the reload error and the rollback error.
|
|
if !containsString(err.Error(), "reload failed") {
|
|
t.Errorf("expected error to mention reload failure, got: %v", err)
|
|
}
|
|
if !containsString(err.Error(), "rollback also failed") {
|
|
t.Errorf("expected error to mention 'rollback also failed', got: %v", err)
|
|
}
|
|
if !containsString(err.Error(), "manual operator inspection required") {
|
|
t.Errorf("expected error to flag manual inspection, got: %v", err)
|
|
}
|
|
|
|
// Metadata must surface manual_action_required + both error strings.
|
|
if result.Metadata["manual_action_required"] != "true" {
|
|
t.Errorf("expected manual_action_required=true, got %q", result.Metadata["manual_action_required"])
|
|
}
|
|
if result.Metadata["rolled_back"] != "false" {
|
|
t.Errorf("expected rolled_back=false, got %q", result.Metadata["rolled_back"])
|
|
}
|
|
if result.Metadata["rollback_error"] == "" {
|
|
t.Error("expected rollback_error in metadata")
|
|
}
|
|
}
|
|
|
|
func TestSSH_ReloadFails_RestoreThenSecondReloadFails(t *testing.T) {
|
|
originalCert := []byte("ORIGINAL_CERT")
|
|
originalKey := []byte("ORIGINAL_KEY")
|
|
|
|
mock := &mockSSHClient{
|
|
statByPath: map[string]statResponse{
|
|
"/etc/ssl/cert.pem": {info: &stubFileInfo{size: int64(len(originalCert)), mode: 0644}},
|
|
"/etc/ssl/key.pem": {info: &stubFileInfo{size: int64(len(originalKey)), mode: 0600}},
|
|
},
|
|
readByPath: map[string][]byte{
|
|
"/etc/ssl/cert.pem": originalCert,
|
|
"/etc/ssl/key.pem": originalKey,
|
|
},
|
|
// Both Execute calls (initial reload + retry-reload after rollback)
|
|
// fail. The remote files are back to pre-deploy state but the
|
|
// daemon may be in a stuck/partial state — operator needs to
|
|
// know that.
|
|
executeErrSequence: []error{fmt.Errorf("reload step 1 failed"), fmt.Errorf("reload step 2 failed")},
|
|
executeOutSequence: []string{"out1", "out2"},
|
|
}
|
|
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
ReloadCommand: "systemctl reload nginx",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.DeploymentRequest{
|
|
CertPEM: "NEW_CERT",
|
|
KeyPEM: "NEW_KEY",
|
|
}
|
|
|
|
result, err := c.DeployCertificate(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error when retry-reload after rollback fails")
|
|
}
|
|
if result.Success {
|
|
t.Fatal("expected failure result")
|
|
}
|
|
|
|
// Wrapped error mentions reload failure, rollback success, and
|
|
// retry-reload failure — operator must understand the daemon may
|
|
// not be running the original config even though the files are back.
|
|
if !containsString(err.Error(), "rolled back files") {
|
|
t.Errorf("expected error to mention 'rolled back files', got: %v", err)
|
|
}
|
|
if !containsString(err.Error(), "retry-reload also failed") {
|
|
t.Errorf("expected error to mention retry-reload failure, got: %v", err)
|
|
}
|
|
if !containsString(err.Error(), "daemon may need manual restart") {
|
|
t.Errorf("expected error to flag daemon state, got: %v", err)
|
|
}
|
|
|
|
// Metadata flags daemon_state_unknown + rolled_back=true (files OK).
|
|
if result.Metadata["daemon_state_unknown"] != "true" {
|
|
t.Errorf("expected daemon_state_unknown=true, got %q", result.Metadata["daemon_state_unknown"])
|
|
}
|
|
if result.Metadata["rolled_back"] != "true" {
|
|
t.Errorf("expected rolled_back=true, got %q", result.Metadata["rolled_back"])
|
|
}
|
|
|
|
// Both Execute calls happened; both WriteFile-on-restore calls
|
|
// happened (cert + key restored).
|
|
if len(mock.executeCalls) != 2 {
|
|
t.Errorf("expected 2 Execute calls, got %d", len(mock.executeCalls))
|
|
}
|
|
}
|
|
|
|
// writeOrderTrackingMock wraps mockSSHClient to fail the Nth WriteFile
|
|
// for a given path. Used by TestSSH_ReloadFails_RollbackAlsoFails_-
|
|
// OperatorActionable to fail the restore (2nd write) while letting the
|
|
// initial deploy (1st write) succeed for the same path.
|
|
type writeOrderTrackingMock struct {
|
|
base *mockSSHClient
|
|
writeCountByPath map[string]int
|
|
failOnNthWriteForPath map[string]int
|
|
}
|
|
|
|
func (w *writeOrderTrackingMock) Connect(ctx context.Context) error { return w.base.Connect(ctx) }
|
|
func (w *writeOrderTrackingMock) WriteFile(remotePath string, data []byte, mode os.FileMode) error {
|
|
if w.writeCountByPath == nil {
|
|
w.writeCountByPath = map[string]int{}
|
|
}
|
|
w.writeCountByPath[remotePath]++
|
|
w.base.writeFileCalls = append(w.base.writeFileCalls, writeFileCall{Path: remotePath, Data: data, Mode: mode})
|
|
if n, ok := w.failOnNthWriteForPath[remotePath]; ok {
|
|
if w.writeCountByPath[remotePath] == n {
|
|
return fmt.Errorf("injected write failure on call %d to %s", n, remotePath)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
func (w *writeOrderTrackingMock) Execute(ctx context.Context, cmd string) (string, error) {
|
|
return w.base.Execute(ctx, cmd)
|
|
}
|
|
func (w *writeOrderTrackingMock) StatFile(remotePath string) (os.FileInfo, error) {
|
|
return w.base.StatFile(remotePath)
|
|
}
|
|
func (w *writeOrderTrackingMock) ReadFile(remotePath string) ([]byte, error) {
|
|
return w.base.ReadFile(remotePath)
|
|
}
|
|
func (w *writeOrderTrackingMock) Remove(remotePath string) error { return w.base.Remove(remotePath) }
|
|
func (w *writeOrderTrackingMock) Close() error { return w.base.Close() }
|
|
|
|
// --- ValidateDeployment tests ---
|
|
|
|
func TestValidateDeployment_Success(t *testing.T) {
|
|
mock := &mockSSHClient{statFileSize: 2048}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.ValidationRequest{
|
|
CertificateID: "mc-test",
|
|
Serial: "ABC123",
|
|
}
|
|
|
|
result, err := c.ValidateDeployment(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
if !result.Valid {
|
|
t.Fatalf("expected valid, got %s", result.Message)
|
|
}
|
|
|
|
// Should have stat'd both files
|
|
if len(mock.statFileCalls) != 2 {
|
|
t.Fatalf("expected 2 stat calls, got %d", len(mock.statFileCalls))
|
|
}
|
|
if mock.statFileCalls[0] != "/etc/ssl/cert.pem" {
|
|
t.Errorf("expected cert path, got %s", mock.statFileCalls[0])
|
|
}
|
|
if mock.statFileCalls[1] != "/etc/ssl/key.pem" {
|
|
t.Errorf("expected key path, got %s", mock.statFileCalls[1])
|
|
}
|
|
}
|
|
|
|
func TestValidateDeployment_CertNotFound(t *testing.T) {
|
|
mock := &mockSSHClient{statFileErr: fmt.Errorf("file not found")}
|
|
cfg := &Config{
|
|
Host: "server.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.ValidationRequest{
|
|
CertificateID: "mc-test",
|
|
Serial: "ABC123",
|
|
}
|
|
|
|
result, err := c.ValidateDeployment(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing cert")
|
|
}
|
|
if result.Valid {
|
|
t.Fatal("expected invalid result")
|
|
}
|
|
}
|
|
|
|
func TestValidateDeployment_ConnectionFailure(t *testing.T) {
|
|
mock := &mockSSHClient{connectErr: fmt.Errorf("connection refused")}
|
|
cfg := &Config{
|
|
Host: "unreachable.local",
|
|
Port: 22,
|
|
CertPath: "/etc/ssl/cert.pem",
|
|
KeyPath: "/etc/ssl/key.pem",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
}
|
|
c := NewWithClient(cfg, mock, testLogger())
|
|
|
|
req := target.ValidationRequest{
|
|
CertificateID: "mc-test",
|
|
Serial: "ABC123",
|
|
}
|
|
|
|
result, err := c.ValidateDeployment(context.Background(), req)
|
|
if err == nil {
|
|
t.Fatal("expected error for connection failure")
|
|
}
|
|
if result.Valid {
|
|
t.Fatal("expected invalid result")
|
|
}
|
|
}
|
|
|
|
// --- Helper tests ---
|
|
|
|
func TestParsePermissions(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected os.FileMode
|
|
wantErr bool
|
|
}{
|
|
{"0644", 0644, false},
|
|
{"0600", 0600, false},
|
|
{"0755", 0755, false},
|
|
{"invalid", 0, true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.input, func(t *testing.T) {
|
|
mode, err := parsePermissions(tc.input)
|
|
if tc.wantErr && err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !tc.wantErr && err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !tc.wantErr && mode != tc.expected {
|
|
t.Errorf("expected %v, got %v", tc.expected, mode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestApplyDefaults(t *testing.T) {
|
|
cfg := &Config{}
|
|
applyDefaults(cfg)
|
|
|
|
if cfg.Port != 22 {
|
|
t.Errorf("expected port 22, got %d", cfg.Port)
|
|
}
|
|
if cfg.AuthMethod != "key" {
|
|
t.Errorf("expected auth_method key, got %s", cfg.AuthMethod)
|
|
}
|
|
if cfg.CertMode != "0644" {
|
|
t.Errorf("expected cert_mode 0644, got %s", cfg.CertMode)
|
|
}
|
|
if cfg.KeyMode != "0600" {
|
|
t.Errorf("expected key_mode 0600, got %s", cfg.KeyMode)
|
|
}
|
|
if cfg.Timeout != 30 {
|
|
t.Errorf("expected timeout 30, got %d", cfg.Timeout)
|
|
}
|
|
}
|
|
|
|
// TestDeployCertificate_FullChainMode tests that when ChainPath is not set but
|
|
// ChainPEM is provided, the chain is appended to the certificate data before writing.
|
|
func TestDeployCertificate_FullChainMode(t *testing.T) {
|
|
keyFile := createTempKeyFile(t)
|
|
|
|
cfg := &Config{
|
|
Host: "example.com",
|
|
Port: 22,
|
|
User: "deploy",
|
|
AuthMethod: "key",
|
|
PrivateKeyPath: keyFile,
|
|
CertPath: "/etc/ssl/certs/cert.pem",
|
|
KeyPath: "/etc/ssl/private/key.pem",
|
|
ChainPath: "", // Not set, so chain should be appended to cert
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
Timeout: 30,
|
|
}
|
|
|
|
mock := &mockSSHClient{}
|
|
connector := NewWithClient(cfg, mock, testLogger())
|
|
|
|
deployReq := target.DeploymentRequest{
|
|
CertPEM: "-----BEGIN CERTIFICATE-----\nMIIBk...\n-----END CERTIFICATE-----",
|
|
KeyPEM: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----",
|
|
ChainPEM: "-----BEGIN CERTIFICATE-----\nMIIBj...\n-----END CERTIFICATE-----",
|
|
}
|
|
|
|
result, err := connector.DeployCertificate(context.Background(), deployReq)
|
|
if err != nil {
|
|
t.Fatalf("deployment failed: %v", err)
|
|
}
|
|
if !result.Success {
|
|
t.Fatalf("deployment result was not successful: %s", result.Message)
|
|
}
|
|
|
|
// Verify that the cert file received contains both cert and chain concatenated
|
|
if len(mock.writeFileCalls) < 2 {
|
|
t.Fatalf("expected at least 2 WriteFile calls, got %d", len(mock.writeFileCalls))
|
|
}
|
|
|
|
certWriteCall := mock.writeFileCalls[0]
|
|
if certWriteCall.Path != "/etc/ssl/certs/cert.pem" {
|
|
t.Errorf("expected cert path /etc/ssl/certs/cert.pem, got %s", certWriteCall.Path)
|
|
}
|
|
|
|
certData := string(certWriteCall.Data)
|
|
if !containsString(certData, "BEGIN CERTIFICATE") || !containsString(certData, "END CERTIFICATE") {
|
|
t.Errorf("cert data should contain combined cert and chain")
|
|
}
|
|
|
|
// Verify chain was not written separately (since ChainPath is empty)
|
|
if len(mock.writeFileCalls) > 2 {
|
|
t.Errorf("expected only 2 WriteFile calls (cert + key), got %d", len(mock.writeFileCalls))
|
|
}
|
|
}
|
|
|
|
// TestDeployCertificate_Permissions tests that the correct file permissions are
|
|
// passed to WriteFile for both certificate and key files.
|
|
func TestDeployCertificate_Permissions(t *testing.T) {
|
|
keyFile := createTempKeyFile(t)
|
|
|
|
cfg := &Config{
|
|
Host: "example.com",
|
|
Port: 22,
|
|
User: "deploy",
|
|
AuthMethod: "key",
|
|
PrivateKeyPath: keyFile,
|
|
CertPath: "/etc/ssl/certs/cert.pem",
|
|
KeyPath: "/etc/ssl/private/key.pem",
|
|
ChainPath: "",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
Timeout: 30,
|
|
}
|
|
|
|
mock := &mockSSHClient{}
|
|
connector := NewWithClient(cfg, mock, testLogger())
|
|
|
|
deployReq := target.DeploymentRequest{
|
|
CertPEM: "-----BEGIN CERTIFICATE-----\nMIIBk...\n-----END CERTIFICATE-----",
|
|
KeyPEM: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----",
|
|
ChainPEM: "",
|
|
}
|
|
|
|
_, err := connector.DeployCertificate(context.Background(), deployReq)
|
|
if err != nil {
|
|
t.Fatalf("deployment failed: %v", err)
|
|
}
|
|
|
|
if len(mock.writeFileCalls) < 2 {
|
|
t.Fatalf("expected at least 2 WriteFile calls, got %d", len(mock.writeFileCalls))
|
|
}
|
|
|
|
// Check cert file permissions (0644 = rw-r--r--)
|
|
certMode := mock.writeFileCalls[0].Mode
|
|
expectedCertMode := os.FileMode(0644)
|
|
if certMode != expectedCertMode {
|
|
t.Errorf("expected cert mode 0644, got %o", certMode)
|
|
}
|
|
|
|
// Check key file permissions (0600 = rw-------)
|
|
keyMode := mock.writeFileCalls[1].Mode
|
|
expectedKeyMode := os.FileMode(0600)
|
|
if keyMode != expectedKeyMode {
|
|
t.Errorf("expected key mode 0600, got %o", keyMode)
|
|
}
|
|
}
|
|
|
|
// TestValidateDeployment_KeyNotFound tests that ValidateDeployment fails when
|
|
// the key file is not found on the remote server.
|
|
func TestValidateDeployment_KeyNotFound(t *testing.T) {
|
|
keyFile := createTempKeyFile(t)
|
|
|
|
cfg := &Config{
|
|
Host: "example.com",
|
|
Port: 22,
|
|
User: "deploy",
|
|
AuthMethod: "key",
|
|
PrivateKeyPath: keyFile,
|
|
CertPath: "/etc/ssl/certs/cert.pem",
|
|
KeyPath: "/etc/ssl/private/key.pem",
|
|
ChainPath: "",
|
|
CertMode: "0644",
|
|
KeyMode: "0600",
|
|
Timeout: 30,
|
|
}
|
|
|
|
// Create a custom mock that succeeds for cert but fails for key
|
|
mock := &conditionalStatMockSSHClient{
|
|
base: &mockSSHClient{},
|
|
}
|
|
|
|
connector := NewWithClient(cfg, mock, testLogger())
|
|
|
|
valReq := target.ValidationRequest{
|
|
Serial: "11111",
|
|
}
|
|
|
|
result, err := connector.ValidateDeployment(context.Background(), valReq)
|
|
if err == nil {
|
|
t.Error("expected validation to fail when key file is not found")
|
|
}
|
|
if result.Valid {
|
|
t.Error("expected Valid=false when key file is missing")
|
|
}
|
|
if !containsString(result.Message, "key file not found") {
|
|
t.Errorf("expected 'key file not found' in message, got: %s", result.Message)
|
|
}
|
|
}
|
|
|
|
// conditionalStatMockSSHClient wraps mockSSHClient to fail on key path during StatFile.
|
|
type conditionalStatMockSSHClient struct {
|
|
base *mockSSHClient
|
|
callCount int
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) Connect(ctx context.Context) error {
|
|
return m.base.Connect(ctx)
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) WriteFile(remotePath string, data []byte, mode os.FileMode) error {
|
|
return m.base.WriteFile(remotePath, data, mode)
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) Execute(ctx context.Context, command string) (string, error) {
|
|
return m.base.Execute(ctx, command)
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) StatFile(remotePath string) (os.FileInfo, error) {
|
|
m.callCount++
|
|
// First call succeeds (cert), second call fails (key) — wrap
|
|
// os.ErrNotExist so the connector's errors.Is check propagates the
|
|
// "file not found" semantics through the Bundle 6 stat-error
|
|
// handling.
|
|
if m.callCount == 2 {
|
|
return nil, fmt.Errorf("file not found: %w", os.ErrNotExist)
|
|
}
|
|
return &stubFileInfo{size: 1024, mode: 0644}, nil
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) ReadFile(remotePath string) ([]byte, error) {
|
|
return m.base.ReadFile(remotePath)
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) Remove(remotePath string) error {
|
|
return m.base.Remove(remotePath)
|
|
}
|
|
|
|
func (m *conditionalStatMockSSHClient) Close() error {
|
|
return m.base.Close()
|
|
}
|
|
|
|
// --- Helpers ---
|
|
|
|
// createTempKeyFile creates a temporary file that simulates an SSH private key.
|
|
func createTempKeyFile(t *testing.T) string {
|
|
t.Helper()
|
|
dir := t.TempDir()
|
|
keyFile := dir + "/id_rsa"
|
|
if err := os.WriteFile(keyFile, []byte("fake-key-data"), 0600); err != nil {
|
|
t.Fatalf("failed to create temp key file: %v", err)
|
|
}
|
|
return keyFile
|
|
}
|
|
|
|
// containsString is a helper to check if a string contains a substring.
|
|
func containsString(s, substr string) bool {
|
|
return len(s) >= len(substr) && stringIndex(s, substr) != -1
|
|
}
|
|
|
|
// stringIndex returns the index of the first occurrence of substr in s, or -1 if not found.
|
|
func stringIndex(s, substr string) int {
|
|
for i := 0; i <= len(s)-len(substr); i++ {
|
|
match := true
|
|
for j := 0; j < len(substr); j++ {
|
|
if s[i+j] != substr[j] {
|
|
match = false
|
|
break
|
|
}
|
|
}
|
|
if match {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|