mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-09 23:09:01 +00:00
fix: security audit remediation (AUDIT-001, 003, 004, 005, 006, 018)
- AUDIT-001: Validate OpenSSL revoke inputs (hex-only serials, RFC 5280 reasons) - AUDIT-003: Enforce /20 CIDR size cap at API level (create + update) - AUDIT-004: Support comma-separated CERTCTL_AUTH_SECRET for zero-downtime key rotation - AUDIT-005: Add ReadHeaderTimeout (5s) to prevent Slowloris - AUDIT-006: Document audit trail query parameter exclusion rationale - AUDIT-018: Add immediate-run-on-start to short-lived expiry scheduler loop Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -78,7 +78,12 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt
|
||||
|
||||
latency := time.Since(start).Milliseconds()
|
||||
|
||||
// Record audit event asynchronously (best-effort, don't block response)
|
||||
// Record audit event asynchronously (best-effort, don't block response).
|
||||
// SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI)
|
||||
// to prevent query parameters from being recorded in the immutable audit trail.
|
||||
// Query strings may contain cursor tokens, API keys passed as params, or other
|
||||
// sensitive filter values. Since the audit trail is append-only with no deletion
|
||||
// capability, any sensitive data recorded would persist permanently.
|
||||
go func() {
|
||||
if err := recorder.RecordAPICall(
|
||||
context.Background(),
|
||||
|
||||
@@ -328,6 +328,46 @@ func TestAuditLog_CapturesLatency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) {
|
||||
recorder := newWaitableAuditRecorder()
|
||||
mw := NewAuditLog(recorder, AuditConfig{})
|
||||
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Send a request with sensitive query parameters
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?api_key=secret123&cursor=abc&status=active", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !recorder.Wait(1 * time.Second) {
|
||||
t.Fatal("timeout waiting for audit record")
|
||||
}
|
||||
|
||||
calls := recorder.getCalls()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 audit call, got %d", len(calls))
|
||||
}
|
||||
|
||||
// Path should contain ONLY the path, no query parameters
|
||||
if calls[0].Path != "/api/v1/certificates" {
|
||||
t.Errorf("expected path /api/v1/certificates (no query params), got %s", calls[0].Path)
|
||||
}
|
||||
if strings.Contains(calls[0].Path, "api_key") {
|
||||
t.Error("audit path contains 'api_key' — query parameters leaked into audit trail")
|
||||
}
|
||||
if strings.Contains(calls[0].Path, "secret123") {
|
||||
t.Error("audit path contains sensitive value 'secret123' — query parameters leaked into audit trail")
|
||||
}
|
||||
if strings.Contains(calls[0].Path, "cursor") {
|
||||
t.Error("audit path contains 'cursor' — query parameters leaked into audit trail")
|
||||
}
|
||||
if strings.Contains(calls[0].Path, "?") {
|
||||
t.Error("audit path contains '?' — query string leaked into audit trail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditServiceAdapter_TranslatesCallToEvent(t *testing.T) {
|
||||
var capturedActor, capturedActorType, capturedAction, capturedResourceType, capturedResourceID string
|
||||
var capturedDetails map[string]interface{}
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewAuth_MultiKeyAcceptsBothKeys(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: "key-one,key-two",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First key should work
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req1.Header.Set("Authorization", "Bearer key-one")
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr1, req1)
|
||||
if rr1.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for first key, got %d", rr1.Code)
|
||||
}
|
||||
|
||||
// Second key should work
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req2.Header.Set("Authorization", "Bearer key-two")
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
if rr2.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for second key, got %d", rr2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_MultiKeyRejectsInvalidKey(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: "key-one,key-two",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Invalid key should be rejected
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Authorization", "Bearer wrong-key")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for invalid key, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_MultiKeyWithSpaces(t *testing.T) {
|
||||
// Keys with leading/trailing spaces should be trimmed
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: " key-one , key-two ",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Authorization", "Bearer key-one")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for trimmed key, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_SingleKeyStillWorks(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: "my-single-key",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Authorization", "Bearer my-single-key")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for single key, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_NoneMode(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Type: "none",
|
||||
Secret: "",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// No auth header needed in none mode
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 in none mode, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_MissingAuthHeader(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: "test-key",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for missing auth, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_InvalidBearerFormat(t *testing.T) {
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: "test-key",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for non-Bearer auth, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth_RemovedKeyIsRejected(t *testing.T) {
|
||||
// Simulate key rotation: only key-two is configured (key-one was removed)
|
||||
cfg := AuthConfig{
|
||||
Type: "api-key",
|
||||
Secret: "key-two",
|
||||
}
|
||||
|
||||
mw := NewAuth(cfg)
|
||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Old key should be rejected
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req.Header.Set("Authorization", "Bearer key-one")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for removed key, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// New key should work
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||
req2.Header.Set("Authorization", "Bearer key-two")
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
if rr2.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 for current key, got %d", rr2.Code)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -100,12 +101,17 @@ func HashAPIKey(key string) string {
|
||||
// AuthConfig holds configuration for the Auth middleware.
|
||||
type AuthConfig struct {
|
||||
Type string // "api-key", "jwt", "none"
|
||||
Secret string // The raw API key (server compares against this)
|
||||
Secret string // The raw API key or comma-separated list of valid API keys
|
||||
}
|
||||
|
||||
// NewAuth creates an authentication middleware based on config.
|
||||
// When Type is "none", all requests pass through (demo/development mode).
|
||||
// When Type is "api-key", requests must include a valid Bearer token.
|
||||
// The Secret field supports a comma-separated list of valid API keys for
|
||||
// zero-downtime key rotation. Rotation workflow:
|
||||
// 1. Add new key to comma-separated list, restart server
|
||||
// 2. Update all agents/clients to use new key
|
||||
// 3. Remove old key from list, restart server
|
||||
func NewAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
if cfg.Type == "none" {
|
||||
return func(next http.Handler) http.Handler {
|
||||
@@ -113,8 +119,21 @@ func NewAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-compute hash of the expected key for constant-time comparison
|
||||
expectedHash := HashAPIKey(cfg.Secret)
|
||||
// Pre-compute hashes of all valid keys for constant-time comparison.
|
||||
// Supports comma-separated list for zero-downtime key rotation.
|
||||
keys := strings.Split(cfg.Secret, ",")
|
||||
var expectedHashes []string
|
||||
for _, k := range keys {
|
||||
k = strings.TrimSpace(k)
|
||||
if k != "" {
|
||||
expectedHashes = append(expectedHashes, HashAPIKey(k))
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if only one key is configured in production mode
|
||||
if len(expectedHashes) == 1 {
|
||||
slog.Warn("only one API key configured — consider adding a rotation key via comma-separated CERTCTL_AUTH_SECRET for zero-downtime rotation")
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -136,8 +155,16 @@ func NewAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
token := authHeader[7:]
|
||||
tokenHash := HashAPIKey(token)
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(tokenHash), []byte(expectedHash)) != 1 {
|
||||
// Check against all valid keys using constant-time comparison
|
||||
authorized := false
|
||||
for _, expectedHash := range expectedHashes {
|
||||
if subtle.ConstantTimeCompare([]byte(tokenHash), []byte(expectedHash)) == 1 {
|
||||
authorized = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !authorized {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
http.Error(w, `{"error":"Invalid API key"}`, http.StatusUnauthorized)
|
||||
return
|
||||
|
||||
@@ -32,9 +32,12 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/issuer"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// Config represents the OpenSSL/Custom CA issuer connector configuration.
|
||||
@@ -258,6 +261,36 @@ func (c *Connector) RenewCertificate(ctx context.Context, request issuer.Renewal
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// hexSerialRegex validates that a serial number contains only hexadecimal characters.
|
||||
// Certificate serial numbers are integers represented in hex (RFC 5280).
|
||||
var hexSerialRegex = regexp.MustCompile(`^[0-9a-fA-F]+$`)
|
||||
|
||||
// validateSerial validates a certificate serial number for safe use in shell commands.
|
||||
// Serial numbers must be non-empty, hex-only strings with no shell metacharacters.
|
||||
func validateSerial(serial string) error {
|
||||
if serial == "" {
|
||||
return fmt.Errorf("serial number cannot be empty")
|
||||
}
|
||||
if !hexSerialRegex.MatchString(serial) {
|
||||
return fmt.Errorf("serial number %q contains non-hex characters (expected ^[0-9a-fA-F]+$)", serial)
|
||||
}
|
||||
if err := validation.ValidateShellCommand(serial); err != nil {
|
||||
return fmt.Errorf("serial number failed shell safety validation: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRevocationReason validates a revocation reason against RFC 5280 reason codes.
|
||||
func validateRevocationReason(reason string) error {
|
||||
if !domain.IsValidRevocationReason(reason) {
|
||||
return fmt.Errorf("invalid revocation reason %q (must be a valid RFC 5280 reason code)", reason)
|
||||
}
|
||||
if err := validation.ValidateShellCommand(reason); err != nil {
|
||||
return fmt.Errorf("revocation reason failed shell safety validation: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeCertificate revokes a certificate by calling the revoke script if configured.
|
||||
func (c *Connector) RevokeCertificate(ctx context.Context, request issuer.RevocationRequest) error {
|
||||
if c.config.RevokeScript == "" {
|
||||
@@ -270,6 +303,14 @@ func (c *Connector) RevokeCertificate(ctx context.Context, request issuer.Revoca
|
||||
reason = *request.Reason
|
||||
}
|
||||
|
||||
// Validate serial number (hex-only) and reason code (RFC 5280) before shell execution
|
||||
if err := validateSerial(request.Serial); err != nil {
|
||||
return fmt.Errorf("revocation input validation failed: %w", err)
|
||||
}
|
||||
if err := validateRevocationReason(reason); err != nil {
|
||||
return fmt.Errorf("revocation input validation failed: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("revoking certificate via revoke script",
|
||||
"serial", request.Serial,
|
||||
"reason", reason)
|
||||
|
||||
@@ -289,7 +289,7 @@ func TestOpenSSLConnector(t *testing.T) {
|
||||
}
|
||||
|
||||
revokeReq := issuer.RevocationRequest{
|
||||
Serial: "test-serial-12345",
|
||||
Serial: "ABCDEF1234567890",
|
||||
}
|
||||
|
||||
// Should return nil (no-op) when revoke script not configured
|
||||
@@ -324,8 +324,10 @@ func TestOpenSSLConnector(t *testing.T) {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
|
||||
reason := "keyCompromise"
|
||||
revokeReq := issuer.RevocationRequest{
|
||||
Serial: "test-serial-12345",
|
||||
Serial: "ABCDEF1234567890",
|
||||
Reason: &reason,
|
||||
}
|
||||
|
||||
err := connector.RevokeCertificate(ctx, revokeReq)
|
||||
@@ -334,6 +336,139 @@ func TestOpenSSLConnector(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Test 15: RevokeCertificate rejects injection payloads in serial number
|
||||
t.Run("RevokeCertificate_InjectionSerial", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
signScript := filepath.Join(tmpDir, "sign.sh")
|
||||
if err := os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755); err != nil {
|
||||
t.Fatalf("Failed to create sign script: %v", err)
|
||||
}
|
||||
revokeScript := filepath.Join(tmpDir, "revoke.sh")
|
||||
if err := os.WriteFile(revokeScript, []byte("#!/bin/sh\nexit 0"), 0755); err != nil {
|
||||
t.Fatalf("Failed to create revoke script: %v", err)
|
||||
}
|
||||
|
||||
config := &openssl.Config{
|
||||
SignScript: signScript,
|
||||
RevokeScript: revokeScript,
|
||||
}
|
||||
connector := openssl.New(config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
if err := connector.ValidateConfig(ctx, rawConfig); err != nil {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
|
||||
injectionPayloads := []string{
|
||||
"1234;rm -rf /",
|
||||
"1234|cat /etc/passwd",
|
||||
"1234&whoami",
|
||||
"$(id)",
|
||||
"`id`",
|
||||
"1234\nid",
|
||||
"../../../etc/passwd",
|
||||
"test-serial-12345", // hyphens not allowed (not hex)
|
||||
}
|
||||
|
||||
for _, payload := range injectionPayloads {
|
||||
t.Run(payload, func(t *testing.T) {
|
||||
req := issuer.RevocationRequest{Serial: payload}
|
||||
err := connector.RevokeCertificate(ctx, req)
|
||||
if err == nil {
|
||||
t.Errorf("Expected injection payload %q to be rejected, but it was accepted", payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Test 16: RevokeCertificate rejects invalid reason codes
|
||||
t.Run("RevokeCertificate_InvalidReason", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
signScript := filepath.Join(tmpDir, "sign.sh")
|
||||
if err := os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755); err != nil {
|
||||
t.Fatalf("Failed to create sign script: %v", err)
|
||||
}
|
||||
revokeScript := filepath.Join(tmpDir, "revoke.sh")
|
||||
if err := os.WriteFile(revokeScript, []byte("#!/bin/sh\nexit 0"), 0755); err != nil {
|
||||
t.Fatalf("Failed to create revoke script: %v", err)
|
||||
}
|
||||
|
||||
config := &openssl.Config{
|
||||
SignScript: signScript,
|
||||
RevokeScript: revokeScript,
|
||||
}
|
||||
connector := openssl.New(config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
if err := connector.ValidateConfig(ctx, rawConfig); err != nil {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
|
||||
invalidReasons := []string{
|
||||
"notARealReason",
|
||||
"keyCompromise;rm -rf /",
|
||||
"$(whoami)",
|
||||
"`id`",
|
||||
}
|
||||
|
||||
for _, reason := range invalidReasons {
|
||||
t.Run(reason, func(t *testing.T) {
|
||||
r := reason
|
||||
req := issuer.RevocationRequest{
|
||||
Serial: "ABCDEF1234567890",
|
||||
Reason: &r,
|
||||
}
|
||||
err := connector.RevokeCertificate(ctx, req)
|
||||
if err == nil {
|
||||
t.Errorf("Expected invalid reason %q to be rejected, but it was accepted", reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Test 17: RevokeCertificate accepts all valid RFC 5280 reason codes
|
||||
t.Run("RevokeCertificate_ValidReasons", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
signScript := filepath.Join(tmpDir, "sign.sh")
|
||||
if err := os.WriteFile(signScript, []byte("#!/bin/sh\nexit 0"), 0755); err != nil {
|
||||
t.Fatalf("Failed to create sign script: %v", err)
|
||||
}
|
||||
revokeScript := filepath.Join(tmpDir, "revoke.sh")
|
||||
if err := os.WriteFile(revokeScript, []byte("#!/bin/sh\nexit 0"), 0755); err != nil {
|
||||
t.Fatalf("Failed to create revoke script: %v", err)
|
||||
}
|
||||
|
||||
config := &openssl.Config{
|
||||
SignScript: signScript,
|
||||
RevokeScript: revokeScript,
|
||||
}
|
||||
connector := openssl.New(config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
if err := connector.ValidateConfig(ctx, rawConfig); err != nil {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
|
||||
validReasons := []string{
|
||||
"unspecified", "keyCompromise", "caCompromise", "affiliationChanged",
|
||||
"superseded", "cessationOfOperation", "certificateHold", "privilegeWithdrawn",
|
||||
}
|
||||
|
||||
for _, reason := range validReasons {
|
||||
t.Run(reason, func(t *testing.T) {
|
||||
r := reason
|
||||
req := issuer.RevocationRequest{
|
||||
Serial: "ABCDEF1234567890",
|
||||
Reason: &r,
|
||||
}
|
||||
err := connector.RevokeCertificate(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid reason %q to be accepted, got error: %v", reason, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Test 10: GetOrderStatus always returns "completed"
|
||||
t.Run("GetOrderStatus", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -356,6 +356,15 @@ func (s *Scheduler) shortLivedExpiryCheckLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(s.shortLivedExpiryCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start (with idempotency guard)
|
||||
s.shortLivedExpiryCheckRunning.Store(true)
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
defer s.shortLivedExpiryCheckRunning.Store(false)
|
||||
s.runShortLivedExpiryCheck(ctx)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -58,6 +58,36 @@ func (s *NetworkScanService) GetTarget(ctx context.Context, id string) (*domain.
|
||||
return s.networkScanRepo.Get(ctx, id)
|
||||
}
|
||||
|
||||
// maxCIDRHostBits is the maximum number of host bits allowed in a CIDR range.
|
||||
// A /20 network has 12 host bits = 4096 IPs max. This prevents operators from
|
||||
// accidentally creating scan targets that would exhaust server resources.
|
||||
const maxCIDRHostBits = 12
|
||||
|
||||
// validateCIDRs validates a list of CIDRs for syntax correctness and size limits.
|
||||
// Each CIDR must be a valid CIDR notation or plain IP address, and no single CIDR
|
||||
// may be larger than /20 (4096 IPs). This validation runs at API request time so
|
||||
// operators get an immediate 400 error instead of a silent truncation at scan time.
|
||||
func validateCIDRs(cidrs []string) error {
|
||||
for _, cidr := range cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// Try parsing as plain IP (single host)
|
||||
if ip := net.ParseIP(cidr); ip == nil {
|
||||
return fmt.Errorf("invalid CIDR or IP: %s", cidr)
|
||||
}
|
||||
continue // Single IPs are always valid size
|
||||
}
|
||||
// Enforce /20 size cap at API level
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
hostBits := bits - ones
|
||||
if hostBits > maxCIDRHostBits {
|
||||
return fmt.Errorf("CIDR %s is too large (/%d has %d host bits, max /%d with %d host bits = 4096 IPs)",
|
||||
cidr, ones, hostBits, bits-maxCIDRHostBits, maxCIDRHostBits)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTarget creates a new network scan target.
|
||||
func (s *NetworkScanService) CreateTarget(ctx context.Context, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
|
||||
if target.Name == "" {
|
||||
@@ -66,14 +96,9 @@ func (s *NetworkScanService) CreateTarget(ctx context.Context, target *domain.Ne
|
||||
if len(target.CIDRs) == 0 {
|
||||
return nil, fmt.Errorf("at least one CIDR is required")
|
||||
}
|
||||
// Validate CIDRs
|
||||
for _, cidr := range target.CIDRs {
|
||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
||||
// Try parsing as plain IP
|
||||
if ip := net.ParseIP(cidr); ip == nil {
|
||||
return nil, fmt.Errorf("invalid CIDR or IP: %s", cidr)
|
||||
}
|
||||
}
|
||||
// Validate CIDRs (syntax + /20 size cap)
|
||||
if err := validateCIDRs(target.CIDRs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(target.Ports) == 0 {
|
||||
target.Ports = []int64{443}
|
||||
@@ -115,13 +140,9 @@ func (s *NetworkScanService) UpdateTarget(ctx context.Context, id string, target
|
||||
existing.Name = target.Name
|
||||
}
|
||||
if len(target.CIDRs) > 0 {
|
||||
// Validate new CIDRs
|
||||
for _, cidr := range target.CIDRs {
|
||||
if _, _, err := net.ParseCIDR(cidr); err != nil {
|
||||
if ip := net.ParseIP(cidr); ip == nil {
|
||||
return nil, fmt.Errorf("invalid CIDR or IP: %s", cidr)
|
||||
}
|
||||
}
|
||||
// Validate new CIDRs (syntax + /20 size cap)
|
||||
if err := validateCIDRs(target.CIDRs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
existing.CIDRs = target.CIDRs
|
||||
}
|
||||
|
||||
@@ -391,6 +391,92 @@ func TestExpandCIDR_AllowsPrivateRanges(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// AUDIT-003: CIDR size validation at API level
|
||||
|
||||
func TestValidateCIDRs_AcceptsValidSizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cidrs []string
|
||||
}{
|
||||
{"single IP", []string{"192.168.1.1"}},
|
||||
{"/24 network", []string{"10.0.0.0/24"}},
|
||||
{"/20 network (max)", []string{"10.0.0.0/20"}},
|
||||
{"/30 tiny network", []string{"10.0.0.0/30"}},
|
||||
{"multiple valid", []string{"10.0.0.0/24", "192.168.1.0/24"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateCIDRs(tt.cidrs)
|
||||
if err != nil {
|
||||
t.Errorf("expected valid CIDRs to be accepted, got error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCIDRs_RejectsOversized(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cidrs []string
|
||||
}{
|
||||
{"/19 too large", []string{"10.0.0.0/19"}},
|
||||
{"/16 way too large", []string{"10.0.0.0/16"}},
|
||||
{"/8 massive", []string{"10.0.0.0/8"}},
|
||||
{"/0 everything", []string{"0.0.0.0/0"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateCIDRs(tt.cidrs)
|
||||
if err == nil {
|
||||
t.Errorf("expected oversized CIDR %v to be rejected", tt.cidrs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCIDRs_RejectsInvalid(t *testing.T) {
|
||||
err := validateCIDRs([]string{"not-a-cidr"})
|
||||
if err == nil {
|
||||
t.Error("expected invalid CIDR to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkScanService_CreateTarget_RejectsOversizedCIDR(t *testing.T) {
|
||||
repo := &mockNetworkScanRepo{}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditService := NewAuditService(auditRepo)
|
||||
svc := NewNetworkScanService(repo, nil, auditService, nil)
|
||||
|
||||
_, err := svc.CreateTarget(context.Background(), &domain.NetworkScanTarget{
|
||||
Name: "Test",
|
||||
CIDRs: []string{"10.0.0.0/8"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected CreateTarget to reject /8 CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkScanService_UpdateTarget_RejectsOversizedCIDR(t *testing.T) {
|
||||
repo := &mockNetworkScanRepo{
|
||||
targets: []*domain.NetworkScanTarget{
|
||||
{ID: "nst-1", Name: "Original", CIDRs: []string{"10.0.0.0/24"}, Enabled: true},
|
||||
},
|
||||
}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditService := NewAuditService(auditRepo)
|
||||
svc := NewNetworkScanService(repo, nil, auditService, nil)
|
||||
|
||||
// Try to update from /24 to /8 — should be rejected
|
||||
_, err := svc.UpdateTarget(context.Background(), "nst-1", &domain.NetworkScanTarget{
|
||||
CIDRs: []string{"10.0.0.0/8"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected UpdateTarget to reject /8 CIDR update (bypass attempt)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandCIDR_SingleLoopbackIP(t *testing.T) {
|
||||
ips := expandCIDR("127.0.0.1")
|
||||
if len(ips) != 0 {
|
||||
|
||||
Reference in New Issue
Block a user