mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-08 02:31:34 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 672e1d991d | |||
| 89b910a8f1 | |||
| 6315ef102a | |||
| 119986fa7e | |||
| 3853b7460c | |||
| e9947dc0fe | |||
| b813660c74 | |||
| 387fb555ac | |||
| f549a7aa79 | |||
| b219e5d68a |
@@ -107,6 +107,16 @@ jobs:
|
||||
tags: |
|
||||
${{ env.REGISTRY }}/shankar0123/certctl-server:${{ steps.version.outputs.VERSION }}
|
||||
${{ env.REGISTRY }}/shankar0123/certctl-server:latest
|
||||
# Proxy propagation (M-4, Issue #9) — forwards runner-level proxy
|
||||
# secrets into the Docker build so self-hosted runners behind
|
||||
# corporate proxies can reach public registries. GitHub-hosted
|
||||
# runners don't need proxies, so the secrets are optional and
|
||||
# resolve to empty strings when unset — byte-identical to the
|
||||
# pre-fix behaviour for the public-runner path.
|
||||
build-args: |
|
||||
HTTP_PROXY=${{ secrets.HTTP_PROXY }}
|
||||
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
|
||||
NO_PROXY=${{ secrets.NO_PROXY }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
@@ -119,6 +129,13 @@ jobs:
|
||||
tags: |
|
||||
${{ env.REGISTRY }}/shankar0123/certctl-agent:${{ steps.version.outputs.VERSION }}
|
||||
${{ env.REGISTRY }}/shankar0123/certctl-agent:latest
|
||||
# Proxy propagation (M-4, Issue #9) — see server-image step for
|
||||
# rationale. Empty secrets resolve to empty build args, leaving
|
||||
# the un-proxied code path byte-identical to the pre-fix tree.
|
||||
build-args: |
|
||||
HTTP_PROXY=${{ secrets.HTTP_PROXY }}
|
||||
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
|
||||
NO_PROXY=${{ secrets.NO_PROXY }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
|
||||
+27
@@ -3,6 +3,22 @@
|
||||
# Stage 1: Build frontend
|
||||
FROM node:20-alpine AS frontend
|
||||
|
||||
# Proxy propagation (M-4, Issue #9) — defaulted to empty so un-proxied builds
|
||||
# behave identically to the pre-fix tree. When `HTTP_PROXY`/`HTTPS_PROXY`/
|
||||
# `NO_PROXY` are forwarded via `docker build --build-arg` (or compose
|
||||
# `build.args`), they are re-exported as ENV with both upper- and lower-case
|
||||
# names because npm/apk/curl read the lowercase variants while Go, Node, and
|
||||
# most HTTP libraries read the uppercase ones.
|
||||
ARG HTTP_PROXY=
|
||||
ARG HTTPS_PROXY=
|
||||
ARG NO_PROXY=
|
||||
ENV HTTP_PROXY=${HTTP_PROXY} \
|
||||
HTTPS_PROXY=${HTTPS_PROXY} \
|
||||
NO_PROXY=${NO_PROXY} \
|
||||
http_proxy=${HTTP_PROXY} \
|
||||
https_proxy=${HTTPS_PROXY} \
|
||||
no_proxy=${NO_PROXY}
|
||||
|
||||
WORKDIR /app/web
|
||||
|
||||
COPY web/ .
|
||||
@@ -13,6 +29,17 @@ RUN npm ci --include=dev || npm ci --include=dev && \
|
||||
# Stage 2: Build Go binary
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
# Proxy propagation (M-4, Issue #9) — see Stage 1 rationale.
|
||||
ARG HTTP_PROXY=
|
||||
ARG HTTPS_PROXY=
|
||||
ARG NO_PROXY=
|
||||
ENV HTTP_PROXY=${HTTP_PROXY} \
|
||||
HTTPS_PROXY=${HTTPS_PROXY} \
|
||||
NO_PROXY=${NO_PROXY} \
|
||||
http_proxy=${HTTP_PROXY} \
|
||||
https_proxy=${HTTPS_PROXY} \
|
||||
no_proxy=${NO_PROXY}
|
||||
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -2,6 +2,22 @@
|
||||
# Stage 1: Build
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
# Proxy propagation (M-4, Issue #9) — defaulted to empty so un-proxied builds
|
||||
# behave identically to the pre-fix tree. When `HTTP_PROXY`/`HTTPS_PROXY`/
|
||||
# `NO_PROXY` are forwarded via `docker build --build-arg` (or compose
|
||||
# `build.args`), they are re-exported as ENV with both upper- and lower-case
|
||||
# names because apk and curl read the lowercase variants while Go reads the
|
||||
# uppercase ones.
|
||||
ARG HTTP_PROXY=
|
||||
ARG HTTPS_PROXY=
|
||||
ARG NO_PROXY=
|
||||
ENV HTTP_PROXY=${HTTP_PROXY} \
|
||||
HTTPS_PROXY=${HTTPS_PROXY} \
|
||||
NO_PROXY=${NO_PROXY} \
|
||||
http_proxy=${HTTP_PROXY} \
|
||||
https_proxy=${HTTPS_PROXY} \
|
||||
no_proxy=${NO_PROXY}
|
||||
|
||||
RUN apk add --no-cache git ca-certificates
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -320,7 +320,7 @@ Core lifecycle management — Local CA + ACME v2 issuers, NGINX target connector
|
||||
30+ milestones shipping enterprise-grade features for free. Sub-CA mode, ACME DNS-01/DNS-PERSIST-01/EAB/ARI (RFC 9773)/profile selection, step-ca, Vault PKI, DigiCert CertCentral, Sectigo SCM, Google CAS, AWS ACM PCA, Entrust, GlobalSign, EJBCA, OpenSSL/Custom CA issuers. NGINX, Apache, HAProxy, Traefik, Caddy, Envoy, Postfix, Dovecot, IIS (WinRM), F5 BIG-IP, SSH, Windows Certificate Store, Java Keystore, Kubernetes Secrets targets. EST server (RFC 7030) and SCEP server (RFC 8894) enrollment protocols. RFC 5280 revocation with DER CRL + embedded OCSP responder. Certificate profiles, ownership tracking, team assignment, agent groups, interactive approval workflows. Filesystem, network, and cloud secret manager (AWS SM, Azure KV, GCP SM) certificate discovery with triage GUI. Dynamic issuer/target configuration via GUI with AES-256-GCM encrypted storage. First-run onboarding wizard. Post-deployment TLS verification. Certificate export (PEM/PKCS#12). S/MIME support. Prometheus metrics. Scheduled certificate digest emails. Slack, Teams, PagerDuty, OpsGenie, SMTP notifications. MCP server (80 tools), CLI (12 commands), Helm chart. Compliance mapping (SOC 2, PCI-DSS 4.0, NIST SP 800-57). 5 turnkey deployment examples. Agent install script. Migration guides from certbot, acme.sh, and cert-manager. See the [Feature Inventory](docs/features.md) for details.
|
||||
|
||||
### V3: certctl Pro
|
||||
Team access controls and identity provider integration. Role-based access control with profile-gating. Event-driven architecture with real-time operational views. Advanced search, compliance scoring, and HSM/TPM integration.
|
||||
Enterprise capabilities for larger deployments are available in the commercial tier.
|
||||
|
||||
### V4+: Cloud & Scale
|
||||
Kubernetes cert-manager external issuer, cloud infrastructure targets, extended CA support, and platform-scale features.
|
||||
|
||||
+77
-1
@@ -89,7 +89,45 @@ func main() {
|
||||
encryptionKey = crypto.DeriveKey(cfg.Encryption.ConfigEncryptionKey)
|
||||
logger.Info("config encryption enabled (AES-256-GCM)")
|
||||
} else {
|
||||
logger.Warn("CERTCTL_CONFIG_ENCRYPTION_KEY not set — issuer configs stored in plaintext (not recommended for production)")
|
||||
// C-2 fix: fail closed at startup when database-sourced issuer or target
|
||||
// rows exist without a configured encryption key. Previously the server
|
||||
// would emit a one-line warning and silently persist new GUI-created
|
||||
// configs as plaintext (CWE-311). Refuse to start instead: the operator
|
||||
// must either configure CERTCTL_CONFIG_ENCRYPTION_KEY or remove the
|
||||
// vulnerable rows before the control plane can boot.
|
||||
ctx := context.Background()
|
||||
dbIssuers, ierr := issuerRepo.List(ctx)
|
||||
if ierr != nil {
|
||||
logger.Error("startup check: failed to list issuers", "error", ierr)
|
||||
os.Exit(1)
|
||||
}
|
||||
dbTargets, terr := targetRepo.List(ctx)
|
||||
if terr != nil {
|
||||
logger.Error("startup check: failed to list targets", "error", terr)
|
||||
os.Exit(1)
|
||||
}
|
||||
var dbIssuerCount, dbTargetCount int
|
||||
for _, iss := range dbIssuers {
|
||||
if iss != nil && iss.Source == "database" {
|
||||
dbIssuerCount++
|
||||
}
|
||||
}
|
||||
for _, tgt := range dbTargets {
|
||||
if tgt != nil && tgt.Source == "database" {
|
||||
dbTargetCount++
|
||||
}
|
||||
}
|
||||
if dbIssuerCount > 0 || dbTargetCount > 0 {
|
||||
logger.Error(
|
||||
"startup refused: CERTCTL_CONFIG_ENCRYPTION_KEY is not set but database-sourced configs exist "+
|
||||
"(would expose sensitive fields as plaintext, CWE-311). "+
|
||||
"Set the encryption key or remove the affected rows before restarting.",
|
||||
"database_sourced_issuers", dbIssuerCount,
|
||||
"database_sourced_targets", dbTargetCount,
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Warn("CERTCTL_CONFIG_ENCRYPTION_KEY not set — env-seeded issuers will be stored in plaintext; GUI-created issuers and targets will be rejected until a key is configured")
|
||||
}
|
||||
|
||||
issuerRegistry := service.NewIssuerRegistry(logger)
|
||||
@@ -445,6 +483,24 @@ func main() {
|
||||
|
||||
// Register SCEP (RFC 8894) handlers if enabled
|
||||
if cfg.SCEP.Enabled {
|
||||
// H-2 fix: fail closed at startup when SCEP is enabled without a
|
||||
// challenge password configured. Previously the service-layer guard
|
||||
// at internal/service/scep.go:72-79 skipped the password check when
|
||||
// s.challengePassword == "", meaning any client that could reach the
|
||||
// /scep endpoint could enroll an arbitrary CSR against the configured
|
||||
// issuer (CWE-306, missing authentication for a critical function).
|
||||
// Refuse to start instead: the operator must set
|
||||
// CERTCTL_SCEP_CHALLENGE_PASSWORD (or disable SCEP) before the control
|
||||
// plane can boot.
|
||||
if err := preflightSCEPChallengePassword(cfg.SCEP.Enabled, cfg.SCEP.ChallengePassword); err != nil {
|
||||
logger.Error(
|
||||
"startup refused: SCEP is enabled but CERTCTL_SCEP_CHALLENGE_PASSWORD is not set "+
|
||||
"(would allow unauthenticated certificate enrollment, CWE-306). "+
|
||||
"Set a non-empty challenge password or disable SCEP before restarting.",
|
||||
"error", err,
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
issuerConn, ok := issuerRegistry.Get(cfg.SCEP.IssuerID)
|
||||
if !ok {
|
||||
logger.Error("SCEP issuer not found in registry", "issuer_id", cfg.SCEP.IssuerID)
|
||||
@@ -645,3 +701,23 @@ func main() {
|
||||
logger.Info("certctl server stopped")
|
||||
}
|
||||
|
||||
// preflightSCEPChallengePassword is the startup guard behind the H-2 fix
// (CWE-306, missing authentication for a critical function).
//
// It reports an error for exactly one combination of inputs: SCEP is enabled
// and challengePassword is the empty string. Every other combination
// (SCEP disabled with or without a password, or SCEP enabled with any
// non-empty password) returns nil.
//
// The function performs no logging and never terminates the process; it only
// returns the error. main() is the caller that turns a non-nil result into a
// structured log line followed by os.Exit(1). Keeping the decision in a pure
// helper lets it be unit tested without booting the server.
func preflightSCEPChallengePassword(enabled bool, challengePassword string) error {
	// The only unsafe configuration is "enabled with no shared secret".
	if enabled && challengePassword == "" {
		return fmt.Errorf("SCEP enabled but CERTCTL_SCEP_CHALLENGE_PASSWORD is empty: " +
			"SCEP enrollment would accept any client (CWE-306); " +
			"configure a non-empty shared secret or set CERTCTL_SCEP_ENABLED=false")
	}
	return nil
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/api/middleware"
|
||||
@@ -538,3 +539,68 @@ func TestMain_ContextPropagation(t *testing.T) {
|
||||
t.Logf("Context value may not be propagated (status %d), this may be expected", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreflightSCEPChallengePassword is the H-2 regression guard for the
|
||||
// startup pre-flight check. The helper MUST return a non-nil error whenever
|
||||
// SCEP is enabled with an empty challenge password — that configuration
|
||||
// previously allowed unauthenticated certificate enrollment (CWE-306).
|
||||
// Disabled-SCEP and configured-password cases must pass cleanly.
|
||||
func TestPreflightSCEPChallengePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
challengePassword string
|
||||
wantErr bool
|
||||
wantErrSubstring string
|
||||
}{
|
||||
{
|
||||
name: "disabled_empty_password_ok",
|
||||
enabled: false,
|
||||
challengePassword: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "disabled_with_password_ok",
|
||||
enabled: false,
|
||||
challengePassword: "leftover-value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_password_rejected",
|
||||
enabled: true,
|
||||
challengePassword: "",
|
||||
wantErr: true,
|
||||
wantErrSubstring: "CERTCTL_SCEP_CHALLENGE_PASSWORD",
|
||||
},
|
||||
{
|
||||
name: "enabled_with_password_ok",
|
||||
enabled: true,
|
||||
challengePassword: "hunter2",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "enabled_single_char_password_ok",
|
||||
enabled: true,
|
||||
challengePassword: "x",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := preflightSCEPChallengePassword(tt.enabled, tt.challengePassword)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if tt.wantErrSubstring != "" && !strings.Contains(err.Error(), tt.wantErrSubstring) {
|
||||
t.Errorf("expected error to mention %q, got: %v", tt.wantErrSubstring, err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "CWE-306") {
|
||||
t.Errorf("expected error to cite CWE-306 for traceability, got: %v", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("expected no error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,16 @@ services:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile
|
||||
# Proxy propagation (M-4, Issue #9) — forwards host shell's proxy env
|
||||
# vars into the Docker build so the Node frontend stage and Go module
|
||||
# download can reach the public registries behind corporate proxies.
|
||||
# Defaults to empty; omit the variables from the host environment for
|
||||
# un-proxied builds and the behaviour is byte-identical to the pre-fix
|
||||
# tree.
|
||||
args:
|
||||
HTTP_PROXY: ${HTTP_PROXY:-}
|
||||
HTTPS_PROXY: ${HTTPS_PROXY:-}
|
||||
NO_PROXY: ${NO_PROXY:-}
|
||||
environment:
|
||||
# Verbose logging for development
|
||||
CERTCTL_LOG_LEVEL: debug
|
||||
@@ -29,6 +39,15 @@ services:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile.agent
|
||||
# Proxy propagation (M-4, Issue #9) — forwards host shell's proxy env
|
||||
# vars into the Docker build so the Go module download stage can reach
|
||||
# the public Go module proxy behind corporate proxies. Defaults to
|
||||
# empty; omit the variables from the host environment for un-proxied
|
||||
# builds and the behaviour is byte-identical to the pre-fix tree.
|
||||
args:
|
||||
HTTP_PROXY: ${HTTP_PROXY:-}
|
||||
HTTPS_PROXY: ${HTTPS_PROXY:-}
|
||||
NO_PROXY: ${NO_PROXY:-}
|
||||
environment:
|
||||
CERTCTL_LOG_LEVEL: debug
|
||||
|
||||
|
||||
@@ -150,6 +150,16 @@ services:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile
|
||||
# Proxy propagation (M-4, Issue #9) — forwards host shell's proxy env
|
||||
# vars into the Docker build so the Node frontend stage and Go module
|
||||
# download can reach the public registries behind corporate proxies.
|
||||
# Defaults to empty; omit the variables from the host environment for
|
||||
# un-proxied builds and the behaviour is byte-identical to the pre-fix
|
||||
# tree.
|
||||
args:
|
||||
HTTP_PROXY: ${HTTP_PROXY:-}
|
||||
HTTPS_PROXY: ${HTTPS_PROXY:-}
|
||||
NO_PROXY: ${NO_PROXY:-}
|
||||
container_name: certctl-test-server
|
||||
depends_on:
|
||||
postgres:
|
||||
@@ -266,6 +276,15 @@ services:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile.agent
|
||||
# Proxy propagation (M-4, Issue #9) — forwards host shell's proxy env
|
||||
# vars into the Docker build so the Go module download stage can reach
|
||||
# the public Go module proxy behind corporate proxies. Defaults to
|
||||
# empty; omit the variables from the host environment for un-proxied
|
||||
# builds and the behaviour is byte-identical to the pre-fix tree.
|
||||
args:
|
||||
HTTP_PROXY: ${HTTP_PROXY:-}
|
||||
HTTPS_PROXY: ${HTTPS_PROXY:-}
|
||||
NO_PROXY: ${NO_PROXY:-}
|
||||
container_name: certctl-test-agent
|
||||
depends_on:
|
||||
certctl-server:
|
||||
|
||||
@@ -36,6 +36,16 @@ services:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile
|
||||
# Proxy propagation (M-4, Issue #9) — forwards host shell's proxy env
|
||||
# vars into the Docker build so the Node frontend stage and Go module
|
||||
# download can reach the public registries behind corporate proxies.
|
||||
# Defaults to empty; omit the variables from the host environment for
|
||||
# un-proxied builds and the behaviour is byte-identical to the pre-fix
|
||||
# tree.
|
||||
args:
|
||||
HTTP_PROXY: ${HTTP_PROXY:-}
|
||||
HTTPS_PROXY: ${HTTPS_PROXY:-}
|
||||
NO_PROXY: ${NO_PROXY:-}
|
||||
container_name: certctl-server
|
||||
depends_on:
|
||||
postgres:
|
||||
@@ -75,6 +85,15 @@ services:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile.agent
|
||||
# Proxy propagation (M-4, Issue #9) — forwards host shell's proxy env
|
||||
# vars into the Docker build so the Go module download stage can reach
|
||||
# the public Go module proxy behind corporate proxies. Defaults to
|
||||
# empty; omit the variables from the host environment for un-proxied
|
||||
# builds and the behaviour is byte-identical to the pre-fix tree.
|
||||
args:
|
||||
HTTP_PROXY: ${HTTP_PROXY:-}
|
||||
HTTPS_PROXY: ${HTTPS_PROXY:-}
|
||||
NO_PROXY: ${NO_PROXY:-}
|
||||
container_name: certctl-agent
|
||||
depends_on:
|
||||
certctl-server:
|
||||
|
||||
@@ -465,9 +465,12 @@ GlobalSign Atlas High Volume CA REST API with dual authentication: mTLS for the
|
||||
| `CERTCTL_GLOBALSIGN_API_SECRET` | Yes | — | API secret for request authentication |
|
||||
| `CERTCTL_GLOBALSIGN_CLIENT_CERT_PATH` | Yes | — | Path to mTLS client certificate PEM |
|
||||
| `CERTCTL_GLOBALSIGN_CLIENT_KEY_PATH` | Yes | — | Path to mTLS client private key PEM |
|
||||
| `CERTCTL_GLOBALSIGN_SERVER_CA_PATH` | No | system trust store | PEM bundle used to verify the Atlas API server certificate. Set this for private/lab Atlas deployments whose server TLS chain is not in the host's default trust bundle. |
|
||||
|
||||
**Authentication:** Dual — mTLS client certificate for TLS handshake plus `X-API-Key` and `X-API-Secret` headers on every request.
|
||||
|
||||
**TLS verification:** The connector always verifies the server certificate. When `server_ca_path` is set, the PEM bundle at that path is used as the trust anchor; otherwise the host's system trust store is used. TLS 1.2 is the minimum protocol version.
|
||||
|
||||
**Issuance model:** `POST /v2/certificates` returns a serial number. Certificate PEM is available after validation completes. Typically resolves within seconds for DV. `GetOrderStatus` polls the certificate endpoint.
|
||||
|
||||
**Note:** CRL and OCSP are managed by GlobalSign. certctl records revocations locally and notifies GlobalSign via `PUT /v2/certificates/{serial}/revoke`.
|
||||
|
||||
@@ -116,6 +116,14 @@ type GlobalSignConfig struct {
|
||||
// ClientKeyPath is the path to the mTLS client private key PEM file.
|
||||
// Setting: CERTCTL_GLOBALSIGN_CLIENT_KEY_PATH environment variable.
|
||||
ClientKeyPath string
|
||||
|
||||
// ServerCAPath is the optional path to a PEM file containing the CA
|
||||
// certificate(s) used to verify the GlobalSign Atlas HVCA API server
|
||||
// certificate. If empty, the system trust store is used. Set this
|
||||
// for private/lab Atlas deployments whose server TLS chain is not
|
||||
// present in the host's default trust bundle.
|
||||
// Setting: CERTCTL_GLOBALSIGN_SERVER_CA_PATH environment variable.
|
||||
ServerCAPath string
|
||||
}
|
||||
|
||||
// EJBCAConfig contains EJBCA (Keyfactor) issuer connector configuration.
|
||||
@@ -641,7 +649,12 @@ type SCEPConfig struct {
|
||||
|
||||
// ChallengePassword is the shared secret used to authenticate SCEP enrollment requests.
|
||||
// Clients include this in the PKCS#10 CSR challengePassword attribute.
|
||||
// Required when SCEP is enabled.
|
||||
//
|
||||
// REQUIRED when Enabled is true. If SCEP is enabled and this value is empty,
|
||||
// cmd/server/main.go's preflightSCEPChallengePassword check will refuse to
|
||||
// start the server (H-2, CWE-306): an empty shared secret allowed any client
|
||||
// that could reach /scep to enroll a CSR against the configured issuer. The
|
||||
// service-layer PKCSReq path also rejects this configuration as a defense-in-depth measure.
|
||||
ChallengePassword string
|
||||
}
|
||||
|
||||
@@ -882,6 +895,7 @@ func Load() (*Config, error) {
|
||||
APISecret: getEnv("CERTCTL_GLOBALSIGN_API_SECRET", ""),
|
||||
ClientCertPath: getEnv("CERTCTL_GLOBALSIGN_CLIENT_CERT_PATH", ""),
|
||||
ClientKeyPath: getEnv("CERTCTL_GLOBALSIGN_CLIENT_KEY_PATH", ""),
|
||||
ServerCAPath: getEnv("CERTCTL_GLOBALSIGN_SERVER_CA_PATH", ""),
|
||||
},
|
||||
EJBCA: EJBCAConfig{
|
||||
APIUrl: getEnv("CERTCTL_EJBCA_API_URL", ""),
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -64,6 +65,14 @@ type Config struct {
|
||||
// Must match the certificate in ClientCertPath.
|
||||
// Required. Set via CERTCTL_GLOBALSIGN_CLIENT_KEY_PATH environment variable.
|
||||
ClientKeyPath string `json:"client_key_path"`
|
||||
|
||||
// ServerCAPath is the filesystem path to a PEM file containing the CA
|
||||
// certificate(s) used to verify the GlobalSign Atlas HVCA API server certificate.
|
||||
// Optional. If empty, the system trust store is used. This option exists for
|
||||
// private/lab deployments of GlobalSign Atlas that terminate TLS with an
|
||||
// internal CA not present in the host's default trust bundle.
|
||||
// Set via CERTCTL_GLOBALSIGN_SERVER_CA_PATH environment variable.
|
||||
ServerCAPath string `json:"server_ca_path,omitempty"`
|
||||
}
|
||||
|
||||
// Connector implements the issuer.Connector interface for GlobalSign Atlas HVCA.
|
||||
@@ -153,14 +162,12 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
return fmt.Errorf("failed to load GlobalSign client certificate: %w", err)
|
||||
}
|
||||
|
||||
// Create an mTLS client for validation
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
// InsecureSkipVerify=true allows testing against self-signed server certs.
|
||||
// In production, GlobalSign's API uses a proper certificate chain.
|
||||
// This matches the pattern used by other connectors (F5, network scanner, etc.)
|
||||
// that also need to bypass hostname verification for internal/lab environments.
|
||||
InsecureSkipVerify: true,
|
||||
// Build a verifying mTLS TLS config. If ServerCAPath is set, that PEM
|
||||
// bundle is used as the trust anchor for the server certificate;
|
||||
// otherwise the system trust store is used. TLS 1.2 is the minimum.
|
||||
tlsConfig, err := buildServerTLSConfig(&cfg, cert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build GlobalSign TLS config: %w", err)
|
||||
}
|
||||
|
||||
validationClient := &http.Client{
|
||||
@@ -225,9 +232,9 @@ func (c *Connector) getHTTPClient(ctx context.Context) (*http.Client, error) {
|
||||
return nil, fmt.Errorf("failed to load GlobalSign client certificate: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
InsecureSkipVerify: true,
|
||||
tlsConfig, err := buildServerTLSConfig(c.config, cert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build GlobalSign TLS config: %w", err)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
@@ -238,6 +245,38 @@ func (c *Connector) getHTTPClient(ctx context.Context) (*http.Client, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildServerTLSConfig returns a TLS configuration for the GlobalSign Atlas
|
||||
// HVCA API client. It always verifies the server certificate. When
|
||||
// cfg.ServerCAPath is set, the PEM bundle at that path is used as the
|
||||
// trust anchor (enables pinning a private/lab CA); otherwise the host's
|
||||
// system trust store is used. TLS 1.2 is the minimum protocol version.
|
||||
//
|
||||
// This helper is the single source of truth for both the ValidateConfig
|
||||
// probe client and the steady-state getHTTPClient production client, so
|
||||
// any future TLS policy change applies uniformly.
|
||||
func buildServerTLSConfig(cfg *Config, clientCert tls.Certificate) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
if cfg.ServerCAPath != "" {
|
||||
caPEM, err := os.ReadFile(cfg.ServerCAPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read server CA bundle at %s: %w", cfg.ServerCAPath, err)
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caPEM) {
|
||||
return nil, fmt.Errorf("no valid PEM certificates found in server CA bundle at %s", cfg.ServerCAPath)
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// IssueCertificate submits a certificate order to GlobalSign Atlas HVCA.
|
||||
// Returns the serial number immediately; typically the cert is available within seconds (DV) to minutes (OV).
|
||||
func (c *Connector) IssueCertificate(ctx context.Context, request issuer.IssuanceRequest) (*issuer.IssuanceResult, error) {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
@@ -161,11 +160,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
testCertPEM, _ := generateTestCert(t)
|
||||
testChainPEM, _ := generateTestCert(t)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v2/certificates" && r.Method == http.MethodPost {
|
||||
@@ -223,11 +218,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("IssueCertificate_Pending", func(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v2/certificates" && r.Method == http.MethodPost {
|
||||
@@ -271,11 +262,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("IssueCertificate_Error", func(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v2/certificates" && r.Method == http.MethodPost {
|
||||
@@ -312,11 +299,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
testCertPEM, _ := generateTestCert(t)
|
||||
testChainPEM, _ := generateTestCert(t)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/v2/certificates/12345") && r.Method == http.MethodGet {
|
||||
@@ -356,11 +339,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("GetOrderStatus_Pending", func(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/v2/certificates/98765") && r.Method == http.MethodGet {
|
||||
@@ -401,11 +380,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
testCertPEM, _ := generateTestCert(t)
|
||||
testChainPEM, _ := generateTestCert(t)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v2/certificates" && r.Method == http.MethodPost {
|
||||
@@ -448,11 +423,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("RevokeCertificate_Success", func(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/v2/certificates/") && strings.HasSuffix(r.URL.Path, "/revoke") && r.Method == http.MethodPut {
|
||||
@@ -492,11 +463,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("RevokeCertificate_Error", func(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/v2/certificates/") && strings.HasSuffix(r.URL.Path, "/revoke") && r.Method == http.MethodPut {
|
||||
@@ -532,11 +499,7 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
testChainPEM, _ := generateTestCert(t)
|
||||
authHeadersChecked := 0
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check for auth headers on every request
|
||||
@@ -584,6 +547,177 @@ func TestGlobalSignConnector(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestGlobalSign_ServerTLSConfig exercises the server-side TLS verification
|
||||
// policy added by H-5. The connector must always verify the GlobalSign Atlas
|
||||
// HVCA API server certificate: by default against the host's system trust
|
||||
// store, and when ServerCAPath is set, against the pinned PEM bundle at that
|
||||
// path. InsecureSkipVerify is no longer reachable from any production code path.
|
||||
func TestGlobalSign_ServerTLSConfig(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
ctx := context.Background()
|
||||
|
||||
// writeClientMTLS generates a throwaway client cert+key pair and writes them
|
||||
// to disk. ValidateConfig requires valid ClientCertPath / ClientKeyPath files
|
||||
// before it reaches the server-CA validation path under test.
|
||||
writeClientMTLS := func(t *testing.T) (certPath, keyPath string) {
|
||||
t.Helper()
|
||||
certPEM, keyPEM := generateTestCert(t)
|
||||
dir := t.TempDir()
|
||||
certPath = dir + "/client-cert.pem"
|
||||
keyPath = dir + "/client-key.pem"
|
||||
if err := os.WriteFile(certPath, []byte(certPEM), 0600); err != nil {
|
||||
t.Fatalf("failed to write client cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, []byte(keyPEM), 0600); err != nil {
|
||||
t.Fatalf("failed to write client key: %v", err)
|
||||
}
|
||||
return certPath, keyPath
|
||||
}
|
||||
|
||||
// certToPEM re-encodes a parsed certificate as a PEM block for trust-store
|
||||
// pinning. httptest.NewTLSServer.Certificate() returns the server's self-
|
||||
// signed cert; pinning that cert trusts exactly that one server.
|
||||
certToPEM := func(t *testing.T, cert *x509.Certificate) string {
|
||||
t.Helper()
|
||||
return string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}))
|
||||
}
|
||||
|
||||
t.Run("PinnedCA_TrustsExpectedServer", func(t *testing.T) {
|
||||
// Mock Atlas API served over HTTPS with a self-signed cert. We pin
|
||||
// that cert's PEM as the client's trust anchor; the validation probe
|
||||
// should succeed because the pinned pool contains the server's issuer.
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/v2/certificates" && r.Method == http.MethodGet {
|
||||
if r.Header.Get("ApiKey") == "gs-test-key" && r.Header.Get("ApiSecret") == "gs-test-secret" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"certificates":[]}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
caPEM := certToPEM(t, srv.Certificate())
|
||||
caPath := t.TempDir() + "/atlas-ca.pem"
|
||||
if err := os.WriteFile(caPath, []byte(caPEM), 0600); err != nil {
|
||||
t.Fatalf("failed to write pinned CA: %v", err)
|
||||
}
|
||||
|
||||
clientCert, clientKey := writeClientMTLS(t)
|
||||
config := globalsign.Config{
|
||||
APIUrl: srv.URL,
|
||||
APIKey: "gs-test-key",
|
||||
APISecret: "gs-test-secret",
|
||||
ClientCertPath: clientCert,
|
||||
ClientKeyPath: clientKey,
|
||||
ServerCAPath: caPath,
|
||||
}
|
||||
|
||||
connector := globalsign.New(&config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
if err := connector.ValidateConfig(ctx, rawConfig); err != nil {
|
||||
t.Fatalf("ValidateConfig with pinned CA should succeed, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PinnedCA_RejectsUntrustedServer", func(t *testing.T) {
|
||||
// Mock server presents its own self-signed cert; we pin an UNRELATED
|
||||
// cert as the trust anchor. The TLS handshake must fail before any
|
||||
// request is sent — this is exactly what H-5 remediates.
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
unrelatedPEM, _ := generateTestCert(t)
|
||||
caPath := t.TempDir() + "/unrelated-ca.pem"
|
||||
if err := os.WriteFile(caPath, []byte(unrelatedPEM), 0600); err != nil {
|
||||
t.Fatalf("failed to write unrelated CA: %v", err)
|
||||
}
|
||||
|
||||
clientCert, clientKey := writeClientMTLS(t)
|
||||
config := globalsign.Config{
|
||||
APIUrl: srv.URL,
|
||||
APIKey: "gs-test-key",
|
||||
APISecret: "gs-test-secret",
|
||||
ClientCertPath: clientCert,
|
||||
ClientKeyPath: clientKey,
|
||||
ServerCAPath: caPath,
|
||||
}
|
||||
|
||||
connector := globalsign.New(&config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("ValidateConfig must fail when the server cert is not signed by the pinned CA")
|
||||
}
|
||||
// The failure must originate from TLS verification, not from any other path.
|
||||
if !strings.Contains(err.Error(), "x509") &&
|
||||
!strings.Contains(err.Error(), "certificate") &&
|
||||
!strings.Contains(err.Error(), "unknown authority") {
|
||||
t.Errorf("expected TLS verification error, got: %v", err)
|
||||
}
|
||||
t.Logf("Untrusted server cert correctly rejected: %v", err)
|
||||
})
|
||||
|
||||
t.Run("ServerCAPath_MissingFile", func(t *testing.T) {
|
||||
clientCert, clientKey := writeClientMTLS(t)
|
||||
config := globalsign.Config{
|
||||
APIUrl: "https://example.invalid",
|
||||
APIKey: "gs-test-key",
|
||||
APISecret: "gs-test-secret",
|
||||
ClientCertPath: clientCert,
|
||||
ClientKeyPath: clientKey,
|
||||
ServerCAPath: "/nonexistent/path/to/ca.pem",
|
||||
}
|
||||
|
||||
connector := globalsign.New(&config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("ValidateConfig must fail when ServerCAPath points to a missing file")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to read server CA bundle") {
|
||||
t.Errorf("expected 'failed to read server CA bundle' error, got: %v", err)
|
||||
}
|
||||
t.Logf("Missing server CA file correctly rejected: %v", err)
|
||||
})
|
||||
|
||||
t.Run("ServerCAPath_InvalidPEM", func(t *testing.T) {
|
||||
clientCert, clientKey := writeClientMTLS(t)
|
||||
badCAPath := t.TempDir() + "/garbage.pem"
|
||||
if err := os.WriteFile(badCAPath, []byte("this is not a PEM certificate at all"), 0600); err != nil {
|
||||
t.Fatalf("failed to write garbage file: %v", err)
|
||||
}
|
||||
|
||||
config := globalsign.Config{
|
||||
APIUrl: "https://example.invalid",
|
||||
APIKey: "gs-test-key",
|
||||
APISecret: "gs-test-secret",
|
||||
ClientCertPath: clientCert,
|
||||
ClientKeyPath: clientKey,
|
||||
ServerCAPath: badCAPath,
|
||||
}
|
||||
|
||||
connector := globalsign.New(&config, logger)
|
||||
rawConfig, _ := json.Marshal(config)
|
||||
err := connector.ValidateConfig(ctx, rawConfig)
|
||||
if err == nil {
|
||||
t.Fatal("ValidateConfig must fail when ServerCAPath contains no valid PEM certificates")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no valid PEM certificates") {
|
||||
t.Errorf("expected 'no valid PEM certificates' error, got: %v", err)
|
||||
}
|
||||
t.Logf("Invalid PEM correctly rejected: %v", err)
|
||||
})
|
||||
}
|
||||
|
||||
// generateTestCert generates a self-signed test certificate and returns PEM strings.
|
||||
func generateTestCert(t *testing.T) (certPEM string, keyPEM string) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/notifier"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// Config represents the email notifier configuration.
|
||||
@@ -123,7 +124,22 @@ func (c *Connector) SendEvent(ctx context.Context, event notifier.Event) error {
|
||||
|
||||
// sendEmail sends an email message using the configured SMTP server.
|
||||
// It handles both TLS and plain authentication modes.
|
||||
//
|
||||
// Header values (From, To, Subject) are validated up-front to reject CR, LF,
|
||||
// and NUL characters. This blocks SMTP header injection (CWE-113) and also
|
||||
// prevents injection into the SMTP envelope commands MAIL FROM and RCPT TO,
|
||||
// since net/smtp does not sanitize those inputs itself.
|
||||
func (c *Connector) sendEmail(ctx context.Context, to, subject, body string) error {
|
||||
if err := validation.ValidateHeaderValue("From", c.config.FromAddress); err != nil {
|
||||
return fmt.Errorf("invalid sender: %w", err)
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("To", to); err != nil {
|
||||
return fmt.Errorf("invalid recipient: %w", err)
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("Subject", subject); err != nil {
|
||||
return fmt.Errorf("invalid subject: %w", err)
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(c.config.SMTPHost, strconv.Itoa(c.config.SMTPPort))
|
||||
|
||||
// Connect to SMTP server
|
||||
@@ -182,8 +198,13 @@ func (c *Connector) sendEmail(ctx context.Context, to, subject, body string) err
|
||||
}
|
||||
defer wc.Close()
|
||||
|
||||
// Format and write email headers and body
|
||||
message := c.formatEmailMessage(c.config.FromAddress, to, subject, body)
|
||||
// Format and write email headers and body. The format function
|
||||
// re-validates header values as defense-in-depth; the early-return
|
||||
// above should have already caught any injection attempt.
|
||||
message, err := c.formatEmailMessage(c.config.FromAddress, to, subject, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to format message: %w", err)
|
||||
}
|
||||
if _, err := wc.Write(message); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
@@ -197,7 +218,22 @@ func (c *Connector) sendEmail(ctx context.Context, to, subject, body string) err
|
||||
|
||||
// sendHTMLEmail sends an HTML email message using the configured SMTP server.
|
||||
// Used by the digest service for rich HTML digest emails.
|
||||
//
|
||||
// Header values (From, To, Subject) are validated up-front to reject CR, LF,
|
||||
// and NUL characters. This blocks SMTP header injection (CWE-113) and also
|
||||
// prevents injection into the SMTP envelope commands MAIL FROM and RCPT TO,
|
||||
// since net/smtp does not sanitize those inputs itself.
|
||||
func (c *Connector) sendHTMLEmail(ctx context.Context, to, subject, htmlBody string) error {
|
||||
if err := validation.ValidateHeaderValue("From", c.config.FromAddress); err != nil {
|
||||
return fmt.Errorf("invalid sender: %w", err)
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("To", to); err != nil {
|
||||
return fmt.Errorf("invalid recipient: %w", err)
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("Subject", subject); err != nil {
|
||||
return fmt.Errorf("invalid subject: %w", err)
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(c.config.SMTPHost, strconv.Itoa(c.config.SMTPPort))
|
||||
|
||||
var auth smtp.Auth
|
||||
@@ -250,7 +286,12 @@ func (c *Connector) sendHTMLEmail(ctx context.Context, to, subject, htmlBody str
|
||||
}
|
||||
defer wc.Close()
|
||||
|
||||
message := c.formatHTMLEmailMessage(c.config.FromAddress, to, subject, htmlBody)
|
||||
// The format function re-validates header values as defense-in-depth;
|
||||
// the early-return above should have already caught any injection attempt.
|
||||
message, err := c.formatHTMLEmailMessage(c.config.FromAddress, to, subject, htmlBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to format message: %w", err)
|
||||
}
|
||||
if _, err := wc.Write(message); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
@@ -263,7 +304,20 @@ func (c *Connector) sendHTMLEmail(ctx context.Context, to, subject, htmlBody str
|
||||
}
|
||||
|
||||
// formatEmailMessage formats an email message with standard headers.
|
||||
func (c *Connector) formatEmailMessage(from, to, subject, body string) []byte {
|
||||
// It rejects any header value containing CR, LF, or NUL bytes to prevent
|
||||
// SMTP header injection (CWE-113). See internal/validation.ValidateHeaderValue.
|
||||
// The body is not validated — CR/LF in the body is legitimate content, and
|
||||
// SMTP dot-stuffing / length framing are handled by net/smtp.
|
||||
func (c *Connector) formatEmailMessage(from, to, subject, body string) ([]byte, error) {
|
||||
if err := validation.ValidateHeaderValue("From", from); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("To", to); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("Subject", subject); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message := fmt.Sprintf(
|
||||
"From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\nContent-Type: text/plain; charset=utf-8\r\n\r\n%s",
|
||||
from,
|
||||
@@ -272,11 +326,24 @@ func (c *Connector) formatEmailMessage(from, to, subject, body string) []byte {
|
||||
time.Now().Format(time.RFC1123Z),
|
||||
body,
|
||||
)
|
||||
return []byte(message)
|
||||
return []byte(message), nil
|
||||
}
|
||||
|
||||
// formatHTMLEmailMessage formats an HTML email message with MIME headers.
|
||||
func (c *Connector) formatHTMLEmailMessage(from, to, subject, htmlBody string) []byte {
|
||||
// It rejects any header value containing CR, LF, or NUL bytes to prevent
|
||||
// SMTP header injection (CWE-113). See internal/validation.ValidateHeaderValue.
|
||||
// The HTML body is not validated at this layer — CR/LF in HTML content is
|
||||
// legitimate, and SMTP dot-stuffing / length framing are handled by net/smtp.
|
||||
func (c *Connector) formatHTMLEmailMessage(from, to, subject, htmlBody string) ([]byte, error) {
|
||||
if err := validation.ValidateHeaderValue("From", from); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("To", to); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validation.ValidateHeaderValue("Subject", subject); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message := fmt.Sprintf(
|
||||
"From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=utf-8\r\n\r\n%s",
|
||||
from,
|
||||
@@ -285,7 +352,7 @@ func (c *Connector) formatHTMLEmailMessage(from, to, subject, htmlBody string) [
|
||||
time.Now().Format(time.RFC1123Z),
|
||||
htmlBody,
|
||||
)
|
||||
return []byte(message)
|
||||
return []byte(message), nil
|
||||
}
|
||||
|
||||
// formatAlertBody formats an alert notification as email body text.
|
||||
|
||||
@@ -138,7 +138,10 @@ func TestEmail_FormatMessage_RFC822Headers(t *testing.T) {
|
||||
subject := "Test Subject"
|
||||
body := "Test Body"
|
||||
|
||||
message := conn.formatEmailMessage(from, to, subject, body)
|
||||
message, err := conn.formatEmailMessage(from, to, subject, body)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
messageStr := string(message)
|
||||
|
||||
if !strings.Contains(messageStr, "From: "+from) {
|
||||
@@ -177,7 +180,10 @@ func TestEmail_FormatHTMLEmailMessage_Headers(t *testing.T) {
|
||||
subject := "HTML Test"
|
||||
htmlBody := "<html><body><h1>Test</h1></body></html>"
|
||||
|
||||
message := conn.formatHTMLEmailMessage(from, to, subject, htmlBody)
|
||||
message, err := conn.formatHTMLEmailMessage(from, to, subject, htmlBody)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
messageStr := string(message)
|
||||
|
||||
if !strings.Contains(messageStr, "From: "+from) {
|
||||
@@ -200,6 +206,67 @@ func TestEmail_FormatHTMLEmailMessage_Headers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmail_FormatEmailMessage_RejectsCRLFInjection exercises the CRLF
|
||||
// sanitizer (CWE-113). A subject containing "\r\nBcc: ..." must be rejected
|
||||
// rather than silently stripped — authentication-relevant headers are
|
||||
// security-critical and silent mutation masks malicious intent.
|
||||
func TestEmail_FormatEmailMessage_RejectsCRLFInjection(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
FromAddress: "sender@example.com",
|
||||
}
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
from, to, sub string
|
||||
wantField string
|
||||
}{
|
||||
{"CRLF in Subject", "sender@example.com", "recipient@example.com", "hello\r\nBcc: attacker@example.com", "Subject"},
|
||||
{"LF in To", "sender@example.com", "recipient@example.com\nBcc: x@y", "ok", "To"},
|
||||
{"CR in From", "sender@example.com\rExtra: header", "recipient@example.com", "ok", "From"},
|
||||
{"NUL in Subject", "sender@example.com", "recipient@example.com", "hi\x00there", "Subject"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := conn.formatEmailMessage(tc.from, tc.to, tc.sub, "body")
|
||||
if err == nil {
|
||||
t.Fatal("expected injection error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.wantField) {
|
||||
t.Errorf("expected error to mention field %q, got %q", tc.wantField, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmail_FormatHTMLEmailMessage_RejectsCRLFInjection mirrors the plain-text
|
||||
// test for the HTML codepath used by the digest service.
|
||||
func TestEmail_FormatHTMLEmailMessage_RejectsCRLFInjection(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
FromAddress: "sender@example.com",
|
||||
}
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
|
||||
_, err := conn.formatHTMLEmailMessage(
|
||||
"sender@example.com",
|
||||
"recipient@example.com",
|
||||
"digest\r\nBcc: attacker@example.com",
|
||||
"<p>hi</p>",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected CRLF injection error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Subject") {
|
||||
t.Errorf("expected error to mention Subject field, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmail_FormatAlertBody(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SMTPHost: "smtp.example.com",
|
||||
|
||||
@@ -14,8 +14,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/connector/notifier"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// webhookClientTimeout bounds every outbound webhook request and its
|
||||
// resolution/dial phase. Kept as a package-level constant so the timeout is
|
||||
// shared by the transport dialer and the http.Client, and so tests can reason
|
||||
// about it without plumbing configuration.
|
||||
const webhookClientTimeout = 30 * time.Second
|
||||
|
||||
// Config represents the webhook notifier configuration.
|
||||
type Config struct {
|
||||
URL string `json:"url"`
|
||||
@@ -25,20 +32,69 @@ type Config struct {
|
||||
|
||||
// Connector implements the notifier.Connector interface for webhook notifications.
|
||||
// It sends alert and event notifications via HTTP POST with optional HMAC signing.
|
||||
//
|
||||
// validateURL is injected so that the production constructor (New) installs the
|
||||
// strict validation.ValidateSafeURL guard while newForTest can install a
|
||||
// permissive validator. This is the only way to keep the production SSRF
|
||||
// defence unconditionally on in real code while still allowing tests to point
|
||||
// at httptest loopback servers. Without this seam, every test using
|
||||
// httptest.NewServer would be blocked by the guard's loopback rejection — that
|
||||
// is the correct behaviour in production but makes legitimate unit tests
|
||||
// impossible to write. The test seam is unexported so no external caller can
|
||||
// use it to disable the guard.
|
||||
type Connector struct {
|
||||
config *Config
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
config *Config
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
validateURL func(string) error
|
||||
}
|
||||
|
||||
// New creates a new webhook notifier with the given configuration and logger.
|
||||
//
|
||||
// The returned connector uses an http.Transport whose DialContext is hardened
|
||||
// by validation.SafeHTTPDialContext. That guard re-resolves the target host
|
||||
// at dial time and refuses any connection whose resolved address lies in a
|
||||
// reserved range (loopback, cloud-metadata link-local, multicast, broadcast,
|
||||
// unspecified, IPv6 link-local/multicast). This is the authoritative SSRF
|
||||
// defence; validation.ValidateSafeURL inside ValidateConfig/postWebhook is a
|
||||
// fast early diagnostic. The two layers together defeat both misconfigured
|
||||
// URLs and DNS-rebinding attacks where a name's resolved address changes
|
||||
// between validation and dial.
|
||||
func New(config *Config, logger *slog.Logger) *Connector {
|
||||
transport := &http.Transport{
|
||||
DialContext: validation.SafeHTTPDialContext(webhookClientTimeout),
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
return &Connector{
|
||||
config: config,
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Timeout: webhookClientTimeout,
|
||||
Transport: transport,
|
||||
},
|
||||
validateURL: validation.ValidateSafeURL,
|
||||
}
|
||||
}
|
||||
|
||||
// newForTest is an unexported constructor used exclusively by the webhook
|
||||
// package's own tests. It installs a permissive URL validator and the stdlib
|
||||
// default transport so tests can point the connector at httptest loopback
|
||||
// servers (127.0.0.1), which the production SafeHTTPDialContext guard would
|
||||
// correctly reject. Production callers cannot reach this constructor because
|
||||
// it is unexported; only same-package tests (package webhook) can use it.
|
||||
// The SSRF-rejection tests that verify the guard itself still call New so
|
||||
// they exercise the real, strict validator.
|
||||
func newForTest(config *Config, logger *slog.Logger) *Connector {
|
||||
return &Connector{
|
||||
config: config,
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: webhookClientTimeout,
|
||||
},
|
||||
validateURL: func(string) error { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +110,18 @@ func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessag
|
||||
return fmt.Errorf("webhook url is required")
|
||||
}
|
||||
|
||||
// SSRF guard (CWE-918). Reject reserved-address URLs before issuing any
|
||||
// outbound HTTP — this catches the obvious 127.0.0.1 / ::1 /
|
||||
// 169.254.169.254 / 0.0.0.0 cases at config-ingestion time and produces
|
||||
// a clear operator-facing error. The authoritative, TOCTOU-safe check
|
||||
// still runs at dial time inside SafeHTTPDialContext. Routed through
|
||||
// c.validateURL so newForTest can install a permissive validator for
|
||||
// same-package unit tests; production New always wires
|
||||
// validation.ValidateSafeURL here.
|
||||
if err := c.validateURL(cfg.URL); err != nil {
|
||||
return fmt.Errorf("webhook url rejected: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("validating webhook configuration", "url", cfg.URL)
|
||||
|
||||
// Test webhook connectivity with a HEAD request
|
||||
@@ -150,7 +218,17 @@ func (c *Connector) SendEvent(ctx context.Context, event notifier.Event) error {
|
||||
// postWebhook sends a payload to the webhook URL with proper headers and signing.
|
||||
// If a secret is configured, it signs the payload using HMAC-SHA256 and includes
|
||||
// the signature in the X-Signature header.
|
||||
//
|
||||
// The URL is re-validated here even though ValidateConfig already accepted it:
|
||||
// configuration can be mutated in place, reloaded dynamically, or set directly
|
||||
// by tests that bypass ValidateConfig, so this call is a defence-in-depth
|
||||
// guard that fails closed before any outbound request is built. Authoritative
|
||||
// DNS-rebinding defence still runs at dial time via SafeHTTPDialContext.
|
||||
func (c *Connector) postWebhook(ctx context.Context, payload interface{}) error {
|
||||
if err := c.validateURL(c.config.URL); err != nil {
|
||||
return fmt.Errorf("webhook url rejected: %w", err)
|
||||
}
|
||||
|
||||
// Marshal payload to JSON
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestWebhook_ValidateConfig_ValidURL(t *testing.T) {
|
||||
|
||||
// Create a new logger (or use test logger)
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
err := conn.ValidateConfig(context.Background(), rawConfig)
|
||||
if err != nil {
|
||||
@@ -47,7 +47,7 @@ func TestWebhook_ValidateConfig_MissingURL(t *testing.T) {
|
||||
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
err := conn.ValidateConfig(context.Background(), rawConfig)
|
||||
if err == nil {
|
||||
@@ -96,7 +96,7 @@ func TestWebhook_SendAlert_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
alert := notifier.Alert{
|
||||
ID: "alert-123",
|
||||
@@ -160,7 +160,7 @@ func TestWebhook_SendAlert_HMACSignature(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
alert := notifier.Alert{
|
||||
ID: "alert-456",
|
||||
@@ -199,7 +199,7 @@ func TestWebhook_SendAlert_NoSignatureWithoutSecret(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
alert := notifier.Alert{
|
||||
ID: "alert-789",
|
||||
@@ -239,7 +239,7 @@ func TestWebhook_SendAlert_CustomHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
alert := notifier.Alert{
|
||||
ID: "alert-custom",
|
||||
@@ -276,7 +276,7 @@ func TestWebhook_SendAlert_HTTPError(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
alert := notifier.Alert{
|
||||
ID: "alert-error",
|
||||
@@ -318,7 +318,7 @@ func TestWebhook_SendEvent_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
certID := "mc-api-prod"
|
||||
event := notifier.Event{
|
||||
@@ -367,7 +367,7 @@ func TestWebhook_SendEvent_WithoutCertificateID(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := newTestLogger()
|
||||
conn := New(cfg, logger)
|
||||
conn := newForTest(cfg, logger)
|
||||
|
||||
event := notifier.Event{
|
||||
ID: "event-456",
|
||||
@@ -389,6 +389,130 @@ func TestWebhook_SendEvent_WithoutCertificateID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// The SSRF tests below exercise the CWE-918 guard added alongside H-4. Each
|
||||
// case pairs a reserved-address URL with the call surface that should reject
|
||||
// it. ValidateConfig is the early-fail path; SendAlert/SendEvent reach the
|
||||
// same guard via postWebhook and are the defence-in-depth that still rejects
|
||||
// even when ValidateConfig was bypassed (e.g. dynamic config reload mutating
|
||||
// c.config.URL in place).
|
||||
|
||||
func TestWebhook_ValidateConfig_RejectsReservedURLs(t *testing.T) {
|
||||
// These must all fail at config-ingestion time without ever opening a
|
||||
// socket — the reserved-address filter is the whole point of H-4.
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"loopback v4", "http://127.0.0.1/hook"},
|
||||
{"loopback v4 with port", "http://127.0.0.1:8080/"},
|
||||
{"loopback v6 bracketed", "http://[::1]/hook"},
|
||||
{"AWS metadata", "http://169.254.169.254/latest/meta-data/"},
|
||||
{"generic link-local", "http://169.254.1.2/"},
|
||||
{"unspecified v4", "http://0.0.0.0/"},
|
||||
{"unspecified v6", "http://[::]/"},
|
||||
{"IPv6 link-local", "http://[fe80::1]/"},
|
||||
{"multicast", "https://224.0.0.5/"},
|
||||
{"broadcast", "http://255.255.255.255/"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := &Config{URL: tc.url}
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
conn := New(cfg, newTestLogger())
|
||||
|
||||
err := conn.ValidateConfig(context.Background(), rawConfig)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateConfig(%q) returned nil, want SSRF rejection", tc.url)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "reserved") && !strings.Contains(err.Error(), "rejected") {
|
||||
t.Errorf("expected reserved/rejected error, got %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhook_ValidateConfig_RejectsDangerousSchemes(t *testing.T) {
|
||||
// Only http(s) is a legitimate webhook transport. Every other scheme is
|
||||
// an SSRF amplifier (file, gopher, ftp, javascript, data, ldap, dict,
|
||||
// jar) and must be refused at config time.
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"file", "file:///etc/passwd"},
|
||||
{"gopher", "gopher://example.com/_x"},
|
||||
{"ftp", "ftp://example.com/"},
|
||||
{"javascript", "javascript:alert(1)"},
|
||||
{"data", "data:text/plain;base64,SGVsbG8="},
|
||||
{"ldap", "ldap://example.com/"},
|
||||
{"dict", "dict://example.com:2628/d:foo"},
|
||||
{"jar", "jar:http://example.com/foo.jar!/"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := &Config{URL: tc.url}
|
||||
rawConfig, _ := json.Marshal(cfg)
|
||||
conn := New(cfg, newTestLogger())
|
||||
|
||||
err := conn.ValidateConfig(context.Background(), rawConfig)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateConfig(%q) returned nil, want scheme rejection", tc.url)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "rejected") && !strings.Contains(err.Error(), "scheme") {
|
||||
t.Errorf("expected scheme/rejected error, got %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhook_SendAlert_RejectsReservedURLInPostWebhook(t *testing.T) {
|
||||
// Simulate config drift: URL was legitimate at ValidateConfig time but
|
||||
// has since been rewritten to an SSRF target. postWebhook must catch
|
||||
// this on every call without ever hitting the wire.
|
||||
cfg := &Config{URL: "http://169.254.169.254/latest/meta-data/"}
|
||||
conn := New(cfg, newTestLogger())
|
||||
|
||||
alert := notifier.Alert{
|
||||
ID: "alert-ssrf",
|
||||
Type: "test",
|
||||
Severity: "info",
|
||||
Subject: "Test",
|
||||
Message: "Test",
|
||||
Recipient: "ops@example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := conn.SendAlert(context.Background(), alert)
|
||||
if err == nil {
|
||||
t.Fatal("SendAlert returned nil, want SSRF rejection from postWebhook")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "reserved") && !strings.Contains(err.Error(), "rejected") {
|
||||
t.Errorf("expected reserved/rejected error, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhook_SendEvent_RejectsReservedURLInPostWebhook(t *testing.T) {
|
||||
cfg := &Config{URL: "http://[::1]:9/webhook"}
|
||||
conn := New(cfg, newTestLogger())
|
||||
|
||||
event := notifier.Event{
|
||||
ID: "event-ssrf",
|
||||
Type: "test",
|
||||
Subject: "Test",
|
||||
Body: "Test",
|
||||
Recipient: "ops@example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := conn.SendEvent(context.Background(), event)
|
||||
if err == nil {
|
||||
t.Fatal("SendEvent returned nil, want SSRF rejection from postWebhook")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "reserved") && !strings.Contains(err.Error(), "rejected") {
|
||||
t.Errorf("expected reserved/rejected error, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compute HMAC-SHA256 signature
|
||||
func computeHMACSHA256(data []byte, secret string) string {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
|
||||
@@ -6,12 +6,29 @@ import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
// ErrEncryptionKeyRequired is returned by EncryptIfKeySet and DecryptIfKeySet when
|
||||
// the caller provides an empty key but the data on the wire requires protection.
|
||||
//
|
||||
// Historically these helpers silently returned plaintext when no key was configured,
|
||||
// which produced a data-at-rest confidentiality bypass (CWE-311): sensitive fields
|
||||
// in dynamically-configured issuer and target records (source='database') were
|
||||
// persisted to PostgreSQL without any encryption whenever the operator forgot to
|
||||
// set CERTCTL_CONFIG_ENCRYPTION_KEY. Callers could not distinguish the encrypted
|
||||
// and plaintext branches at runtime, so the only visible signal was a warning
|
||||
// line emitted once at startup.
|
||||
//
|
||||
// The fix is to fail closed: EncryptIfKeySet/DecryptIfKeySet now require a key
|
||||
// whenever they are invoked on sensitive material, and the server refuses to
|
||||
// start if any source='database' rows already exist without a configured key.
|
||||
var ErrEncryptionKeyRequired = errors.New("crypto: CERTCTL_CONFIG_ENCRYPTION_KEY is required to encrypt or decrypt sensitive config")
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM with a random 12-byte nonce prepended to the output.
|
||||
// The key must be exactly 32 bytes (AES-256). Returns [12-byte nonce][ciphertext+tag].
|
||||
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
|
||||
@@ -81,11 +98,17 @@ func DeriveKey(passphrase string) []byte {
|
||||
return pbkdf2.Key([]byte(passphrase), salt, 100000, 32, sha256.New)
|
||||
}
|
||||
|
||||
// EncryptIfKeySet encrypts plaintext if a key is provided, otherwise returns plaintext unchanged.
|
||||
// This supports the development/demo fallback where encryption isn't configured.
|
||||
// EncryptIfKeySet encrypts plaintext with the supplied 32-byte AES-256 key.
|
||||
//
|
||||
// The second return value is always true when err == nil — the "wasEncrypted"
|
||||
// flag is retained for source-compatibility with callers that previously used it
|
||||
// to log provenance. Callers MUST handle err: passing an empty key now returns
|
||||
// ErrEncryptionKeyRequired rather than silently emitting plaintext. See the
|
||||
// package-level ErrEncryptionKeyRequired documentation for the history behind
|
||||
// this behavior change.
|
||||
func EncryptIfKeySet(plaintext []byte, key []byte) ([]byte, bool, error) {
|
||||
if len(key) == 0 {
|
||||
return plaintext, false, nil
|
||||
return nil, false, ErrEncryptionKeyRequired
|
||||
}
|
||||
encrypted, err := Encrypt(plaintext, key)
|
||||
if err != nil {
|
||||
@@ -94,10 +117,17 @@ func EncryptIfKeySet(plaintext []byte, key []byte) ([]byte, bool, error) {
|
||||
return encrypted, true, nil
|
||||
}
|
||||
|
||||
// DecryptIfKeySet decrypts ciphertext if a key is provided, otherwise returns ciphertext unchanged.
|
||||
// DecryptIfKeySet decrypts ciphertext with the supplied 32-byte AES-256 key.
|
||||
//
|
||||
// Passing an empty key now returns ErrEncryptionKeyRequired. Callers that
|
||||
// legitimately store plaintext (e.g. env-seeded source='env' rows that keep
|
||||
// the raw JSON in the unencrypted `config` column) must branch on the presence
|
||||
// of the ciphertext themselves rather than relying on this helper to silently
|
||||
// pass bytes through. See the package-level ErrEncryptionKeyRequired
|
||||
// documentation for the history behind this behavior change.
|
||||
func DecryptIfKeySet(ciphertext []byte, key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return ciphertext, nil
|
||||
return nil, ErrEncryptionKeyRequired
|
||||
}
|
||||
return Decrypt(ciphertext, key)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -148,31 +149,140 @@ func TestEncryptIfKeySet_WithKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptIfKeySet_NilKey(t *testing.T) {
|
||||
// TestEncryptIfKeySet_EmptyKeyFailsClosed asserts the C-2 regression guard:
|
||||
// EncryptIfKeySet must refuse to silently emit plaintext when no key is configured.
|
||||
// The pre-fix behavior was to return plaintext with wasEncrypted=false, which
|
||||
// produced a data-at-rest confidentiality bypass (CWE-311) for GUI-created
|
||||
// issuer and target configs.
|
||||
func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
||||
plaintext := []byte("config data")
|
||||
|
||||
result, wasEncrypted, err := EncryptIfKeySet(plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptIfKeySet with nil key failed: %v", err)
|
||||
cases := []struct {
|
||||
name string
|
||||
key []byte
|
||||
}{
|
||||
{"nil_key", nil},
|
||||
{"empty_key", []byte{}},
|
||||
}
|
||||
if wasEncrypted {
|
||||
t.Fatal("expected wasEncrypted=false when key is nil")
|
||||
}
|
||||
if !bytes.Equal(result, plaintext) {
|
||||
t.Fatal("result should be unchanged plaintext when key is nil")
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, wasEncrypted, err := EncryptIfKeySet(plaintext, tc.key)
|
||||
if err == nil {
|
||||
t.Fatal("expected ErrEncryptionKeyRequired, got nil")
|
||||
}
|
||||
if !errors.Is(err, ErrEncryptionKeyRequired) {
|
||||
t.Fatalf("expected ErrEncryptionKeyRequired, got %v", err)
|
||||
}
|
||||
if wasEncrypted {
|
||||
t.Fatal("wasEncrypted must be false on error")
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil result on error, got %q", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptIfKeySet_NilKey(t *testing.T) {
|
||||
// TestDecryptIfKeySet_EmptyKeyFailsClosed asserts the matching C-2 regression
|
||||
// guard on the read path: DecryptIfKeySet must refuse to pass ciphertext
|
||||
// through as plaintext when no key is configured.
|
||||
func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
||||
data := []byte("plaintext config data")
|
||||
|
||||
result, err := DecryptIfKeySet(data, nil)
|
||||
cases := []struct {
|
||||
name string
|
||||
key []byte
|
||||
}{
|
||||
{"nil_key", nil},
|
||||
{"empty_key", []byte{}},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := DecryptIfKeySet(data, tc.key)
|
||||
if err == nil {
|
||||
t.Fatal("expected ErrEncryptionKeyRequired, got nil")
|
||||
}
|
||||
if !errors.Is(err, ErrEncryptionKeyRequired) {
|
||||
t.Fatalf("expected ErrEncryptionKeyRequired, got %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil result on error, got %q", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext proves the
|
||||
// "if set" helpers produce real AES-GCM output (not plaintext) and that a full
|
||||
// round-trip through both helpers recovers the original bytes.
|
||||
func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.T) {
|
||||
key := DeriveKey("round-trip-key")
|
||||
plaintext := []byte(`{"api_key":"s3cr3t","token":"abc"}`)
|
||||
|
||||
encrypted, wasEncrypted, err := EncryptIfKeySet(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptIfKeySet with nil key failed: %v", err)
|
||||
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(result, data) {
|
||||
t.Fatal("result should be unchanged when key is nil")
|
||||
if !wasEncrypted {
|
||||
t.Fatal("wasEncrypted must be true when key is present")
|
||||
}
|
||||
if bytes.Equal(encrypted, plaintext) {
|
||||
t.Fatal("EncryptIfKeySet returned plaintext — would regress C-2")
|
||||
}
|
||||
|
||||
decrypted, err := DecryptIfKeySet(encrypted, key)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptIfKeySet failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decrypted, plaintext) {
|
||||
t.Fatalf("round-trip mismatch: got %q, want %q", decrypted, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecryptIfKeySet_RejectsTamperedCiphertext confirms the AEAD auth tag
|
||||
// still rejects modified ciphertext when routed through the helper.
|
||||
func TestDecryptIfKeySet_RejectsTamperedCiphertext(t *testing.T) {
|
||||
key := DeriveKey("tamper-test-key")
|
||||
plaintext := []byte("authenticated data")
|
||||
|
||||
encrypted, _, err := EncryptIfKeySet(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||
}
|
||||
// Flip a byte inside the GCM body (past the 12-byte nonce) to invalidate the tag.
|
||||
if len(encrypted) <= 13 {
|
||||
t.Fatalf("ciphertext too short to tamper: %d bytes", len(encrypted))
|
||||
}
|
||||
encrypted[13] ^= 0xFF
|
||||
|
||||
if _, err := DecryptIfKeySet(encrypted, key); err == nil {
|
||||
t.Fatal("DecryptIfKeySet accepted tampered ciphertext — AEAD tag check bypassed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncryptIfKeySet_PreservesErrEncryptionKeyRequiredSentinel guards the
|
||||
// stability of the public sentinel error so audit-log detectors and callers
|
||||
// outside this package can rely on errors.Is(err, ErrEncryptionKeyRequired).
|
||||
func TestEncryptIfKeySet_PreservesErrEncryptionKeyRequiredSentinel(t *testing.T) {
|
||||
if ErrEncryptionKeyRequired == nil {
|
||||
t.Fatal("ErrEncryptionKeyRequired sentinel must be non-nil")
|
||||
}
|
||||
if ErrEncryptionKeyRequired.Error() == "" {
|
||||
t.Fatal("ErrEncryptionKeyRequired must carry a non-empty message")
|
||||
}
|
||||
// Wrap it and confirm errors.Is unwraps correctly — real callers wrap with %w.
|
||||
wrapped := wrapSentinel(ErrEncryptionKeyRequired)
|
||||
if !errors.Is(wrapped, ErrEncryptionKeyRequired) {
|
||||
t.Fatal("errors.Is must unwrap ErrEncryptionKeyRequired through %w-wrapped callers")
|
||||
}
|
||||
}
|
||||
|
||||
// wrapSentinel is a tiny helper that mimics how production callers propagate
|
||||
// the sentinel (e.g. fmt.Errorf("failed to encrypt config: %w", err)).
|
||||
func wrapSentinel(err error) error {
|
||||
return errors.Join(errors.New("failed to encrypt config"), err)
|
||||
}
|
||||
|
||||
func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
|
||||
|
||||
@@ -66,7 +66,12 @@ func TestCertificateLifecycle(t *testing.T) {
|
||||
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
|
||||
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
||||
agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService)
|
||||
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, nil, slog.Default())
|
||||
// 32-byte AES-256 test key — C-2 remediation makes IssuerService fail closed
|
||||
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
|
||||
// must supply a real key so the encrypt path runs instead of returning
|
||||
// ErrEncryptionKeyRequired.
|
||||
testEncryptionKey := []byte("0123456789abcdef0123456789abcdef")
|
||||
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, slog.Default())
|
||||
|
||||
// Initialize handlers
|
||||
certificateHandler := handler.NewCertificateHandler(certificateService)
|
||||
@@ -677,6 +682,46 @@ func (m *mockJobRepository) ListPendingByAgentID(ctx context.Context, agentID st
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ClaimPendingJobs mirrors the production H-6 semantics: Pending jobs of the given type
|
||||
// (or any type when jobType is empty) flip to Running before being returned. limit <= 0
|
||||
// means unlimited.
|
||||
func (m *mockJobRepository) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) {
|
||||
var claimed []*domain.Job
|
||||
for _, j := range m.jobs {
|
||||
if j.Status != domain.JobStatusPending {
|
||||
continue
|
||||
}
|
||||
if jobType != "" && j.Type != jobType {
|
||||
continue
|
||||
}
|
||||
j.Status = domain.JobStatusRunning
|
||||
claimed = append(claimed, j)
|
||||
if limit > 0 && len(claimed) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
return claimed, nil
|
||||
}
|
||||
|
||||
// ClaimPendingByAgentID mirrors the production H-6 semantics: Pending deployment rows for
|
||||
// the agent flip to Running; AwaitingCSR rows are returned with state preserved.
|
||||
func (m *mockJobRepository) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) {
|
||||
var result []*domain.Job
|
||||
for _, j := range m.jobs {
|
||||
if j.AgentID == nil || *j.AgentID != agentID {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case j.Status == domain.JobStatusPending && j.Type == domain.JobTypeDeployment:
|
||||
j.Status = domain.JobStatusRunning
|
||||
result = append(result, j)
|
||||
case j.Status == domain.JobStatusAwaitingCSR:
|
||||
result = append(result, j)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type mockAuditRepository struct {
|
||||
events []*domain.AuditEvent
|
||||
}
|
||||
@@ -1134,9 +1179,9 @@ func (m *mockRevocationRepository) Create(ctx context.Context, revocation *domai
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRevocationRepository) GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error) {
|
||||
func (m *mockRevocationRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.CertificateRevocation, error) {
|
||||
for _, r := range m.revocations {
|
||||
if r.SerialNumber == serial {
|
||||
if r.IssuerID == issuerID && r.SerialNumber == serial {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,12 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository
|
||||
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
|
||||
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
||||
agentService := service.NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, renewalService)
|
||||
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, nil, logger)
|
||||
// 32-byte AES-256 test key — C-2 remediation makes IssuerService fail closed
|
||||
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
|
||||
// must supply a real key so the encrypt path runs instead of returning
|
||||
// ErrEncryptionKeyRequired.
|
||||
testEncryptionKey := []byte("0123456789abcdef0123456789abcdef")
|
||||
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, logger)
|
||||
|
||||
certificateHandler := handler.NewCertificateHandler(certificateService)
|
||||
issuerHandler := handler.NewIssuerHandler(issuerService)
|
||||
|
||||
@@ -31,10 +31,15 @@ type CertificateRepository interface {
|
||||
|
||||
// RevocationRepository defines operations for managing certificate revocations.
|
||||
type RevocationRepository interface {
|
||||
// Create records a new certificate revocation.
|
||||
// Create records a new certificate revocation. Uniqueness is scoped to
|
||||
// (issuer_id, serial_number) per RFC 5280 §5.2.3, so duplicate serials
|
||||
// across different issuers are permitted.
|
||||
Create(ctx context.Context, revocation *domain.CertificateRevocation) error
|
||||
// GetBySerial retrieves a revocation by serial number.
|
||||
GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error)
|
||||
// GetByIssuerAndSerial retrieves a revocation by the (issuer_id, serial_number)
|
||||
// pair. Callers (OCSP, CRL generation) always know the issuer because
|
||||
// protocol endpoints carry it in the request path; RFC 5280 §5.2.3 guarantees
|
||||
// uniqueness only within a single issuer.
|
||||
GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.CertificateRevocation, error)
|
||||
// ListAll returns all revocations, ordered by revocation time (for CRL generation).
|
||||
ListAll(ctx context.Context) ([]*domain.CertificateRevocation, error)
|
||||
// ListByCertificate returns all revocations for a certificate.
|
||||
@@ -115,10 +120,20 @@ type JobRepository interface {
|
||||
ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error)
|
||||
// UpdateStatus updates a job's status and optional error message.
|
||||
UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error
|
||||
// GetPendingJobs returns jobs not yet processed of a specific type.
|
||||
// GetPendingJobs returns jobs not yet processed of a specific type. Prefer ClaimPendingJobs in
|
||||
// production paths where concurrent schedulers may race — see H-6 (CWE-362) remediation.
|
||||
GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error)
|
||||
// ListPendingByAgentID returns pending deployment jobs and AwaitingCSR jobs for a specific agent.
|
||||
// Prefer ClaimPendingByAgentID in production paths — see H-6 (CWE-362) remediation.
|
||||
ListPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error)
|
||||
// ClaimPendingJobs atomically claims up to `limit` Pending jobs and transitions them to Running
|
||||
// using SELECT FOR UPDATE SKIP LOCKED inside a transaction. An empty jobType matches any type;
|
||||
// limit <= 0 means no limit. H-6 (CWE-362) race remediation.
|
||||
ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error)
|
||||
// ClaimPendingByAgentID atomically claims pending deployment jobs for an agent (flipping them
|
||||
// to Running) and locks AwaitingCSR jobs against concurrent observers (leaving state intact,
|
||||
// since the CSR-submission path drives the next transition). H-6 (CWE-362) race remediation.
|
||||
ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error)
|
||||
}
|
||||
|
||||
// RenewalPolicyRepository defines operations for managing renewal policies.
|
||||
|
||||
@@ -237,7 +237,14 @@ func (r *JobRepository) UpdateStatus(ctx context.Context, id string, status doma
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPendingJobs returns jobs not yet processed of a specific type
|
||||
// GetPendingJobs returns jobs not yet processed of a specific type.
|
||||
//
|
||||
// The SELECT uses FOR UPDATE SKIP LOCKED so that concurrent scheduler replicas
|
||||
// cannot observe the same rows when invoked inside a transaction; combine with
|
||||
// a subsequent UPDATE to Running for correct dispatch semantics. For the
|
||||
// standard production dispatch path, prefer ClaimPendingJobs which wraps the
|
||||
// lock, read, and state transition in a single transaction and is the
|
||||
// authoritative race-free claim primitive (CWE-362 fix for H-6).
|
||||
func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts,
|
||||
@@ -245,6 +252,7 @@ func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobTy
|
||||
FROM jobs
|
||||
WHERE type = $1 AND status = $2
|
||||
ORDER BY scheduled_at ASC
|
||||
FOR UPDATE SKIP LOCKED
|
||||
`, jobType, domain.JobStatusPending)
|
||||
|
||||
if err != nil {
|
||||
@@ -268,10 +276,115 @@ func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobTy
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// ListPendingByAgentID returns pending deployment jobs and AwaitingCSR jobs for a specific agent.
|
||||
// Deployment jobs are matched by agent_id directly (set at creation time), with a fallback
|
||||
// for legacy jobs where agent_id is NULL but target_id resolves to the agent via deployment_targets.
|
||||
// AwaitingCSR jobs are matched through certificate → target mappings → agent ownership.
|
||||
// ClaimPendingJobs atomically claims up to `limit` Pending jobs and transitions
|
||||
// them to Running inside a single transaction. The SELECT uses FOR UPDATE SKIP
|
||||
// LOCKED so concurrent scheduler replicas observe disjoint result sets — each
|
||||
// row can be claimed by exactly one caller per tick (CWE-362 fix for H-6).
|
||||
//
|
||||
// Passing an empty jobType claims any type. Passing limit<=0 claims all
|
||||
// available rows. The claimed rows are returned with Status already set to
|
||||
// domain.JobStatusRunning.
|
||||
//
|
||||
// Downstream processors (ProcessRenewalJob, ProcessDeploymentJob) already call
|
||||
// UpdateStatus(Running) unconditionally on entry, so this pre-flip is
|
||||
// idempotent with respect to existing processing logic.
|
||||
func (r *JobRepository) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) {
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to begin claim transaction: %w", err)
|
||||
}
|
||||
// Rollback is a no-op after Commit — safe deferred cleanup if an error path
|
||||
// triggers an early return before Commit().
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Build the SELECT — jobType="" means any type, limit<=0 means unlimited.
|
||||
query := `
|
||||
SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
WHERE status = $1`
|
||||
args := []interface{}{domain.JobStatusPending}
|
||||
if jobType != "" {
|
||||
query += ` AND type = $2`
|
||||
args = append(args, jobType)
|
||||
}
|
||||
query += `
|
||||
ORDER BY scheduled_at ASC
|
||||
FOR UPDATE SKIP LOCKED`
|
||||
if limit > 0 {
|
||||
query += fmt.Sprintf(` LIMIT %d`, limit)
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query claimable jobs: %w", err)
|
||||
}
|
||||
|
||||
var jobs []*domain.Job
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
rows.Close()
|
||||
return nil, fmt.Errorf("error iterating claimable job rows: %w", err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if len(jobs) == 0 {
|
||||
// No rows to claim — commit the (read-only) tx and return.
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("failed to commit empty claim tx: %w", err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Flip claimed rows to Running. Build IN clause safely with placeholders.
|
||||
ids := make([]interface{}, len(jobs))
|
||||
placeholders := make([]byte, 0, len(jobs)*5)
|
||||
for i, job := range jobs {
|
||||
ids[i] = job.ID
|
||||
if i > 0 {
|
||||
placeholders = append(placeholders, ',')
|
||||
}
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", i+2)...)
|
||||
}
|
||||
updateQuery := fmt.Sprintf(
|
||||
`UPDATE jobs SET status = $1 WHERE id IN (%s)`,
|
||||
string(placeholders),
|
||||
)
|
||||
updateArgs := append([]interface{}{domain.JobStatusRunning}, ids...)
|
||||
if _, err := tx.ExecContext(ctx, updateQuery, updateArgs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to transition claimed jobs to Running: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("failed to commit claim transaction: %w", err)
|
||||
}
|
||||
|
||||
// Reflect the committed state in the returned objects.
|
||||
for _, job := range jobs {
|
||||
job.Status = domain.JobStatusRunning
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// ListPendingByAgentID returns pending deployment jobs and AwaitingCSR jobs for
|
||||
// a specific agent. Deployment jobs are matched by agent_id directly (set at
|
||||
// creation time), with a fallback for legacy jobs where agent_id is NULL but
|
||||
// target_id resolves to the agent via deployment_targets. AwaitingCSR jobs are
|
||||
// matched through certificate → target mappings → agent ownership.
|
||||
//
|
||||
// The SELECT uses FOR UPDATE SKIP LOCKED so concurrent pollers (e.g. two agent
|
||||
// instances running with the same agent_id) cannot observe the same rows when
|
||||
// this method is invoked inside a transaction. For the production agent work
|
||||
// poll path, prefer ClaimPendingByAgentID which additionally transitions
|
||||
// claimed Pending deployment rows to Running atomically (H-6 CWE-362 fix).
|
||||
func (r *JobRepository) ListPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts,
|
||||
@@ -326,6 +439,137 @@ func (r *JobRepository) ListPendingByAgentID(ctx context.Context, agentID string
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// ClaimPendingByAgentID atomically claims agent work inside a single
|
||||
// transaction. Pending Deployment jobs assigned to the agent (directly via
|
||||
// agent_id, or via legacy target→agent fallback) are transitioned from
|
||||
// Pending to Running. AwaitingCSR Renewal/Issuance jobs linked to the agent
|
||||
// via certificate → target mappings are locked with FOR UPDATE SKIP LOCKED
|
||||
// and returned without a state transition — the flow requires the agent to
|
||||
// submit a CSR to advance state, and pre-flipping AwaitingCSR would violate
|
||||
// the renewal state machine (CWE-362 fix for H-6).
|
||||
//
|
||||
// Claimed rows are invisible to other concurrent claim calls for the lifetime
|
||||
// of the transaction; rows claimed as Running remain invisible after commit
|
||||
// because ListPendingByAgentID's filter is status='Pending'.
|
||||
func (r *JobRepository) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) {
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to begin agent claim transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Branch 1 + 2: Pending Deployment jobs (direct agent_id match or legacy
|
||||
// target fallback). These get flipped to Running atomically below.
|
||||
pendingRows, err := tx.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
WHERE agent_id = $1 AND status = 'Pending' AND type = 'Deployment'
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT j.id, j.type, j.certificate_id, j.target_id, j.agent_id, j.status, j.attempts, j.max_attempts,
|
||||
j.last_error, j.scheduled_at, j.started_at, j.completed_at, j.created_at
|
||||
FROM jobs j
|
||||
INNER JOIN deployment_targets dt ON j.target_id = dt.id
|
||||
WHERE j.agent_id IS NULL AND j.status = 'Pending' AND j.type = 'Deployment'
|
||||
AND dt.agent_id = $1
|
||||
|
||||
ORDER BY created_at ASC
|
||||
FOR UPDATE SKIP LOCKED
|
||||
`, agentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query pending deployment jobs for agent: %w", err)
|
||||
}
|
||||
|
||||
var pendingJobs []*domain.Job
|
||||
for pendingRows.Next() {
|
||||
job, err := scanJob(pendingRows)
|
||||
if err != nil {
|
||||
pendingRows.Close()
|
||||
return nil, err
|
||||
}
|
||||
pendingJobs = append(pendingJobs, job)
|
||||
}
|
||||
if err := pendingRows.Err(); err != nil {
|
||||
pendingRows.Close()
|
||||
return nil, fmt.Errorf("error iterating pending deployment rows: %w", err)
|
||||
}
|
||||
pendingRows.Close()
|
||||
|
||||
// Branch 3: AwaitingCSR jobs for this agent. Locked with FOR UPDATE SKIP
|
||||
// LOCKED to prevent duplicate delivery to concurrent pollers, but state is
|
||||
// NOT transitioned — the agent advances state via CSR submission.
|
||||
csrRows, err := tx.QueryContext(ctx, `
|
||||
SELECT j.id, j.type, j.certificate_id, j.target_id, j.agent_id, j.status, j.attempts, j.max_attempts,
|
||||
j.last_error, j.scheduled_at, j.started_at, j.completed_at, j.created_at
|
||||
FROM jobs j
|
||||
WHERE j.status = 'AwaitingCSR'
|
||||
AND j.type IN ('Renewal', 'Issuance')
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM certificate_target_mappings ctm
|
||||
INNER JOIN deployment_targets dt ON ctm.target_id = dt.id
|
||||
WHERE ctm.certificate_id = j.certificate_id
|
||||
AND dt.agent_id = $1
|
||||
)
|
||||
ORDER BY j.created_at ASC
|
||||
FOR UPDATE SKIP LOCKED
|
||||
`, agentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query AwaitingCSR jobs for agent: %w", err)
|
||||
}
|
||||
|
||||
var csrJobs []*domain.Job
|
||||
for csrRows.Next() {
|
||||
job, err := scanJob(csrRows)
|
||||
if err != nil {
|
||||
csrRows.Close()
|
||||
return nil, err
|
||||
}
|
||||
csrJobs = append(csrJobs, job)
|
||||
}
|
||||
if err := csrRows.Err(); err != nil {
|
||||
csrRows.Close()
|
||||
return nil, fmt.Errorf("error iterating AwaitingCSR rows: %w", err)
|
||||
}
|
||||
csrRows.Close()
|
||||
|
||||
// Transition locked Pending deployments to Running before commit.
|
||||
if len(pendingJobs) > 0 {
|
||||
ids := make([]interface{}, len(pendingJobs))
|
||||
placeholders := make([]byte, 0, len(pendingJobs)*5)
|
||||
for i, job := range pendingJobs {
|
||||
ids[i] = job.ID
|
||||
if i > 0 {
|
||||
placeholders = append(placeholders, ',')
|
||||
}
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", i+2)...)
|
||||
}
|
||||
updateQuery := fmt.Sprintf(
|
||||
`UPDATE jobs SET status = $1 WHERE id IN (%s)`,
|
||||
string(placeholders),
|
||||
)
|
||||
updateArgs := append([]interface{}{domain.JobStatusRunning}, ids...)
|
||||
if _, err := tx.ExecContext(ctx, updateQuery, updateArgs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to transition claimed deployment jobs to Running: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("failed to commit agent claim transaction: %w", err)
|
||||
}
|
||||
|
||||
// Reflect the committed state in returned Pending deployment jobs; leave
|
||||
// AwaitingCSR jobs untouched.
|
||||
for _, job := range pendingJobs {
|
||||
job.Status = domain.JobStatusRunning
|
||||
}
|
||||
|
||||
// Preserve the legacy ordering: Pending deployments first, AwaitingCSR
|
||||
// second. Callers that want a strict created_at merge can re-sort.
|
||||
return append(pendingJobs, csrJobs...), nil
|
||||
}
|
||||
|
||||
// scanJob scans a job from a row or rows
|
||||
func scanJob(scanner interface {
|
||||
Scan(...interface{}) error
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -703,10 +706,10 @@ func TestRevocationRepository_CRUD(t *testing.T) {
|
||||
t.Fatalf("Idempotent create failed: %v", err)
|
||||
}
|
||||
|
||||
// GetBySerial
|
||||
got, err := repo.GetBySerial(ctx, "DEADBEEF01")
|
||||
// GetByIssuerAndSerial — lookups are scoped to (issuer_id, serial) per RFC 5280 §5.2.3.
|
||||
got, err := repo.GetByIssuerAndSerial(ctx, issuerID, "DEADBEEF01")
|
||||
if err != nil {
|
||||
t.Fatalf("GetBySerial failed: %v", err)
|
||||
t.Fatalf("GetByIssuerAndSerial failed: %v", err)
|
||||
}
|
||||
if got.Reason != "keyCompromise" {
|
||||
t.Errorf("Reason = %q, want %q", got.Reason, "keyCompromise")
|
||||
@@ -734,12 +737,116 @@ func TestRevocationRepository_CRUD(t *testing.T) {
|
||||
if err := repo.MarkIssuerNotified(ctx, "rev-test-1"); err != nil {
|
||||
t.Fatalf("MarkIssuerNotified failed: %v", err)
|
||||
}
|
||||
got, _ = repo.GetBySerial(ctx, "DEADBEEF01")
|
||||
got, _ = repo.GetByIssuerAndSerial(ctx, issuerID, "DEADBEEF01")
|
||||
if !got.IssuerNotified {
|
||||
t.Error("expected IssuerNotified=true after marking")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRevocationRepository_CrossIssuerSerialCollision verifies that the same
|
||||
// serial number can coexist under two different issuers — RFC 5280 §5.2.3
|
||||
// defines serial uniqueness only within a single CA, and certctl supports
|
||||
// multi-issuer deployments where serial collisions across issuers are
|
||||
// legitimate (e.g., Local CA serial 0x01 and Vault PKI serial 0x01).
|
||||
//
|
||||
// This test locks in the behavior change from migration 000012: the unique
|
||||
// index is on (issuer_id, serial_number), not on serial_number alone.
|
||||
func TestRevocationRepository_CrossIssuerSerialCollision(t *testing.T) {
|
||||
tdb := getTestDB(t)
|
||||
db := tdb.freshSchema(t)
|
||||
repo := postgres.NewRevocationRepository(db)
|
||||
certRepo := postgres.NewCertificateRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().Truncate(time.Microsecond)
|
||||
|
||||
// First issuer + cert + revocation with serial "CAFEBABE01".
|
||||
ownerID1, teamID1, issuerID1, policyID1 := insertCertPrereqsRaw(t, db, ctx, "dup-a")
|
||||
cert1 := &domain.ManagedCertificate{
|
||||
ID: "mc-dup-a", Name: "dup-a", CommonName: "a.example.com",
|
||||
SANs: []string{}, OwnerID: ownerID1, TeamID: teamID1,
|
||||
IssuerID: issuerID1, RenewalPolicyID: policyID1,
|
||||
Status: domain.CertificateStatusRevoked,
|
||||
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
|
||||
CreatedAt: now, UpdatedAt: now,
|
||||
}
|
||||
if err := certRepo.Create(ctx, cert1); err != nil {
|
||||
t.Fatalf("Create cert1 failed: %v", err)
|
||||
}
|
||||
if err := repo.Create(ctx, &domain.CertificateRevocation{
|
||||
ID: "rev-dup-a", CertificateID: "mc-dup-a", SerialNumber: "CAFEBABE01",
|
||||
Reason: "keyCompromise", RevokedBy: "admin", RevokedAt: now,
|
||||
IssuerID: issuerID1, CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("Create revocation under issuer1 failed: %v", err)
|
||||
}
|
||||
|
||||
// Second issuer + cert + revocation with the SAME serial "CAFEBABE01".
|
||||
// Under the pre-000012 global-unique index this would silently drop via
|
||||
// ON CONFLICT DO NOTHING. Under the new (issuer_id, serial_number) scope
|
||||
// it must succeed.
|
||||
ownerID2, teamID2, issuerID2, policyID2 := insertCertPrereqsRaw(t, db, ctx, "dup-b")
|
||||
cert2 := &domain.ManagedCertificate{
|
||||
ID: "mc-dup-b", Name: "dup-b", CommonName: "b.example.com",
|
||||
SANs: []string{}, OwnerID: ownerID2, TeamID: teamID2,
|
||||
IssuerID: issuerID2, RenewalPolicyID: policyID2,
|
||||
Status: domain.CertificateStatusRevoked,
|
||||
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
|
||||
CreatedAt: now, UpdatedAt: now,
|
||||
}
|
||||
if err := certRepo.Create(ctx, cert2); err != nil {
|
||||
t.Fatalf("Create cert2 failed: %v", err)
|
||||
}
|
||||
if err := repo.Create(ctx, &domain.CertificateRevocation{
|
||||
ID: "rev-dup-b", CertificateID: "mc-dup-b", SerialNumber: "CAFEBABE01",
|
||||
Reason: "superseded", RevokedBy: "admin", RevokedAt: now,
|
||||
IssuerID: issuerID2, CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("Create revocation under issuer2 failed (cross-issuer duplicate serial must be allowed): %v", err)
|
||||
}
|
||||
|
||||
// Both revocations must be retrievable under their respective issuers.
|
||||
revA, err := repo.GetByIssuerAndSerial(ctx, issuerID1, "CAFEBABE01")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIssuerAndSerial(issuer1) failed: %v", err)
|
||||
}
|
||||
if revA.ID != "rev-dup-a" || revA.Reason != "keyCompromise" {
|
||||
t.Errorf("issuer1 lookup returned wrong row: id=%q reason=%q", revA.ID, revA.Reason)
|
||||
}
|
||||
|
||||
revB, err := repo.GetByIssuerAndSerial(ctx, issuerID2, "CAFEBABE01")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIssuerAndSerial(issuer2) failed: %v", err)
|
||||
}
|
||||
if revB.ID != "rev-dup-b" || revB.Reason != "superseded" {
|
||||
t.Errorf("issuer2 lookup returned wrong row: id=%q reason=%q", revB.ID, revB.Reason)
|
||||
}
|
||||
|
||||
// ListAll should see both revocations.
|
||||
all, err := repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAll failed: %v", err)
|
||||
}
|
||||
if len(all) != 2 {
|
||||
t.Errorf("len(all) = %d, want 2 (cross-issuer duplicate serials)", len(all))
|
||||
}
|
||||
|
||||
// Same-issuer idempotency guard still works (ON CONFLICT DO NOTHING on
|
||||
// (issuer_id, serial_number) — re-inserting the same (issuer, serial)
|
||||
// pair must not error and must not duplicate the row).
|
||||
if err := repo.Create(ctx, &domain.CertificateRevocation{
|
||||
ID: "rev-dup-a-repeat", CertificateID: "mc-dup-a", SerialNumber: "CAFEBABE01",
|
||||
Reason: "superseded", RevokedBy: "admin", RevokedAt: now,
|
||||
IssuerID: issuerID1, CreatedAt: now,
|
||||
}); err != nil {
|
||||
t.Fatalf("Idempotent create under same issuer failed: %v", err)
|
||||
}
|
||||
all, _ = repo.ListAll(ctx)
|
||||
if len(all) != 2 {
|
||||
t.Errorf("len(all) after idempotent re-insert = %d, want 2", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Team Repository Tests
|
||||
// ============================================================
|
||||
@@ -1578,3 +1685,334 @@ func TestEmptyResultSets(t *testing.T) {
|
||||
t.Errorf("expected empty agent groups, got %d", len(groups))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// H-6 (CWE-362) Claim-Based Concurrency Tests
|
||||
//
|
||||
// These tests exercise the `SELECT ... FOR UPDATE SKIP LOCKED` worker-queue pattern
|
||||
// introduced to remediate the H-6 race condition. They validate two invariants:
|
||||
//
|
||||
// 1. Disjoint claim: under concurrent callers, no Pending row is returned to more
|
||||
// than one worker (i.e. each claim is exclusive).
|
||||
// 2. State transition: claimed rows are atomically flipped to Running inside the
|
||||
// same transaction that locked them, so a subsequent query must see the row in
|
||||
// the Running state and no other worker can observe it as Pending again.
|
||||
//
|
||||
// Skipped automatically in `-short` mode (CI) since they require a real PostgreSQL
|
||||
// instance and take ~1s under contention.
|
||||
// ============================================================
|
||||
|
||||
// seedPendingJobs creates n Pending renewal jobs against a single prerequisite
|
||||
// certificate and returns the generated job IDs.
|
||||
func seedPendingJobs(t *testing.T, ctx context.Context, db *sql.DB, certID string, n int) []string {
|
||||
t.Helper()
|
||||
certRepo := postgres.NewCertificateRepository(db)
|
||||
jobRepo := postgres.NewJobRepository(db)
|
||||
|
||||
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, certID)
|
||||
|
||||
now := time.Now().Truncate(time.Microsecond)
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-" + certID, Name: certID, CommonName: certID + ".example.com",
|
||||
SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
|
||||
IssuerID: issuerID, RenewalPolicyID: policyID,
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
|
||||
CreatedAt: now, UpdatedAt: now,
|
||||
}
|
||||
if err := certRepo.Create(ctx, cert); err != nil {
|
||||
t.Fatalf("seedPendingJobs: create cert failed: %v", err)
|
||||
}
|
||||
|
||||
ids := make([]string, 0, n)
|
||||
for i := 0; i < n; i++ {
|
||||
job := &domain.Job{
|
||||
ID: fmt.Sprintf("job-%s-%03d", certID, i),
|
||||
Type: domain.JobTypeRenewal,
|
||||
CertificateID: "mc-" + certID,
|
||||
Status: domain.JobStatusPending,
|
||||
Attempts: 0,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := jobRepo.Create(ctx, job); err != nil {
|
||||
t.Fatalf("seedPendingJobs: create job %d failed: %v", i, err)
|
||||
}
|
||||
ids = append(ids, job.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// TestJobRepository_ClaimPendingJobs_FlipsToRunning validates the basic claim
|
||||
// semantics: a single call transitions Pending rows to Running atomically, and
|
||||
// the rows returned to the caller reflect the post-update state.
|
||||
func TestJobRepository_ClaimPendingJobs_FlipsToRunning(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test requires PostgreSQL")
|
||||
}
|
||||
tdb := getTestDB(t)
|
||||
db := tdb.freshSchema(t)
|
||||
jobRepo := postgres.NewJobRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
seeded := seedPendingJobs(t, ctx, db, "claimflip", 5)
|
||||
|
||||
claimed, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimPendingJobs failed: %v", err)
|
||||
}
|
||||
if len(claimed) != len(seeded) {
|
||||
t.Fatalf("len(claimed) = %d, want %d", len(claimed), len(seeded))
|
||||
}
|
||||
|
||||
// In-memory return values must reflect the transitioned state.
|
||||
for _, j := range claimed {
|
||||
if j.Status != domain.JobStatusRunning {
|
||||
t.Errorf("claimed job %s Status = %q, want %q", j.ID, j.Status, domain.JobStatusRunning)
|
||||
}
|
||||
}
|
||||
|
||||
// Persisted rows must also be Running — a fresh Get must not see Pending.
|
||||
for _, id := range seeded {
|
||||
got, err := jobRepo.Get(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("Get(%s) failed: %v", id, err)
|
||||
}
|
||||
if got.Status != domain.JobStatusRunning {
|
||||
t.Errorf("persisted job %s Status = %q, want %q", id, got.Status, domain.JobStatusRunning)
|
||||
}
|
||||
}
|
||||
|
||||
// A subsequent claim must return zero rows — nothing is Pending anymore.
|
||||
residual, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("residual ClaimPendingJobs failed: %v", err)
|
||||
}
|
||||
if len(residual) != 0 {
|
||||
t.Errorf("residual claims = %d, want 0 (all should be Running now)", len(residual))
|
||||
}
|
||||
}
|
||||
|
||||
// TestJobRepository_ClaimPendingJobs_ConcurrentDisjoint validates the core H-6
|
||||
// invariant: under concurrent access, no row is handed to more than one worker.
|
||||
//
|
||||
// The test seeds M Pending jobs, fans out N goroutines each of which loops
|
||||
// calling ClaimPendingJobs with limit=1, and finally asserts the union of all
|
||||
// claimed IDs is exactly M with zero duplicates. Workers that transiently
|
||||
// observe zero rows (because peers are holding the only remaining rows) re-check
|
||||
// an atomic progress counter before exiting, so transient SKIP-LOCKED zeros do
|
||||
// not cause premature termination.
|
||||
func TestJobRepository_ClaimPendingJobs_ConcurrentDisjoint(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test requires PostgreSQL")
|
||||
}
|
||||
tdb := getTestDB(t)
|
||||
db := tdb.freshSchema(t)
|
||||
jobRepo := postgres.NewJobRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
const M = 40 // seeded Pending jobs
|
||||
const N = 8 // concurrent workers
|
||||
seeded := seedPendingJobs(t, ctx, db, "concurrent", M)
|
||||
seededSet := make(map[string]bool, M)
|
||||
for _, id := range seeded {
|
||||
seededSet[id] = true
|
||||
}
|
||||
|
||||
var (
|
||||
totalClaimed int64
|
||||
allClaims []string
|
||||
mu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
for w := 0; w < N; w++ {
|
||||
wg.Add(1)
|
||||
go func(worker int) {
|
||||
defer wg.Done()
|
||||
emptyStreak := 0
|
||||
for iter := 0; iter < M*4; iter++ { // generous ceiling to prevent hangs
|
||||
claimed, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 1)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d ClaimPendingJobs failed: %v", worker, err)
|
||||
return
|
||||
}
|
||||
if len(claimed) == 0 {
|
||||
// Transient zero (peer holds lock) vs. terminal zero (all claimed).
|
||||
// Bail only once the shared counter proves work is done, but guard
|
||||
// with a streak so we don't spin forever under starvation.
|
||||
if atomic.LoadInt64(&totalClaimed) >= int64(M) {
|
||||
return
|
||||
}
|
||||
emptyStreak++
|
||||
if emptyStreak >= 20 {
|
||||
return
|
||||
}
|
||||
time.Sleep(500 * time.Microsecond)
|
||||
continue
|
||||
}
|
||||
emptyStreak = 0
|
||||
mu.Lock()
|
||||
for _, j := range claimed {
|
||||
if j.Status != domain.JobStatusRunning {
|
||||
t.Errorf("worker %d got job %s in Status=%q (want Running) — claim did not flip state", worker, j.ID, j.Status)
|
||||
}
|
||||
allClaims = append(allClaims, j.ID)
|
||||
}
|
||||
mu.Unlock()
|
||||
atomic.AddInt64(&totalClaimed, int64(len(claimed)))
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Invariant 1: no duplicate claims across the worker pool.
|
||||
seen := make(map[string]int, len(allClaims))
|
||||
for _, id := range allClaims {
|
||||
seen[id]++
|
||||
}
|
||||
for id, count := range seen {
|
||||
if count > 1 {
|
||||
t.Errorf("job %s claimed %d times — SKIP LOCKED invariant violated", id, count)
|
||||
}
|
||||
}
|
||||
|
||||
// Invariant 2: every seeded job appears in the claim set exactly once.
|
||||
if len(seen) != M {
|
||||
t.Errorf("distinct claimed IDs = %d, want %d (all seeded jobs must be claimed)", len(seen), M)
|
||||
}
|
||||
for id := range seededSet {
|
||||
if seen[id] == 0 {
|
||||
t.Errorf("seeded job %s was never claimed by any worker", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Invariant 3: persisted state reflects the transition — every seeded row
|
||||
// is now Running; none is Pending.
|
||||
for id := range seededSet {
|
||||
got, err := jobRepo.Get(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("Get(%s) failed: %v", id, err)
|
||||
}
|
||||
if got.Status != domain.JobStatusRunning {
|
||||
t.Errorf("job %s Status = %q, want %q", id, got.Status, domain.JobStatusRunning)
|
||||
}
|
||||
}
|
||||
|
||||
// Final progress counter must match the total number of seeded jobs.
|
||||
if got := atomic.LoadInt64(&totalClaimed); got != int64(M) {
|
||||
t.Errorf("totalClaimed = %d, want %d", got, M)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJobRepository_ClaimPendingByAgentID_TransitionsDeployments validates the
|
||||
// agent-scoped claim variant: Pending deployment rows for a given agent flip to
|
||||
// Running; AwaitingCSR rows are returned but their state is preserved (the CSR
|
||||
// submission path drives their next transition).
|
||||
func TestJobRepository_ClaimPendingByAgentID_TransitionsDeployments(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test requires PostgreSQL")
|
||||
}
|
||||
tdb := getTestDB(t)
|
||||
db := tdb.freshSchema(t)
|
||||
jobRepo := postgres.NewJobRepository(db)
|
||||
agentRepo := postgres.NewAgentRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "agentclaim")
|
||||
|
||||
now := time.Now().Truncate(time.Microsecond)
|
||||
cert := &domain.ManagedCertificate{
|
||||
ID: "mc-agentclaim", Name: "agentclaim", CommonName: "agentclaim.example.com",
|
||||
SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
|
||||
IssuerID: issuerID, RenewalPolicyID: policyID,
|
||||
Status: domain.CertificateStatusActive,
|
||||
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
|
||||
CreatedAt: now, UpdatedAt: now,
|
||||
}
|
||||
if err := postgres.NewCertificateRepository(db).Create(ctx, cert); err != nil {
|
||||
t.Fatalf("create cert failed: %v", err)
|
||||
}
|
||||
|
||||
agent := &domain.Agent{
|
||||
ID: "a-claim",
|
||||
Name: "claim-agent",
|
||||
Hostname: "claim-agent-host",
|
||||
Status: domain.AgentStatusOnline,
|
||||
RegisteredAt: now,
|
||||
APIKeyHash: "hash-claim",
|
||||
}
|
||||
if err := agentRepo.Create(ctx, agent); err != nil {
|
||||
t.Fatalf("create agent failed: %v", err)
|
||||
}
|
||||
|
||||
agentID := agent.ID
|
||||
mkJob := func(id string, typ domain.JobType, status domain.JobStatus) *domain.Job {
|
||||
return &domain.Job{
|
||||
ID: id, Type: typ, CertificateID: cert.ID,
|
||||
AgentID: &agentID,
|
||||
Status: status,
|
||||
Attempts: 0,
|
||||
MaxAttempts: 3,
|
||||
ScheduledAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
}
|
||||
jobs := []*domain.Job{
|
||||
mkJob("job-agentclaim-dep-1", domain.JobTypeDeployment, domain.JobStatusPending),
|
||||
mkJob("job-agentclaim-dep-2", domain.JobTypeDeployment, domain.JobStatusPending),
|
||||
mkJob("job-agentclaim-csr-1", domain.JobTypeRenewal, domain.JobStatusAwaitingCSR),
|
||||
// A Pending Renewal (not Deployment) must NOT be returned by the per-agent claim.
|
||||
mkJob("job-agentclaim-ren-pending", domain.JobTypeRenewal, domain.JobStatusPending),
|
||||
}
|
||||
for _, j := range jobs {
|
||||
if err := jobRepo.Create(ctx, j); err != nil {
|
||||
t.Fatalf("create job %s failed: %v", j.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
claimed, err := jobRepo.ClaimPendingByAgentID(ctx, agentID)
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimPendingByAgentID failed: %v", err)
|
||||
}
|
||||
// Expect exactly the 2 deployments + 1 AwaitingCSR.
|
||||
if len(claimed) != 3 {
|
||||
t.Fatalf("len(claimed) = %d, want 3 (2 deployments + 1 AwaitingCSR)", len(claimed))
|
||||
}
|
||||
|
||||
statusByID := map[string]domain.JobStatus{}
|
||||
for _, j := range claimed {
|
||||
statusByID[j.ID] = j.Status
|
||||
}
|
||||
// Both deployments must be Running in the returned slice (in-memory reflection).
|
||||
for _, id := range []string{"job-agentclaim-dep-1", "job-agentclaim-dep-2"} {
|
||||
if statusByID[id] != domain.JobStatusRunning {
|
||||
t.Errorf("returned deployment %s Status = %q, want Running", id, statusByID[id])
|
||||
}
|
||||
}
|
||||
// AwaitingCSR must remain AwaitingCSR.
|
||||
if statusByID["job-agentclaim-csr-1"] != domain.JobStatusAwaitingCSR {
|
||||
t.Errorf("returned AwaitingCSR Status = %q, want AwaitingCSR", statusByID["job-agentclaim-csr-1"])
|
||||
}
|
||||
// The unrelated Pending Renewal must not be returned.
|
||||
if _, ok := statusByID["job-agentclaim-ren-pending"]; ok {
|
||||
t.Errorf("Pending Renewal job was returned by ClaimPendingByAgentID — scope violation")
|
||||
}
|
||||
|
||||
// Persisted state: deployments Running, AwaitingCSR unchanged, Pending Renewal still Pending.
|
||||
for id, want := range map[string]domain.JobStatus{
|
||||
"job-agentclaim-dep-1": domain.JobStatusRunning,
|
||||
"job-agentclaim-dep-2": domain.JobStatusRunning,
|
||||
"job-agentclaim-csr-1": domain.JobStatusAwaitingCSR,
|
||||
"job-agentclaim-ren-pending": domain.JobStatusPending,
|
||||
} {
|
||||
got, err := jobRepo.Get(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("Get(%s) failed: %v", id, err)
|
||||
}
|
||||
if got.Status != want {
|
||||
t.Errorf("persisted %s Status = %q, want %q", id, got.Status, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,13 +19,18 @@ func NewRevocationRepository(db *sql.DB) *RevocationRepository {
|
||||
}
|
||||
|
||||
// Create records a new certificate revocation.
|
||||
//
|
||||
// Uniqueness is scoped to (issuer_id, serial_number) per RFC 5280 §4.1.2.2.
|
||||
// Serial numbers are only unique within an issuer, so certctl supports
|
||||
// collisions across different issuer connectors. The composite ON CONFLICT
|
||||
// target matches migration 000012's unique index.
|
||||
func (r *RevocationRepository) Create(ctx context.Context, revocation *domain.CertificateRevocation) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO certificate_revocations (
|
||||
id, certificate_id, serial_number, reason, revoked_by, revoked_at,
|
||||
issuer_id, issuer_notified, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (serial_number) DO NOTHING
|
||||
ON CONFLICT (issuer_id, serial_number) DO NOTHING
|
||||
`, revocation.ID, revocation.CertificateID, revocation.SerialNumber,
|
||||
revocation.Reason, revocation.RevokedBy, revocation.RevokedAt,
|
||||
revocation.IssuerID, revocation.IssuerNotified, revocation.CreatedAt)
|
||||
@@ -37,20 +42,24 @@ func (r *RevocationRepository) Create(ctx context.Context, revocation *domain.Ce
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBySerial retrieves a revocation by serial number.
|
||||
func (r *RevocationRepository) GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error) {
|
||||
// GetByIssuerAndSerial retrieves a revocation by the (issuer_id, serial) pair.
|
||||
//
|
||||
// Per RFC 5280 §4.1.2.2, serial numbers are unique only within a single issuer.
|
||||
// Callers (OCSP handlers, CRL generation) always know the issuer because the
|
||||
// OCSP URL carries it as a path parameter and CRLs are generated per-issuer.
|
||||
func (r *RevocationRepository) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.CertificateRevocation, error) {
|
||||
var rev domain.CertificateRevocation
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, certificate_id, serial_number, reason, revoked_by, revoked_at,
|
||||
issuer_id, issuer_notified, created_at
|
||||
FROM certificate_revocations
|
||||
WHERE serial_number = $1
|
||||
`, serial).Scan(&rev.ID, &rev.CertificateID, &rev.SerialNumber,
|
||||
WHERE issuer_id = $1 AND serial_number = $2
|
||||
`, issuerID, serial).Scan(&rev.ID, &rev.CertificateID, &rev.SerialNumber,
|
||||
&rev.Reason, &rev.RevokedBy, &rev.RevokedAt,
|
||||
&rev.IssuerID, &rev.IssuerNotified, &rev.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get revocation by serial: %w", err)
|
||||
return nil, fmt.Errorf("failed to get revocation by issuer and serial: %w", err)
|
||||
}
|
||||
|
||||
return &rev, nil
|
||||
|
||||
+27
-12
@@ -2,11 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
@@ -57,8 +58,11 @@ func (s *AgentService) Register(ctx context.Context, name string, hostname strin
|
||||
return nil, "", fmt.Errorf("agent name and hostname are required")
|
||||
}
|
||||
|
||||
// Generate API key
|
||||
apiKey := generateAPIKey()
|
||||
// Generate API key. crypto/rand failure is non-recoverable — propagate immediately.
|
||||
apiKey, err := generateAPIKey()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate agent api key: %w", err)
|
||||
}
|
||||
apiKeyHash := hashAPIKey(apiKey)
|
||||
|
||||
now := time.Now()
|
||||
@@ -280,8 +284,13 @@ func (s *AgentService) GetPendingWork(ctx context.Context, agentID string) ([]*d
|
||||
return nil, fmt.Errorf("failed to fetch agent: %w", err)
|
||||
}
|
||||
|
||||
// Return only jobs assigned to this agent (via agent_id or target→agent relationship)
|
||||
return s.jobRepo.ListPendingByAgentID(ctx, agentID)
|
||||
// Atomically claim jobs assigned to this agent. H-6 (CWE-362) remediation:
|
||||
// ClaimPendingByAgentID uses SELECT ... FOR UPDATE SKIP LOCKED so concurrent poll
|
||||
// requests (duplicate agents, retry storms, or a lagging long-poll) never observe
|
||||
// the same Pending deployment row. Pending deployments are flipped to Running inside
|
||||
// the claim transaction; AwaitingCSR jobs keep their state since CSR submission is
|
||||
// the state-machine trigger for their next transition.
|
||||
return s.jobRepo.ClaimPendingByAgentID(ctx, agentID)
|
||||
}
|
||||
|
||||
// ReportJobStatus updates a job's status based on agent feedback.
|
||||
@@ -380,7 +389,10 @@ func (s *AgentService) GetAgent(ctx context.Context, id string) (*domain.Agent,
|
||||
// RegisterAgent creates and registers a new agent (handler interface method).
|
||||
func (s *AgentService) RegisterAgent(ctx context.Context, agent domain.Agent) (*domain.Agent, error) {
|
||||
agent.ID = generateID("agent")
|
||||
apiKey := generateAPIKey()
|
||||
apiKey, err := generateAPIKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate agent api key: %w", err)
|
||||
}
|
||||
agent.APIKeyHash = hashAPIKey(apiKey)
|
||||
agent.Status = domain.AgentStatusOnline
|
||||
now := time.Now()
|
||||
@@ -487,14 +499,17 @@ func (s *AgentService) CertificatePickup(ctx context.Context, agentID, certID st
|
||||
return string(certPEM), nil
|
||||
}
|
||||
|
||||
// generateAPIKey creates a random API key for an agent.
|
||||
func generateAPIKey() string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
// generateAPIKey creates a cryptographically secure random API key for an agent.
|
||||
// It fills a 32-byte buffer from crypto/rand (256 bits of entropy) and encodes it with
|
||||
// base64.RawURLEncoding, yielding a 43-character URL-safe, unpadded ASCII string.
|
||||
// The plaintext key is shown to the caller exactly once; only its SHA-256 hash is stored.
|
||||
// Fixes C-1 (CWE-338: previously used math/rand, which is not cryptographically secure).
|
||||
func generateAPIKey() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate agent api key: %w", err)
|
||||
}
|
||||
return string(b)
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// hashAPIKey hashes an API key using SHA256.
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -594,3 +595,44 @@ func TestListAgents(t *testing.T) {
|
||||
t.Errorf("expected total 2, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateAPIKey_Properties is the core regression test for C-1 (CWE-338).
|
||||
// It verifies that generateAPIKey produces cryptographically random,
|
||||
// unpadded base64url-encoded, 32-byte (256-bit) keys that never collide
|
||||
// across consecutive calls. Exact length and alphabet are verified against
|
||||
// base64.RawURLEncoding so any silent change to entropy or encoding fails
|
||||
// fast.
|
||||
//
|
||||
// Note on the error branch: since Go 1.24 (issue #66821) crypto/rand.Read
|
||||
// treats entropy-source failures as fatal — the process is terminated
|
||||
// rather than returning an error. The defensive `if err != nil` branch
|
||||
// in generateAPIKey is therefore unreachable from tests on modern Go.
|
||||
// It is kept to preserve the documented (string, error) contract and
|
||||
// to remain correct on older Go toolchains or future changes.
|
||||
func TestGenerateAPIKey_Properties(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 64)
|
||||
for i := 0; i < 64; i++ {
|
||||
k, err := generateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generateAPIKey failed: %v", err)
|
||||
}
|
||||
if k == "" {
|
||||
t.Fatal("expected non-empty API key")
|
||||
}
|
||||
// base64.RawURLEncoding of 32 bytes yields exactly 43 chars.
|
||||
if got, want := len(k), 43; got != want {
|
||||
t.Fatalf("expected key length %d, got %d (%q)", want, got, k)
|
||||
}
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(k)
|
||||
if err != nil {
|
||||
t.Fatalf("key %q not valid base64url: %v", k, err)
|
||||
}
|
||||
if len(decoded) != 32 {
|
||||
t.Fatalf("expected 32 decoded bytes (256 bits entropy), got %d", len(decoded))
|
||||
}
|
||||
if _, dup := seen[k]; dup {
|
||||
t.Fatalf("collision detected after %d calls; weak PRNG?", i+1)
|
||||
}
|
||||
seen[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,8 +117,10 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
|
||||
// Short-lived cert exemption: if the cert's profile has TTL < 1 hour,
|
||||
// always return "good" — expiry is sufficient revocation for short-lived certs.
|
||||
if s.profileRepo != nil && s.certRepo != nil {
|
||||
// Look up cert by serial through revocation table
|
||||
rev, _ := s.revocationRepo.GetBySerial(context.Background(), serialHex)
|
||||
// Look up cert by (issuer_id, serial) — per RFC 5280 §4.1.2.2, serial numbers
|
||||
// are unique only within a single issuer. The OCSP URL path carries issuer_id,
|
||||
// so we scope the lookup to avoid cross-issuer collisions.
|
||||
rev, _ := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex)
|
||||
if rev != nil {
|
||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
||||
if err == nil && cert.CertificateProfileID != "" {
|
||||
@@ -135,8 +137,8 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this serial is revoked
|
||||
rev, err := s.revocationRepo.GetBySerial(context.Background(), serialHex)
|
||||
// Check if this (issuer_id, serial) is revoked — RFC 5280 §4.1.2.2 scoping.
|
||||
rev, err := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex)
|
||||
if err != nil {
|
||||
// Not revoked — return "good" status
|
||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
||||
|
||||
+28
-12
@@ -327,8 +327,20 @@ func (s *IssuerService) SeedFromEnvVars(ctx context.Context, cfg *config.Config)
|
||||
seeds := s.buildEnvVarSeeds(cfg)
|
||||
seeded := 0
|
||||
for _, seed := range seeds {
|
||||
// Encrypt the config if key is set
|
||||
if len(seed.Config) > 0 {
|
||||
// Encrypt the config only when an encryption key is configured.
|
||||
//
|
||||
// Env-seeded issuers carry Source="env" and are reconstructable on every
|
||||
// boot from process environment, so persisting their config in plaintext
|
||||
// adds no new exposure: the same bytes already live in the operator's
|
||||
// deployment manifest. When no key is configured we therefore leave
|
||||
// EncryptedConfig nil and keep the raw JSON in the `config` column —
|
||||
// IssuerRegistry.Rebuild falls through to `cfg.Config` when there is no
|
||||
// ciphertext to decrypt, so registry load still works.
|
||||
//
|
||||
// Database-sourced rows (Source="database") never reach this branch:
|
||||
// they are created through the GUI/API write paths, which require the
|
||||
// encryption key and fail closed via crypto.ErrEncryptionKeyRequired.
|
||||
if len(seed.Config) > 0 && len(s.encryptionKey) > 0 {
|
||||
encrypted, _, encErr := crypto.EncryptIfKeySet([]byte(seed.Config), s.encryptionKey)
|
||||
if encErr != nil {
|
||||
s.logger.Error("failed to encrypt seed config", "id", seed.ID, "error", encErr)
|
||||
@@ -565,17 +577,21 @@ func (s *IssuerService) buildEnvVarSeeds(cfg *config.Config) []*domain.Issuer {
|
||||
|
||||
// Conditional: GlobalSign — only seed if API URL and API key are set
|
||||
if cfg.GlobalSign.APIUrl != "" && cfg.GlobalSign.APIKey != "" {
|
||||
globalSignConfig := map[string]interface{}{
|
||||
"api_url": cfg.GlobalSign.APIUrl,
|
||||
"api_key": cfg.GlobalSign.APIKey,
|
||||
"api_secret": cfg.GlobalSign.APISecret,
|
||||
"client_cert_path": cfg.GlobalSign.ClientCertPath,
|
||||
"client_key_path": cfg.GlobalSign.ClientKeyPath,
|
||||
}
|
||||
if cfg.GlobalSign.ServerCAPath != "" {
|
||||
globalSignConfig["server_ca_path"] = cfg.GlobalSign.ServerCAPath
|
||||
}
|
||||
seeds = append(seeds, &domain.Issuer{
|
||||
ID: "iss-globalsign",
|
||||
Name: "GlobalSign Atlas",
|
||||
Type: domain.IssuerTypeGlobalSign,
|
||||
Config: mustJSON(map[string]interface{}{
|
||||
"api_url": cfg.GlobalSign.APIUrl,
|
||||
"api_key": cfg.GlobalSign.APIKey,
|
||||
"api_secret": cfg.GlobalSign.APISecret,
|
||||
"client_cert_path": cfg.GlobalSign.ClientCertPath,
|
||||
"client_key_path": cfg.GlobalSign.ClientKeyPath,
|
||||
}),
|
||||
ID: "iss-globalsign",
|
||||
Name: "GlobalSign Atlas",
|
||||
Type: domain.IssuerTypeGlobalSign,
|
||||
Config: mustJSON(globalSignConfig),
|
||||
Enabled: true,
|
||||
Source: "env",
|
||||
CreatedAt: now,
|
||||
|
||||
@@ -217,7 +217,7 @@ func TestIssuerService_Create(t *testing.T) {
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
registry := NewIssuerRegistry(slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, testEncryptionKey, slog.Default())
|
||||
|
||||
config := map[string]interface{}{"endpoint": "https://acme.example.com/v2/new-account"}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
@@ -342,7 +342,7 @@ func TestIssuerService_Update(t *testing.T) {
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
registry := NewIssuerRegistry(slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, testEncryptionKey, slog.Default())
|
||||
|
||||
config := map[string]interface{}{"endpoint": "https://acme.example.com"}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
@@ -568,7 +568,7 @@ func TestIssuerService_CreateIssuer_HandlerInterface(t *testing.T) {
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
registry := NewIssuerRegistry(slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, testEncryptionKey, slog.Default())
|
||||
|
||||
config := map[string]interface{}{"url": "https://example.com"}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
@@ -680,7 +680,7 @@ func TestIssuerService_Create_LowercaseType(t *testing.T) {
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
registry := NewIssuerRegistry(slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, testEncryptionKey, slog.Default())
|
||||
|
||||
config := map[string]interface{}{"endpoint": "https://acme.example.com"}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
@@ -710,7 +710,7 @@ func TestIssuerService_CreateIssuer_LowercaseType(t *testing.T) {
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
registry := NewIssuerRegistry(slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, testEncryptionKey, slog.Default())
|
||||
|
||||
config := map[string]interface{}{"url": "https://example.com"}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
@@ -752,7 +752,7 @@ func TestIssuerService_Create_M49Types(t *testing.T) {
|
||||
auditService := NewAuditService(auditRepo)
|
||||
|
||||
registry := NewIssuerRegistry(slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
||||
service := NewIssuerService(repo, auditService, registry, testEncryptionKey, slog.Default())
|
||||
|
||||
config := map[string]interface{}{"api_url": "https://example.com"}
|
||||
configJSON, _ := json.Marshal(config)
|
||||
|
||||
+10
-3
@@ -35,11 +35,18 @@ func NewJobService(
|
||||
|
||||
// ProcessPendingJobs fetches and processes all pending jobs.
|
||||
// It routes jobs to the appropriate service based on job type and handles errors gracefully.
|
||||
//
|
||||
// Concurrency (H-6 CWE-362): jobs are claimed via ClaimPendingJobs which uses
|
||||
// SELECT ... FOR UPDATE SKIP LOCKED and flips Pending → Running atomically. Concurrent
|
||||
// scheduler replicas in HA deployments will therefore never observe the same Pending row,
|
||||
// and the subsequent UpdateStatus(Running) calls inside the downstream service methods are
|
||||
// idempotent against the pre-flipped state.
|
||||
func (s *JobService) ProcessPendingJobs(ctx context.Context) error {
|
||||
// Fetch pending jobs
|
||||
pendingJobs, err := s.jobRepo.ListByStatus(ctx, domain.JobStatusPending)
|
||||
// Claim pending jobs atomically (H-6 remediation: was ListByStatus which had no row lock).
|
||||
// Empty jobType matches all types; zero limit means unlimited (preserves prior semantics).
|
||||
pendingJobs, err := s.jobRepo.ClaimPendingJobs(ctx, "", 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list pending jobs: %w", err)
|
||||
return fmt.Errorf("failed to claim pending jobs: %w", err)
|
||||
}
|
||||
|
||||
if len(pendingJobs) == 0 {
|
||||
|
||||
@@ -119,7 +119,10 @@ func TestESTService_MaxTTL_ForwardedToIssuer(t *testing.T) {
|
||||
|
||||
func TestSCEPService_CryptoValidation_RejectsWeakKey(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
// H-2: SCEPService now requires a configured challenge password. Pass a
|
||||
// matching client password so this test exercises the crypto-policy path
|
||||
// rather than being short-circuited by the challenge-password guard.
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
// Profile requiring ECDSA P-384 minimum
|
||||
profileRepo := newM11cProfileRepo()
|
||||
@@ -136,7 +139,7 @@ func TestSCEPService_CryptoValidation_RejectsWeakKey(t *testing.T) {
|
||||
// P-256 CSR should be rejected
|
||||
csrPEM := generateCSRPEM(t, "device.example.com", nil)
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-001")
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-001")
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for ECDSA P-256 against P-384 minimum")
|
||||
}
|
||||
@@ -152,7 +155,8 @@ func TestSCEPService_CryptoValidation_AcceptsStrongKey(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
// H-2: happy path exercises the authenticated branch.
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-standard"] = &domain.CertificateProfile{
|
||||
@@ -167,7 +171,7 @@ func TestSCEPService_CryptoValidation_AcceptsStrongKey(t *testing.T) {
|
||||
|
||||
csrPEM := generateCSRPEM(t, "device-ok.example.com", nil)
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-002")
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-002")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success: %v", err)
|
||||
}
|
||||
@@ -179,7 +183,8 @@ func TestSCEPService_CryptoValidation_AcceptsStrongKey(t *testing.T) {
|
||||
func TestSCEPService_MaxTTL_ForwardedToIssuer(t *testing.T) {
|
||||
capturingMock := &capturingIssuerConnector{}
|
||||
|
||||
svc := NewSCEPService("iss-local", capturingMock, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
// H-2: challenge password required for enrollment.
|
||||
svc := NewSCEPService("iss-local", capturingMock, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
profileRepo := newM11cProfileRepo()
|
||||
profileRepo.profiles["prof-device"] = &domain.CertificateProfile{
|
||||
@@ -192,7 +197,7 @@ func TestSCEPService_MaxTTL_ForwardedToIssuer(t *testing.T) {
|
||||
|
||||
csrPEM := generateCSRPEM(t, "mdm-device.example.com", nil)
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-003")
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-003")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -341,12 +346,13 @@ func TestESTService_NoProfileRepo_PassesThrough(t *testing.T) {
|
||||
|
||||
func TestSCEPService_NoProfileRepo_PassesThrough(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
// H-2: challenge password required for enrollment.
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
svc.SetProfileID("nonexistent-profile")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "no-profile-scep.example.com", nil)
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-004")
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-004")
|
||||
if err != nil {
|
||||
t.Fatalf("expected success when no profile repo set: %v", err)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
"github.com/shankar0123/certctl/internal/tlsprobe"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// SentinelAgentID is the agent ID used for network-discovered certificates.
|
||||
@@ -318,51 +319,27 @@ func (s *NetworkScanService) expandEndpoints(cidrs []string, ports []int64) []st
|
||||
return endpoints
|
||||
}
|
||||
|
||||
// isReservedCIDR checks if an IP address falls within reserved ranges that should not be scanned.
|
||||
// Filters out loopback, link-local (including cloud metadata), and multicast ranges.
|
||||
// Does NOT filter RFC 1918 ranges since certctl is self-hosted and internal networks are a primary use case.
|
||||
func isReservedIP(ip net.IP) bool {
|
||||
// Loopback: 127.0.0.0/8
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Link-local: 169.254.0.0/16 (includes cloud metadata 169.254.169.254)
|
||||
if linkLocal := net.ParseIP("169.254.0.0"); linkLocal != nil {
|
||||
if _, linkLocalNet, _ := net.ParseCIDR("169.254.0.0/16"); linkLocalNet != nil {
|
||||
if linkLocalNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Multicast: 224.0.0.0/4
|
||||
if multicast := net.ParseIP("224.0.0.0"); multicast != nil {
|
||||
if _, multicastNet, _ := net.ParseCIDR("224.0.0.0/4"); multicastNet != nil {
|
||||
if multicastNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast: 255.255.255.255
|
||||
if ip.String() == "255.255.255.255" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
// The reserved-IP filter used by expandCIDR previously lived here as an
|
||||
// unexported isReservedIP helper. It has been moved to
|
||||
// internal/validation.IsReservedIP so the webhook notifier can share a single
|
||||
// authoritative implementation (H-4, CWE-918). The behaviour is
|
||||
// byte-identical with the previous helper — RFC 1918 is intentionally NOT
|
||||
// filtered, matching certctl's self-hosted design. If you change the
|
||||
// validation package's IsReservedIP, you are changing the network-scanner's
|
||||
// behaviour; audit both code paths together.
|
||||
|
||||
// expandCIDR expands a CIDR notation or single IP into a list of IPs.
|
||||
// Limits expansion to /20 (4096 IPs) to prevent accidental huge scans.
|
||||
// Filters out reserved IP ranges to prevent SSRF attacks.
|
||||
// Filters out reserved IP ranges (via validation.IsReservedIP) to prevent
|
||||
// SSRF amplification via network-scan targets pointed at cloud metadata or
|
||||
// loopback.
|
||||
func expandCIDR(cidr string) []string {
|
||||
// Try as CIDR first
|
||||
ip, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// Try as single IP
|
||||
if singleIP := net.ParseIP(cidr); singleIP != nil {
|
||||
if isReservedIP(singleIP) {
|
||||
if validation.IsReservedIP(singleIP) {
|
||||
return nil
|
||||
}
|
||||
return []string{singleIP.String()}
|
||||
@@ -380,7 +357,7 @@ func expandCIDR(cidr string) []string {
|
||||
var ips []string
|
||||
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incrementIP(ip) {
|
||||
// Skip reserved IPs
|
||||
if isReservedIP(ip) {
|
||||
if validation.IsReservedIP(ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/validation"
|
||||
)
|
||||
|
||||
// mockNetworkScanRepo for testing
|
||||
@@ -248,9 +249,9 @@ func TestIsReservedIP_Loopback(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := isReservedIP(net.ParseIP(tt.ip))
|
||||
result := validation.IsReservedIP(net.ParseIP(tt.ip))
|
||||
if result != tt.expected {
|
||||
t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -269,9 +270,9 @@ func TestIsReservedIP_LinkLocal(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := isReservedIP(net.ParseIP(tt.ip))
|
||||
result := validation.IsReservedIP(net.ParseIP(tt.ip))
|
||||
if result != tt.expected {
|
||||
t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -289,18 +290,18 @@ func TestIsReservedIP_Multicast(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := isReservedIP(net.ParseIP(tt.ip))
|
||||
result := validation.IsReservedIP(net.ParseIP(tt.ip))
|
||||
if result != tt.expected {
|
||||
t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReservedIP_Broadcast(t *testing.T) {
|
||||
result := isReservedIP(net.ParseIP("255.255.255.255"))
|
||||
result := validation.IsReservedIP(net.ParseIP("255.255.255.255"))
|
||||
if !result {
|
||||
t.Errorf("isReservedIP(255.255.255.255) = %v, expected true", result)
|
||||
t.Errorf("validation.IsReservedIP(255.255.255.255) = %v, expected true", result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,9 +321,9 @@ func TestIsReservedIP_AllowsPrivateRanges(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
result := isReservedIP(net.ParseIP(tt.ip))
|
||||
result := validation.IsReservedIP(net.ParseIP(tt.ip))
|
||||
if result != tt.expected {
|
||||
t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -340,9 +341,9 @@ func TestIsReservedIP_AllowsPublic(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := isReservedIP(net.ParseIP(tt.ip))
|
||||
result := validation.IsReservedIP(net.ParseIP(tt.ip))
|
||||
if result != tt.expected {
|
||||
t.Errorf("isReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
t.Errorf("validation.IsReservedIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
@@ -68,14 +69,34 @@ func (s *SCEPService) GetCACert(ctx context.Context) (string, error) {
|
||||
// PKCSReq processes a SCEP enrollment request.
|
||||
// RFC 8894 Section 3.3.1: PKCSReq contains a PKCS#10 CSR for certificate enrollment.
|
||||
// The CSR PEM and challenge password are extracted by the handler from the PKCS#7 envelope.
|
||||
//
|
||||
// H-2 fix (CWE-306): the previous implementation skipped the shared-secret
|
||||
// check entirely when s.challengePassword was empty, meaning any unauthenticated
|
||||
// client that could reach /scep could enroll a CSR against the configured
|
||||
// issuer. Reject that configuration defense-in-depth even though main() already
|
||||
// refuses to start in the same state (see preflightSCEPChallengePassword). The
|
||||
// non-empty branch now uses crypto/subtle.ConstantTimeCompare to avoid leaking
|
||||
// the shared secret through a response-time side channel.
|
||||
func (s *SCEPService) PKCSReq(ctx context.Context, csrPEM string, challengePassword string, transactionID string) (*domain.SCEPEnrollResult, error) {
|
||||
// Validate challenge password
|
||||
if s.challengePassword != "" {
|
||||
if challengePassword != s.challengePassword {
|
||||
s.logger.Warn("SCEP enrollment rejected: invalid challenge password",
|
||||
"transaction_id", transactionID)
|
||||
return nil, fmt.Errorf("invalid challenge password")
|
||||
}
|
||||
// Defense-in-depth: refuse any enrollment when no shared secret is
|
||||
// configured. The server-level pre-flight check in cmd/server/main.go
|
||||
// normally prevents the service from being constructed in this state, but
|
||||
// this branch also protects future call sites (tests, library reuse, a
|
||||
// future REST-over-HTTPS wrapper) from silently accepting unauthenticated
|
||||
// CSRs.
|
||||
if s.challengePassword == "" {
|
||||
s.logger.Warn("SCEP enrollment rejected: server has no challenge password configured",
|
||||
"transaction_id", transactionID)
|
||||
return nil, fmt.Errorf("SCEP challenge password not configured on server")
|
||||
}
|
||||
// Constant-time compare avoids leaking the configured secret through
|
||||
// response-time variance. ConstantTimeCompare returns 1 only when both
|
||||
// slices have equal length AND equal content; a mismatched-length input
|
||||
// still takes the same path as a content mismatch.
|
||||
if subtle.ConstantTimeCompare([]byte(challengePassword), []byte(s.challengePassword)) != 1 {
|
||||
s.logger.Warn("SCEP enrollment rejected: invalid challenge password",
|
||||
"transaction_id", transactionID)
|
||||
return nil, fmt.Errorf("invalid challenge password")
|
||||
}
|
||||
|
||||
return s.processEnrollment(ctx, csrPEM, transactionID, "scep_pkcsreq")
|
||||
|
||||
@@ -58,11 +58,13 @@ func TestSCEPService_PKCSReq_Success(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
// H-2: SCEPService now requires a configured challenge password; the happy
|
||||
// path exercises a matching client-submitted password.
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "device.example.com", []string{"device.example.com"})
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-001")
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-001")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -81,9 +83,9 @@ func TestSCEPService_PKCSReq_Success(t *testing.T) {
|
||||
|
||||
func TestSCEPService_PKCSReq_InvalidCSR(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), "not-valid-pem", "", "txn-002")
|
||||
_, err := svc.PKCSReq(context.Background(), "not-valid-pem", "secret123", "txn-002")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid CSR")
|
||||
}
|
||||
@@ -91,11 +93,11 @@ func TestSCEPService_PKCSReq_InvalidCSR(t *testing.T) {
|
||||
|
||||
func TestSCEPService_PKCSReq_MissingCN(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "", []string{"test.example.com"})
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-003")
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-003")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing CN")
|
||||
}
|
||||
@@ -106,11 +108,11 @@ func TestSCEPService_PKCSReq_MissingCN(t *testing.T) {
|
||||
|
||||
func TestSCEPService_PKCSReq_IssuerError(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{Err: errors.New("issuance failed")}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "test.example.com", nil)
|
||||
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-004")
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-004")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
@@ -151,19 +153,49 @@ func TestSCEPService_PKCSReq_ChallengePassword_Invalid(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCEPService_PKCSReq_ChallengePassword_NotRequired(t *testing.T) {
|
||||
// When server has no challenge password configured, any value should be accepted
|
||||
// TestSCEPService_PKCSReq_ChallengePassword_EmptyServerConfigRejected is the
|
||||
// H-2 regression guard. Before the fix (internal/service/scep.go:72-79 skipped
|
||||
// the password check when s.challengePassword was empty), an unconfigured
|
||||
// server accepted any enrollment (CWE-306). The service now rejects PKCSReq
|
||||
// defense-in-depth even if main()'s pre-flight is somehow bypassed.
|
||||
func TestSCEPService_PKCSReq_ChallengePassword_EmptyServerConfigRejected(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "device.example.com", nil)
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "any-value", "txn-007")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
// Any client-submitted password (including empty) must be rejected when
|
||||
// the server has no shared secret configured.
|
||||
for _, clientPassword := range []string{"", "any-value", "guess"} {
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, clientPassword, "txn-empty")
|
||||
if err == nil {
|
||||
t.Fatalf("expected rejection when server challenge password is empty (client=%q)", clientPassword)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not configured") {
|
||||
t.Errorf("expected 'not configured' in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// TestSCEPService_PKCSReq_ChallengePassword_ConstantTimeLengthIndependence
|
||||
// guards against regression from crypto/subtle.ConstantTimeCompare to a
|
||||
// short-circuiting byte compare. ConstantTimeCompare returns 0 whenever the
|
||||
// two slices differ in length OR content, so a same-prefix-but-longer input
|
||||
// must be rejected the same way as a completely different string.
|
||||
func TestSCEPService_PKCSReq_ChallengePassword_ConstantTimeLengthIndependence(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
svc := NewSCEPService("iss-local", mockIssuer, nil, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "device.example.com", nil)
|
||||
|
||||
for _, bad := range []string{"secret", "secret12", "secret1234", "SECRET123", "wrong"} {
|
||||
_, err := svc.PKCSReq(context.Background(), csrPEM, bad, "txn-ct")
|
||||
if err == nil {
|
||||
t.Fatalf("expected rejection for bad password %q", bad)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid challenge password") {
|
||||
t.Errorf("expected 'invalid challenge password' for %q, got: %v", bad, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,12 +203,12 @@ func TestSCEPService_PKCSReq_WithProfile(t *testing.T) {
|
||||
mockIssuer := &mockIssuerConnector{}
|
||||
auditRepo := newMockAuditRepository()
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "")
|
||||
svc := NewSCEPService("iss-local", mockIssuer, auditSvc, slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})), "secret123")
|
||||
svc.SetProfileID("profile-mdm-device")
|
||||
|
||||
csrPEM := generateCSRPEM(t, "device.example.com", nil)
|
||||
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "", "txn-008")
|
||||
result, err := svc.PKCSReq(context.Background(), csrPEM, "secret123", "txn-008")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func newTestTargetService() (*TargetService, *mockTargetRepo, *mockAuditRepo, *m
|
||||
auditSvc := NewAuditService(auditRepo)
|
||||
agentRepo := &mockAgentRepo{Agents: make(map[string]*domain.Agent), HeartbeatUpdates: make(map[string]time.Time)}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
return NewTargetService(targetRepo, auditSvc, agentRepo, nil, logger), targetRepo, auditRepo, agentRepo
|
||||
return NewTargetService(targetRepo, auditSvc, agentRepo, testEncryptionKey, logger), targetRepo, auditRepo, agentRepo
|
||||
}
|
||||
|
||||
func TestTargetService_List_Success(t *testing.T) {
|
||||
|
||||
@@ -12,6 +12,13 @@ import (
|
||||
|
||||
var errNotFound = errors.New("not found")
|
||||
|
||||
// testEncryptionKey is a deterministic 32-byte AES-256 key for unit tests that
|
||||
// exercise IssuerService/TargetService write paths. After the C-2 remediation
|
||||
// these services fail closed when no key is configured, so happy-path tests
|
||||
// must supply a real key. Using a constant keeps wire-format assertions stable
|
||||
// across runs and avoids flaky PBKDF2 timing.
|
||||
var testEncryptionKey = []byte("0123456789abcdef0123456789abcdef") // 32 bytes
|
||||
|
||||
// mockCertRepo is a test implementation of CertificateRepository
|
||||
type mockCertRepo struct {
|
||||
Certs map[string]*domain.ManagedCertificate
|
||||
@@ -271,6 +278,56 @@ func (m *mockJobRepo) ListPendingByAgentID(ctx context.Context, agentID string)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ClaimPendingJobs simulates the H-6 atomic claim semantics: matching rows are transitioned
|
||||
// Pending → Running before being returned. The in-memory mock has no concurrency primitives
|
||||
// beyond the existing mutex, which is sufficient for single-goroutine service tests.
|
||||
func (m *mockJobRepo) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var claimed []*domain.Job
|
||||
for _, j := range m.Jobs {
|
||||
if j.Status != domain.JobStatusPending {
|
||||
continue
|
||||
}
|
||||
if jobType != "" && j.Type != jobType {
|
||||
continue
|
||||
}
|
||||
j.Status = domain.JobStatusRunning
|
||||
claimed = append(claimed, j)
|
||||
if limit > 0 && len(claimed) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
return claimed, nil
|
||||
}
|
||||
|
||||
// ClaimPendingByAgentID simulates the H-6 per-agent claim: Pending deployment rows scoped
|
||||
// to the agent flip to Running; AwaitingCSR rows are returned but keep their state.
|
||||
func (m *mockJobRepo) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.ListErr != nil {
|
||||
return nil, m.ListErr
|
||||
}
|
||||
var result []*domain.Job
|
||||
for _, j := range m.Jobs {
|
||||
if j.AgentID == nil || *j.AgentID != agentID {
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case j.Status == domain.JobStatusPending && j.Type == domain.JobTypeDeployment:
|
||||
j.Status = domain.JobStatusRunning
|
||||
result = append(result, j)
|
||||
case j.Status == domain.JobStatusAwaitingCSR:
|
||||
result = append(result, j)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockJobRepo) AddJob(job *domain.Job) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -922,9 +979,9 @@ func (m *mockRevocationRepo) Create(ctx context.Context, revocation *domain.Cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRevocationRepo) GetBySerial(ctx context.Context, serial string) (*domain.CertificateRevocation, error) {
|
||||
func (m *mockRevocationRepo) GetByIssuerAndSerial(ctx context.Context, issuerID, serial string) (*domain.CertificateRevocation, error) {
|
||||
for _, r := range m.Revocations {
|
||||
if r.SerialNumber == serial {
|
||||
if r.IssuerID == issuerID && r.SerialNumber == serial {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +69,14 @@ func (m *mockVerificationJobRepo) ListPendingByAgentID(ctx context.Context, agen
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockVerificationJobRepo) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockVerificationJobRepo) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// newVerificationTestService creates a VerificationService wired with test doubles.
|
||||
func newVerificationTestService(jobs map[string]*domain.Job, jobRepoErr error) (*VerificationService, *mockVerificationJobRepo, *mockAuditRepo) {
|
||||
jobRepo := &mockVerificationJobRepo{jobs: jobs, err: jobRepoErr}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateHeaderValue rejects any value that contains characters capable of
|
||||
// breaking out of a header line and injecting additional headers or body
|
||||
// content. It guards against CRLF injection (CWE-113) in RFC 5322 message
|
||||
// headers (SMTP, IMAP, etc.) and RFC 7230 HTTP headers alike.
|
||||
//
|
||||
// Disallowed characters:
|
||||
// - Carriage return ("\r")
|
||||
// - Line feed ("\n")
|
||||
// - NUL ("\x00")
|
||||
//
|
||||
// The field name is included in the returned error solely for operator
|
||||
// diagnostics; the offending value is not echoed back, so untrusted input
|
||||
// does not leak into logs that render this error.
|
||||
//
|
||||
// Callers should invoke this on any string that will be interpolated into a
|
||||
// header (From, To, Subject, Reply-To, custom X-* headers, etc.) before the
|
||||
// headers are serialized. Values containing CR/LF/NUL MUST be rejected
|
||||
// outright; silent stripping is inappropriate for authentication-relevant
|
||||
// headers because it can mask malicious intent while still altering the
|
||||
// message.
|
||||
func ValidateHeaderValue(field, value string) error {
|
||||
if field == "" {
|
||||
field = "header"
|
||||
}
|
||||
if strings.ContainsAny(value, "\r\n\x00") {
|
||||
return fmt.Errorf("%s contains disallowed control character (CR, LF, or NUL)", field)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateHeaderValue_AcceptsSafeInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
value string
|
||||
}{
|
||||
{"plain ASCII", "Subject", "Renewal reminder"},
|
||||
{"empty string", "Reply-To", ""},
|
||||
{"utf-8 multibyte", "Subject", "résumé — 日本語"},
|
||||
{"tabs and spaces permitted", "Subject", "a\tb c"},
|
||||
{"typical email address", "From", "alerts@example.com"},
|
||||
{"long Subject within limits", "Subject", strings.Repeat("x", 998)},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if err := ValidateHeaderValue(tc.field, tc.value); err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHeaderValue_RejectsControlCharacters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
value string
|
||||
}{
|
||||
{"injected CRLF + header", "Subject", "hello\r\nBcc: attacker@example.com"},
|
||||
{"lone LF", "From", "alice@example.com\nBcc: x@y"},
|
||||
{"lone CR", "Subject", "hello\rworld"},
|
||||
{"NUL byte", "To", "bob@example.com\x00extra"},
|
||||
{"CRLFCRLF body injection", "Subject", "ping\r\n\r\nMalicious body"},
|
||||
{"CR at end", "Subject", "trailing\r"},
|
||||
{"LF at start", "Subject", "\nleading"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateHeaderValue(tc.field, tc.value)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error rejecting control characters, got nil")
|
||||
}
|
||||
// Error must mention the field so operators can pinpoint the offender.
|
||||
if !strings.Contains(err.Error(), tc.field) {
|
||||
t.Errorf("expected error to mention field %q, got %q", tc.field, err.Error())
|
||||
}
|
||||
// Error must NOT leak the raw value back into logs.
|
||||
if strings.Contains(err.Error(), tc.value) {
|
||||
t.Errorf("error leaks raw value; expected redaction: %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHeaderValue_DefaultFieldName(t *testing.T) {
|
||||
err := ValidateHeaderValue("", "bad\r\nvalue")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for CRLF input, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "header") {
|
||||
t.Errorf("expected default field name 'header' in error, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IsReservedIP reports whether the given IP falls inside a range that
|
||||
// outbound HTTP egress (and the network-scanner CIDR expander) MUST treat
|
||||
// as unreachable: loopback, link-local (including cloud-provider metadata
|
||||
// endpoints at 169.254.169.254), multicast, and broadcast.
|
||||
//
|
||||
// RFC 1918 ranges (10/8, 172.16/12, 192.168/16) are intentionally NOT
|
||||
// treated as reserved. certctl is designed to manage certificates inside
|
||||
// private networks and filtering private address space would break the
|
||||
// primary use case. The threat model here is outbound HTTP to
|
||||
// cloud-metadata or localhost services, not general network reachability.
|
||||
//
|
||||
// This function is byte-identical in behaviour to the previous unexported
|
||||
// copy in internal/service/network_scan.go. It is exported here so both
|
||||
// the network scanner and the webhook notifier share a single
|
||||
// authoritative implementation. Broader IPv6 coverage and unspecified-
|
||||
// address handling live in SafeHTTPDialContext, where stricter policy is
|
||||
// appropriate for outbound HTTP egress.
|
||||
func IsReservedIP(ip net.IP) bool {
|
||||
// Loopback: 127.0.0.0/8 (and ::1 via IsLoopback).
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Link-local: 169.254.0.0/16 (includes cloud metadata 169.254.169.254).
|
||||
if linkLocal := net.ParseIP("169.254.0.0"); linkLocal != nil {
|
||||
if _, linkLocalNet, _ := net.ParseCIDR("169.254.0.0/16"); linkLocalNet != nil {
|
||||
if linkLocalNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Multicast: 224.0.0.0/4.
|
||||
if multicast := net.ParseIP("224.0.0.0"); multicast != nil {
|
||||
if _, multicastNet, _ := net.ParseCIDR("224.0.0.0/4"); multicastNet != nil {
|
||||
if multicastNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast: 255.255.255.255.
|
||||
if ip.String() == "255.255.255.255" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isReservedIPForDial applies IsReservedIP plus additional ranges that are
|
||||
// meaningful for outbound HTTP egress but were not part of the original
|
||||
// network-scanner filter: the unspecified address (0.0.0.0 / ::) and IPv6
|
||||
// link-local / multicast ranges. Kept private so IsReservedIP stays
|
||||
// byte-identical with the previous scanner behaviour.
|
||||
func isReservedIPForDial(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
if IsReservedIP(ip) {
|
||||
return true
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
// IPv6 link-local fe80::/10.
|
||||
if _, n, err := net.ParseCIDR("fe80::/10"); err == nil && n.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
// IPv6 multicast ff00::/8.
|
||||
if _, n, err := net.ParseCIDR("ff00::/8"); err == nil && n.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateSafeURL parses rawURL and rejects anything that would let an
|
||||
// attacker aim an outbound HTTP client at a SSRF-sensitive destination
|
||||
// (CWE-918). Guards enforced:
|
||||
//
|
||||
// 1. The scheme must be http or https. Schemes like file://, gopher://,
|
||||
// ftp://, data:, javascript:, ldap://, and dict:// are rejected outright;
|
||||
// webhook delivery only speaks HTTP(S).
|
||||
// 2. A hostname must be present. Empty-host URLs like "http:///foo" are
|
||||
// rejected to prevent ambiguous defaulting.
|
||||
// 3. If the host is a literal IP address, the IP must not be reserved
|
||||
// (see isReservedIPForDial). This stops the obvious 127.0.0.1 / ::1 /
|
||||
// 169.254.169.254 / 0.0.0.0 attacks at config time.
|
||||
// 4. If the host is a DNS name and resolution succeeds, every resolved
|
||||
// A/AAAA record must be non-reserved. A single reserved result is
|
||||
// enough to reject. Resolution failure is tolerated (offline CI
|
||||
// environments, short-lived test servers) — the authoritative
|
||||
// enforcement runs at dial time anyway.
|
||||
//
|
||||
// The DNS resolution check here is a best-effort early diagnostic. The
|
||||
// authoritative, TOCTOU-safe enforcement is SafeHTTPDialContext, which
|
||||
// re-checks after resolution at dial time and defeats DNS rebinding.
|
||||
// Callers that need SSRF-safe HTTP egress should use BOTH
|
||||
// ValidateSafeURL (at config ingestion) AND SafeHTTPDialContext
|
||||
// (installed on http.Transport).
|
||||
func ValidateSafeURL(rawURL string) error {
|
||||
if rawURL == "" {
|
||||
return fmt.Errorf("url is required")
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid url: %w", err)
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(u.Scheme)
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return fmt.Errorf("url scheme %q is not allowed; only http and https are permitted", u.Scheme)
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return fmt.Errorf("url must include a host")
|
||||
}
|
||||
|
||||
// Literal IP? Reject if reserved (strict policy for outbound egress).
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isReservedIPForDial(ip) {
|
||||
return fmt.Errorf("url host resolves to a reserved address and cannot be used")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DNS name. Resolve and reject if any answer is reserved.
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
// Resolution failure is not itself a SSRF signal; let the dial-time
|
||||
// DialContext handle the final decision. This keeps the validator
|
||||
// tolerant of offline validation environments (CI, tests) while
|
||||
// still blocking clearly-bad literal-IP URLs above.
|
||||
return nil
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isReservedIPForDial(ip) {
|
||||
return fmt.Errorf("url host resolves to a reserved address and cannot be used")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SafeHTTPDialContext returns a DialContext function suitable for
|
||||
// installing on an http.Transport. Every dial attempt resolves the host
|
||||
// again and rejects any connection whose resolved IP lies inside a
|
||||
// reserved range. This is the authoritative SSRF / DNS-rebinding guard:
|
||||
// even if ValidateSafeURL was bypassed, or if DNS changed between
|
||||
// validation and dial, the outbound connection will fail closed.
|
||||
//
|
||||
// The timeout argument bounds both the resolution and the underlying TCP
|
||||
// dial. Pass 0 to use a sensible default (10s).
|
||||
func SafeHTTPDialContext(timeout time.Duration) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid dial address %q: %w", addr, err)
|
||||
}
|
||||
|
||||
// If the host is already a literal IP, check it directly.
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isReservedIPForDial(ip) {
|
||||
return nil, fmt.Errorf("refusing to dial reserved address %s", ip.String())
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
// Resolve and reject any answer that lands in a reserved range.
|
||||
// We then dial an explicit resolved IP so a racing DNS change
|
||||
// cannot substitute a different (and possibly reserved) answer
|
||||
// between our check and the actual TCP dial.
|
||||
resCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
ips, err := (&net.Resolver{}).LookupIP(resCtx, "ip", host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve %s: %w", host, err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no addresses found for %s", host)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isReservedIPForDial(ip) {
|
||||
return nil, fmt.Errorf("refusing to dial %s: resolves to reserved address %s", host, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Dial the first non-reserved resolved IP directly, pinning the
|
||||
// target so later DNS changes cannot redirect us.
|
||||
pinned := net.JoinHostPort(ips[0].String(), port)
|
||||
return dialer.DialContext(ctx, network, pinned)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsReservedIP_ByteIdenticalWithNetworkScannerBehavior(t *testing.T) {
|
||||
// These expectations MUST NOT drift from the original unexported
|
||||
// isReservedIP in internal/service/network_scan.go. Any deviation here
|
||||
// is a behaviour change in the network scanner and requires a separate,
|
||||
// deliberate migration.
|
||||
cases := []struct {
|
||||
name string
|
||||
ip string
|
||||
reserved bool
|
||||
}{
|
||||
{"loopback v4", "127.0.0.1", true},
|
||||
{"loopback v4 range upper", "127.255.255.254", true},
|
||||
{"loopback v6", "::1", true},
|
||||
{"AWS metadata", "169.254.169.254", true},
|
||||
{"link-local range edge", "169.254.0.0", true},
|
||||
{"multicast 224", "224.0.0.1", true},
|
||||
{"multicast upper", "239.255.255.255", true},
|
||||
{"broadcast", "255.255.255.255", true},
|
||||
// The original network-scanner filter does NOT include unspecified
|
||||
// or IPv6 link-local, so these must remain non-reserved at this
|
||||
// layer. Stricter outbound-dial policy lives in SafeHTTPDialContext.
|
||||
{"unspecified v4", "0.0.0.0", false},
|
||||
{"IPv6 link-local", "fe80::1", false},
|
||||
{"IPv6 multicast", "ff00::1", false},
|
||||
// RFC 1918 is intentionally allowed (self-hosted design).
|
||||
{"RFC 1918 10/8", "10.0.0.1", false},
|
||||
{"RFC 1918 172.16/12", "172.16.0.1", false},
|
||||
{"RFC 1918 192.168/16", "192.168.1.1", false},
|
||||
// Ordinary public addresses pass.
|
||||
{"public v4", "8.8.8.8", false},
|
||||
{"public v6", "2606:4700:4700::1111", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("test setup: failed to parse %q", tc.ip)
|
||||
}
|
||||
if got := IsReservedIP(ip); got != tc.reserved {
|
||||
t.Errorf("IsReservedIP(%s)=%v, want %v", tc.ip, got, tc.reserved)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSafeURL_AcceptsSafePublicURLs(t *testing.T) {
|
||||
cases := []string{
|
||||
"https://example.com/webhook",
|
||||
"http://example.com/hook",
|
||||
"https://example.com:8443/hook",
|
||||
"https://webhook.site/abc-123",
|
||||
"http://10.0.0.5/internal", // RFC 1918 allowed
|
||||
"http://192.168.1.10:8080/webhook", // RFC 1918 allowed
|
||||
"http://172.16.5.1/intranet", // RFC 1918 allowed
|
||||
}
|
||||
for _, raw := range cases {
|
||||
t.Run(raw, func(t *testing.T) {
|
||||
if err := ValidateSafeURL(raw); err != nil {
|
||||
t.Errorf("ValidateSafeURL(%q) unexpectedly failed: %v", raw, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSafeURL_RejectsReservedLiteralIPs(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"loopback v4", "http://127.0.0.1/x"},
|
||||
{"loopback v4 with port", "http://127.0.0.1:8080/"},
|
||||
{"loopback v6 bracketed", "http://[::1]/x"},
|
||||
{"AWS metadata endpoint", "http://169.254.169.254/latest/meta-data/"},
|
||||
{"link-local IP", "http://169.254.1.2/"},
|
||||
{"broadcast", "http://255.255.255.255/"},
|
||||
{"multicast", "https://224.0.0.5/"},
|
||||
{"unspecified v4", "http://0.0.0.0/"},
|
||||
{"unspecified v6", "http://[::]/"},
|
||||
{"IPv6 link-local", "http://[fe80::1]/"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateSafeURL(tc.url)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateSafeURL(%q) returned nil, want error", tc.url)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "reserved") {
|
||||
t.Errorf("error should mention 'reserved' for operator diagnostics, got %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSafeURL_RejectsDangerousSchemes(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"file scheme", "file:///etc/passwd"},
|
||||
{"gopher scheme", "gopher://example.com/"},
|
||||
{"ftp scheme", "ftp://example.com/"},
|
||||
{"javascript scheme", "javascript:alert(1)"},
|
||||
{"data scheme", "data:text/plain;base64,SGVsbG8="},
|
||||
{"ldap scheme", "ldap://example.com/"},
|
||||
{"dict scheme", "dict://example.com:2628/d:foo"},
|
||||
{"jar scheme", "jar:http://example.com/foo.jar!/"},
|
||||
{"empty scheme", "example.com/hook"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateSafeURL(tc.url)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateSafeURL(%q) returned nil, want error", tc.url)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "scheme") && !strings.Contains(err.Error(), "host") {
|
||||
t.Errorf("error should mention scheme or host, got %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSafeURL_RejectsMissingHost(t *testing.T) {
|
||||
cases := []string{
|
||||
"http:///foo",
|
||||
"https://",
|
||||
}
|
||||
for _, raw := range cases {
|
||||
t.Run(raw, func(t *testing.T) {
|
||||
err := ValidateSafeURL(raw)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateSafeURL(%q) returned nil, want error", raw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSafeURL_RejectsEmpty(t *testing.T) {
|
||||
if err := ValidateSafeURL(""); err == nil {
|
||||
t.Fatal("ValidateSafeURL(\"\") returned nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSafeURL_RejectsMalformed(t *testing.T) {
|
||||
// url.Parse is famously lax; we lean on the scheme/host checks to catch
|
||||
// malformed inputs that produce empty schemes or hosts.
|
||||
cases := []string{
|
||||
"://missing-scheme",
|
||||
"http//missing-colon.example.com",
|
||||
}
|
||||
for _, raw := range cases {
|
||||
t.Run(raw, func(t *testing.T) {
|
||||
err := ValidateSafeURL(raw)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateSafeURL(%q) returned nil, want error", raw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeHTTPDialContext_RejectsLiteralReservedAddress(t *testing.T) {
|
||||
dial := SafeHTTPDialContext(2 * time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cases := []string{
|
||||
"127.0.0.1:9",
|
||||
"169.254.169.254:80",
|
||||
"[::1]:22",
|
||||
"0.0.0.0:80",
|
||||
}
|
||||
for _, addr := range cases {
|
||||
t.Run(addr, func(t *testing.T) {
|
||||
conn, err := dial(ctx, "tcp", addr)
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("dial(%q) returned nil err, want reserved-address rejection", addr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "reserved") {
|
||||
t.Errorf("expected reserved-address rejection, got %q", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeHTTPDialContext_RejectsHostResolvingToReservedAddress(t *testing.T) {
|
||||
// The stdlib resolver treats "localhost" as 127.0.0.1 / ::1 on every
|
||||
// platform we care about; this exercises the post-resolution check and
|
||||
// demonstrates that DNS-rebinding attacks (where a name points at a
|
||||
// reserved IP) are rejected at dial time rather than at validation time.
|
||||
dial := SafeHTTPDialContext(2 * time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dial(ctx, "tcp", "localhost:9")
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
t.Fatal("dial(localhost:9) returned nil err, want reserved-address rejection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "reserved") {
|
||||
t.Errorf("expected reserved-address rejection for localhost, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeHTTPDialContext_InvalidAddress(t *testing.T) {
|
||||
dial := SafeHTTPDialContext(1 * time.Second)
|
||||
_, err := dial(context.Background(), "tcp", "no-port")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid dial address, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeHTTPDialContext_DefaultTimeoutWhenZero(t *testing.T) {
|
||||
// Not directly observable, but we at least exercise the branch to
|
||||
// prevent a nil-ptr regression if the timeout default is dropped.
|
||||
dial := SafeHTTPDialContext(0)
|
||||
_, err := dial(context.Background(), "tcp", "127.0.0.1:1")
|
||||
if err == nil {
|
||||
t.Fatal("expected reserved-address rejection")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
-- Rollback Migration 000012: Restore global-serial uniqueness.
--
-- Returns the schema to its pre-000012 state, where `serial_number` alone is
-- unique. Before rolling back, operators must make sure no serial_number is
-- shared by rows of different issuers; otherwise creating the unique index
-- below will fail.

-- Remove both indexes introduced by the forward migration.
DROP INDEX IF EXISTS idx_certificate_revocations_serial_lookup;
DROP INDEX IF EXISTS idx_certificate_revocations_issuer_serial;

-- Reinstate the original global unique index on the serial number.
CREATE UNIQUE INDEX IF NOT EXISTS idx_certificate_revocations_serial ON certificate_revocations (serial_number);
|
||||
@@ -0,0 +1,31 @@
|
||||
-- Migration 000012: Scope Revocation Uniqueness to (issuer_id, serial_number)
--
-- Background: RFC 5280 §5.2.3 makes a certificate serial number unique only
-- within its issuing CA. The earlier global unique index on
-- `certificate_revocations.serial_number` was therefore too strict. certctl
-- works with many issuer connectors (Local CA, Vault, DigiCert, Sectigo,
-- Google CAS, AWS ACM PCA, step-ca, Entrust, GlobalSign, EJBCA, ACME,
-- OpenSSL), and two different CAs can legitimately issue distinct
-- certificates carrying the same serial value. With the old index, recording
-- a revocation for such a collision was silently dropped by
-- ON CONFLICT DO NOTHING.
--
-- Change: uniqueness is now enforced on the (issuer_id, serial_number) pair,
-- in line with RFC 5280 and with the revocation-recording call site (see
-- RevocationSvc.RevokeCertificateWithActor, which already populates IssuerID
-- at Create time).
--
-- Duplicate detection: if rows with an identical (issuer_id, serial_number)
-- pair already exist, creating the unique index fails. That is intentional;
-- operators must resolve the duplicates manually and re-run the migration.

-- 1. Remove the overly broad global-serial unique index.
DROP INDEX IF EXISTS idx_certificate_revocations_serial;

-- 2. Enforce uniqueness per issuing CA (RFC 5280 §5.2.3).
CREATE UNIQUE INDEX IF NOT EXISTS idx_certificate_revocations_issuer_serial ON certificate_revocations (issuer_id, serial_number);

-- 3. Keep serial-only lookups fast for OCSP/CRL paths that search within a
--    known issuer scope. This index is non-unique; uniqueness comes from the
--    composite index created in step 2.
CREATE INDEX IF NOT EXISTS idx_certificate_revocations_serial_lookup ON certificate_revocations (serial_number);
|
||||
@@ -195,6 +195,7 @@ export const issuerTypes: IssuerTypeConfig[] = [
|
||||
{ key: 'api_secret', label: 'API Secret', placeholder: 'GlobalSign API secret', required: true, type: 'password', sensitive: true },
|
||||
{ key: 'client_cert_path', label: 'Client Certificate Path', placeholder: '/path/to/client.crt', required: true },
|
||||
{ key: 'client_key_path', label: 'Client Key Path', placeholder: '/path/to/client.key', required: true, sensitive: true },
|
||||
{ key: 'server_ca_path', label: 'Server CA Path (optional)', placeholder: '/path/to/atlas-ca.pem', required: false },
|
||||
],
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user