diff --git a/internal/connector/notifier/webhook/webhook.go b/internal/connector/notifier/webhook/webhook.go index 305df87..47e027a 100644 --- a/internal/connector/notifier/webhook/webhook.go +++ b/internal/connector/notifier/webhook/webhook.go @@ -14,8 +14,15 @@ import ( "time" "github.com/shankar0123/certctl/internal/connector/notifier" + "github.com/shankar0123/certctl/internal/validation" ) +// webhookClientTimeout bounds every outbound webhook request and its +// resolution/dial phase. Kept as a package-level constant so the timeout is +// shared by the transport dialer and the http.Client, and so tests can reason +// about it without plumbing configuration. +const webhookClientTimeout = 30 * time.Second + // Config represents the webhook notifier configuration. type Config struct { URL string `json:"url"` @@ -25,20 +32,69 @@ type Config struct { // Connector implements the notifier.Connector interface for webhook notifications. // It sends alert and event notifications via HTTP POST with optional HMAC signing. +// +// validateURL is injected so that the production constructor (New) installs the +// strict validation.ValidateSafeURL guard while newForTest can install a +// permissive validator. This is the only way to keep the production SSRF +// defence unconditionally on in real code while still allowing tests to point +// at httptest loopback servers. Without this seam, every test using +// httptest.NewServer would be blocked by the guard's loopback rejection — that +// is the correct behaviour in production but makes legitimate unit tests +// impossible to write. The test seam is unexported so no external caller can +// use it to disable the guard. type Connector struct { - config *Config - logger *slog.Logger - client *http.Client + config *Config + logger *slog.Logger + client *http.Client + validateURL func(string) error } // New creates a new webhook notifier with the given configuration and logger. +// +// The returned connector uses an http.Transport whose DialContext is hardened +// by validation.SafeHTTPDialContext. That guard re-resolves the target host +// at dial time and refuses any connection whose resolved address lies in a +// reserved range (loopback, cloud-metadata link-local, multicast, broadcast, +// unspecified, IPv6 link-local/multicast). This is the authoritative SSRF +// defence; validation.ValidateSafeURL inside ValidateConfig/postWebhook is a +// fast early diagnostic. The two layers together defeat both misconfigured +// URLs and DNS-rebinding attacks where a name's resolved address changes +// between validation and dial. func New(config *Config, logger *slog.Logger) *Connector { + transport := &http.Transport{ + DialContext: validation.SafeHTTPDialContext(webhookClientTimeout), + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + } return &Connector{ config: config, logger: logger, client: &http.Client{ - Timeout: 30 * time.Second, + Timeout: webhookClientTimeout, + Transport: transport, }, + validateURL: validation.ValidateSafeURL, + } +} + +// newForTest is an unexported constructor used exclusively by the webhook +// package's own tests. It installs a permissive URL validator and the stdlib +// default transport so tests can point the connector at httptest loopback +// servers (127.0.0.1), which the production SafeHTTPDialContext guard would +// correctly reject. Production callers cannot reach this constructor because +// it is unexported; only same-package tests (package webhook) can use it. +// The SSRF-rejection tests that verify the guard itself still call New so +// they exercise the real, strict validator. +func newForTest(config *Config, logger *slog.Logger) *Connector { + return &Connector{ + config: config, + logger: logger, + client: &http.Client{ + Timeout: webhookClientTimeout, + }, + validateURL: func(string) error { return nil }, } } @@ -54,6 +110,18 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag return fmt.Errorf("webhook url is required") } + // SSRF guard (CWE-918). Reject reserved-address URLs before issuing any + // outbound HTTP — this catches the obvious 127.0.0.1 / ::1 / + // 169.254.169.254 / 0.0.0.0 cases at config-ingestion time and produces + // a clear operator-facing error. The authoritative, TOCTOU-safe check + // still runs at dial time inside SafeHTTPDialContext. Routed through + // c.validateURL so newForTest can install a permissive validator for + // same-package unit tests; production New always wires + // validation.ValidateSafeURL here. + if err := c.validateURL(cfg.URL); err != nil { + return fmt.Errorf("webhook url rejected: %w", err) + } + c.logger.Info("validating webhook configuration", "url", cfg.URL) // Test webhook connectivity with a HEAD request @@ -150,7 +218,17 @@ func (c *Connector) SendEvent(ctx context.Context, event notifier.Event) error { // postWebhook sends a payload to the webhook URL with proper headers and signing. // If a secret is configured, it signs the payload using HMAC-SHA256 and includes // the signature in the X-Signature header. +// +// The URL is re-validated here even though ValidateConfig already accepted it: +// configuration can be mutated in place, reloaded dynamically, or set directly +// by tests that bypass ValidateConfig, so this call is a defence-in-depth +// guard that fails closed before any outbound request is built. Authoritative +// DNS-rebinding defence still runs at dial time via SafeHTTPDialContext. func (c *Connector) postWebhook(ctx context.Context, payload interface{}) error { + if err := c.validateURL(c.config.URL); err != nil { + return fmt.Errorf("webhook url rejected: %w", err) + } + // Marshal payload to JSON jsonData, err := json.Marshal(payload) if err != nil { diff --git a/internal/connector/notifier/webhook/webhook_test.go b/internal/connector/notifier/webhook/webhook_test.go index b6e0268..4f309c0 100644 --- a/internal/connector/notifier/webhook/webhook_test.go +++ b/internal/connector/notifier/webhook/webhook_test.go @@ -32,7 +32,7 @@ func TestWebhook_ValidateConfig_ValidURL(t *testing.T) { // Create a new logger (or use test logger) logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) err := conn.ValidateConfig(context.Background(), rawConfig) if err != nil { @@ -47,7 +47,7 @@ func TestWebhook_ValidateConfig_MissingURL(t *testing.T) { rawConfig, _ := json.Marshal(cfg) logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) err := conn.ValidateConfig(context.Background(), rawConfig) if err == nil { @@ -96,7 +96,7 @@ func TestWebhook_SendAlert_Success(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) alert := notifier.Alert{ ID: "alert-123", @@ -160,7 +160,7 @@ func TestWebhook_SendAlert_HMACSignature(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) alert := notifier.Alert{ ID: "alert-456", @@ -199,7 +199,7 @@ func TestWebhook_SendAlert_NoSignatureWithoutSecret(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) alert := notifier.Alert{ ID: "alert-789", @@ -239,7 +239,7 @@ func TestWebhook_SendAlert_CustomHeaders(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) alert := notifier.Alert{ ID: "alert-custom", @@ -276,7 +276,7 @@ func TestWebhook_SendAlert_HTTPError(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) alert := notifier.Alert{ ID: "alert-error", @@ -318,7 +318,7 @@ func TestWebhook_SendEvent_Success(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) certID := "mc-api-prod" event := notifier.Event{ @@ -367,7 +367,7 @@ func TestWebhook_SendEvent_WithoutCertificateID(t *testing.T) { } logger := newTestLogger() - conn := New(cfg, logger) + conn := newForTest(cfg, logger) event := notifier.Event{ ID: "event-456", @@ -389,6 +389,130 @@ func TestWebhook_SendEvent_WithoutCertificateID(t *testing.T) { } } +// The SSRF tests below exercise the CWE-918 guard added alongside H-4. Each +// case pairs a reserved-address URL with the call surface that should reject +// it. ValidateConfig is the early-fail path; SendAlert/SendEvent reach the +// same guard via postWebhook and are the defence-in-depth that still rejects +// even when ValidateConfig was bypassed (e.g. dynamic config reload mutating +// c.config.URL in place). + +func TestWebhook_ValidateConfig_RejectsReservedURLs(t *testing.T) { + // These must all fail at config-ingestion time without ever opening a + // socket — the reserved-address filter is the whole point of H-4. + cases := []struct { + name string + url string + }{ + {"loopback v4", "http://127.0.0.1/hook"}, + {"loopback v4 with port", "http://127.0.0.1:8080/"}, + {"loopback v6 bracketed", "http://[::1]/hook"}, + {"AWS metadata", "http://169.254.169.254/latest/meta-data/"}, + {"generic link-local", "http://169.254.1.2/"}, + {"unspecified v4", "http://0.0.0.0/"}, + {"unspecified v6", "http://[::]/"}, + {"IPv6 link-local", "http://[fe80::1]/"}, + {"multicast", "https://224.0.0.5/"}, + {"broadcast", "http://255.255.255.255/"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{URL: tc.url} + rawConfig, _ := json.Marshal(cfg) + conn := New(cfg, newTestLogger()) + + err := conn.ValidateConfig(context.Background(), rawConfig) + if err == nil { + t.Fatalf("ValidateConfig(%q) returned nil, want SSRF rejection", tc.url) + } + if !strings.Contains(err.Error(), "reserved") && !strings.Contains(err.Error(), "rejected") { + t.Errorf("expected reserved/rejected error, got %q", err.Error()) + } + }) + } +} + +func TestWebhook_ValidateConfig_RejectsDangerousSchemes(t *testing.T) { + // Only http(s) is a legitimate webhook transport. Every other scheme is + // an SSRF amplifier (file, gopher, ftp, javascript, data, ldap, dict, + // jar) and must be refused at config time. + cases := []struct { + name string + url string + }{ + {"file", "file:///etc/passwd"}, + {"gopher", "gopher://example.com/_x"}, + {"ftp", "ftp://example.com/"}, + {"javascript", "javascript:alert(1)"}, + {"data", "data:text/plain;base64,SGVsbG8="}, + {"ldap", "ldap://example.com/"}, + {"dict", "dict://example.com:2628/d:foo"}, + {"jar", "jar:http://example.com/foo.jar!/"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &Config{URL: tc.url} + rawConfig, _ := json.Marshal(cfg) + conn := New(cfg, newTestLogger()) + + err := conn.ValidateConfig(context.Background(), rawConfig) + if err == nil { + t.Fatalf("ValidateConfig(%q) returned nil, want scheme rejection", tc.url) + } + if !strings.Contains(err.Error(), "rejected") && !strings.Contains(err.Error(), "scheme") { + t.Errorf("expected scheme/rejected error, got %q", err.Error()) + } + }) + } +} + +func TestWebhook_SendAlert_RejectsReservedURLInPostWebhook(t *testing.T) { + // Simulate config drift: URL was legitimate at ValidateConfig time but + // has since been rewritten to an SSRF target. postWebhook must catch + // this on every call without ever hitting the wire. + cfg := &Config{URL: "http://169.254.169.254/latest/meta-data/"} + conn := New(cfg, newTestLogger()) + + alert := notifier.Alert{ + ID: "alert-ssrf", + Type: "test", + Severity: "info", + Subject: "Test", + Message: "Test", + Recipient: "ops@example.com", + CreatedAt: time.Now(), + } + + err := conn.SendAlert(context.Background(), alert) + if err == nil { + t.Fatal("SendAlert returned nil, want SSRF rejection from postWebhook") + } + if !strings.Contains(err.Error(), "reserved") && !strings.Contains(err.Error(), "rejected") { + t.Errorf("expected reserved/rejected error, got %q", err.Error()) + } +} + +func TestWebhook_SendEvent_RejectsReservedURLInPostWebhook(t *testing.T) { + cfg := &Config{URL: "http://[::1]:9/webhook"} + conn := New(cfg, newTestLogger()) + + event := notifier.Event{ + ID: "event-ssrf", + Type: "test", + Subject: "Test", + Body: "Test", + Recipient: "ops@example.com", + CreatedAt: time.Now(), + } + + err := conn.SendEvent(context.Background(), event) + if err == nil { + t.Fatal("SendEvent returned nil, want SSRF rejection from postWebhook") + } + if !strings.Contains(err.Error(), "reserved") && !strings.Contains(err.Error(), "rejected") { + t.Errorf("expected reserved/rejected error, got %q", err.Error()) + } +} + // Helper function to compute HMAC-SHA256 signature func computeHMACSHA256(data []byte, secret string) string { h := hmac.New(sha256.New, []byte(secret)) diff --git a/internal/service/network_scan.go b/internal/service/network_scan.go index f239161..3464eed 100644 --- a/internal/service/network_scan.go +++ b/internal/service/network_scan.go @@ -14,6 +14,7 @@ import ( "github.com/shankar0123/certctl/internal/domain" "github.com/shankar0123/certctl/internal/repository" "github.com/shankar0123/certctl/internal/tlsprobe" + "github.com/shankar0123/certctl/internal/validation" ) // SentinelAgentID is the agent ID used for network-discovered certificates. @@ -318,51 +319,27 @@ func (s *NetworkScanService) expandEndpoints(cidrs []string, ports []int64) []st return endpoints } -// isReservedCIDR checks if an IP address falls within reserved ranges that should not be scanned. -// Filters out loopback, link-local (including cloud metadata), and multicast ranges. -// Does NOT filter RFC 1918 ranges since certctl is self-hosted and internal networks are a primary use case. -func isReservedIP(ip net.IP) bool { - // Loopback: 127.0.0.0/8 - if ip.IsLoopback() { - return true - } - - // Link-local: 169.254.0.0/16 (includes cloud metadata 169.254.169.254) - if linkLocal := net.ParseIP("169.254.0.0"); linkLocal != nil { - if _, linkLocalNet, _ := net.ParseCIDR("169.254.0.0/16"); linkLocalNet != nil { - if linkLocalNet.Contains(ip) { - return true - } - } - } - - // Multicast: 224.0.0.0/4 - if multicast := net.ParseIP("224.0.0.0"); multicast != nil { - if _, multicastNet, _ := net.ParseCIDR("224.0.0.0/4"); multicastNet != nil { - if multicastNet.Contains(ip) { - return true - } - } - } - - // Broadcast: 255.255.255.255 - if ip.String() == "255.255.255.255" { - return true - } - - return false -} +// The reserved-IP filter used by expandCIDR previously lived here as an +// unexported isReservedIP helper. It has been moved to +// internal/validation.IsReservedIP so the webhook notifier can share a single +// authoritative implementation (H-4, CWE-918). The behaviour is +// byte-identical with the previous helper — RFC 1918 is intentionally NOT +// filtered, matching certctl's self-hosted design. If you change the +// validation package's IsReservedIP, you are changing the network-scanner's +// behaviour; audit both code paths together. // expandCIDR expands a CIDR notation or single IP into a list of IPs. // Limits expansion to /20 (4096 IPs) to prevent accidental huge scans. -// Filters out reserved IP ranges to prevent SSRF attacks. +// Filters out reserved IP ranges (via validation.IsReservedIP) to prevent +// SSRF amplification via network-scan targets pointed at cloud metadata or +// loopback. func expandCIDR(cidr string) []string { // Try as CIDR first ip, ipNet, err := net.ParseCIDR(cidr) if err != nil { // Try as single IP if singleIP := net.ParseIP(cidr); singleIP != nil { - if isReservedIP(singleIP) { + if validation.IsReservedIP(singleIP) { return nil } return []string{singleIP.String()} @@ -380,7 +357,7 @@ func expandCIDR(cidr string) []string { var ips []string for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incrementIP(ip) { // Skip reserved IPs - if isReservedIP(ip) { + if validation.IsReservedIP(ip) { continue } diff --git a/internal/service/network_scan_test.go b/internal/service/network_scan_test.go index cb1b304..3096fda 100644 --- a/internal/service/network_scan_test.go +++ b/internal/service/network_scan_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/validation" ) // mockNetworkScanRepo for testing @@ -248,9 +249,9 @@ func TestIsReservedIP_Loopback(t *testing.T) { for _, tt := range tests { t.Run(tt.ip, func(t *testing.T) { - result := isReservedIP(net.ParseIP(tt.ip)) + result := validation.IsReservedIP(net.ParseIP(tt.ip)) if result != tt.expected { - t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) + t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) } }) } @@ -269,9 +270,9 @@ func TestIsReservedIP_LinkLocal(t *testing.T) { for _, tt := range tests { t.Run(tt.ip, func(t *testing.T) { - result := isReservedIP(net.ParseIP(tt.ip)) + result := validation.IsReservedIP(net.ParseIP(tt.ip)) if result != tt.expected { - t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) + t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) } }) } @@ -289,18 +290,18 @@ func TestIsReservedIP_Multicast(t *testing.T) { for _, tt := range tests { t.Run(tt.ip, func(t *testing.T) { - result := isReservedIP(net.ParseIP(tt.ip)) + result := validation.IsReservedIP(net.ParseIP(tt.ip)) if result != tt.expected { - t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) + t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) } }) } } func TestIsReservedIP_Broadcast(t *testing.T) { - result := isReservedIP(net.ParseIP("255.255.255.255")) + result := validation.IsReservedIP(net.ParseIP("255.255.255.255")) if !result { - t.Errorf("isReservedIP(255.255.255.255) = %v, expected true", result) + t.Errorf("validation.IsReservedIP(255.255.255.255) = %v, expected true", result) } } @@ -320,9 +321,9 @@ func TestIsReservedIP_AllowsPrivateRanges(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - result := isReservedIP(net.ParseIP(tt.ip)) + result := validation.IsReservedIP(net.ParseIP(tt.ip)) if result != tt.expected { - t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) + t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) } }) } @@ -340,9 +341,9 @@ func TestIsReservedIP_AllowsPublic(t *testing.T) { for _, tt := range tests { t.Run(tt.ip, func(t *testing.T) { - result := isReservedIP(net.ParseIP(tt.ip)) + result := validation.IsReservedIP(net.ParseIP(tt.ip)) if result != tt.expected { - t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) + t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected) } }) } diff --git a/internal/validation/ssrf.go b/internal/validation/ssrf.go new file mode 100644 index 0000000..6b8d34e --- /dev/null +++ b/internal/validation/ssrf.go @@ -0,0 +1,212 @@ +package validation + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +// IsReservedIP reports whether the given IP falls inside a range that +// outbound HTTP egress (and the network-scanner CIDR expander) MUST treat +// as unreachable: loopback, link-local (including cloud-provider metadata +// endpoints at 169.254.169.254), multicast, and broadcast. +// +// RFC 1918 ranges (10/8, 172.16/12, 192.168/16) are intentionally NOT +// treated as reserved. certctl is designed to manage certificates inside +// private networks and filtering private address space would break the +// primary use case. The threat model here is outbound HTTP to +// cloud-metadata or localhost services, not general network reachability. +// +// This function is byte-identical in behaviour to the previous unexported +// copy in internal/service/network_scan.go. It is exported here so both +// the network scanner and the webhook notifier share a single +// authoritative implementation. Broader IPv6 coverage and unspecified- +// address handling live in SafeHTTPDialContext, where stricter policy is +// appropriate for outbound HTTP egress. +func IsReservedIP(ip net.IP) bool { + // Loopback: 127.0.0.0/8 (and ::1 via IsLoopback). + if ip.IsLoopback() { + return true + } + + // Link-local: 169.254.0.0/16 (includes cloud metadata 169.254.169.254). + if linkLocal := net.ParseIP("169.254.0.0"); linkLocal != nil { + if _, linkLocalNet, _ := net.ParseCIDR("169.254.0.0/16"); linkLocalNet != nil { + if linkLocalNet.Contains(ip) { + return true + } + } + } + + // Multicast: 224.0.0.0/4. + if multicast := net.ParseIP("224.0.0.0"); multicast != nil { + if _, multicastNet, _ := net.ParseCIDR("224.0.0.0/4"); multicastNet != nil { + if multicastNet.Contains(ip) { + return true + } + } + } + + // Broadcast: 255.255.255.255. + if ip.String() == "255.255.255.255" { + return true + } + + return false +} + +// isReservedIPForDial applies IsReservedIP plus additional ranges that are +// meaningful for outbound HTTP egress but were not part of the original +// network-scanner filter: the unspecified address (0.0.0.0 / ::) and IPv6 +// link-local / multicast ranges. Kept private so IsReservedIP stays +// byte-identical with the previous scanner behaviour. +func isReservedIPForDial(ip net.IP) bool { + if ip == nil { + return true + } + if IsReservedIP(ip) { + return true + } + if ip.IsUnspecified() { + return true + } + // IPv6 link-local fe80::/10. + if _, n, err := net.ParseCIDR("fe80::/10"); err == nil && n.Contains(ip) { + return true + } + // IPv6 multicast ff00::/8. + if _, n, err := net.ParseCIDR("ff00::/8"); err == nil && n.Contains(ip) { + return true + } + return false +} + +// ValidateSafeURL parses rawURL and rejects anything that would let an +// attacker aim an outbound HTTP client at a SSRF-sensitive destination +// (CWE-918). Guards enforced: +// +// 1. The scheme must be http or https. Schemes like file://, gopher://, +// ftp://, data:, javascript:, ldap://, and dict:// are rejected outright; +// webhook delivery only speaks HTTP(S). +// 2. A hostname must be present. Empty-host URLs like "http:///foo" are +// rejected to prevent ambiguous defaulting. +// 3. If the host is a literal IP address, the IP must not be reserved +// (see isReservedIPForDial). This stops the obvious 127.0.0.1 / ::1 / +// 169.254.169.254 / 0.0.0.0 attacks at config time. +// 4. If the host is a DNS name and resolution succeeds, every resolved +// A/AAAA record must be non-reserved. A single reserved result is +// enough to reject. Resolution failure is tolerated (offline CI +// environments, short-lived test servers) — the authoritative +// enforcement runs at dial time anyway. +// +// The DNS resolution check here is a best-effort early diagnostic. The +// authoritative, TOCTOU-safe enforcement is SafeHTTPDialContext, which +// re-checks after resolution at dial time and defeats DNS rebinding. +// Callers that need SSRF-safe HTTP egress should use BOTH +// ValidateSafeURL (at config ingestion) AND SafeHTTPDialContext +// (installed on http.Transport). +func ValidateSafeURL(rawURL string) error { + if rawURL == "" { + return fmt.Errorf("url is required") + } + + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid url: %w", err) + } + + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("url scheme %q is not allowed; only http and https are permitted", u.Scheme) + } + + host := u.Hostname() + if host == "" { + return fmt.Errorf("url must include a host") + } + + // Literal IP? Reject if reserved (strict policy for outbound egress). + if ip := net.ParseIP(host); ip != nil { + if isReservedIPForDial(ip) { + return fmt.Errorf("url host resolves to a reserved address and cannot be used") + } + return nil + } + + // DNS name. Resolve and reject if any answer is reserved. + ips, err := net.LookupIP(host) + if err != nil { + // Resolution failure is not itself a SSRF signal; let the dial-time + // DialContext handle the final decision. This keeps the validator + // tolerant of offline validation environments (CI, tests) while + // still blocking clearly-bad literal-IP URLs above. + return nil + } + for _, ip := range ips { + if isReservedIPForDial(ip) { + return fmt.Errorf("url host resolves to a reserved address and cannot be used") + } + } + + return nil +} + +// SafeHTTPDialContext returns a DialContext function suitable for +// installing on an http.Transport. Every dial attempt resolves the host +// again and rejects any connection whose resolved IP lies inside a +// reserved range. This is the authoritative SSRF / DNS-rebinding guard: +// even if ValidateSafeURL was bypassed, or if DNS changed between +// validation and dial, the outbound connection will fail closed. +// +// The timeout argument bounds both the resolution and the underlying TCP +// dial. Pass 0 to use a sensible default (10s). +func SafeHTTPDialContext(timeout time.Duration) func(ctx context.Context, network, addr string) (net.Conn, error) { + if timeout <= 0 { + timeout = 10 * time.Second + } + dialer := &net.Dialer{ + Timeout: timeout, + KeepAlive: 30 * time.Second, + } + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("invalid dial address %q: %w", addr, err) + } + + // If the host is already a literal IP, check it directly. + if ip := net.ParseIP(host); ip != nil { + if isReservedIPForDial(ip) { + return nil, fmt.Errorf("refusing to dial reserved address %s", ip.String()) + } + return dialer.DialContext(ctx, network, addr) + } + + // Resolve and reject any answer that lands in a reserved range. + // We then dial an explicit resolved IP so a racing DNS change + // cannot substitute a different (and possibly reserved) answer + // between our check and the actual TCP dial. + resCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + ips, err := (&net.Resolver{}).LookupIP(resCtx, "ip", host) + if err != nil { + return nil, fmt.Errorf("resolve %s: %w", host, err) + } + if len(ips) == 0 { + return nil, fmt.Errorf("no addresses found for %s", host) + } + for _, ip := range ips { + if isReservedIPForDial(ip) { + return nil, fmt.Errorf("refusing to dial %s: resolves to reserved address %s", host, ip.String()) + } + } + + // Dial the first non-reserved resolved IP directly, pinning the + // target so later DNS changes cannot redirect us. + pinned := net.JoinHostPort(ips[0].String(), port) + return dialer.DialContext(ctx, network, pinned) + } +} diff --git a/internal/validation/ssrf_test.go b/internal/validation/ssrf_test.go new file mode 100644 index 0000000..a51d455 --- /dev/null +++ b/internal/validation/ssrf_test.go @@ -0,0 +1,230 @@ +package validation + +import ( + "context" + "net" + "strings" + "testing" + "time" +) + +func TestIsReservedIP_ByteIdenticalWithNetworkScannerBehavior(t *testing.T) { + // These expectations MUST NOT drift from the original unexported + // isReservedIP in internal/service/network_scan.go. Any deviation here + // is a behaviour change in the network scanner and requires a separate, + // deliberate migration. + cases := []struct { + name string + ip string + reserved bool + }{ + {"loopback v4", "127.0.0.1", true}, + {"loopback v4 range upper", "127.255.255.254", true}, + {"loopback v6", "::1", true}, + {"AWS metadata", "169.254.169.254", true}, + {"link-local range edge", "169.254.0.0", true}, + {"multicast 224", "224.0.0.1", true}, + {"multicast upper", "239.255.255.255", true}, + {"broadcast", "255.255.255.255", true}, + // The original network-scanner filter does NOT include unspecified + // or IPv6 link-local, so these must remain non-reserved at this + // layer. Stricter outbound-dial policy lives in SafeHTTPDialContext. + {"unspecified v4", "0.0.0.0", false}, + {"IPv6 link-local", "fe80::1", false}, + {"IPv6 multicast", "ff00::1", false}, + // RFC 1918 is intentionally allowed (self-hosted design). + {"RFC 1918 10/8", "10.0.0.1", false}, + {"RFC 1918 172.16/12", "172.16.0.1", false}, + {"RFC 1918 192.168/16", "192.168.1.1", false}, + // Ordinary public addresses pass. + {"public v4", "8.8.8.8", false}, + {"public v6", "2606:4700:4700::1111", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ip := net.ParseIP(tc.ip) + if ip == nil { + t.Fatalf("test setup: failed to parse %q", tc.ip) + } + if got := IsReservedIP(ip); got != tc.reserved { + t.Errorf("IsReservedIP(%s)=%v, want %v", tc.ip, got, tc.reserved) + } + }) + } +} + +func TestValidateSafeURL_AcceptsSafePublicURLs(t *testing.T) { + cases := []string{ + "https://example.com/webhook", + "http://example.com/hook", + "https://example.com:8443/hook", + "https://webhook.site/abc-123", + "http://10.0.0.5/internal", // RFC 1918 allowed + "http://192.168.1.10:8080/webhook", // RFC 1918 allowed + "http://172.16.5.1/intranet", // RFC 1918 allowed + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + if err := ValidateSafeURL(raw); err != nil { + t.Errorf("ValidateSafeURL(%q) unexpectedly failed: %v", raw, err) + } + }) + } +} + +func TestValidateSafeURL_RejectsReservedLiteralIPs(t *testing.T) { + cases := []struct { + name string + url string + }{ + {"loopback v4", "http://127.0.0.1/x"}, + {"loopback v4 with port", "http://127.0.0.1:8080/"}, + {"loopback v6 bracketed", "http://[::1]/x"}, + {"AWS metadata endpoint", "http://169.254.169.254/latest/meta-data/"}, + {"link-local IP", "http://169.254.1.2/"}, + {"broadcast", "http://255.255.255.255/"}, + {"multicast", "https://224.0.0.5/"}, + {"unspecified v4", "http://0.0.0.0/"}, + {"unspecified v6", "http://[::]/"}, + {"IPv6 link-local", "http://[fe80::1]/"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateSafeURL(tc.url) + if err == nil { + t.Fatalf("ValidateSafeURL(%q) returned nil, want error", tc.url) + } + if !strings.Contains(err.Error(), "reserved") { + t.Errorf("error should mention 'reserved' for operator diagnostics, got %q", err.Error()) + } + }) + } +} + +func TestValidateSafeURL_RejectsDangerousSchemes(t *testing.T) { + cases := []struct { + name string + url string + }{ + {"file scheme", "file:///etc/passwd"}, + {"gopher scheme", "gopher://example.com/"}, + {"ftp scheme", "ftp://example.com/"}, + {"javascript scheme", "javascript:alert(1)"}, + {"data scheme", "data:text/plain;base64,SGVsbG8="}, + {"ldap scheme", "ldap://example.com/"}, + {"dict scheme", "dict://example.com:2628/d:foo"}, + {"jar scheme", "jar:http://example.com/foo.jar!/"}, + {"empty scheme", "example.com/hook"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateSafeURL(tc.url) + if err == nil { + t.Fatalf("ValidateSafeURL(%q) returned nil, want error", tc.url) + } + if !strings.Contains(err.Error(), "scheme") && !strings.Contains(err.Error(), "host") { + t.Errorf("error should mention scheme or host, got %q", err.Error()) + } + }) + } +} + +func TestValidateSafeURL_RejectsMissingHost(t *testing.T) { + cases := []string{ + "http:///foo", + "https://", + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + err := ValidateSafeURL(raw) + if err == nil { + t.Fatalf("ValidateSafeURL(%q) returned nil, want error", raw) + } + }) + } +} + +func TestValidateSafeURL_RejectsEmpty(t *testing.T) { + if err := ValidateSafeURL(""); err == nil { + t.Fatal("ValidateSafeURL(\"\") returned nil, want error") + } +} + +func TestValidateSafeURL_RejectsMalformed(t *testing.T) { + // url.Parse is famously lax; we lean on the scheme/host checks to catch + // malformed inputs that produce empty schemes or hosts. + cases := []string{ + "://missing-scheme", + "http//missing-colon.example.com", + } + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + err := ValidateSafeURL(raw) + if err == nil { + t.Fatalf("ValidateSafeURL(%q) returned nil, want error", raw) + } + }) + } +} + +func TestSafeHTTPDialContext_RejectsLiteralReservedAddress(t *testing.T) { + dial := SafeHTTPDialContext(2 * time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cases := []string{ + "127.0.0.1:9", + "169.254.169.254:80", + "[::1]:22", + "0.0.0.0:80", + } + for _, addr := range cases { + t.Run(addr, func(t *testing.T) { + conn, err := dial(ctx, "tcp", addr) + if err == nil { + _ = conn.Close() + t.Fatalf("dial(%q) returned nil err, want reserved-address rejection", addr) + } + if !strings.Contains(err.Error(), "reserved") { + t.Errorf("expected reserved-address rejection, got %q", err.Error()) + } + }) + } +} + +func TestSafeHTTPDialContext_RejectsHostResolvingToReservedAddress(t *testing.T) { + // The stdlib resolver treats "localhost" as 127.0.0.1 / ::1 on every + // platform we care about; this exercises the post-resolution check and + // demonstrates that DNS-rebinding attacks (where a name points at a + // reserved IP) are rejected at dial time rather than at validation time. + dial := SafeHTTPDialContext(2 * time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := dial(ctx, "tcp", "localhost:9") + if err == nil { + _ = conn.Close() + t.Fatal("dial(localhost:9) returned nil err, want reserved-address rejection") + } + if !strings.Contains(err.Error(), "reserved") { + t.Errorf("expected reserved-address rejection for localhost, got %q", err.Error()) + } +} + +func TestSafeHTTPDialContext_InvalidAddress(t *testing.T) { + dial := SafeHTTPDialContext(1 * time.Second) + _, err := dial(context.Background(), "tcp", "no-port") + if err == nil { + t.Fatal("expected error for invalid dial address, got nil") + } +} + +func TestSafeHTTPDialContext_DefaultTimeoutWhenZero(t *testing.T) { + // Not directly observable, but we at least exercise the branch to + // prevent a nil-ptr regression if the timeout default is dropped. + dial := SafeHTTPDialContext(0) + _, err := dial(context.Background(), "tcp", "127.0.0.1:1") + if err == nil { + t.Fatal("expected reserved-address rejection") + } +}