Merge fix/coverage-M.SSH-extended: Bundle M.SSH-extended — SSH 71.6% -> 90.2%; H-002 closed

This commit is contained in:
shankar0123
2026-04-27 19:07:38 +00:00
@@ -0,0 +1,628 @@
package ssh
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"errors"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/pkg/sftp"
gossh "golang.org/x/crypto/ssh"
)
// Bundle M.SSH-extended (H-002 closure): in-process SSH server fixture that
// exercises realSSHClient.Connect, Execute, WriteFile, StatFile, and Close
// end-to-end. Same pattern as M.Email's hand-rolled SMTP fixture — minimal
// in-process protocol server bound to net.Listen("tcp", "127.0.0.1:0") with
// t.Cleanup-driven shutdown.
//
// The SSH server uses Ed25519 host keys (lightest crypto for tests),
// password authentication (simplest auth), and supports two channel types:
//
// - "session" with "exec" subsystem — used by realSSHClient.Execute
// - "session" with "subsystem sftp" — used by realSSHClient.WriteFile,
// StatFile (proxied through pkg/sftp.NewServer over the channel)
//
// The fixture lives in tests only; production code never imports it.
// fakeSSHServer is a minimal in-process SSH server bound to a random port.
type fakeSSHServer struct {
t *testing.T
listener net.Listener
addr string
user string
password string
wg sync.WaitGroup
mu sync.Mutex
closed bool
// Optional behaviour toggles for failure-mode tests.
rejectAuth bool // reject all auth attempts (auth failure path)
dropOnHandshake bool // close conn before SSH NewServerConn returns (handshake failure)
failExec bool // exec sessions return non-zero exit (Execute error path)
failSFTP bool // refuse sftp subsystem (SFTP failure path)
}
// startFakeSSHServer binds a fresh server on a random local port and returns
// it ready to accept Connect calls. t.Cleanup is wired to close the listener
// + drain in-flight handlers.
func startFakeSSHServer(t *testing.T, opts ...func(*fakeSSHServer)) *fakeSSHServer {
t.Helper()
srv := &fakeSSHServer{
t: t,
user: "testuser",
password: "testpass",
}
for _, opt := range opts {
opt(srv)
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
srv.listener = listener
srv.addr = listener.Addr().String()
t.Cleanup(srv.Close)
srv.wg.Add(1)
go srv.acceptLoop()
return srv
}
// host returns the host:port the listener is bound to. Splits via SplitHostPort
// so the test caller can pass them separately to Config.
func (s *fakeSSHServer) hostPort() (string, int) {
host, portStr, err := net.SplitHostPort(s.addr)
if err != nil {
s.t.Fatalf("SplitHostPort: %v", err)
}
var port int
for _, c := range portStr {
if c >= '0' && c <= '9' {
port = port*10 + int(c-'0')
}
}
return host, port
}
func (s *fakeSSHServer) Close() {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return
}
s.closed = true
s.mu.Unlock()
_ = s.listener.Close()
s.wg.Wait()
}
func (s *fakeSSHServer) acceptLoop() {
defer s.wg.Done()
// Generate a fresh Ed25519 host key for this server instance.
_, hostKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
s.t.Errorf("ed25519.GenerateKey: %v", err)
return
}
signer, err := gossh.NewSignerFromKey(hostKey)
if err != nil {
s.t.Errorf("NewSignerFromKey: %v", err)
return
}
cfg := &gossh.ServerConfig{
PasswordCallback: func(c gossh.ConnMetadata, p []byte) (*gossh.Permissions, error) {
if s.rejectAuth {
return nil, errors.New("auth rejected (test fixture)")
}
if c.User() == s.user && string(p) == s.password {
return &gossh.Permissions{}, nil
}
return nil, errors.New("invalid credentials")
},
PublicKeyCallback: func(c gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
if s.rejectAuth {
return nil, errors.New("auth rejected (test fixture)")
}
// Accept any pubkey; testers using key-auth don't need to also
// configure trust, since this is a pure connectivity fixture.
return &gossh.Permissions{}, nil
},
}
cfg.AddHostKey(signer)
for {
conn, err := s.listener.Accept()
if err != nil {
// Listener closed — exit cleanly.
return
}
s.wg.Add(1)
go func(c net.Conn) {
defer s.wg.Done()
s.handleConn(c, cfg)
}(conn)
}
}
func (s *fakeSSHServer) handleConn(nConn net.Conn, cfg *gossh.ServerConfig) {
defer nConn.Close()
if s.dropOnHandshake {
// Close immediately to surface a handshake error on the client side.
return
}
_, chans, reqs, err := gossh.NewServerConn(nConn, cfg)
if err != nil {
// Common: closed connection during handshake (test cleanup, auth fail).
return
}
go gossh.DiscardRequests(reqs)
for newCh := range chans {
if newCh.ChannelType() != "session" {
_ = newCh.Reject(gossh.UnknownChannelType, "unknown channel type")
continue
}
ch, requests, err := newCh.Accept()
if err != nil {
continue
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleSession(ch, requests)
}()
}
}
func (s *fakeSSHServer) handleSession(ch gossh.Channel, reqs <-chan *gossh.Request) {
defer ch.Close()
for req := range reqs {
switch req.Type {
case "exec":
if s.failExec {
_ = req.Reply(true, nil)
_, _ = ch.Write([]byte("exec failure (test fixture)\n"))
_, _ = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 1}) // exit code 1
return
}
// Echo back a canned success response so Execute returns without error.
_ = req.Reply(true, nil)
_, _ = ch.Write([]byte("exec ok\n"))
_, _ = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) // exit code 0
return
case "subsystem":
// Payload is the subsystem name in standard SSH wire form: 4-byte
// length prefix + bytes. Look for "sftp".
if len(req.Payload) >= 4 {
name := string(req.Payload[4:])
if name == "sftp" {
if s.failSFTP {
_ = req.Reply(false, nil)
return
}
_ = req.Reply(true, nil)
srv, err := sftp.NewServer(ch)
if err != nil {
return
}
_ = srv.Serve()
return
}
}
_ = req.Reply(false, nil)
default:
if req.WantReply {
_ = req.Reply(false, nil)
}
}
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Connect happy path / failure paths
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_Connect_Password_Success(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: srv.password,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
if c.sshClient == nil {
t.Errorf("expected sshClient to be set after Connect")
}
if c.sftpClient == nil {
t.Errorf("expected sftpClient to be set after Connect")
}
}
func TestRealSSHClient_Connect_Password_WrongPassword(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: "wrong-password",
Timeout: 5,
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected wrong-password to fail Connect")
_ = c.Close()
}
}
func TestRealSSHClient_Connect_AuthRejected_AllAttempts(t *testing.T) {
srv := startFakeSSHServer(t, func(s *fakeSSHServer) { s.rejectAuth = true })
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: srv.password,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected auth rejection to fail Connect")
_ = c.Close()
} else if !strings.Contains(err.Error(), "SSH handshake") {
t.Errorf("expected handshake error, got %v", err)
}
}
func TestRealSSHClient_Connect_HandshakeDropped(t *testing.T) {
srv := startFakeSSHServer(t, func(s *fakeSSHServer) { s.dropOnHandshake = true })
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "password",
Password: srv.password,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected handshake-drop to fail Connect")
_ = c.Close()
}
}
func TestRealSSHClient_Connect_TCPConnRefused(t *testing.T) {
// Bind a listener, immediately close it — the port is still allocated
// but no one is listening. Connect must return a TCP-connection error.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
addr := listener.Addr().String()
_ = listener.Close()
host, portStr, _ := net.SplitHostPort(addr)
var port int
for _, c := range portStr {
if c >= '0' && c <= '9' {
port = port*10 + int(c-'0')
}
}
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: "anyone",
AuthMethod: "password",
Password: "anything",
Timeout: 1, // 1-second timeout
}}
if err := c.Connect(context.Background()); err == nil {
t.Errorf("expected TCP-refused, got nil")
_ = c.Close()
} else if !strings.Contains(err.Error(), "TCP connection") {
t.Errorf("expected TCP-connection error, got %v", err)
}
}
func TestRealSSHClient_Connect_KeyAuth_Success(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
// Generate an ed25519 client key and serialize it to OpenSSH PEM.
pub, priv, err := ed25519.GenerateKey(rand.Reader)
_ = pub
if err != nil {
t.Fatalf("ed25519.GenerateKey: %v", err)
}
pemBlock, err := gossh.MarshalPrivateKey(priv, "test-key")
if err != nil {
t.Fatalf("MarshalPrivateKey: %v", err)
}
keyPath := filepath.Join(t.TempDir(), "id_test")
if err := os.WriteFile(keyPath, encodePEMBlock(pemBlock.Type, pemBlock.Bytes), 0600); err != nil {
t.Fatalf("WriteFile key: %v", err)
}
c := &realSSHClient{config: &Config{
Host: host,
Port: port,
User: srv.user,
AuthMethod: "key",
PrivateKeyPath: keyPath,
Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect (key auth): %v", err)
}
defer c.Close()
}
// encodePEMBlock builds a minimal PEM-format block with the given type+bytes.
// (Avoids pulling in encoding/pem in the test header — it's already imported
// transitively but this keeps the import list minimal.)
func encodePEMBlock(blockType string, blockBytes []byte) []byte {
var buf bytes.Buffer
buf.WriteString("-----BEGIN ")
buf.WriteString(blockType)
buf.WriteString("-----\n")
// Base64-encode in 64-char lines.
enc := base64Encode(blockBytes)
for i := 0; i < len(enc); i += 64 {
end := i + 64
if end > len(enc) {
end = len(enc)
}
buf.Write(enc[i:end])
buf.WriteByte('\n')
}
buf.WriteString("-----END ")
buf.WriteString(blockType)
buf.WriteString("-----\n")
return buf.Bytes()
}
func base64Encode(in []byte) []byte {
const enc = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
out := make([]byte, (len(in)+2)/3*4)
j := 0
for i := 0; i < len(in); i += 3 {
var v uint32
v = uint32(in[i]) << 16
if i+1 < len(in) {
v |= uint32(in[i+1]) << 8
}
if i+2 < len(in) {
v |= uint32(in[i+2])
}
out[j] = enc[(v>>18)&0x3f]
out[j+1] = enc[(v>>12)&0x3f]
if i+1 < len(in) {
out[j+2] = enc[(v>>6)&0x3f]
} else {
out[j+2] = '='
}
if i+2 < len(in) {
out[j+3] = enc[v&0x3f]
} else {
out[j+3] = '='
}
j += 4
}
return out
}
// ─────────────────────────────────────────────────────────────────────────────
// Execute
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_Execute_Success(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
out, err := c.Execute(context.Background(), "echo hello")
if err != nil {
t.Fatalf("Execute: %v", err)
}
if !strings.Contains(out, "exec ok") {
t.Errorf("expected canned 'exec ok' output, got %q", out)
}
}
func TestRealSSHClient_Execute_NotConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if _, err := c.Execute(context.Background(), "anything"); err == nil {
t.Errorf("expected error when sshClient is nil")
}
}
func TestRealSSHClient_Execute_ExitCode1(t *testing.T) {
srv := startFakeSSHServer(t, func(s *fakeSSHServer) { s.failExec = true })
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
out, err := c.Execute(context.Background(), "anything")
if err == nil {
t.Errorf("expected non-zero exit code to surface as error; got out=%q", out)
}
}
// ─────────────────────────────────────────────────────────────────────────────
// WriteFile / StatFile via SFTP
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_WriteFile_StatFile_RoundTrip(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
// Use a temp path the in-process sftp server can write to. pkg/sftp's
// default server uses the OS filesystem, so use a t.TempDir-derived path.
dir := t.TempDir()
target := filepath.Join(dir, "out.pem")
payload := []byte("-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n")
if err := c.WriteFile(target, payload, 0640); err != nil {
t.Fatalf("WriteFile: %v", err)
}
size, err := c.StatFile(target)
if err != nil {
t.Fatalf("StatFile: %v", err)
}
if size != int64(len(payload)) {
t.Errorf("expected size %d, got %d", len(payload), size)
}
// Verify mode 0640 was set.
info, err := os.Stat(target)
if err != nil {
t.Fatalf("os.Stat: %v", err)
}
if info.Mode().Perm() != 0640 {
t.Errorf("expected mode 0640, got %v", info.Mode().Perm())
}
// Verify content round-trips.
gotBytes, err := os.ReadFile(target)
if err != nil {
t.Fatalf("ReadFile: %v", err)
}
if !bytes.Equal(gotBytes, payload) {
t.Errorf("payload round-trip mismatch:\n got: %q\n want: %q", gotBytes, payload)
}
}
func TestRealSSHClient_WriteFile_NotConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if err := c.WriteFile("/tmp/x", []byte("y"), 0600); err == nil {
t.Errorf("expected error when sftpClient is nil")
}
}
func TestRealSSHClient_StatFile_NotConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if _, err := c.StatFile("/tmp/x"); err == nil {
t.Errorf("expected error when sftpClient is nil")
}
}
func TestRealSSHClient_StatFile_NotExist(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
defer c.Close()
if _, err := c.StatFile("/nonexistent/path/to/file"); err == nil {
t.Errorf("expected error stat'ing nonexistent file")
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Close
// ─────────────────────────────────────────────────────────────────────────────
func TestRealSSHClient_Close_Idempotent(t *testing.T) {
srv := startFakeSSHServer(t)
host, port := srv.hostPort()
c := &realSSHClient{config: &Config{
Host: host, Port: port, User: srv.user,
AuthMethod: "password", Password: srv.password, Timeout: 5,
}}
if err := c.Connect(context.Background()); err != nil {
t.Fatalf("Connect: %v", err)
}
if err := c.Close(); err != nil {
t.Errorf("first Close: %v", err)
}
// Second close — idempotent (should not panic, may return nil)
if err := c.Close(); err != nil {
t.Errorf("second Close: %v", err)
}
}
func TestRealSSHClient_Close_NeverConnected(t *testing.T) {
c := &realSSHClient{config: &Config{}}
if err := c.Close(); err != nil {
t.Errorf("Close on never-connected client should be nil, got %v", err)
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Suppress unused-import warning under some Go versions.
// ─────────────────────────────────────────────────────────────────────────────
var _ = io.EOF
var _ = time.Second