Files
certctl/internal/mcp/tools_test.go
T
Shankar 2c5383da9f fix(bundle-3): MCP Trust-Boundary Fencing — 5 audit findings closed
Closes Audit-2026-04-25 H-002, H-003, M-003, M-004, M-005 (all CWE-1039
LLM Prompt Injection at the MCP↔consumer trust boundary, TB-7).

Strategy: wrapper-layer fencing. All 87 MCP tools route their success
path through textResult and their failure path through errorResult. By
fencing at those two wrappers we cover every existing tool AND every
future tool with a single change — no per-tool wiring required.

What changed
- internal/mcp/fence.go (new) — FenceUntrusted helper with strategy
  doc + per-finding rationale. Both fenceMCPResponse and fenceMCPError
  use it internally.
- internal/mcp/tools.go — textResult wraps response body via
  fenceMCPResponse; errorResult wraps error string via fenceMCPError.
- internal/mcp/tools_test.go — TestTextResult / TestErrorResult updated
  to assert fenced shape (start marker + end marker + inner body).
- internal/mcp/injection_regression_test.go (new) — 5 regression test
  functions, one per audit finding, each replays 5 classic LLM
  injection payloads (instruction_override, system_role_spoofing,
  delimiter_break_attempt, markdown_link_phishing, data_exfil_via_url)
  and asserts the planted payload appears VERBATIM (preservation,
  operator visibility) INSIDE the fence boundaries.
- internal/mcp/fence_guardrail_test.go (new) — CI guardrail that walks
  every non-test .go file in the mcp package and fails if it finds a
  bare gomcp.CallToolResult literal outside tools.go. Prevents future
  tools from silently bypassing the fence.

Delimiter-forgery defense
The naive constant fence (--- UNTRUSTED MCP_RESPONSE END ---) is
forgeable: an attacker who controls a field value can plant the literal
end marker and "break out" of the fence. Defense: every fence call
generates a 6-byte crypto/rand nonce, hex-encoded, and embeds it in
BOTH the START and END markers. An attacker would need to guess the
nonce (a 2^48 search space per fence) to forge a matching END marker
inside the payload. The delimiter_break_attempt regression test
exercises this.

Per-finding mapping
- H-002 Cert Subject DN injection (CSR submitter controlled) →
  TestMCP_PromptInjection_H002_CertSubjectDN
- H-003 Discovered cert metadata injection (cert owner controlled) →
  TestMCP_PromptInjection_H003_DiscoveredCertMetadata
- M-003 Agent heartbeat injection (agent self-reports hostname/OS/IP)
  → TestMCP_PromptInjection_M003_AgentHeartbeat
- M-004 Upstream CA error injection (CA controls error string) →
  TestMCP_PromptInjection_M004_UpstreamCAError
- M-005 Audit details + notification body injection (downstream actors
  control these) → TestMCP_PromptInjection_M005_AuditDetailsAndNotifications

Verification gates
- go vet ./...                                 → clean
- go build ./...                               → clean
- go test -short -count=1 ./...                → all packages pass
- go test -count=1 ./internal/mcp/...          → all packages pass
- npx tsc --noEmit (web)                       → clean
- npx vitest run (web)                         → 337 passed
- python3 yaml.safe_load(api/openapi.yaml)     → 89 paths, 56 schemas

Threat-model placement: TB-7 (MCP↔LLM consumer). certctl owns the
boundary; consumer-side prompt engineering is recommended but not
relied upon. Defense-in-depth: per-call nonce closes the
delimiter-forgery edge case that constant fences would have left
exposed.

Bundle 3 of the 2026-04-25 comprehensive audit (88 findings).
2026-04-25 22:44:33 +00:00

436 lines
12 KiB
Go

package mcp
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
gomcp "github.com/modelcontextprotocol/go-sdk/mcp"
)
// requestLog is a mutex-guarded, append-only record of the HTTP requests the
// mock certctl API receives while an MCP tool handler runs. Tests append via
// add (from the server goroutine) and inspect the newest entry via last.
type (
	requestLog struct {
		mu       sync.Mutex
		requests []capturedRequest
	}

	// capturedRequest is a snapshot of one inbound request: the HTTP method,
	// the URL path, the raw (still-encoded) query string, and the body text.
	capturedRequest struct {
		Method string
		Path   string
		Query  string
		Body   string
	}
)

