certctl/internal/connector/target/ssh/ssh_server_fixture_test.go
shankar0123 636de7f6b5 ssh: pre-deploy snapshot + reload-failure rollback
Closes Bundle 6 of the 2026-05-02 deployment-target coverage audit
(see cowork/deployment-target-audit-2026-05-02/RESULTS.md). Pre-fix,
DeployCertificate at ssh.go:201-316 wrote new cert/key/chain via
SFTP then ran the operator's reload command. If reload failed, the
new files stayed on the remote — partial-success state with no
rollback path. docs/deployment-atomicity.md L92 promised "Pre-deploy
SCP backup of remote files"; the code didn't deliver.

This commit:

1. Pre-deploy snapshot. Before any WriteFile, iterate the deploy's
   target paths (cert, key, optional chain). For each path:
   - StatFile to detect existence. errors.Is(err, os.ErrNotExist)
     means first-time deploy (rollback = Remove). Other stat
     errors bail out before any write happens.
   - ReadFile into an in-memory backups map[string][]byte keyed
     by remote path. Original mode captured into a parallel
     modes map for restore fidelity.

2. SSHClient interface evolution — three changes:
   - StatFile(path) (os.FileInfo, error) — was (int64, error).
     FileInfo carries Mode() needed for accurate restore. Existing
     fixture tests updated to call info.Size() instead of the
     bare size value.
   - ReadFile(path) ([]byte, error) — new method; SFTP Open + read
     via io.ReadAll. realSSHClient implements via sftpClient.Open.
   - Remove(path) error — new method; SFTP Remove. Used by the
     rollback path to clean up first-time-deploy partial state.

3. On-reload-failure rollback. Replace the bare error-return at
   L282-295 with restoreFromBackups + retry-reload escalation:
   - For paths in the snapshot map, WriteFile the original bytes
     with the original mode (0600 fallback if mode capture was
     incomplete).
   - For paths that didn't exist pre-deploy, Remove the new file.
   - Re-run the reload command (best-effort second attempt). If
     it succeeds, the target is back to pre-deploy state. If it
     fails, the remote is in pre-deploy file state but the daemon
     may be stuck — surface as wrapped error so the operator
     knows where to look.

4. DeploymentResult.Metadata gains backup_status_{cert,key,chain}
   so operators can see per-path snapshot state on both success
   ("snapshotted" / "no_pre_existing" / "n/a") and failure
   ("restored" / "removed" / "restore_failed" / "remove_failed").
   buildMetadataWithBackup helper centralises the metadata
   shape so success and failure paths emit a consistent set
   of keys.

5. Helper extraction. restoreFromBackups(ctx, paths, backups,
   modes) is a private method on Connector; returns the first
   error + per-key restore status map for clean test seams.
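
The snapshot and restore steps above can be sketched as follows. This is a minimal illustration, not the code in ssh.go: fileClient, memClient, snapshot, and restore are hypothetical names, and StatMode is a simplification that returns only the file mode where the real StatFile returns an os.FileInfo.

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

// fileClient is a hypothetical subset of the SSHClient interface.
type fileClient interface {
	StatMode(path string) (os.FileMode, error)
	ReadFile(path string) ([]byte, error)
	WriteFile(path string, data []byte, mode os.FileMode) error
	Remove(path string) error
}

type memFile struct {
	data []byte
	mode os.FileMode
}

// memClient is an in-memory fake so the sketch runs without a remote host.
type memClient struct{ files map[string]memFile }

func (m *memClient) StatMode(p string) (os.FileMode, error) {
	f, ok := m.files[p]
	if !ok {
		return 0, os.ErrNotExist
	}
	return f.mode, nil
}

func (m *memClient) ReadFile(p string) ([]byte, error) {
	f, ok := m.files[p]
	if !ok {
		return nil, os.ErrNotExist
	}
	return append([]byte(nil), f.data...), nil
}

func (m *memClient) WriteFile(p string, d []byte, mode os.FileMode) error {
	m.files[p] = memFile{data: append([]byte(nil), d...), mode: mode}
	return nil
}

func (m *memClient) Remove(p string) error { delete(m.files, p); return nil }

// snapshot captures bytes and mode for every path that already exists.
// Missing paths are skipped; any other error aborts before a write happens.
func snapshot(c fileClient, paths []string) (map[string][]byte, map[string]os.FileMode, error) {
	backups, modes := map[string][]byte{}, map[string]os.FileMode{}
	for _, p := range paths {
		mode, err := c.StatMode(p)
		if errors.Is(err, os.ErrNotExist) {
			continue // first-time deploy: rollback will Remove this path
		}
		if err != nil {
			return nil, nil, fmt.Errorf("stat %s: %w", p, err)
		}
		data, err := c.ReadFile(p)
		if err != nil {
			return nil, nil, fmt.Errorf("read %s: %w", p, err)
		}
		backups[p], modes[p] = data, mode.Perm()
	}
	return backups, modes, nil
}

// restore rewrites snapshotted paths and removes paths that had no snapshot.
// It returns a per-path status map and the first error encountered.
func restore(c fileClient, paths []string, backups map[string][]byte, modes map[string]os.FileMode) (map[string]string, error) {
	status := map[string]string{}
	var first error
	for _, p := range paths {
		var err error
		if data, had := backups[p]; had {
			mode, ok := modes[p]
			if !ok {
				mode = 0600 // fallback when mode capture was incomplete
			}
			status[p] = "restored"
			if err = c.WriteFile(p, data, mode); err != nil {
				status[p] = "restore_failed"
			}
		} else {
			status[p] = "removed"
			if err = c.Remove(p); err != nil {
				status[p] = "remove_failed"
			}
		}
		if err != nil && first == nil {
			first = err
		}
	}
	return status, first
}

func main() {
	c := &memClient{files: map[string]memFile{"/etc/tls/cert.pem": {data: []byte("old"), mode: 0644}}}
	paths := []string{"/etc/tls/cert.pem", "/etc/tls/key.pem"}
	backups, modes, _ := snapshot(c, paths)
	for _, p := range paths {
		_ = c.WriteFile(p, []byte("new"), 0600) // the deploy step
	}
	status, err := restore(c, paths, backups, modes)
	fmt.Println(status, err)
}
```

Keeping the backups in memory suits small PEM files; the paths and contents in main are placeholders.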

DeploymentResult shape on failure:
- rollback OK + retry-reload OK → Success=false, "reload command
  failed; rolled back to pre-deploy state" (clean recoverable
  failure; remote fully restored, daemon serving original cert).
- rollback OK + retry-reload FAIL → wrapped error noting "rolled
  back files; retry-reload also failed; daemon may need manual
  restart". Metadata flags daemon_state_unknown=true.
- rollback FAIL → operator-actionable wrapped error containing
  BOTH the reload error AND the rollback error; metadata flags
  manual_action_required=true.

Tests added to ssh_test.go (4 new tests, ~330 LOC):
- TestSSH_ReloadFails_FilesRestored — happy rollback path with
  pre-existing remote bytes for cert/key/chain. Asserts every
  path's last WriteFile call contains the captured backup bytes
  verbatim, no Remove calls fired (all paths had snapshots), and
  metadata reports backup_status=restored for each path.
- TestSSH_NoExistingCert_ReloadFails_NewCertRemoved — first-time
  deploy variant. StatFile returns os.ErrNotExist for every path;
  rollback Removes each written file but performs no WriteFile
  during restore (no backup to restore from). Asserts exactly 3
  WriteFile calls (deploy only) and 3 Remove calls (rollback).
- TestSSH_ReloadFails_RollbackAlsoFails_OperatorActionable —
  uses a writeOrderTrackingMock to fail the SECOND WriteFile to
  the cert path (i.e. the restore call, not the initial deploy).
  Asserts wrapped error contains both the reload error and the
  rollback error, and metadata flags manual_action_required=true.
- TestSSH_ReloadFails_RestoreThenSecondReloadFails — partial-
  recovery escalation. Rollback succeeds but the post-restore
  retry-reload fails. Asserts wrapped error mentions "rolled back
  files; retry-reload also failed" and metadata flags
  daemon_state_unknown=true.

Existing tests preserved by extending mockSSHClient with backward-
compatible per-path response maps (statByPath / readByPath /
writeFileErrByPath / executeErrSequence). Legacy global fields
(statFileSize / statFileErr / writeFileErr / executeErr) still
work when no per-path override matches, so TestValidateConfig_*
and TestDeployCertificate_Success_* don't need changes.

docs/deployment-atomicity.md L92 unchanged from today's text —
Bundle 1 doc-realignment hasn't shipped, so the "Pre-deploy SCP
backup of remote files" line was never softened. Post-Bundle-6
the claim is honest (was aspirational pre-fix).

Verified locally (sandbox lacks staticcheck install due to disk
pressure; CI runs the full lint gate):
- gofmt -l ./internal/connector/target/ssh/  clean
- go vet ./internal/connector/target/ssh/  clean
- go build ./internal/connector/target/ssh/...  clean
- go build ./cmd/agent/...  clean
- go test -race -count=1 ./internal/connector/target/ssh/  green

Audit reference: cowork/deployment-target-audit-2026-05-02/RESULTS.md
Bundle 6.
2026-05-02 17:13:38 +00:00

629 lines
18 KiB
Go

package ssh
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"errors"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
)
// Bundle M.SSH-extended (H-002 closure): in-process SSH server fixture that
// exercises realSSHClient.Connect, Execute, WriteFile, StatFile, and Close
// end-to-end. Same pattern as M.Email's hand-rolled SMTP fixture — minimal
// in-process protocol server bound to net.Listen("tcp", "127.0.0.1:0") with
// t.Cleanup-driven shutdown.
//
// The SSH server uses Ed25519 host keys (lightest crypto for tests), password
// authentication plus an accept-any public-key callback, and accepts a single
// channel type, "session", carrying two request types:
//
//   - "exec" request: used by realSSHClient.Execute
//   - "subsystem" request naming "sftp": used by realSSHClient.WriteFile and
//     StatFile (served by pkg/sftp.NewServer over the channel)
//
// The fixture lives in tests only; production code never imports it.
// fakeSSHServer is a minimal in-process SSH server bound to a random port.
type fakeSSHServer struct {
t *testing.T
listener net.Listener
addr string
user string
password string
wg sync.WaitGroup
mu sync.Mutex
closed bool
// Optional behaviour toggles for failure-mode tests.
rejectAuth bool // reject all auth attempts (auth failure path)
dropOnHandshake bool // close conn before SSH NewServerConn returns (handshake failure)
failExec bool // exec sessions return non-zero exit (Execute error path)
failSFTP bool // refuse sftp subsystem (SFTP failure path)
}
// startFakeSSHServer binds a fresh server on a random local port and returns
// it ready to accept Connect calls. t.Cleanup is wired to close the listener
// + drain in-flight handlers.
func startFakeSSHServer(t *testing.T, opts ...func(*fakeSSHServer)) *fakeSSHServer {
t.Helper()
srv := &fakeSSHServer{
t: t,
user: "testuser",
password: "testpass",
}
for _, opt := range opts {
opt(srv)
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
srv.listener = listener
srv.addr = listener.Addr().String()
t.Cleanup(srv.Close)
srv.wg.Add(1)
go srv.acceptLoop()
return srv
}
// hostPort returns the host and port the listener is bound to, split via
// SplitHostPort so the test caller can pass them separately to Config.
func (s *fakeSSHServer) hostPort() (string, int) {
host, portStr, err := net.SplitHostPort(s.addr)
if err != nil {
s.t.Fatalf("SplitHostPort: %v", err)
}
var port int
for _, c := range portStr {
if c >= '0' && c <= '9' {
port = port*10 + int(c-'0')
}
}
return host, port
}
func (s *fakeSSHServer) Close() {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return
}
s.closed = true
s.mu.Unlock()
_ = s.listener.Close()
s.wg.Wait()
}
func (s *fakeSSHServer) acceptLoop() {
defer s.wg.Done()
// Generate a fresh Ed25519 host key for this server instance.
_, hostKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
s.t.Errorf("ed25519.GenerateKey: %v", err)
return
}
signer, err := gossh.NewSignerFromKey(hostKey)
if err != nil {
s.t.Errorf("NewSignerFromKey: %v", err)
return
}
cfg := &gossh.ServerConfig{
PasswordCallback: func(c gossh.ConnMetadata, p []byte) (*gossh.Permissions, error) {
if s.rejectAuth {
return nil, errors.New("auth rejected (test fixture)")
}
if c.User() == s.user && string(p) == s.password {
return &gossh.Permissions{}, nil
}
return nil, errors.New("invalid credentials")
},
PublicKeyCallback: func(c gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
if s.rejectAuth {
return nil, errors.New("auth rejected (test fixture)")
}
// Accept any pubkey; testers using key-auth don't need to also
// configure trust, since this is a pure connectivity fixture.
return &gossh.Permissions{}, nil
},
}
cfg.AddHostKey(signer)
for {
conn, err := s.listener.Accept()
if err != nil {
// Listener closed — exit cleanly.
return
}
s.wg.Add(1)
go func(c net.Conn) {
defer s.wg.Done()
s.handleConn(c, cfg)
}(conn)
}
}
func (s *fakeSSHServer) handleConn(nConn net.Conn, cfg *gossh.ServerConfig) {
defer nConn.Close()
if s.dropOnHandshake {
// Close immediately to surface a handshake error on the client side.
return
}
_, chans, reqs, err := gossh.NewServerConn(nConn, cfg)
if err != nil {
// Common: closed connection during handshake (test cleanup, auth fail).
return
}
go gossh.DiscardRequests(reqs)
for newCh := range chans {
if newCh.ChannelType() != "session" {
_ = newCh.Reject(gossh.UnknownChannelType, "unknown channel type")
continue
}
ch, requests, err := newCh.Accept()
if err != nil {
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleSession(ch, requests)
}()
}
}
func (s *fakeSSHServer) handleSession(ch gossh.Channel, reqs <-chan *gossh.Request) {
defer ch.Close()
for req := range reqs {
switch req.Type {
case "exec":
if s.failExec {
_ = req.Reply(true, nil)
_, _ = ch.Write([]byte("exec failure (test fixture)\n"))
_, _ = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 1}) // exit code 1
return
}
// Send a canned success response (not an echo of the command) so Execute returns without error.
_ = req.Reply(true, nil)
_, _ = ch.Write([]byte("exec ok\n"))
_, _ = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) // exit code 0
return
case "subsystem":
// Payload is the subsystem name in standard SSH wire form: 4-byte
// length prefix + bytes. Look for "sftp".
if len(req.Payload) >= 4 {
name := string(req.Payload[4:])
if name == "sftp" {
if s.failSFTP {
_ = req.Reply(false, nil)
return
}
_ = req.Reply(true, nil)
srv, err := sftp.NewServer(ch)
if err != nil {
return
}
_ = srv.Serve()
return
}
}
_ = req.Reply(false, nil)
default:
if req.WantReply {
_ = req.Reply(false, nil)
}
}
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Connect happy path / failure paths
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_Connect_Password_Success(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: srv.password,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
if c.sshClient == nil {
t.Errorf("expected sshClient to be set after Connect")
}
if c.sftpClient == nil {
t.Errorf("expected sftpClient to be set after Connect")
}
}
func TestRealSSHClient_Connect_Password_WrongPassword(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: "wrong-password",
Timeout: 5,
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected wrong-password to fail Connect")
_ = c.Close()
}
}
func TestRealSSHClient_Connect_AuthRejected_AllAttempts(t *testing.T) {
srv := startFakeSSHServer(t, func(s *fakeSSHServer) { s.rejectAuth = true })
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: srv.password,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected auth rejection to fail Connect")
_ = c.Close()
} else if !strings.Contains(err.Error(), "SSH handshake") {
t.Errorf("expected handshake error, got %v", err)
}
}
func TestRealSSHClient_Connect_HandshakeDropped(t *testing.T) {
srv := startFakeSSHServer(t, func(s *fakeSSHServer) { s.dropOnHandshake = true })
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: srv.password,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected handshake-drop to fail Connect")
_ = c.Close()
}
}
func TestRealSSHClient_Connect_TCPConnRefused(t *testing.T) {
// Bind a listener to obtain a free port, then close it immediately. The port
// has just been released, so nothing is listening on it and Connect should
// return a TCP-connection error.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
addr := listener.Addr().String()
_ = listener.Close()
host, portStr, _ := net.SplitHostPort(addr)
var port int
for _, c := range portStr {
if c >= '0' && c <= '9' {
port = port*10 + int(c-'0')
}
}
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: "anyone",
AuthMethod: "password",
Password: "anything",
Timeout: 1, // 1-second timeout
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected TCP-refused, got nil")
_ = c.Close()
} else if !strings.Contains(err.Error(), "TCP connection") {
t.Errorf("expected TCP-connection error, got %v", err)
}
}
func TestRealSSHClient_Connect_KeyAuth_Success(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
// Generate an ed25519 client key and serialize it to OpenSSH PEM.
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("ed25519.GenerateKey: %v", err)
}
pemBlock, err := gossh.MarshalPrivateKey(priv, "test-key")
if err != nil {
t.Fatalf("MarshalPrivateKey: %v", err)
}
keyPath := filepath.Join(t.TempDir(), "id_test")
if err := os.WriteFile(keyPath, encodePEMBlock(pemBlock.Type, pemBlock.Bytes), 0600); err != nil {
t.Fatalf("WriteFile key: %v", err)
}
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "key",
PrivateKeyPath: keyPath,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect (key auth): %v", err)
}
defer c.Close()
}
// encodePEMBlock builds a minimal PEM-format block with the given type+bytes.
// (Avoids pulling in encoding/pem in the test header — it's already imported
// transitively but this keeps the import list minimal.)
func encodePEMBlock(blockType string, blockBytes []byte) []byte {
var buf bytes.Buffer
buf.WriteString("-----BEGIN ")
buf.WriteString(blockType)
buf.WriteString("-----\n")
// Base64-encode in 64-char lines.
enc := base64Encode(blockBytes)
for i := 0; i < len(enc); i += 64 {
end := i + 64
if end > len(enc) {
end = len(enc)
}
buf.Write(enc[i:end])
buf.WriteByte('\n')
}
buf.WriteString("-----END ")
buf.WriteString(blockType)
buf.WriteString("-----\n")
return buf.Bytes()
}
func base64Encode(in []byte) []byte {
const enc = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
out := make([]byte, (len(in)+2)/3*4)
j := 0
for i := 0; i < len(in); i += 3 {
v := uint32(in[i]) << 16
if i+1 < len(in) {
v |= uint32(in[i+1]) << 8
}
if i+2 < len(in) {
v |= uint32(in[i+2])
}
out[j] = enc[(v>>18)&0x3f]
out[j+1] = enc[(v>>12)&0x3f]
if i+1 < len(in) {
out[j+2] = enc[(v>>6)&0x3f]
} else {
out[j+2] = '='
}
if i+2 < len(in) {
out[j+3] = enc[v&0x3f]
} else {
out[j+3] = '='
}
j += 4
}
return out
}
// ─────────────────────────────────────────────────────────────────────────────
// Execute
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_Execute_Success(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
out, err := c.Execute(context.Background(), "echo hello")
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(out, "exec ok") {
t.Errorf("expected canned 'exec ok' output, got %q", out)
}
}
func TestRealSSHClient_Execute_NotConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if _, err := c.Execute(context.Background(), "anything"); err == nil {
t.Errorf("expected error when sshClient is nil")
}
}
func TestRealSSHClient_Execute_ExitCode1(t *testing.T) {
srv := startFakeSSHServer(t, func(s *fakeSSHServer) { s.failExec = true })
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
out, err := c.Execute(context.Background(), "anything")
if err == nil {
t.Errorf("expected non-zero exit code to surface as error; got out=%q", out)
}
}
// ─────────────────────────────────────────────────────────────────────────────
// WriteFile / StatFile via SFTP
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_WriteFile_StatFile_RoundTrip(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
// Use a temp path the in-process sftp server can write to. pkg/sftp's
// default server uses the OS filesystem, so use a t.TempDir-derived path.
dir := t.TempDir()
target := filepath.Join(dir, "out.pem")
payload := []byte("-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n")
if err := c.WriteFile(target, payload, 0640); err != nil {
t.Fatalf("WriteFile: %v", err)
}
info, err := c.StatFile(target)
if err != nil {
t.Fatalf("StatFile: %v", err)
}
if info.Size() != int64(len(payload)) {
t.Errorf("expected size %d, got %d", len(payload), info.Size())
}
// Verify mode 0640 was set.
osInfo, err := os.Stat(target)
if err != nil {
t.Fatalf("os.Stat: %v", err)
}
if osInfo.Mode().Perm() != 0640 {
t.Errorf("expected mode 0640, got %v", osInfo.Mode().Perm())
}
// Verify content round-trips.
gotBytes, err := os.ReadFile(target)
if err != nil {
t.Fatalf("ReadFile: %v", err)
}
if !bytes.Equal(gotBytes, payload) {
t.Errorf("payload round-trip mismatch:\n got: %q\n want: %q", gotBytes, payload)
}
}
func TestRealSSHClient_WriteFile_NotConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if err := c.WriteFile("/tmp/x", []byte("y"), 0600); err == nil {
t.Errorf("expected error when sftpClient is nil")
}
}
func TestRealSSHClient_StatFile_NotConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if _, err := c.StatFile("/tmp/x"); err == nil {
t.Errorf("expected error when sftpClient is nil")
}
}
func TestRealSSHClient_StatFile_NotExist(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
if _, err := c.StatFile("/nonexistent/path/to/file"); err == nil {
t.Errorf("expected error stat'ing nonexistent file")
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Close
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_Close_Idempotent(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
if err := c.Close(); err != nil {
t.Errorf("first Close: %v", err)
}
// Second close must be idempotent: no panic, and it must return nil.
if err := c.Close(); err != nil {
t.Errorf("second Close: %v", err)
}
}
func TestRealSSHClient_Close_NeverConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if err := c.Close(); err != nil {
t.Errorf("Close on never-connected client should be nil, got %v", err)
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Keep the io and time imports referenced. Nothing else in this file uses
// them, and Go rejects unused imports at compile time (an error, not a warning).
// ─────────────────────────────────────────────────────────────────────────────
var _ = io.EOF
var _ = time.Second