// add records req at the end of the log under the lock.
func (rl *requestLog) add(req capturedRequest) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.requests = append(rl.requests, req)
}

// last returns the most recently recorded request, or the zero
// capturedRequest when nothing has been recorded yet.
func (rl *requestLog) last() capturedRequest {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	var latest capturedRequest
	if n := len(rl.requests); n > 0 {
		latest = rl.requests[n-1]
	}
	return latest
}
// mockCertctlAPI starts an httptest server that stands in for the certctl
// REST API. Every request is recorded in log (method, path, raw query, full
// body) and answered with a canned JSON response chosen by method and path:
//
//   - DELETE                    -> 204 No Content, empty body
//   - path ends in /renew or
//     /deploy (any other method) -> 202 {"status":"accepted","job_id":"job-001"}
//   - other POST                -> 201 {"id":"new-resource"}
//   - anything else             -> 200 {"data":[{"id":"test-1"}],"total":1}
//
// The caller owns the returned server and must Close it.
func mockCertctlAPI(log *requestLog) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Drain the whole body. A single Read may legally return fewer bytes
		// than are available, and a fixed 4 KiB read truncated larger bodies,
		// which would make body-content assertions unreliable.
		var body strings.Builder
		if r.Body != nil {
			buf := make([]byte, 4096)
			for {
				n, err := r.Body.Read(buf)
				body.Write(buf[:n])
				if err != nil {
					break // io.EOF on a clean finish; any other error just ends the capture
				}
			}
		}
		log.add(capturedRequest{
			Method: r.Method,
			Path:   r.URL.Path,
			Query:  r.URL.RawQuery,
			Body:   body.String(),
		})

		w.Header().Set("Content-Type", "application/json")
		switch {
		case r.Method == http.MethodDelete:
			w.WriteHeader(http.StatusNoContent)
		case strings.HasSuffix(r.URL.Path, "/renew") || strings.HasSuffix(r.URL.Path, "/deploy"):
			w.WriteHeader(http.StatusAccepted)
			json.NewEncoder(w).Encode(map[string]string{"status": "accepted", "job_id": "job-001"})
		case r.Method == http.MethodPost:
			w.WriteHeader(http.StatusCreated)
			json.NewEncoder(w).Encode(map[string]string{"id": "new-resource"})
		default:
			w.WriteHeader(http.StatusOK)
			json.NewEncoder(w).Encode(map[string]interface{}{
				"data":  []interface{}{map[string]string{"id": "test-1"}},
				"total": 1,
			})
		}
	}))
}
// TestRegisterTools_ToolCount is a smoke test: RegisterTools must register the
// tool set against a fresh MCP server without panicking. As noted when this
// test was written, the SDK does not expose a tool-count method, so the count
// itself is not asserted here.
func TestRegisterTools_ToolCount(t *testing.T) {
	server := gomcp.NewServer(&gomcp.Implementation{
		Name:    "certctl-test",
		Version: "test",
	}, nil)
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()

	// Fail fast on a constructor error instead of handing RegisterTools a
	// client that may be nil.
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	RegisterTools(server, client)
	t.Log("RegisterTools completed without panic")
}
// TestPaginationQuery checks that paginationQuery emits only the parameters
// whose values are positive, and that the page parameter carries the exact
// decimal page number (and is absent when page is 0).
func TestPaginationQuery(t *testing.T) {
	tests := []struct {
		name     string
		page     int
		perPage  int
		wantLen  int
		wantPage string // expected q.Get("page"); "" means the key must be absent
	}{
		{"both set", 2, 50, 2, "2"},
		{"page only", 3, 0, 1, "3"},
		{"per_page only", 0, 100, 1, ""},
		{"neither set", 0, 0, 0, ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			q := paginationQuery(tt.page, tt.perPage)
			if len(q) != tt.wantLen {
				t.Errorf("expected %d query params, got %d", tt.wantLen, len(q))
			}
			// The previous check reduced to "page is non-empty" and never
			// compared the value; pin the exact string instead.
			if got := q.Get("page"); got != tt.wantPage {
				t.Errorf("page param = %q, want %q", got, tt.wantPage)
			}
		})
	}
}
// TestTextResult checks that textResult returns a single TextContent item
// whose text wraps the response body in MCP_RESPONSE untrusted-data fences.
//
// Bundle-3: the fence labels the data as MCP_RESPONSE so LLM consumers can be
// instructed to interpret the inner JSON as opaque content rather than
// instructions. See internal/mcp/fence.go for the strategy doc. The body must
// sit BETWEEN the markers; a body merely present next to them would not be
// fenced at all.
func TestTextResult(t *testing.T) {
	const (
		startMarker = "--- UNTRUSTED MCP_RESPONSE START"
		endMarker   = "--- UNTRUSTED MCP_RESPONSE END"
		body        = `{"id":"mc-test","status":"Active"}`
	)
	data := json.RawMessage(body)
	result, metadata, err := textResult(data)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if metadata != nil {
		t.Errorf("expected nil metadata, got %v", metadata)
	}
	if result == nil {
		t.Fatal("expected non-nil result")
	}
	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content item, got %d", len(result.Content))
	}
	tc, ok := result.Content[0].(*gomcp.TextContent)
	if !ok {
		t.Fatal("expected TextContent type")
	}

	text := tc.Text
	start := strings.Index(text, startMarker)
	end := strings.LastIndex(text, endMarker)
	inner := strings.Index(text, body)
	if start < 0 {
		t.Errorf("missing start fence in text content: %s", text)
	}
	if end < 0 {
		t.Errorf("missing end fence in text content: %s", text)
	}
	if inner < 0 {
		t.Errorf("inner body missing from fenced content: %s", text)
	}
	if start >= 0 && end >= 0 && inner >= 0 &&
		!(start+len(startMarker) <= inner && inner+len(body) <= end) {
		t.Errorf("inner body is not enclosed by the fence markers: %s", text)
	}
}
// TestErrorResult checks that errorResult returns a nil result and an error
// whose message wraps the original error text in MCP_ERROR fences.
//
// Bundle-3: upstream-CA error strings are attacker-controllable (M-004), so
// the fence keeps an injected "ignore previous instructions" payload in a CA
// error from steering the LLM consumer. The original text must sit BETWEEN the
// markers for that to hold.
func TestErrorResult(t *testing.T) {
	const (
		startMarker = "--- UNTRUSTED MCP_ERROR START"
		endMarker   = "--- UNTRUSTED MCP_ERROR END"
	)
	result, _, err := errorResult(http.ErrServerClosed)
	if result != nil {
		t.Errorf("expected nil result, got %v", result)
	}
	if err == nil {
		t.Fatal("expected non-nil error")
	}

	msg := err.Error()
	innerText := http.ErrServerClosed.Error()
	start := strings.Index(msg, startMarker)
	end := strings.LastIndex(msg, endMarker)
	inner := strings.Index(msg, innerText)
	if start < 0 {
		t.Errorf("missing start fence in error: %s", msg)
	}
	if end < 0 {
		t.Errorf("missing end fence in error: %s", msg)
	}
	if inner < 0 {
		t.Errorf("inner error missing from fenced content: %s", msg)
	}
	if start >= 0 && end >= 0 && inner >= 0 &&
		!(start+len(startMarker) <= inner && inner+len(innerText) <= end) {
		t.Errorf("inner error is not enclosed by the fence markers: %s", msg)
	}
}
// TestToolEndToEnd_ListCertificates verifies the full flow:
// MCP tool handler -> HTTP client -> mock API -> response formatting.
// It performs by hand the steps a registered list-certificates tool would
// take (pagination + status query, GET, textResult). It then checks the
// outbound request and that the mock's payload reaches the tool's text
// content.
func TestToolEndToEnd_ListCertificates(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	q := paginationQuery(1, 50)
	q.Set("status", "Active")
	data, err := client.Get("/api/v1/certificates", q)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodGet {
		t.Errorf("expected GET, got %s", req.Method)
	}
	if req.Path != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", req.Path)
	}
	if !strings.Contains(req.Query, "status=Active") {
		t.Errorf("expected status=Active in query, got %s", req.Query)
	}
	if !strings.Contains(req.Query, "page=1") {
		t.Errorf("expected page=1 in query, got %s", req.Query)
	}

	result, _, err := textResult(data)
	if err != nil {
		t.Fatalf("unexpected error formatting result: %v", err)
	}
	if result == nil {
		t.Fatal("expected non-nil result")
	}
	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content item, got %d", len(result.Content))
	}
	tc, ok := result.Content[0].(*gomcp.TextContent)
	if !ok {
		t.Fatal("expected TextContent type")
	}
	// "test-1" is the id in the mock's default list response.
	if !strings.Contains(tc.Text, "test-1") {
		t.Errorf("mock payload missing from tool output: %s", tc.Text)
	}
}
// TestToolEndToEnd_CreateCertificate posts a CreateCertificateInput to the
// certificates collection. It checks the request (method, path, body carries
// the common name) and that the mock's 201 payload {"id":"new-resource"} is
// returned to the caller.
func TestToolEndToEnd_CreateCertificate(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	input := CreateCertificateInput{
		Name:       "API Production",
		CommonName: "api.example.com",
		IssuerID:   "iss-local",
		OwnerID:    "o-alice",
		TeamID:     "team-platform",
	}
	data, err := client.Post("/api/v1/certificates", input)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodPost {
		t.Errorf("expected POST, got %s", req.Method)
	}
	if req.Path != "/api/v1/certificates" {
		t.Errorf("expected path /api/v1/certificates, got %s", req.Path)
	}
	if !strings.Contains(req.Body, "api.example.com") {
		t.Errorf("expected common_name in body, got %s", req.Body)
	}

	var result map[string]string
	if err := json.Unmarshal(data, &result); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if result["id"] != "new-resource" {
		t.Errorf("expected id=new-resource, got %s", result["id"])
	}
}
// TestToolEndToEnd_TriggerRenewal POSTs to a certificate's /renew endpoint
// and checks the request path plus the mock's 202 payload (job_id "job-001").
func TestToolEndToEnd_TriggerRenewal(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	data, err := client.Post("/api/v1/certificates/mc-api-prod/renew", nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodPost {
		t.Errorf("expected POST, got %s", req.Method)
	}
	if req.Path != "/api/v1/certificates/mc-api-prod/renew" {
		t.Errorf("expected path /api/v1/certificates/mc-api-prod/renew, got %s", req.Path)
	}

	var result map[string]string
	if err := json.Unmarshal(data, &result); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if result["job_id"] != "job-001" {
		t.Errorf("expected job_id=job-001, got %s", result["job_id"])
	}
}
// TestToolEndToEnd_DeleteTarget issues a DELETE for a target. The mock
// answers 204 with an empty body, yet the test expects the returned data to
// decode to {"status":"deleted"}. That shape presumably comes from
// client.Delete itself (not visible here); confirm in the client if this
// test changes.
func TestToolEndToEnd_DeleteTarget(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	data, err := client.Delete("/api/v1/targets/t-platform")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodDelete {
		t.Errorf("expected DELETE, got %s", req.Method)
	}
	if req.Path != "/api/v1/targets/t-platform" {
		t.Errorf("expected path /api/v1/targets/t-platform, got %s", req.Path)
	}

	var result map[string]string
	if err := json.Unmarshal(data, &result); err != nil {
		t.Fatalf("failed to parse response: %v", err)
	}
	if result["status"] != "deleted" {
		t.Errorf("expected status=deleted, got %s", result["status"])
	}
}
// TestToolEndToEnd_RevokeCertificate POSTs a revocation reason to a
// certificate's /revoke endpoint. It checks the method, the path built from
// the input ID, and that the reason reaches the request body.
func TestToolEndToEnd_RevokeCertificate(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	input := RevokeCertificateInput{
		ID:     "mc-api-prod",
		Reason: "keyCompromise",
	}
	_, err = client.Post("/api/v1/certificates/"+input.ID+"/revoke", map[string]string{"reason": input.Reason})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodPost {
		t.Errorf("expected POST, got %s", req.Method)
	}
	if req.Path != "/api/v1/certificates/mc-api-prod/revoke" {
		t.Errorf("expected path /api/v1/certificates/mc-api-prod/revoke, got %s", req.Path)
	}
	if !strings.Contains(req.Body, "keyCompromise") {
		t.Errorf("expected reason in body, got %s", req.Body)
	}
}
// TestToolEndToEnd_AgentHeartbeat POSTs an agent heartbeat payload (os and
// architecture) and checks that it goes to the agent's heartbeat path as a
// POST.
func TestToolEndToEnd_AgentHeartbeat(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	_, err = client.Post("/api/v1/agents/agent-001/heartbeat", map[string]string{
		"os":           "linux",
		"architecture": "amd64",
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodPost {
		t.Errorf("expected POST, got %s", req.Method)
	}
	if req.Path != "/api/v1/agents/agent-001/heartbeat" {
		t.Errorf("expected path /api/v1/agents/agent-001/heartbeat, got %s", req.Path)
	}
}
// TestToolEndToEnd_ListWithFilters lists jobs with pagination plus status and
// type filters. It checks that the request is a GET to the jobs path and
// that both filters appear in the raw query string.
func TestToolEndToEnd_ListWithFilters(t *testing.T) {
	reqLog := &requestLog{}
	api := mockCertctlAPI(reqLog)
	defer api.Close()
	client, err := NewClient(api.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	q := paginationQuery(1, 25)
	q.Set("status", "Pending")
	q.Set("type", "Renewal")
	_, err = client.Get("/api/v1/jobs", q)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := reqLog.last()
	if req.Method != http.MethodGet {
		t.Errorf("expected GET, got %s", req.Method)
	}
	if req.Path != "/api/v1/jobs" {
		t.Errorf("expected path /api/v1/jobs, got %s", req.Path)
	}
	if !strings.Contains(req.Query, "status=Pending") {
		t.Errorf("expected status filter in query, got %s", req.Query)
	}
	if !strings.Contains(req.Query, "type=Renewal") {
		t.Errorf("expected type filter in query, got %s", req.Query)
	}
}
// TestToolEndToEnd_GetRawBinary checks that client.GetRaw returns a non-JSON
// (DER CRL) response body byte-for-byte, together with its Content-Type.
func TestToolEndToEnd_GetRawBinary(t *testing.T) {
	derData := []byte{0x30, 0x82, 0x01, 0x22}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/pkix-crl")
		w.WriteHeader(http.StatusOK)
		w.Write(derData)
	}))
	defer server.Close()
	client, err := NewClient(server.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	data, ct, err := client.GetRaw("/.well-known/pki/crl/iss-local")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if ct != "application/pkix-crl" {
		t.Errorf("expected content-type application/pkix-crl, got %s", ct)
	}
	// Compare content, not just length, so corrupted bytes are caught.
	if string(data) != string(derData) {
		t.Errorf("expected body % x, got % x", derData, data)
	}
}
// TestToolEndToEnd_ErrorPropagation checks that a 403 from the API surfaces
// as a client error, and that errorResult turns it into a tool error. The
// tool error must still carry the original error text, inside the fence.
func TestToolEndToEnd_ErrorPropagation(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusForbidden)
		w.Write([]byte(`{"error":"forbidden"}`))
	}))
	defer server.Close()
	client, err := NewClient(server.URL, "test-key", "", false)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	_, err = client.Get("/api/v1/certificates", nil)
	if err == nil {
		t.Fatal("expected error for 403 response")
	}

	result, _, toolErr := errorResult(err)
	if result != nil {
		t.Errorf("expected nil result from errorResult")
	}
	if toolErr == nil {
		t.Fatal("expected non-nil error from errorResult")
	}
	// Fencing must preserve the upstream message so operators can still see it.
	if !strings.Contains(toolErr.Error(), err.Error()) {
		t.Errorf("tool error %q does not contain upstream error %q", toolErr.Error(), err.Error())
	}
}