EST RFC 7030 hardening master bundle Phases 2-4: end-to-end mTLS sibling

route + RFC 9266 channel binding + HTTP Basic enrollment-password +
per-source-IP failed-auth limit + per-(CN, sourceIP) sliding-window cap.

Two new shared packages so EST + Intune share infrastructure:
- internal/cms/ — RFC 9266 tls-exporter extractor (ExtractTLSExporter
  with stdlib-panic recovery for synthetic ConnectionStates) +
  CSR-side channel-binding parser via raw TBSCertificationRequestInfo
  walk (the stdlib's csr.Attributes can't represent the OCTET STRING
  binding value), VerifyChannelBinding composite, EmbedChannel-
  BindingAttribute fixture helper, typed sentinel errors for missing
  / mismatch / not-TLS-1.3 mapped to HTTP 400 / 409 / 426 in handler.
- internal/trustanchor/ — extracted from scep/intune/trust_anchor*.go
  so the EST mTLS sibling route + Intune dispatcher share the same
  SIGHUP-reloadable PEM bundle primitive. intune.TrustAnchorHolder
  is now `= trustanchor.Holder` (type alias) + NewTrustAnchorHolder =
  trustanchor.New (function alias) — every existing call site compiles
  unchanged. Intune's LoadTrustAnchor is a thin wrapper over
  trustanchor.LoadBundle. White-box tests moved to the new package.
- internal/ratelimit/ — extracted from scep/intune/rate_limit.go (this
  was Phase 4.1, in the same bundle). intune.PerDeviceRateLimiter
  is now a thin wrapper preserving the (subject, issuer)→key
  composition; EST handler reaches for SlidingWindowLimiter directly.

ESTHandler grew six optional fields wired by per-profile setters
(SetMTLSTrust / SetChannelBindingRequired / SetEnrollmentPassword /
SetSourceIPRateLimiter / SetPerPrincipalRateLimiter / SetLabelForLog)
plus four new mTLS-route methods (CACertsMTLS / SimpleEnrollMTLS /
SimpleReEnrollMTLS / CSRAttrsMTLS); shared internal pipeline
handleEnrollOrReEnroll(reEnroll, viaMTLS) keeps the auth/binding/
rate-limit gates DRY. New router method RegisterESTMTLSHandlers
registers /.well-known/est-mtls/<PathID>/{cacerts,simpleenroll,
simplereenroll,csrattrs}; AuthExemptDispatchPrefixes extends the
no-auth chain to /.well-known/est-mtls.

cmd/server/main.go's EST loop wires per-profile mTLS holder +
channel-binding policy + per-principal limiter + (when EnrollmentPassword
non-empty) Basic + source-IP limiter; new preflightESTMTLSClientCATrust-
Bundle returns *trustanchor.Holder so SIGHUP rotates the EST mTLS
bundle live without restart. SCEP + EST mTLS profiles now share a
single union mtlsUnionPoolForTLS passed to buildServerTLSConfigWithMTLS
(replaces the protocol-specific scepMTLSUnionPoolForTLS); per-handler
re-verify enforces "cert must chain to THIS profile's bundle" so
cross-protocol bleed is blocked at the application layer even though
the TLS layer trusts certs from either pool's union.

Phase 3.3 source-IP failed-Basic limiter defaults: 10 attempts / 1h
/ 50k tracked IPs (no env var; tunable in a follow-up). Phase 4.2
per-principal limiter cap from CERTCTL_EST_PROFILE_<NAME>_RATE_
LIMIT_PER_PRINCIPAL_24H (existing field, Phase 1 shipped).

New tests:
- internal/cms/channelbinding_test.go: extractor + CSR-side parser +
  composite + TLS-1.3 round-trip end-to-end + EmbedChannelBinding-
  Attribute round-trip
- internal/trustanchor/holder_test.go: parseBundlePEM white-box +
  LoadBundle + Holder Get/Pool/SetLabelForLog/Reload-happy/
  Reload-keeps-old-on-failure/Reload-keeps-old-on-expired/
  WatchSIGHUP-reloads-pool/WatchSIGHUP-stop-clean
- internal/api/handler/est_hardening_test.go: 16 named cases covering
  mTLS no-trust-pool 500 + no-cert 401 + cross-profile cert 401 +
  happy-path 200 + CACertsMTLS auth gate + CSRAttrsMTLS auth gate +
  channel-binding required-absent-rejected + not-required-absent-
  allowed + writeChannelBindingError mapping + Basic no-header 401
  + Basic wrong-password 401 + Basic correct-200 + Basic-no-password
  no-gate + per-IP failed-attempt lockout 429 + per-principal
  blocks-after-cap + different-principals-independent + no-limiter-
  unbounded.

Pre-commit verification (sandbox): gofmt clean, go vet clean
(excluding repository/postgres which the sandbox can't build —
disk-space testcontainers download), staticcheck clean for
cms/trustanchor/api/handler/api/router/scep/intune/ratelimit/
cmd/server, go test -short -count=1 green for cms/trustanchor/
api/handler/api/router/scep/intune/ratelimit/service. G-3
docs-drift guard reproduced locally clean (Phase 1 already
documented every new env var; Phases 2-4 added zero new env vars).
This commit is contained in:
shankar0123
2026-04-29 23:15:35 +00:00
parent 6cedaf4231
commit 34518b2e66
17 changed files with 3273 additions and 728 deletions
+186 -17
View File
@@ -31,10 +31,12 @@ import (
notifyteams "github.com/shankar0123/certctl/internal/connector/notifier/teams"
"github.com/shankar0123/certctl/internal/crypto/signer"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/ratelimit"
"github.com/shankar0123/certctl/internal/repository/postgres"
"github.com/shankar0123/certctl/internal/scep/intune"
"github.com/shankar0123/certctl/internal/scheduler"
"github.com/shankar0123/certctl/internal/service"
"github.com/shankar0123/certctl/internal/trustanchor"
)
func main() {
@@ -736,8 +738,24 @@ func main() {
// mirrors the SCEP audit-closure pattern (cmd/server/main.go::
// preflightSCEPIntuneTrustAnchor signature took pathID for exactly
// this reason).
// EST RFC 7030 hardening master bundle Phase 2 + SCEP RFC 8894 +
// Intune master bundle Phase 6.5 SHARED union pool: every protocol's
// mTLS profiles contribute their trust certs here so a single TLS
// listener accepts client certs from EITHER protocol's profiles, and
// the per-handler gate re-verifies that the cert chains to THIS
// profile's bundle. Allocated lazily by whichever protocol first
// opts in (left nil when no profile opted in across both protocols
// — buildServerTLSConfigWithMTLS treats nil as 'no mTLS').
var mtlsUnionPoolForTLS *x509.CertPool
// estMTLSStopWatchers collects every per-profile trust-anchor
// SIGHUP-watcher stop func so we can shut them down on server exit
// (mirrors intuneStopWatchers below).
var estMTLSStopWatchers []func()
if cfg.EST.Enabled {
estHandlers := make(map[string]handler.ESTHandler, len(cfg.EST.Profiles))
estMTLSHandlers := make(map[string]handler.ESTHandler)
estMTLSAnyEnabled := false
for i, profile := range cfg.EST.Profiles {
profile := profile // shadow for closure-safety
profileLog := logger.With(
@@ -769,7 +787,102 @@ func main() {
if profile.ProfileID != "" {
estService.SetProfileID(profile.ProfileID)
}
estHandlers[profile.PathID] = handler.NewESTHandler(estService)
estHandler := handler.NewESTHandler(estService)
estHandler.SetLabelForLog(fmt.Sprintf("est (PathID=%q)", profile.PathID))
// Phase 3.1: HTTP Basic enrollment password. Only takes effect
// on the standard /.well-known/est/<PathID>/ route — the mTLS
// sibling skips it because the client cert IS the auth signal.
if profile.EnrollmentPassword != "" {
estHandler.SetEnrollmentPassword(profile.EnrollmentPassword)
// Phase 3.3: per-source-IP failed-auth rate limit.
// Defaults: 10 failed attempts / 1 hour / 50k tracked IPs.
// Hard-coded for now (no env var); a tuning bundle can lift
// these once we've watched real production deploys for a
// release. The shared SlidingWindowLimiter applies the same
// math the SCEP/Intune limiter uses — extracted in Phase 4.1
// of this bundle so both call sites share the implementation.
failed := ratelimit.NewSlidingWindowLimiter(10, time.Hour, 50_000)
estHandler.SetSourceIPRateLimiter(failed)
}
// Phase 2.1: mTLS sibling route. When MTLSEnabled=true, build a
// per-profile SIGHUP-reloadable trust-anchor holder, splice the
// bundle's certs into the EST mTLS union pool, and clone the
// handler with the per-profile trust + channel-binding policy
// so SimpleEnrollMTLS / SimpleReEnrollMTLS verify against just
// THIS profile's bundle.
if profile.MTLSEnabled {
holder, err := preflightESTMTLSClientCATrustBundle(true, profile.PathID, profile.MTLSClientCATrustBundlePath, profileLog)
if err != nil {
profileLog.Error(
"startup refused: EST profile MTLS trust bundle preflight failed "+
"(EST hardening Phase 2: required when MTLS_ENABLED=true). "+
"Verify the bundle file exists at MTLS_CLIENT_CA_TRUST_BUNDLE_PATH, "+
"is readable, parses as PEM, contains ≥1 CERTIFICATE block, "+
"and none of the bundled certs are past NotAfter.",
"error", err,
)
os.Exit(1)
}
// Merge this profile's certs into the union pool the TLS
// layer uses for VerifyClientCertIfGiven. Walk the bundle
// directly so the union pool gets exactly the same certs
// as the per-profile pool (mirrors SCEP's pattern at the
// equivalent loop iteration).
if mtlsUnionPoolForTLS == nil {
mtlsUnionPoolForTLS = x509.NewCertPool()
}
bundleBytes, _ := os.ReadFile(profile.MTLSClientCATrustBundlePath)
rest := bundleBytes
for {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
continue
}
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
mtlsUnionPoolForTLS.AddCert(cert)
}
}
estMTLSAnyEnabled = true
// Build the mTLS sibling-route handler with the per-profile
// trust pool, channel-binding policy, and (if configured)
// per-principal rate limiter.
mtlsHandler := handler.NewESTHandler(estService)
mtlsHandler.SetLabelForLog(fmt.Sprintf("est-mtls (PathID=%q)", profile.PathID))
mtlsHandler.SetMTLSTrust(holder)
mtlsHandler.SetChannelBindingRequired(profile.ChannelBindingRequired)
if profile.RateLimitPerPrincipal24h > 0 {
perPrincipal := ratelimit.NewSlidingWindowLimiter(profile.RateLimitPerPrincipal24h, 24*time.Hour, 100_000)
mtlsHandler.SetPerPrincipalRateLimiter(perPrincipal)
}
estMTLSHandlers[profile.PathID] = mtlsHandler
// Install the SIGHUP watcher so an operator that rotates
// the mTLS trust bundle file gets the new pool live without
// a server restart. Watcher stop func is collected for
// orderly shutdown via the defer below.
estMTLSStopWatchers = append(estMTLSStopWatchers, holder.WatchSIGHUP())
profileLog.Info("EST mTLS sibling route enabled",
"endpoint", "/.well-known/est-mtls/"+profile.PathID,
"client_ca_trust_bundle", profile.MTLSClientCATrustBundlePath,
"channel_binding_required", profile.ChannelBindingRequired,
)
}
// Phase 4.2: per-principal rate limiter on the standard route
// too (additive — both routes share the same per-(CN, IP) cap
// when configured). The mTLS handler above gets its own
// limiter instance so the two routes don't share a bucket.
if profile.RateLimitPerPrincipal24h > 0 {
perPrincipal := ratelimit.NewSlidingWindowLimiter(profile.RateLimitPerPrincipal24h, 24*time.Hour, 100_000)
estHandler.SetPerPrincipalRateLimiter(perPrincipal)
}
estHandlers[profile.PathID] = estHandler
endpoint := "/.well-known/est"
if profile.PathID != "" {
@@ -785,18 +898,30 @@ func main() {
)
}
apiRouter.RegisterESTHandlers(estHandlers)
logger.Info("EST server enabled", "profile_count", len(cfg.EST.Profiles))
if estMTLSAnyEnabled {
apiRouter.RegisterESTMTLSHandlers(estMTLSHandlers)
logger.Info("EST mTLS sibling route enabled (Phase 2)",
"mtls_profile_count", len(estMTLSHandlers),
)
}
logger.Info("EST server enabled",
"profile_count", len(cfg.EST.Profiles),
"mtls_profile_count", len(estMTLSHandlers),
)
// Stop SIGHUP watchers in LIFO on server shutdown.
if len(estMTLSStopWatchers) > 0 {
defer func() {
for _, stop := range estMTLSStopWatchers {
stop()
}
}()
}
}
// SCEP RFC 8894 Phase 6.5: union pool of every enabled mTLS profile's
// trust bundle. Populated inside the SCEP startup block below; passed
// to the TLS-config builder later so the listener accepts client certs
// signed by ANY mTLS profile's CA. The handler-layer gate
// (HandleSCEPMTLS) re-verifies per-profile, so a cert that chains to
// profile A's bundle cannot enroll against profile B even though it
// passes the TLS-layer union check. Stays nil when no profile opted in
// (the TLS config builder treats nil as 'no mTLS').
var scepMTLSUnionPoolForTLS *x509.CertPool
// EST RFC 7030 hardening master bundle Phase 2: SCEP's mTLS union pool
// merged into the SHARED mtlsUnionPoolForTLS variable declared above.
// Variables here intentionally renamed to make the merge explicit.
// Register SCEP (RFC 8894) handlers if enabled.
//
@@ -821,7 +946,6 @@ func main() {
// bundle to prevent cross-profile bleed-through).
scepHandlers := make(map[string]handler.SCEPHandler, len(cfg.SCEP.Profiles))
scepMTLSHandlers := make(map[string]handler.SCEPHandler)
scepMTLSUnionPool := x509.NewCertPool()
scepMTLSAnyEnabled := false
// SCEP RFC 8894 + Intune master bundle Phase 8: per-profile Intune
// trust anchor holders. We track them here so a single SIGHUP
@@ -1017,7 +1141,10 @@ func main() {
continue
}
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
scepMTLSUnionPool.AddCert(cert)
if mtlsUnionPoolForTLS == nil {
mtlsUnionPoolForTLS = x509.NewCertPool()
}
mtlsUnionPoolForTLS.AddCert(cert)
}
}
scepMTLSAnyEnabled = true
@@ -1049,7 +1176,6 @@ func main() {
// no-op-when-disabled case obvious in logs.
if scepMTLSAnyEnabled {
apiRouter.RegisterSCEPMTLSHandlers(scepMTLSHandlers)
scepMTLSUnionPoolForTLS = scepMTLSUnionPool
logger.Info("SCEP mTLS sibling route enabled (Phase 6.5)",
"mtls_profile_count", len(scepMTLSHandlers),
)
@@ -1317,7 +1443,7 @@ func main() {
// sibling route gates additionally on the verified client cert.
// nil pool = no profile opted in = identical TLS shape to the
// pre-Phase-6.5 buildServerTLSConfig path.
TLSConfig: buildServerTLSConfigWithMTLS(tlsCertHolder, scepMTLSUnionPoolForTLS),
TLSConfig: buildServerTLSConfigWithMTLS(tlsCertHolder, mtlsUnionPoolForTLS),
ReadTimeout: 30 * time.Second,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 120 * time.Second, // Must accommodate ACME issuance (order + challenge + finalize)
@@ -1476,6 +1602,41 @@ func preflightSCEPMTLSTrustBundle(enabled bool, bundlePath string) (*x509.CertPo
return pool, nil
}
// preflightESTMTLSClientCATrustBundle validates a per-profile EST mTLS
// client-CA trust bundle and returns a SIGHUP-reloadable holder.
//
// EST RFC 7030 hardening master bundle Phase 2.5.
//
// Mirrors preflightSCEPMTLSTrustBundle's checks (file exists, parses as
// PEM, ≥1 cert, none expired) but returns a *trustanchor.Holder rather
// than a raw *x509.CertPool — the EST handler stores the holder so a
// SIGHUP rotates the trust bundle live without a server restart, exactly
// the way the Intune trust anchor rotation works (Phase 8.5 of the SCEP
// bundle). The handler-side .Pool() accessor on the holder rebuilds an
// x509.CertPool from the current snapshot for each Verify call.
//
// Uses the shared internal/trustanchor.LoadBundle (extracted in EST
// hardening Phase 2.1 from the original Intune-only path) so the EST
// + Intune callers exercise the same loader semantics — empty bundle
// rejected, expired cert rejected with subject in error message,
// non-CERTIFICATE PEM blocks tolerated.
func preflightESTMTLSClientCATrustBundle(enabled bool, pathID, bundlePath string, logger *slog.Logger) (*trustanchor.Holder, error) {
if !enabled {
return nil, nil
}
if bundlePath == "" {
return nil, fmt.Errorf("EST profile (PathID=%q) MTLS enabled but trust bundle path empty: "+
"set CERTCTL_EST_PROFILE_<NAME>_MTLS_CLIENT_CA_TRUST_BUNDLE_PATH to a PEM file "+
"containing the bootstrap-CA certs the operator allows to enroll", pathID)
}
holder, err := trustanchor.New(bundlePath, logger)
if err != nil {
return nil, fmt.Errorf("EST profile (PathID=%q) MTLS trust bundle preflight: %w", pathID, err)
}
holder.SetLabelForLog(fmt.Sprintf("EST mTLS client CA bundle (PathID=%q)", pathID))
return holder, nil
}
// preflightSCEPIntuneTrustAnchor validates a per-profile Microsoft Intune
// Certificate Connector signing-cert trust bundle.
//
@@ -1745,9 +1906,17 @@ func buildFinalHandler(apiHandler, noAuthHandler http.Handler, webDir string, da
}
// RFC 7030 EST endpoints ride the no-auth middleware chain (M-001,
// option D, audit 2026-04-19). Trust boundary is CSR signature + profile
// policy, not HTTP Bearer. /.well-known/est/cacerts is explicitly
// anonymous per RFC 7030 §4.1.1.
// option D, audit 2026-04-19). Trust boundary is CSR signature +
// (per EST hardening Phase 2) optional client cert at the handler
// layer, not HTTP Bearer. /.well-known/est/cacerts is explicitly
// anonymous per RFC 7030 §4.1.1; /.well-known/est-mtls/<PathID>/
// (EST hardening Phase 2 sibling route) requires a client cert
// gate at the handler layer — both share this prefix gate because
// "/.well-known/est-mtls" is itself prefixed by "/.well-known/est".
// EST hardening Phase 3's HTTP Basic enrollment-password is a
// per-profile handler-layer auth that runs INSIDE the no-auth
// middleware chain (since the chain skips the Bearer middleware,
// the handler gets to define its own auth contract).
if strings.HasPrefix(path, "/.well-known/est") {
noAuthHandler.ServeHTTP(w, r)
return
+15 -9
View File
@@ -136,21 +136,27 @@ func buildServerTLSConfig(holder *certHolder) *tls.Config {
}
// buildServerTLSConfigWithMTLS extends buildServerTLSConfig with a client-cert
// trust pool for the SCEP RFC 8894 + Intune master bundle Phase 6.5 mTLS
// sibling route. SCEP profiles that opt into mTLS each contribute their
// trust bundle to the union pool here; the same TLS listener serves both
// /scep[/<pathID>] (no client cert) and /scep-mtls/<pathID> (cert required
// at the handler layer).
// trust pool for the SCEP/EST mTLS sibling routes.
//
// SCEP RFC 8894 + Intune master bundle Phase 6.5 introduced this for the
// /scep-mtls/<pathID> route; EST RFC 7030 hardening master bundle Phase 2
// extended it so the same TLS listener also serves /.well-known/est-mtls/
// <pathID>. Both protocols' mTLS profiles contribute their trust bundles
// to a UNION pool that the caller (cmd/server/main.go) builds by walking
// every enabled mTLS profile's bundle bytes once. The per-protocol
// handlers re-verify against just THIS profile's bundle (so an EST-mTLS
// bootstrap cert can't enroll against a SCEP-mTLS profile and vice versa).
//
// ClientAuth: VerifyClientCertIfGiven — request a cert during handshake; if
// the client presents one, verify it against the union pool; if absent, the
// request still reaches the handler and the per-route handler decides
// whether to accept. Critical that we do NOT use RequireAndVerifyClientCert
// here — that would break the standard /scep route (which is challenge-
// password-only, no client cert expected).
// here — that would break the standard /scep + /.well-known/est routes
// (challenge-password-only / unauth-or-Basic, no client cert expected).
//
// Pass clientCAs == nil to disable mTLS (no profile opted in). The function
// then returns the same shape as buildServerTLSConfig.
// Pass clientCAs == nil to disable mTLS (no profile opted in across either
// protocol). The function then returns the same shape as
// buildServerTLSConfig.
func buildServerTLSConfigWithMTLS(holder *certHolder, clientCAs *x509.CertPool) *tls.Config {
cfg := buildServerTLSConfig(holder)
if clientCAs != nil {
+599 -225
View File
@@ -2,17 +2,23 @@ package handler
import (
"context"
"crypto/subtle"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/cms"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/pkcs7"
"github.com/shankar0123/certctl/internal/ratelimit"
"github.com/shankar0123/certctl/internal/trustanchor"
)
// ESTService defines the service interface for EST enrollment operations.
@@ -33,62 +39,558 @@ type ESTService interface {
// ESTHandler handles HTTP requests for the EST protocol (RFC 7030).
//
// EST endpoints are served under /.well-known/est/ per the RFC.
// EST endpoints are served under /.well-known/est/[<PathID>/] per RFC 7030.
// Wire format: base64-encoded DER (PKCS#7 for certs, PKCS#10 for CSRs).
//
// Supported operations:
// - GET /.well-known/est/cacerts — CA certificate distribution
// - POST /.well-known/est/simpleenroll — initial enrollment
// - POST /.well-known/est/simplereenroll — re-enrollment
// - GET /.well-known/est/csrattrs — CSR attributes
// Supported operations (per route family):
//
// /.well-known/est/[<PathID>/] — legacy + per-profile route family
// GET cacerts — CA certificate distribution
// POST simpleenroll — initial enrollment (HTTP Basic optional, Phase 3)
// POST simplereenroll — re-enrollment (HTTP Basic optional, Phase 3)
// GET csrattrs — CSR attributes
//
// /.well-known/est-mtls/<PathID>/ — mTLS sibling (Phase 2)
// GET cacerts — CA certificate distribution (cert auth required)
// POST simpleenroll — initial enrollment (cert + optional channel binding)
// POST simplereenroll — re-enrollment (cert + optional channel binding)
// GET csrattrs — CSR attributes
//
// EST RFC 7030 hardening master bundle Phases 2-4: ESTHandler grew six
// optional fields wired by per-profile setters in cmd/server/main.go's
// startup loop. None of the new fields are required — a handler with all
// of them unset behaves exactly like the v2.0.x EST handler.
type ESTHandler struct {
svc ESTService
// EST RFC 7030 hardening Phase 2.1: per-profile mTLS client-CA trust
// bundle. When set, the mTLS sibling route (CACertsMTLS /
// SimpleEnrollMTLS / etc.) verifies the inbound client cert chain
// against this pool. Nil when MTLS_ENABLED=false; the mTLS route
// rejects unconditionally in that case (the route shouldn't even be
// registered, but defense in depth).
mtlsTrust *trustanchor.Holder
// EST RFC 7030 hardening Phase 2.4: per-profile channel-binding
// requirement. When true, the mTLS handler refuses simplereenroll
// requests whose CSR doesn't carry a matching id-aa-est-tls-exporter
// (RFC 9266) attribute. Phase 1's Validate() guards
// ChannelBindingRequired=true + MTLSEnabled=false at startup.
channelBindingRequired bool
// EST RFC 7030 hardening Phase 3.1: per-profile HTTP Basic enrollment
// password. When non-empty, the standard /.well-known/est/<PathID>/
// route requires `Authorization: Basic <base64(<user>:<pw>)>` on the
// enrollment endpoints (NOT on cacerts/csrattrs — RFC 7030 §4.1.1
// says cacerts is anonymous). Constant-time compare; per-source-IP
// failed-auth rate limit blocks brute-force.
basicPassword string
// EST RFC 7030 hardening Phase 3.3: per-handler source-IP rate
// limiter for FAILED HTTP Basic auth attempts. Keyed by sourceIP so
// a hostile network segment can't burn through the password.
failedBasicLimiter *ratelimit.SlidingWindowLimiter
// EST RFC 7030 hardening Phase 4.2: per-handler per-principal sliding-
// window rate limit. Keyed by (CSR-CN, sourceIP) so a stolen
// bootstrap cert AND a known device CN can't be used to flood the
// issuer. Disabled when nil; configured per-profile.
perPrincipalLimiter *ratelimit.SlidingWindowLimiter
// labelForLog gives observability code a per-profile string to
// include in audit log lines / Prometheus labels. Defaults to
// "est" when unset.
labelForLog string
}
// NewESTHandler creates a new ESTHandler.
// NewESTHandler creates a new ESTHandler with no per-profile auth
// hardening. Call SetMTLSTrust + SetChannelBindingRequired +
// SetEnrollmentPassword + SetSourceIPRateLimiter + SetPerPrincipalRateLimiter
// from the per-profile startup loop to opt-in to each surface.
func NewESTHandler(svc ESTService) ESTHandler {
return ESTHandler{svc: svc}
}
// CACerts handles GET /.well-known/est/cacerts
// Returns the CA certificate chain as base64-encoded PKCS#7 (certs-only).
// Per RFC 7030 Section 4.1, this is a "certs-only" CMC Simple PKI Response.
// For simplicity and broad client compatibility, we return base64-encoded DER certificates.
// SetMTLSTrust injects the per-profile client-cert trust pool the
// `/.well-known/est-mtls/<PathID>/` sibling route uses to verify inbound
// device cert chains. EST RFC 7030 hardening Phase 2.1.
//
// Like the SCEP equivalent, the TLS layer (cmd/server/tls.go) uses
// VerifyClientCertIfGiven against the UNION of every enabled mTLS
// profile's bundle, so the same TLS listener serves both /.well-known/est
// (anonymous or HTTP Basic) and /.well-known/est-mtls/<PathID>
// (cert-required). The per-profile gate at the handler layer enforces
// 'cert must chain to THIS profile's bundle' so a cert that chains to
// profile A's bundle cannot enroll against profile B.
func (h *ESTHandler) SetMTLSTrust(t *trustanchor.Holder) { h.mtlsTrust = t }
// SetChannelBindingRequired toggles RFC 9266 tls-exporter channel binding
// on the simplereenroll mTLS path. EST RFC 7030 hardening Phase 2.4.
// When true, the handler refuses requests whose CSR lacks the binding
// attribute or whose binding bytes don't match the live TLS exporter.
func (h *ESTHandler) SetChannelBindingRequired(req bool) { h.channelBindingRequired = req }
// SetEnrollmentPassword injects the per-profile HTTP Basic enrollment
// password. EST RFC 7030 hardening Phase 3.1. Empty disables the gate
// (mTLS-only or unauthenticated profile). Constant-time compare via
// crypto/subtle.ConstantTimeCompare.
func (h *ESTHandler) SetEnrollmentPassword(pw string) { h.basicPassword = pw }
// SetSourceIPRateLimiter injects the per-handler failed-Basic-auth
// rate limiter. Phase 3.3. Disabled when nil — but Validate() at
// startup refuses an enabled basic-auth profile without a configured
// limiter, so a real deploy always wires one.
func (h *ESTHandler) SetSourceIPRateLimiter(l *ratelimit.SlidingWindowLimiter) {
h.failedBasicLimiter = l
}
// SetPerPrincipalRateLimiter injects the per-handler (CN, sourceIP)
// sliding-window rate limiter. Phase 4.2. Disabled when nil. Counts
// every successful enrollment, NOT just failures — the goal is to
// bound enrollment-flooding from a compromised credential, not just
// failed-auth brute force.
func (h *ESTHandler) SetPerPrincipalRateLimiter(l *ratelimit.SlidingWindowLimiter) {
h.perPrincipalLimiter = l
}
// SetLabelForLog sets the per-profile observability label. Defaults to
// "est" when unset; cmd/server/main.go's per-profile loop sets this
// to "est (PathID=<id>)" for triage.
func (h *ESTHandler) SetLabelForLog(label string) {
if label == "" {
return
}
h.labelForLog = label
}
// label returns h.labelForLog with the "est" fallback applied. Tiny
// helper so log call sites don't need to repeat the fallback.
func (h ESTHandler) label() string {
if h.labelForLog == "" {
return "est"
}
return h.labelForLog
}
// ----- /.well-known/est/[<PathID>/] route family (legacy + Basic auth) -----
// CACerts handles GET /.well-known/est/[<PathID>/]cacerts.
//
// RFC 7030 §4.1.1 — anonymous endpoint. The HTTP Basic gate is NOT
// applied here (any client must be able to fetch the CA chain to
// verify subsequent enrollment responses).
func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
h.writeCACertsResponse(w, r)
}
caCertPEM, err := h.svc.GetCACerts(r.Context())
if err != nil {
requestID := middleware.GetRequestID(r.Context())
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get CA certificates: %v", err), requestID)
// SimpleEnroll handles POST /.well-known/est/[<PathID>/]simpleenroll.
// Accepts a base64-encoded PKCS#10 CSR + returns base64-encoded PKCS#7.
//
// Auth: HTTP Basic when h.basicPassword != "" (Phase 3); otherwise
// anonymous. Rate-limit: per-(CN, sourceIP) when wired (Phase 4).
func (h ESTHandler) SimpleEnroll(w http.ResponseWriter, r *http.Request) {
h.handleEnrollOrReEnroll(w, r, false /*reEnroll*/, false /*viaMTLS*/)
}
// SimpleReEnroll handles POST /.well-known/est/[<PathID>/]simplereenroll.
// Same as SimpleEnroll but the audit/log distinguishes the renewal flow
// from initial issuance.
func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
h.handleEnrollOrReEnroll(w, r, true /*reEnroll*/, false /*viaMTLS*/)
}
// CSRAttrs handles GET /.well-known/est/[<PathID>/]csrattrs.
// Returns the CSR attributes the server wants the client to include.
// RFC 7030 §4.5 — anonymous endpoint, no Basic auth gate.
func (h ESTHandler) CSRAttrs(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
h.writeCSRAttrsResponse(w, r)
}
// ----- /.well-known/est-mtls/<PathID>/ route family (Phase 2 mTLS) -----
// CACertsMTLS handles GET /.well-known/est-mtls/<PathID>/cacerts.
//
// RFC 7030 §4.1.1 says cacerts is anonymous, but on the mTLS sibling
// route we still require a valid client cert because the mTLS path is
// the audit-distinguished surface — operators using mTLS WANT every
// touchpoint logged. The cert isn't validated for purpose-of-issuance
// here (cacerts isn't an enrollment), but absence is rejected.
func (h ESTHandler) CACertsMTLS(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if _, ok := h.requireClientCertChain(w, r); !ok {
return
}
h.writeCACertsResponse(w, r)
}
// SimpleEnrollMTLS handles POST /.well-known/est-mtls/<PathID>/simpleenroll.
//
// Order of gates (each fails fast with the appropriate HTTP status):
//
// 1. Client cert presented + chains to per-profile mTLS trust pool
// (the TLS layer already verified against the union pool; this is
// the per-profile re-verify that prevents profile A↔B cross-bleed).
// 2. CSR parses + matches the EST contract (handled by the shared
// enrollment helper).
// 3. Per-(CN, sourceIP) rate limit when configured.
// 4. Service-layer enrollment.
//
// Channel binding does NOT apply here — RFC 9266 §1 calls out that
// channel binding is a renewal-time defense-in-depth, not an initial-
// enrollment requirement. (A first-time enrollment doesn't yet have a
// device cert, so binding to the TLS session for the bootstrap cert
// adds nothing.)
func (h ESTHandler) SimpleEnrollMTLS(w http.ResponseWriter, r *http.Request) {
if _, ok := h.requireClientCertChain(w, r); !ok {
return
}
h.handleEnrollOrReEnroll(w, r, false /*reEnroll*/, true /*viaMTLS*/)
}
// SimpleReEnrollMTLS handles POST /.well-known/est-mtls/<PathID>/simplereenroll.
//
// Same as SimpleEnrollMTLS plus the channel-binding gate. RFC 9266 §4.1
// says renewal CSRs SHOULD include the binding attribute when the
// enrollment is over a TLS-1.3 channel; per-profile policy can either
// require this strictly (ChannelBindingRequired=true) or accept its
// absence (default).
func (h ESTHandler) SimpleReEnrollMTLS(w http.ResponseWriter, r *http.Request) {
if _, ok := h.requireClientCertChain(w, r); !ok {
return
}
h.handleEnrollOrReEnroll(w, r, true /*reEnroll*/, true /*viaMTLS*/)
}
// CSRAttrsMTLS handles GET /.well-known/est-mtls/<PathID>/csrattrs.
// Mirrors CACertsMTLS — cert-required even though the unauth route
// version is anonymous.
func (h ESTHandler) CSRAttrsMTLS(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if _, ok := h.requireClientCertChain(w, r); !ok {
return
}
h.writeCSRAttrsResponse(w, r)
}
// ----- shared internal pipeline -----
// handleEnrollOrReEnroll is the shared body for {Simple,SimpleRe}Enroll{,MTLS}.
// reEnroll picks the SimpleReEnroll vs SimpleEnroll service method (purely
// audit / metric distinguishing — same issuer call underneath); viaMTLS
// picks whether the channel-binding + per-principal-limit gates apply
// AND skips the HTTP Basic gate (mTLS handlers carry the auth).
func (h ESTHandler) handleEnrollOrReEnroll(w http.ResponseWriter, r *http.Request, reEnroll, viaMTLS bool) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse PEM to DER for PKCS#7 encoding
requestID := middleware.GetRequestID(r.Context())
if err := verifyESTTransport(r); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest,
fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
return
}
// HTTP Basic gate (Phase 3) — non-mTLS path only. mTLS profiles
// authenticate via the client cert so adding Basic on top would
// double-tax operators with no security benefit.
if !viaMTLS && h.basicPassword != "" {
if !h.requireBasicAuth(w, r) {
return
}
}
csrPEM, err := h.readCSRFromRequest(r)
if err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
return
}
// Parse the CSR once for downstream gates (channel-binding, per-
// principal rate limit). The service re-parses internally — that's a
// minor inefficiency we accept to keep the service interface flat.
csr, _ := decodeCSRPEM(csrPEM)
// Channel-binding gate (Phase 2.4) — mTLS reEnroll only. The optional
// CSR-side attribute is checked even when the per-profile flag isn't
// requiring it (a CSR carrying the attribute MUST match the live
// exporter; a present-but-mismatched binding is always fatal).
if viaMTLS && reEnroll && csr != nil {
if err := cms.VerifyChannelBinding(r.TLS, csr, h.channelBindingRequired); err != nil {
h.writeChannelBindingError(w, requestID, err)
return
}
}
// Per-principal rate-limit gate (Phase 4.2). Keyed by CN+sourceIP so
// (a) a CN with no source-IP rotation can be capped, AND (b) a
// hostile network segment trying to enroll many CNs from one IP is
// also bounded.
if h.perPrincipalLimiter != nil {
if err := h.applyPerPrincipalRateLimit(r, csr); err != nil {
ErrorWithRequestID(w, http.StatusTooManyRequests,
fmt.Sprintf("EST enrollment rate-limited: %v", err), requestID)
return
}
}
var (
result *domain.ESTEnrollResult
callErr error
)
if reEnroll {
result, callErr = h.svc.SimpleReEnroll(r.Context(), csrPEM)
} else {
result, callErr = h.svc.SimpleEnroll(r.Context(), csrPEM)
}
if callErr != nil {
op := "Enrollment"
if reEnroll {
op = "Re-enrollment"
}
ErrorWithRequestID(w, http.StatusInternalServerError,
fmt.Sprintf("%s failed: %v", op, callErr), requestID)
return
}
h.writeCertResponse(w, result)
}
// requireClientCertChain enforces the mTLS gate for the est-mtls sibling
// route. Returns the leaf cert + true on success; on failure writes the
// HTTP error and returns false.
//
// Mirrors SCEPHandler.HandleSCEPMTLS exactly:
// - mtlsTrust nil → 500 (config bug; preflight should have prevented).
// - r.TLS nil or no peer cert → 401 (cert required).
// - chain doesn't verify against per-profile pool → 401.
func (h ESTHandler) requireClientCertChain(w http.ResponseWriter, r *http.Request) (*x509.Certificate, bool) {
requestID := middleware.GetRequestID(r.Context())
if h.mtlsTrust == nil {
ErrorWithRequestID(w, http.StatusInternalServerError,
h.label()+" mTLS handler missing trust pool", requestID)
return nil, false
}
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
ErrorWithRequestID(w, http.StatusUnauthorized,
"Client certificate required for /.well-known/est-mtls", requestID)
return nil, false
}
leaf := r.TLS.PeerCertificates[0]
intermediates := x509.NewCertPool()
for _, c := range r.TLS.PeerCertificates[1:] {
intermediates.AddCert(c)
}
if _, err := leaf.Verify(x509.VerifyOptions{
Roots: h.mtlsTrust.Pool(),
Intermediates: intermediates,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageAny},
}); err != nil {
ErrorWithRequestID(w, http.StatusUnauthorized,
"Client certificate not trusted by this profile", requestID)
return nil, false
}
return leaf, true
}
// requireBasicAuth runs the Phase 3 HTTP Basic password gate. Returns
// true when auth passed. On failure writes WWW-Authenticate + a 401
// (with rate-limit accounting against the source IP).
//
// User: any non-empty value (RFC 7030 §3.2.3 says the username is
// not authoritative when only a shared password is meaningful). Pass:
// constant-time compare against h.basicPassword.
func (h ESTHandler) requireBasicAuth(w http.ResponseWriter, r *http.Request) bool {
requestID := middleware.GetRequestID(r.Context())
srcIP := clientIPForLimiter(r)
// recordFailedBasic ticks a slot on every credential rejection;
// once the IP has burned through its window's worth of failed
// attempts the limiter returns ErrRateLimited (which the next
// recordFailedBasic just no-ops out — we still want to fail-closed
// the auth here). The cleaner design is a pre-check that short-
// circuits the constant-time compare ENTIRELY for an IP at-cap, so
// a brute-force attacker can't smuggle timing data through. We do
// that pre-check via SlidingWindowLimiter.Allow with a peek-style
// fake-key that just queries state without recording a slot.
if h.failedBasicLimiter != nil && srcIP != "" {
if err := h.failedBasicLimiter.Allow(srcIP+"|peek", nowFn()); errors.Is(err, ratelimit.ErrRateLimited) {
// peek-key is shared across requests from this IP; the slot
// pollution is acceptable because the IP is already
// rate-limited and we want to keep them rate-limited.
ErrorWithRequestID(w, http.StatusTooManyRequests,
h.label()+" too many failed enrollment attempts from this source", requestID)
return false
}
}
user, pass, ok := r.BasicAuth()
if !ok || user == "" {
w.Header().Set("WWW-Authenticate", `Basic realm="est-enrollment"`)
ErrorWithRequestID(w, http.StatusUnauthorized,
h.label()+" enrollment requires HTTP Basic auth", requestID)
h.recordFailedBasic(srcIP)
return false
}
if subtle.ConstantTimeCompare([]byte(pass), []byte(h.basicPassword)) != 1 {
w.Header().Set("WWW-Authenticate", `Basic realm="est-enrollment"`)
ErrorWithRequestID(w, http.StatusUnauthorized,
h.label()+" enrollment password incorrect", requestID)
h.recordFailedBasic(srcIP)
return false
}
return true
}
// recordFailedBasic ticks a slot against the source-IP failed-auth
// limiter. Errors from Allow are intentionally ignored — a present
// failure simply means the IP has crossed the limit, which is exactly
// the state the per-IP gate reports back to the next request.
func (h ESTHandler) recordFailedBasic(srcIP string) {
if h.failedBasicLimiter == nil || srcIP == "" {
return
}
_ = h.failedBasicLimiter.Allow(srcIP, nowFn())
}
// applyPerPrincipalRateLimit gates an enrollment by (CN, sourceIP).
// Returns nil when the request is allowed; ErrRateLimited (or wrapped
// equivalent) when the principal has exhausted its window budget.
//
// CN extraction: the CSR's Subject.CommonName is the canonical
// principal in the EST contract (the issued cert will carry that CN).
// sourceIP comes from clientIPForLimiter.
func (h ESTHandler) applyPerPrincipalRateLimit(r *http.Request, csr *x509.CertificateRequest) error {
if h.perPrincipalLimiter == nil {
return nil
}
cn := ""
if csr != nil {
cn = csr.Subject.CommonName
}
srcIP := clientIPForLimiter(r)
key := cn + "|" + srcIP
return h.perPrincipalLimiter.Allow(key, nowFn())
}
// writeChannelBindingError maps cms.* sentinel errors to HTTP statuses
// + audit-friendly messages. Mirrors the SCEP CertRep failInfo error
// translation pattern (signature_invalid → BadMessageCheck etc.).
func (h ESTHandler) writeChannelBindingError(w http.ResponseWriter, requestID string, err error) {
switch {
case errors.Is(err, cms.ErrChannelBindingMissing):
ErrorWithRequestID(w, http.StatusBadRequest,
"EST simplereenroll requires RFC 9266 channel binding for this profile", requestID)
case errors.Is(err, cms.ErrChannelBindingMismatch):
// 409 Conflict signals to the client that the request was
// well-formed but the channel-binding state on certctl's side
// disagreed with the device's — usually MITM or reverse proxy
// terminating TLS in front of certctl.
ErrorWithRequestID(w, http.StatusConflict,
"EST channel binding does not match TLS exporter — TLS terminator in front of certctl?", requestID)
case errors.Is(err, cms.ErrChannelBindingNotTLS13):
ErrorWithRequestID(w, http.StatusUpgradeRequired,
"EST channel binding requires TLS 1.3", requestID)
default:
ErrorWithRequestID(w, http.StatusBadRequest,
fmt.Sprintf("EST channel-binding verification failed: %v", err), requestID)
}
}
// ----- response writers (legacy + mTLS share these) -----
// writeCACertsResponse writes the PKCS#7 certs-only CA chain. Shared
// by CACerts (legacy route) + CACertsMTLS (mTLS route).
func (h ESTHandler) writeCACertsResponse(w http.ResponseWriter, r *http.Request) {
requestID := middleware.GetRequestID(r.Context())
caCertPEM, err := h.svc.GetCACerts(r.Context())
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError,
fmt.Sprintf("Failed to get CA certificates: %v", err), requestID)
return
}
derCerts, err := pkcs7.PEMToDERChain(caCertPEM)
if err != nil {
requestID := middleware.GetRequestID(r.Context())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to encode CA certificates", requestID)
return
}
// Build a simple PKCS#7 SignedData (certs-only, degenerate) structure
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
if err != nil {
requestID := middleware.GetRequestID(r.Context())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to build PKCS#7 response", requestID)
return
}
// RFC 7030 Section 4.1.3: response is base64-encoded application/pkcs7-mime
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
w.Header().Set("Content-Transfer-Encoding", "base64")
w.WriteHeader(http.StatusOK)
encoded := base64.StdEncoding.EncodeToString(pkcs7Data)
// Write base64 with line breaks at 76 chars per RFC 2045
writeBase64Wrapped(w, pkcs7Data)
}
// writeCSRAttrsResponse writes the per-profile CSR attribute hints.
// Shared by CSRAttrs (legacy) + CSRAttrsMTLS (mTLS).
func (h ESTHandler) writeCSRAttrsResponse(w http.ResponseWriter, r *http.Request) {
requestID := middleware.GetRequestID(r.Context())
attrs, err := h.svc.GetCSRAttrs(r.Context())
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError,
fmt.Sprintf("Failed to get CSR attributes: %v", err), requestID)
return
}
if len(attrs) == 0 {
w.WriteHeader(http.StatusNoContent)
return
}
w.Header().Set("Content-Type", "application/csrattrs")
w.Header().Set("Content-Transfer-Encoding", "base64")
w.WriteHeader(http.StatusOK)
w.Write([]byte(base64.StdEncoding.EncodeToString(attrs)))
}
// writeCertResponse writes an EST enrollment response as base64-encoded PKCS#7.
func (h ESTHandler) writeCertResponse(w http.ResponseWriter, result *domain.ESTEnrollResult) {
var derCerts [][]byte
certDER, err := pkcs7.PEMToDERChain(result.CertPEM)
if err != nil || len(certDER) == 0 {
http.Error(w, "Failed to encode certificate", http.StatusInternalServerError)
return
}
derCerts = append(derCerts, certDER...)
if result.ChainPEM != "" {
chainDER, err := pkcs7.PEMToDERChain(result.ChainPEM)
if err == nil {
derCerts = append(derCerts, chainDER...)
}
}
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
if err != nil {
http.Error(w, "Failed to build PKCS#7 response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
w.Header().Set("Content-Transfer-Encoding", "base64")
w.WriteHeader(http.StatusOK)
writeBase64Wrapped(w, pkcs7Data)
}
// writeBase64Wrapped emits b as base64 with CRLF every 76 chars per RFC 2045.
// Pulled out as a helper so the three writers above don't repeat the loop.
func writeBase64Wrapped(w http.ResponseWriter, b []byte) {
encoded := base64.StdEncoding.EncodeToString(b)
for i := 0; i < len(encoded); i += 76 {
end := i + 76
if end > len(encoded) {
@@ -99,66 +601,84 @@ func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
}
}
// SimpleEnroll handles POST /.well-known/est/simpleenroll
// Accepts a base64-encoded PKCS#10 CSR and returns a base64-encoded PKCS#7 certificate.
func (h ESTHandler) SimpleEnroll(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
requestID := middleware.GetRequestID(r.Context())
if err := verifyESTTransport(r); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
return
}
csrPEM, err := h.readCSRFromRequest(r)
// readCSRFromRequest reads and decodes the CSR from an EST enrollment request.
// EST sends CSRs as base64-encoded PKCS#10 DER with Content-Type application/pkcs10.
func (h ESTHandler) readCSRFromRequest(r *http.Request) (string, error) {
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
if err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
return
return "", fmt.Errorf("failed to read request body: %w", err)
}
defer r.Body.Close()
if len(body) == 0 {
return "", fmt.Errorf("empty request body")
}
result, err := h.svc.SimpleEnroll(r.Context(), csrPEM)
bodyStr := strings.TrimSpace(string(body))
if strings.HasPrefix(bodyStr, "-----BEGIN CERTIFICATE REQUEST-----") {
block, _ := pem.Decode([]byte(bodyStr))
if block == nil {
return "", fmt.Errorf("invalid PEM-encoded CSR")
}
if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil {
return "", fmt.Errorf("invalid CSR: %w", err)
}
return bodyStr, nil
}
derBytes, err := base64.StdEncoding.DecodeString(bodyStr)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Enrollment failed: %v", err), requestID)
return
cleaned := strings.Map(func(r rune) rune {
if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
return -1
}
return r
}, bodyStr)
derBytes, err = base64.StdEncoding.DecodeString(cleaned)
if err != nil {
return "", fmt.Errorf("failed to decode base64 CSR: %w", err)
}
}
h.writeCertResponse(w, result)
if _, err := x509.ParseCertificateRequest(derBytes); err != nil {
return "", fmt.Errorf("invalid PKCS#10 CSR: %w", err)
}
csrPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: derBytes,
})
return string(csrPEM), nil
}
// SimpleReEnroll handles POST /.well-known/est/simplereenroll
// Same as SimpleEnroll but for re-enrollment (certificate renewal).
func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
// decodeCSRPEM is a convenience wrapper around pem.Decode +
// x509.ParseCertificateRequest. Returns nil on any decode/parse error
// (callers downstream re-parse via the service path; this is just for
// the handler-side gates that need the CN + binding attribute).
func decodeCSRPEM(csrPEM string) (*x509.CertificateRequest, error) {
block, _ := pem.Decode([]byte(csrPEM))
if block == nil {
return nil, fmt.Errorf("PEM decode failed")
}
requestID := middleware.GetRequestID(r.Context())
if err := verifyESTTransport(r); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
return
}
csrPEM, err := h.readCSRFromRequest(r)
if err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
return
}
result, err := h.svc.SimpleReEnroll(r.Context(), csrPEM)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Re-enrollment failed: %v", err), requestID)
return
}
h.writeCertResponse(w, result)
return x509.ParseCertificateRequest(block.Bytes)
}
// clientIPForLimiter returns the source IP a per-IP rate limiter should
// key against. Honors X-Forwarded-For when the request came through a
// trusted proxy (no proxy-trust list yet — falls back to RemoteAddr).
func clientIPForLimiter(r *http.Request) string {
// Don't blindly trust XFF — ignore it for now and always use
// RemoteAddr. A future bundle can add a documented proxy-trust list.
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
// nowFn is the package-private time source. Override in tests for
// deterministic clock injection without dragging time.Time into the
// handler API surface. Defined in est_clock.go so mocking out
// requires touching only one file.
// verifyESTTransport implements Bundle-4 / M-021 EST transport precondition.
//
// RFC 7030 §3.2.3 ("Linking Identity and POP Information") requires that when
@@ -169,32 +689,11 @@ func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
// TLS-Unique is unavailable; RFC 9266 defines `tls-exporter` as the TLS 1.3
// replacement.
//
// **Current scope of this function (Bundle-4 closure):** certctl does NOT
// currently support EST client certificate authentication. The EST endpoint
// accepts unauthenticated POSTs (the SCEP equivalent enforces a
// challenge-password via `preflightSCEPChallengePassword`; EST has no
// equivalent today). Per RFC 7030 §3.2.3, channel binding is REQUIRED only
// when client certificate authentication is in use; without that, the §3.2.3
// requirement is moot.
//
// What we DO enforce here as defense-in-depth:
//
// 1. r.TLS must be non-nil — the EST endpoint MUST be reached over TLS.
// Defensive: certctl pins HTTPS-only at the server-side TLS config, but
// a future routing-layer regression that exposes EST over plaintext
// would be caught here.
// 2. Negotiated TLS version must be >= TLS 1.2 — RFC 7030 doesn't mandate
// a specific TLS version, but a pre-1.2 negotiation indicates a
// misconfigured client/server pair. certctl's MinVersion is TLS 1.3
// so this should always hold.
// 3. r.TLS.HandshakeComplete must be true — defensive against partial-
// handshake replays.
//
// **Deferred to a future bundle (operator decision required):**
//
// - RFC 9266 `tls-exporter` channel binding when EST mTLS is added.
// - EST mTLS support itself — currently EST is unauth-or-bearer; mTLS
// would be a V3-aligned compliance feature.
// **EST RFC 7030 hardening Phases 2-4 update:** RFC 9266 channel binding is
// now wired in via the cms package (Phase 2.4) and called from
// SimpleReEnrollMTLS when the per-profile policy requires it. This function
// continues to handle the lower-level transport preconditions that ALL EST
// requests share (regardless of mTLS / Basic / unauth profile shape).
//
// Returns nil if all preconditions pass; non-nil error otherwise.
func verifyESTTransport(r *http.Request) error {
@@ -213,130 +712,5 @@ func verifyESTTransport(r *http.Request) error {
return nil
}
// CSRAttrs handles GET /.well-known/est/csrattrs
// Returns the CSR attributes the server wants the client to include in enrollment requests.
func (h ESTHandler) CSRAttrs(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
attrs, err := h.svc.GetCSRAttrs(r.Context())
if err != nil {
requestID := middleware.GetRequestID(r.Context())
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get CSR attributes: %v", err), requestID)
return
}
if len(attrs) == 0 {
// No specific attributes required — return 204
w.WriteHeader(http.StatusNoContent)
return
}
w.Header().Set("Content-Type", "application/csrattrs")
w.Header().Set("Content-Transfer-Encoding", "base64")
w.WriteHeader(http.StatusOK)
w.Write([]byte(base64.StdEncoding.EncodeToString(attrs)))
}
// readCSRFromRequest reads and decodes the CSR from an EST enrollment request.
// EST sends CSRs as base64-encoded PKCS#10 DER with Content-Type application/pkcs10.
func (h ESTHandler) readCSRFromRequest(r *http.Request) (string, error) {
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
if err != nil {
return "", fmt.Errorf("failed to read request body: %w", err)
}
defer r.Body.Close()
if len(body) == 0 {
return "", fmt.Errorf("empty request body")
}
// Check if it's already PEM-encoded (some clients send PEM directly)
bodyStr := strings.TrimSpace(string(body))
if strings.HasPrefix(bodyStr, "-----BEGIN CERTIFICATE REQUEST-----") {
// Validate it parses
block, _ := pem.Decode([]byte(bodyStr))
if block == nil {
return "", fmt.Errorf("invalid PEM-encoded CSR")
}
if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil {
return "", fmt.Errorf("invalid CSR: %w", err)
}
return bodyStr, nil
}
// EST standard: base64-encoded DER PKCS#10
derBytes, err := base64.StdEncoding.DecodeString(bodyStr)
if err != nil {
// Try with padding/whitespace stripped
cleaned := strings.Map(func(r rune) rune {
if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
return -1
}
return r
}, bodyStr)
derBytes, err = base64.StdEncoding.DecodeString(cleaned)
if err != nil {
return "", fmt.Errorf("failed to decode base64 CSR: %w", err)
}
}
// Validate it's a valid PKCS#10 CSR
if _, err := x509.ParseCertificateRequest(derBytes); err != nil {
return "", fmt.Errorf("invalid PKCS#10 CSR: %w", err)
}
// Convert DER to PEM for internal use (certctl services expect PEM)
csrPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: derBytes,
})
return string(csrPEM), nil
}
// writeCertResponse writes an EST enrollment response as base64-encoded PKCS#7.
func (h ESTHandler) writeCertResponse(w http.ResponseWriter, result *domain.ESTEnrollResult) {
// Parse cert and chain PEM to DER
var derCerts [][]byte
// Add the issued certificate
certDER, err := pkcs7.PEMToDERChain(result.CertPEM)
if err != nil || len(certDER) == 0 {
http.Error(w, "Failed to encode certificate", http.StatusInternalServerError)
return
}
derCerts = append(derCerts, certDER...)
// Add the CA chain if present
if result.ChainPEM != "" {
chainDER, err := pkcs7.PEMToDERChain(result.ChainPEM)
if err == nil {
derCerts = append(derCerts, chainDER...)
}
}
// Build PKCS#7 certs-only
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
if err != nil {
http.Error(w, "Failed to build PKCS#7 response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
w.Header().Set("Content-Transfer-Encoding", "base64")
w.WriteHeader(http.StatusOK)
encoded := base64.StdEncoding.EncodeToString(pkcs7Data)
for i := 0; i < len(encoded); i += 76 {
end := i + 76
if end > len(encoded) {
end = len(encoded)
}
w.Write([]byte(encoded[i:end]))
w.Write([]byte("\r\n"))
}
}
// NOTE: PKCS#7 helpers (BuildCertsOnlyPKCS7, PEMToDERChain, ASN.1 wrappers)
// are in the shared internal/pkcs7 package, used by both EST and SCEP handlers.
+15
View File
@@ -0,0 +1,15 @@
package handler
import "time"
// EST RFC 7030 hardening Phase 3.3 / 4.2: nowFn is the time source that
// the EST handler's per-IP failed-Basic-auth limiter and per-(CN,
// sourceIP) rate limiter consult. Tests can override this to inject a
// deterministic clock without dragging time.Time into the handler API
// surface (the handler's setters take ratelimit.SlidingWindowLimiter
// pointers, not time-injection callbacks — keeping the wire-up simple).
//
// nowFn is package-private + lower-case so external callers can't poke
// at it; the est_clock_test.go helper restoreNowFn is the documented
// override pattern for tests in this package.
var nowFn = time.Now
+459
View File
@@ -0,0 +1,459 @@
package handler
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"io"
"log/slog"
"math/big"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/shankar0123/certctl/internal/cms"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/ratelimit"
"github.com/shankar0123/certctl/internal/trustanchor"
)
// EST RFC 7030 hardening master bundle Phases 2-4 tests.
// Covers: mTLS sibling route gates, HTTP Basic enrollment-password auth,
// per-source-IP failed-auth rate limit, RFC 9266 channel binding, and
// per-(CN, sourceIP) per-principal sliding-window rate limit.
// hardeningTestSetup is a per-test fixture: a mock service that always
// succeeds, plus a CA + issued client cert that an mTLS test can attach
// to its synthetic *http.Request.TLS.
type hardeningTestSetup struct {
svc *mockESTService
caCert *x509.Certificate
caKey *ecdsa.PrivateKey
clientCrt *x509.Certificate
clientKey *ecdsa.PrivateKey
trustPool *trustanchor.Holder
bundleDir string
}
func newHardeningTestSetup(t *testing.T) *hardeningTestSetup {
t.Helper()
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ca key: %v", err)
}
caTmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "est-mtls-test-ca"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign,
}
caDER, err := x509.CreateCertificate(rand.Reader, caTmpl, caTmpl, &caKey.PublicKey, caKey)
if err != nil {
t.Fatalf("ca create: %v", err)
}
caCert, _ := x509.ParseCertificate(caDER)
clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("client key: %v", err)
}
clientTmpl := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{CommonName: "test-device-001"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
clientDER, err := x509.CreateCertificate(rand.Reader, clientTmpl, caCert, &clientKey.PublicKey, caKey)
if err != nil {
t.Fatalf("client create: %v", err)
}
clientCrt, _ := x509.ParseCertificate(clientDER)
// Persist the CA bundle on disk so trustanchor.New can load it.
dir := t.TempDir()
bundlePath := filepath.Join(dir, "trust.pem")
body := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER})
if err := os.WriteFile(bundlePath, body, 0o600); err != nil {
t.Fatalf("write bundle: %v", err)
}
holder, err := trustanchor.New(bundlePath, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err != nil {
t.Fatalf("trustanchor.New: %v", err)
}
svc := &mockESTService{
CACertPEM: pemCertString(caDER),
EnrollResult: &domain.ESTEnrollResult{
CertPEM: pemCertString(clientDER),
},
}
return &hardeningTestSetup{
svc: svc,
caCert: caCert,
caKey: caKey,
clientCrt: clientCrt,
clientKey: clientKey,
trustPool: holder,
bundleDir: dir,
}
}
func pemCertString(der []byte) string {
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}))
}
// makeMTLSRequest synthesises a POST against `path` with PEM CSR body and
// r.TLS populated with the given peer cert chain + handshake state. Used
// by the mTLS path tests where a real TLS handshake would force us into a
// full httptest.NewTLSServer setup.
func makeMTLSRequest(t *testing.T, path, csrPEM string, peerCerts []*x509.Certificate, version uint16) *http.Request {
t.Helper()
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(csrPEM))
req.TLS = &tls.ConnectionState{
HandshakeComplete: true,
Version: version,
PeerCertificates: peerCerts,
}
return req
}
// ----- mTLS handler gate -----
func TestSimpleEnrollMTLS_NoTrustPool_500(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc) // intentionally do NOT call SetMTLSTrust
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
w := httptest.NewRecorder()
h.SimpleEnrollMTLS(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want 500 (handler missing trust pool)", w.Code)
}
}
func TestSimpleEnrollMTLS_NoClientCert_401(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
req := httptest.NewRequest(http.MethodPost, "/.well-known/est-mtls/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
w := httptest.NewRecorder()
h.SimpleEnrollMTLS(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401 (no client cert)", w.Code)
}
}
func TestSimpleEnrollMTLS_CertNotInPool_401(t *testing.T) {
s := newHardeningTestSetup(t)
other := newHardeningTestSetup(t) // different CA, unrelated to s.trustPool
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
generateTestCSRPEM(t), []*x509.Certificate{other.clientCrt}, tls.VersionTLS13)
w := httptest.NewRecorder()
h.SimpleEnrollMTLS(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401 (cert not trusted by this profile)", w.Code)
}
}
func TestSimpleEnrollMTLS_HappyPath_200(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
w := httptest.NewRecorder()
h.SimpleEnrollMTLS(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200; body=%q", w.Code, w.Body.String())
}
}
// ----- channel binding (Phase 2.4) -----
func TestSimpleReEnrollMTLS_ChannelBindingRequired_AbsentRejected(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
h.SetChannelBindingRequired(true)
// CSR has no binding attribute. Synthetic ConnectionState — exporter
// extraction will fail (no real TLS secret), and required=true makes
// VerifyChannelBinding propagate that as the missing-binding error.
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simplereenroll",
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
w := httptest.NewRecorder()
h.SimpleReEnrollMTLS(w, req)
// Either 400 (missing) or 426 (TLS 1.3 unavailable on synthetic state).
// Both are correct refusals; pin to "non-2xx" so the test isn't fragile
// against ConnectionState evolution.
if w.Code/100 == 2 {
t.Errorf("required + absent must reject; got 2xx (%d)", w.Code)
}
}
func TestSimpleReEnrollMTLS_ChannelBindingNotRequired_AbsentAllowed(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
h.SetChannelBindingRequired(false)
// CSR has no binding, profile is opt-in only. The handler must allow.
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simplereenroll",
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
w := httptest.NewRecorder()
h.SimpleReEnrollMTLS(w, req)
if w.Code != http.StatusOK {
t.Errorf("required=false + absent must allow; got %d (%s)", w.Code, w.Body.String())
}
}
func TestWriteChannelBindingError_KnownErrorsMapped(t *testing.T) {
// Smoke test the error-to-status mapping so a future cms sentinel rename
// gets caught at compile time + we hit each branch.
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
cases := []struct {
err error
want int
}{
{cms.ErrChannelBindingMissing, http.StatusBadRequest},
{cms.ErrChannelBindingMismatch, http.StatusConflict},
{cms.ErrChannelBindingNotTLS13, http.StatusUpgradeRequired},
}
for _, c := range cases {
w := httptest.NewRecorder()
h.writeChannelBindingError(w, "req-id", c.err)
if w.Code != c.want {
t.Errorf("error=%v → status %d, want %d", c.err, w.Code, c.want)
}
}
}
// ----- HTTP Basic enrollment-password (Phase 3) -----
func TestSimpleEnroll_BasicAuth_NoHeader_401(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetEnrollmentPassword("super-secret")
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401 (Basic required, header absent)", w.Code)
}
if got := w.Header().Get("WWW-Authenticate"); !strings.Contains(got, "Basic") {
t.Errorf("WWW-Authenticate = %q, want to contain 'Basic'", got)
}
}
func TestSimpleEnroll_BasicAuth_WrongPassword_401(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetEnrollmentPassword("super-secret")
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req.SetBasicAuth("device", "wrong-password")
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401 (wrong password)", w.Code)
}
}
func TestSimpleEnroll_BasicAuth_CorrectPassword_200(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetEnrollmentPassword("super-secret")
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req.SetBasicAuth("device", "super-secret")
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200 (correct password); body=%q", w.Code, w.Body.String())
}
}
func TestSimpleEnroll_BasicAuth_NoPassword_NoGate(t *testing.T) {
// When the per-profile enrollment password is empty, the Basic gate is
// off and the handler reverts to the v2.0.x anonymous behavior.
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc) // SetEnrollmentPassword not called
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200 (no Basic gate)", w.Code)
}
}
// ----- source-IP failed-auth rate limit (Phase 3.3) -----
func TestSimpleEnroll_BasicAuth_FailedAttemptLimitedAfterThreshold(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetEnrollmentPassword("super-secret")
// Cap of 2 failed attempts before the IP gets locked. Each failed
// attempt records a slot; the 3rd request should be 429.
limiter := ratelimit.NewSlidingWindowLimiter(2, time.Hour, 10)
h.SetSourceIPRateLimiter(limiter)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req.RemoteAddr = "10.0.0.42:12345"
req.SetBasicAuth("device", "WRONG")
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("attempt %d: want 401, got %d", i, w.Code)
}
}
// The 3rd attempt — even with a correct password — must be rate limited.
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(generateTestCSRPEM(t)))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req.RemoteAddr = "10.0.0.42:12345"
req.SetBasicAuth("device", "super-secret")
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("post-lockout status = %d, want 429 (correct password should still be locked out)", w.Code)
}
}
// ----- per-principal sliding-window rate limit (Phase 4.2) -----
func TestSimpleEnroll_PerPrincipalLimit_BlocksAfterCap(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
limiter := ratelimit.NewSlidingWindowLimiter(2, 24*time.Hour, 100)
h.SetPerPrincipalRateLimiter(limiter)
// First 2 enrollments from same (CN, IP) — pass.
csrPEM := generateTestCSRPEM(t)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(csrPEM))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req.RemoteAddr = "10.0.0.7:5555"
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusOK {
t.Fatalf("attempt %d: want 200, got %d", i, w.Code)
}
}
// Third enrollment from same (CN, IP) — limited.
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(csrPEM))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req.RemoteAddr = "10.0.0.7:5555"
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("3rd same-principal enrollment status = %d, want 429", w.Code)
}
}
func TestSimpleEnroll_PerPrincipalLimit_DifferentPrincipalsIndependent(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
limiter := ratelimit.NewSlidingWindowLimiter(1, 24*time.Hour, 100)
h.SetPerPrincipalRateLimiter(limiter)
csrPEM1 := generateTestCSRPEM(t)
csrPEM2 := generateTestCSRPEM(t) // different key + (default) different CN
req1 := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll", strings.NewReader(csrPEM1))
req1.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req1.RemoteAddr = "10.0.0.10:1111"
w1 := httptest.NewRecorder()
h.SimpleEnroll(w1, req1)
if w1.Code != http.StatusOK {
t.Fatalf("principal 1 first call: want 200, got %d", w1.Code)
}
// Same CN as csrPEM1 but different IP — independent bucket.
req2 := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll", strings.NewReader(csrPEM2))
req2.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
req2.RemoteAddr = "10.0.0.20:2222"
w2 := httptest.NewRecorder()
h.SimpleEnroll(w2, req2)
if w2.Code != http.StatusOK {
t.Errorf("principal 2 first call: want 200, got %d", w2.Code)
}
}
// ----- per-handler smoke test for the un-rolled mTLS variants -----
func TestCACertsMTLS_RequiresClientCert(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
req := httptest.NewRequest(http.MethodGet, "/.well-known/est-mtls/corp/cacerts", nil)
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
w := httptest.NewRecorder()
h.CACertsMTLS(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("CACertsMTLS no-cert status = %d, want 401", w.Code)
}
}
func TestCSRAttrsMTLS_RequiresClientCert(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc)
h.SetMTLSTrust(s.trustPool)
req := httptest.NewRequest(http.MethodGet, "/.well-known/est-mtls/corp/csrattrs", nil)
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
w := httptest.NewRecorder()
h.CSRAttrsMTLS(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("CSRAttrsMTLS no-cert status = %d, want 401", w.Code)
}
}
// ----- ensure the per-principal limit fires only when configured -----
func TestSimpleEnroll_NoPerPrincipalLimiter_AllUnbounded(t *testing.T) {
s := newHardeningTestSetup(t)
h := NewESTHandler(s.svc) // SetPerPrincipalRateLimiter not called
csrPEM := generateTestCSRPEM(t)
for i := 0; i < 50; i++ {
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
strings.NewReader(csrPEM))
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
w := httptest.NewRecorder()
h.SimpleEnroll(w, req)
if w.Code != http.StatusOK {
t.Fatalf("attempt %d: want 200, got %d", i, w.Code)
}
}
}
// silenceUnused keeps the "declared and not used" linter happy when we add
// helpers that future tests may invoke (asn1, atomic).
var _ = asn1.RawValue{}
var _ atomic.Int32
+43 -4
View File
@@ -81,10 +81,11 @@ var AuthExemptRouterRoutes = []string{
// TestDispatch_AuthExemptPrefixes regression test in cmd/server/main_test.go
// pins this slice to buildFinalHandler's actual dispatch logic.
var AuthExemptDispatchPrefixes = []string{
"/.well-known/pki", // RFC 5280 CRL + RFC 6960 OCSP — relying-party-unauth
"/.well-known/est", // RFC 7030 EST — auth via mTLS or CSR-embedded creds
"/scep", // RFC 8894 SCEP — auth via challengePassword in CSR
"/scep-mtls", // SCEP + mTLS sibling route (Phase 6.5) — auth is client cert + challengePassword
"/.well-known/pki", // RFC 5280 CRL + RFC 6960 OCSP — relying-party-unauth
"/.well-known/est", // RFC 7030 EST — auth via mTLS or CSR-embedded creds
"/.well-known/est-mtls", // EST + mTLS sibling route (EST hardening Phase 2) — auth is client cert
"/scep", // RFC 8894 SCEP — auth via challengePassword in CSR
"/scep-mtls", // SCEP + mTLS sibling route (Phase 6.5) — auth is client cert + challengePassword
}
// HandlerRegistry groups all API handler dependencies for router registration.
@@ -445,6 +446,44 @@ func (r *Router) RegisterESTHandlers(handlers map[string]handler.ESTHandler) {
}
}
// RegisterESTMTLSHandlers sets up the sibling `/.well-known/est-mtls/<PathID>/`
// routes for EST profiles that opted into mTLS via
// `CERTCTL_EST_PROFILE_<NAME>_MTLS_ENABLED=true`.
//
// EST RFC 7030 hardening master bundle Phase 2.2 + 2.3: enterprise
// procurement teams routinely reject 'shared password authentication' as
// a checkbox-fail regardless of how strong the password is. This sibling
// route adds client-cert auth at the handler layer AND keeps the (Phase 3)
// HTTP Basic enrollment-password as a defense-in-depth fallback for the
// non-mTLS profile. Devices present a bootstrap cert from a trusted CA,
// then EST-enroll for their long-lived cert. Mirrors the SCEP mTLS
// sibling pattern at RegisterSCEPMTLSHandlers below (commit 6b0d9e from
// the SCEP Phase 6.5 work).
//
// Path conventions: every mTLS profile gets a non-empty PathID, so the
// sibling routes are always /.well-known/est-mtls/<pathID>/. There is no
// "empty PathID = legacy /.well-known/est-mtls" case — mTLS is opt-in
// per profile, the legacy /.well-known/est root is always non-mTLS to
// preserve backward compat with existing deploys.
//
// Each handler in the map MUST have had SetMTLSTrust called so the
// per-profile cert verification has a trust anchor. cmd/server/main.go's
// per-profile EST loop wires this in the same loop iteration that
// registers the handler.
func (r *Router) RegisterESTMTLSHandlers(handlers map[string]handler.ESTHandler) {
for pathID, h := range handlers {
if pathID == "" {
continue // mTLS sibling route requires per-profile PathID
}
hCopy := h // h is captured by value — see RegisterESTHandlers above
prefix := "/.well-known/est-mtls/" + pathID
r.Register("GET "+prefix+"/cacerts", http.HandlerFunc(hCopy.CACertsMTLS))
r.Register("POST "+prefix+"/simpleenroll", http.HandlerFunc(hCopy.SimpleEnrollMTLS))
r.Register("POST "+prefix+"/simplereenroll", http.HandlerFunc(hCopy.SimpleReEnrollMTLS))
r.Register("GET "+prefix+"/csrattrs", http.HandlerFunc(hCopy.CSRAttrsMTLS))
}
}
// RegisterSCEPHandlers sets up SCEP (RFC 8894) routes.
// SCEP uses a single endpoint per profile with operation-based dispatch via
// query parameters. Authentication is via the challengePassword attribute in
+369
View File
@@ -0,0 +1,369 @@
// Package cms implements the small subset of CMS / RFC 7030 / RFC 9266
// helpers that the EST handler needs at request-time: extracting the
// RFC 9266 tls-exporter from a *tls.ConnectionState, and pulling the
// matching value back out of an EST CSR's CMC unsignedAttribute when the
// device proved channel binding.
//
// Why a separate package (vs adding to internal/api/handler/est.go):
//
// 1. internal/api/handler depends on internal/pkcs7 already; if the EST
// mTLS hardening also pulled CMC parsing into handler we'd grow the
// handler-side dep graph by another asn1 surface that has nothing
// specific to HTTP.
//
// 2. Channel-binding extraction is testable in isolation — the unit
// tests construct a *tls.ConnectionState with raw exporter bytes and
// a *x509.CertificateRequest with the CMC unsignedAttribute already
// filled in. No HTTP plumbing required to verify the contract.
//
// 3. Future EST extensions (RFC 7030 §3.5 fullCMC, RFC 9148 EST-coaps)
// are likely to land here too — keep them out of net/http land.
//
// EST RFC 7030 hardening master bundle Phase 2.4.
package cms
import (
"bytes"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"fmt"
)
// ----- RFC 9266 §3 — TLS exporter extraction -----
// TLSExporterLabel is the EXPORTER label registered by RFC 9266 §3.1
// for use as a TLS-1.3 channel binding. Constant rather than string-typed
// so a typo here is a compile error rather than a silent failure mode.
const TLSExporterLabel = "EXPORTER-Channel-Binding"
// TLSExporterLength is the 32-byte exporter length pinned by RFC 9266 §3.1
// (matches the SHA-256 output size; clients and servers MUST agree on the
// length to make the comparison meaningful).
const TLSExporterLength = 32
// ErrChannelBindingMissing is returned when the EST mTLS handler requires
// channel binding (per-profile ChannelBindingRequired=true) but the device's
// CSR has no id-aa-est-tls-exporter unsignedAttribute or the attribute is
// the wrong shape.
var ErrChannelBindingMissing = errors.New("cms: channel binding required but absent or malformed in CSR")
// ErrChannelBindingMismatch is returned when the device's CSR carried a
// channel-binding attribute but its bytes do not match the TLS-1.3 exporter
// extracted from the live connection. This is the signal of an MITM that
// terminates TLS in front of certctl: the device computed exporter X
// against the attacker, certctl sees exporter Y against itself, X≠Y.
var ErrChannelBindingMismatch = errors.New("cms: channel binding in CSR does not match TLS exporter")
// ErrChannelBindingNotTLS13 is returned when the connection is older than
// TLS 1.3 and the per-profile config still requires channel binding.
// RFC 9266's tls-exporter is a TLS-1.3 binding; pre-1.3 connections would
// need RFC 5929 tls-unique, which we deliberately don't support
// (certctl pins TLS-1.3 server-side).
var ErrChannelBindingNotTLS13 = errors.New("cms: tls-exporter channel binding requires TLS 1.3")
// ExtractTLSExporter pulls the 32-byte RFC 9266 channel-binding value from
// the TLS connection state. The connection must be TLS 1.3 + handshake-
// complete; anything else returns a typed error so the caller can map to
// HTTP 400 / 412 cleanly.
//
// Stateless on purpose: callers handle storage + comparison.
//
// Robustness note: stdlib's ConnectionState.ExportKeyingMaterial nil-derefs
// when the underlying secret-derivation closure is unset (i.e. the state
// was hand-constructed by a test fixture rather than produced by a real
// TLS handshake). The recover() below converts that panic into the same
// typed error a missing-binding state would surface, so synthetic test
// states + production TLS-1.3 connections share a single failure mode.
func ExtractTLSExporter(state *tls.ConnectionState) (out []byte, err error) {
if state == nil {
return nil, fmt.Errorf("%w: nil ConnectionState", ErrChannelBindingMissing)
}
if !state.HandshakeComplete {
return nil, fmt.Errorf("%w: handshake incomplete", ErrChannelBindingMissing)
}
// tls.VersionTLS13 == 0x0304. We use the literal so this package doesn't
// have to import "crypto/tls" twice (once for tls.VersionTLS13, once for
// the *tls.ConnectionState type — Go allows it but it's noisy).
if state.Version < 0x0304 {
return nil, fmt.Errorf("%w: negotiated 0x%04x", ErrChannelBindingNotTLS13, state.Version)
}
defer func() {
if r := recover(); r != nil {
out = nil
err = fmt.Errorf("%w: ExportKeyingMaterial unavailable on this connection state (panic=%v)", ErrChannelBindingMissing, r)
}
}()
out, err = state.ExportKeyingMaterial(TLSExporterLabel, nil, TLSExporterLength)
if err != nil {
return nil, fmt.Errorf("cms: ExportKeyingMaterial: %w", err)
}
if len(out) != TLSExporterLength {
return nil, fmt.Errorf("cms: exporter returned %d bytes, want %d", len(out), TLSExporterLength)
}
return out, nil
}
// ----- RFC 7030 §3.5 / RFC 9266 §4.1 — CSR-side channel binding -----
// OIDChannelBindingTLSExporter is the id-aa-est-tls-exporter OID from
// RFC 9266 §4.1 (registered under id-aa = 1.2.840.113549.1.9.16.2 with
// arc 56 by RFC 9266). Devices that signed channel binding into their
// CSR add a CMC unsignedAttribute with this OID + an OCTET STRING value.
//
// Note: the EST RFC 7030 §3.5 historical OID for tls-unique is
// id-aa-cmc-binding (1.2.840.113549.1.9.16.2.43). RFC 9266 §4.1 added
// arc 56 for tls-exporter. We accept BOTH OIDs on the read path so a
// device using a slightly older library that still emits the §3.5 OID
// continues to work — the value bytes are still the 32-byte exporter
// (the OID identifies the binding scheme, not the underlying wire
// format).
var (
OIDChannelBindingTLSExporter = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 16, 2, 56}
OIDCMCEnrollmentBinding = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 16, 2, 43}
)
// ExtractCSRChannelBinding looks for the RFC 9266 channel-binding
// attribute (or the legacy RFC 7030 §3.5 binding attribute) in the CSR's
// raw attributes block. Returns the raw 32-byte exporter value if
// present.
//
// Why we walk csr.RawTBSCertificateRequest manually instead of using
// csr.Attributes:
//
// - csr.Attributes is typed as []pkix.AttributeTypeAndValueSET, where
// the inner Value is [][]pkix.AttributeTypeAndValue. That shape only
// fits attributes whose AttributeValue is itself a SEQUENCE { OID,
// ANY } (e.g. the requestedExtensions attribute). RFC 9266's
// TLSExporterValue is `OCTET STRING` — a primitive, not a SEQUENCE
// — so the stdlib parse path either drops the attribute silently or
// fails the whole CSR parse depending on encoding.
//
// - The PKCS#10 challengePassword path in scep.go works by accident
// because PrintableString happens to round-trip through the
// stdlib's interface{}-typed AttributeTypeAndValue.Value. OCTET
// STRING does not — it's not in the small list of primitive types
// the stdlib's reflect-based unmarshaller handles for `any`.
//
// - Walking the raw TBS is ~30 lines of asn1.Unmarshal calls and
// gives us a stable contract independent of stdlib quirks.
//
// Returns (value, true, nil) on success; (nil, false, nil) when the
// attribute is absent (caller decides whether absence is acceptable per
// the per-profile ChannelBindingRequired flag); (nil, false, err) on
// malformed attribute (always fatal — a present-but-wrong attribute
// signals an attacker rewriting the binding into garbage).
func ExtractCSRChannelBinding(csr *x509.CertificateRequest) ([]byte, bool, error) {
if csr == nil {
return nil, false, fmt.Errorf("cms: nil CSR")
}
if len(csr.RawTBSCertificateRequest) == 0 {
// Stdlib fills RawTBSCertificateRequest on every parse path, so an
// empty value here means the caller hand-crafted the struct. Tests
// can do that — but real handler-side calls always have raw bytes.
return nil, false, nil
}
return walkCSRAttributesForBinding(csr.RawTBSCertificateRequest)
}
// walkCSRAttributesForBinding parses just enough of TBSCertificationRequestInfo
// to reach the [0] IMPLICIT Attributes field, then iterates each Attribute
// looking for the channel-binding OID. The body is intentionally low-level
// so we can keep the asn1 footprint contained to this one helper.
//
// TBSCertificationRequestInfo per RFC 2986 §4.1:
//
// TBSCertificationRequestInfo ::= SEQUENCE {
// version INTEGER (0),
// subject Name,
// subjectPKInfo SubjectPublicKeyInfo,
// attributes [0] IMPLICIT Attributes (SET OF Attribute)
// }
func walkCSRAttributesForBinding(tbs []byte) ([]byte, bool, error) {
// 1. Crack the outer SEQUENCE wrapper.
var inner asn1.RawValue
if rest, err := asn1.Unmarshal(tbs, &inner); err != nil {
return nil, false, fmt.Errorf("cms: TBS outer parse: %w", err)
} else if len(rest) > 0 {
return nil, false, fmt.Errorf("cms: TBS trailing bytes: %d", len(rest))
}
if inner.Tag != asn1.TagSequence {
return nil, false, fmt.Errorf("cms: TBS outer tag %d not SEQUENCE", inner.Tag)
}
rest := inner.Bytes
// 2. Skip version (INTEGER), subject (SEQUENCE = Name), subjectPKInfo
// (SEQUENCE). asn1.Unmarshal into asn1.RawValue advances the cursor
// without parsing the body — perfect for skipping fields we don't care
// about.
for i, label := range []string{"version", "subject", "subjectPKInfo"} {
var rv asn1.RawValue
next, err := asn1.Unmarshal(rest, &rv)
if err != nil {
return nil, false, fmt.Errorf("cms: skip TBS field %d (%s): %w", i, label, err)
}
rest = next
}
// 3. Attributes is [0] IMPLICIT — the on-wire tag is 0xA0 with class
// CONTEXT-SPECIFIC. asn1.Unmarshal into a RawValue accepts arbitrary
// tags; we then walk its Bytes as a SET OF Attribute.
var attrsField asn1.RawValue
if _, err := asn1.Unmarshal(rest, &attrsField); err != nil {
// No attributes block at all — RFC 2986 says [0] is OPTIONAL when
// empty (encoders typically omit the field rather than emit an
// empty SET). Treat as "no binding present", not as an error.
return nil, false, nil
}
if attrsField.Class != asn1.ClassContextSpecific || attrsField.Tag != 0 {
// Some non-attribute-shaped trailing field: not what we expected
// but not strictly a corruption signal — skip silently.
return nil, false, nil
}
// 4. Walk each Attribute in the SET. Each Attribute is
// SEQUENCE { OID, SET OF ANY }.
attrBytes := attrsField.Bytes
for len(attrBytes) > 0 {
var oneAttr asn1.RawValue
next, err := asn1.Unmarshal(attrBytes, &oneAttr)
if err != nil {
return nil, false, fmt.Errorf("cms: walk attributes: %w", err)
}
attrBytes = next
if oneAttr.Tag != asn1.TagSequence {
continue
}
// Inner: OID, then SET.
var oid asn1.ObjectIdentifier
afterOID, err := asn1.Unmarshal(oneAttr.Bytes, &oid)
if err != nil {
continue
}
if !oid.Equal(OIDChannelBindingTLSExporter) && !oid.Equal(OIDCMCEnrollmentBinding) {
continue
}
// Now afterOID is the SET wrapper. Crack it and pull the OCTET
// STRING out of the SET's first element.
var setWrap asn1.RawValue
if _, err := asn1.Unmarshal(afterOID, &setWrap); err != nil {
return nil, false, fmt.Errorf("cms: binding SET parse: %w (%w)", err, ErrChannelBindingMissing)
}
if setWrap.Tag != asn1.TagSet {
return nil, false, fmt.Errorf("cms: binding outer tag %d not SET (%w)", setWrap.Tag, ErrChannelBindingMissing)
}
var octet asn1.RawValue
if _, err := asn1.Unmarshal(setWrap.Bytes, &octet); err != nil {
return nil, false, fmt.Errorf("cms: binding inner parse: %w (%w)", err, ErrChannelBindingMissing)
}
if octet.Tag != asn1.TagOctetString {
return nil, false, fmt.Errorf("cms: binding inner tag %d not OCTET STRING (%w)", octet.Tag, ErrChannelBindingMissing)
}
if len(octet.Bytes) != TLSExporterLength {
return nil, false, fmt.Errorf("cms: binding length %d, want %d (%w)",
len(octet.Bytes), TLSExporterLength, ErrChannelBindingMissing)
}
return octet.Bytes, true, nil
}
return nil, false, nil
}
// VerifyChannelBinding is the convenience composite the EST mTLS handler
// calls per request: extract the exporter from the live TLS connection,
// pull the matching value from the CSR, compare in constant time.
//
// Returns:
// - nil when the binding is present + matches.
// - ErrChannelBindingMissing when the CSR has no binding attribute.
// - ErrChannelBindingMismatch when both sides have a value but they
// differ (the MITM signal).
// - Any error from the exporter extraction (TLS state is wrong, etc).
//
// The required flag controls absence-handling: when required=false a
// missing attribute returns nil (channel binding is optional for this
// profile); when required=true a missing attribute returns
// ErrChannelBindingMissing.
func VerifyChannelBinding(state *tls.ConnectionState, csr *x509.CertificateRequest, required bool) error {
live, err := ExtractTLSExporter(state)
if err != nil {
// If the profile doesn't require channel binding AND the only
// problem is "no TLS 1.3 / no handshake", we still let the request
// through — the binding is opt-in per profile. But if the CSR
// itself carries a binding attribute, the device clearly INTENDED
// to bind, so a TLS state mismatch is a genuine error.
if !required {
if _, present, _ := ExtractCSRChannelBinding(csr); !present {
return nil
}
}
return err
}
csrBinding, present, err := ExtractCSRChannelBinding(csr)
if err != nil {
return err
}
if !present {
if required {
return ErrChannelBindingMissing
}
return nil
}
if subtle.ConstantTimeCompare(live, csrBinding) != 1 {
return ErrChannelBindingMismatch
}
// Sanity: the comparison should be identical bytes for matching cases.
// The bytes.Equal call is dead code under correct subtle.Compare result;
// it's here only to make the contract obvious to readers and to pin the
// symmetry test that asserts ExtractCSRChannelBinding is byte-equivalent
// to ExtractTLSExporter when the device behaved correctly.
if !bytes.Equal(live, csrBinding) {
return ErrChannelBindingMismatch
}
return nil
}
// EmbedChannelBindingAttribute is the test helper inverse of
// ExtractCSRChannelBinding: given an exporter value, returns the DER
// bytes of the Attribute (SEQUENCE { OID, SET { OCTET STRING } }) that
// the caller can splice into the [0] IMPLICIT Attributes field of
// TBSCertificationRequestInfo. Used by the EST channel-binding tests
// AND by any external caller that wants to forge a CSR with a known
// binding for fixture generation.
func EmbedChannelBindingAttribute(exporter []byte) ([]byte, error) {
if len(exporter) != TLSExporterLength {
return nil, fmt.Errorf("cms: exporter length %d, want %d", len(exporter), TLSExporterLength)
}
octet, err := asn1.Marshal(exporter) // marshal []byte as OCTET STRING
if err != nil {
return nil, fmt.Errorf("cms: marshal exporter octet: %w", err)
}
// Wrap in SET OF.
setBody := octet
setEnvelope, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassUniversal,
Tag: asn1.TagSet,
IsCompound: true,
Bytes: setBody,
})
if err != nil {
return nil, fmt.Errorf("cms: marshal SET: %w", err)
}
oid, err := asn1.Marshal(OIDChannelBindingTLSExporter)
if err != nil {
return nil, fmt.Errorf("cms: marshal OID: %w", err)
}
// Wrap as SEQUENCE { OID, SET }.
seqBody := append(append([]byte{}, oid...), setEnvelope...)
seqEnvelope, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassUniversal,
Tag: asn1.TagSequence,
IsCompound: true,
Bytes: seqBody,
})
if err != nil {
return nil, fmt.Errorf("cms: marshal SEQUENCE: %w", err)
}
return seqEnvelope, nil
}
+394
View File
@@ -0,0 +1,394 @@
package cms
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"math/big"
"net"
"testing"
)
// EST RFC 7030 hardening master bundle Phase 2.4 tests.
// ----- ExtractTLSExporter -----
func TestExtractTLSExporter_NilState(t *testing.T) {
if _, err := ExtractTLSExporter(nil); !errors.Is(err, ErrChannelBindingMissing) {
t.Errorf("nil state should return ErrChannelBindingMissing, got %v", err)
}
}
func TestExtractTLSExporter_HandshakeNotComplete(t *testing.T) {
state := &tls.ConnectionState{HandshakeComplete: false, Version: 0x0304}
if _, err := ExtractTLSExporter(state); !errors.Is(err, ErrChannelBindingMissing) {
t.Errorf("incomplete handshake should return ErrChannelBindingMissing, got %v", err)
}
}
func TestExtractTLSExporter_PreTLS13Rejected(t *testing.T) {
state := &tls.ConnectionState{HandshakeComplete: true, Version: 0x0303} // TLS 1.2
if _, err := ExtractTLSExporter(state); !errors.Is(err, ErrChannelBindingNotTLS13) {
t.Errorf("TLS 1.2 should return ErrChannelBindingNotTLS13, got %v", err)
}
}
// TestExtractTLSExporter_TLS13EndToEnd is the only test that builds a full
// real TLS-1.3 session — the exporter is computed on the connection's secret
// state, so we can't fake the ConnectionState. We spin up a localhost TCP
// listener, do a handshake, and then call ExportKeyingMaterial directly to
// pin the contract. This is a small round-trip but we're not testing TLS
// itself — just that ExtractTLSExporter pulls a 32-byte value from a real
// 1.3 state.
func TestExtractTLSExporter_TLS13EndToEnd(t *testing.T) {
cert, key := freshSelfSignedTLSCert(t)
tlsCert := tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: key}
cfg := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
}
clientCfg := &tls.Config{
InsecureSkipVerify: true, //nolint:gosec // hermetic test cert; not for production use
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
}
ln, err := tls.Listen("tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("tls.Listen: %v", err)
}
defer ln.Close()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
// Finish the handshake on the server side.
_ = conn.(*tls.Conn).HandshakeContext(context.Background())
// Hold the connection open until the client side completes its read.
buf := make([]byte, 1)
_, _ = conn.Read(buf)
}()
conn, err := tls.Dial("tcp", ln.Addr().String(), clientCfg)
if err != nil {
t.Fatalf("tls.Dial: %v", err)
}
defer conn.Close()
if err := conn.HandshakeContext(context.Background()); err != nil {
t.Fatalf("client handshake: %v", err)
}
state := conn.ConnectionState()
out, err := ExtractTLSExporter(&state)
if err != nil {
t.Fatalf("ExtractTLSExporter: %v", err)
}
if len(out) != TLSExporterLength {
t.Errorf("len(out) = %d, want %d", len(out), TLSExporterLength)
}
}
// ----- ExtractCSRChannelBinding -----
func TestExtractCSRChannelBinding_NilCSR(t *testing.T) {
if _, _, err := ExtractCSRChannelBinding(nil); err == nil {
t.Fatal("nil CSR should error")
}
}
func TestExtractCSRChannelBinding_AbsentReturnsFalse(t *testing.T) {
csr := freshCSRNoBinding(t)
val, present, err := ExtractCSRChannelBinding(csr)
if err != nil {
t.Fatalf("ExtractCSRChannelBinding: %v", err)
}
if present {
t.Errorf("present=true on a CSR without the binding attribute (val=%x)", val)
}
}
func TestExtractCSRChannelBinding_PresentReturnsExporter(t *testing.T) {
exporter := repeatByte(0x42, TLSExporterLength)
csr := freshCSRWithBinding(t, exporter, OIDChannelBindingTLSExporter)
val, present, err := ExtractCSRChannelBinding(csr)
if err != nil {
t.Fatalf("ExtractCSRChannelBinding: %v", err)
}
if !present {
t.Fatal("present=false on a CSR that carries the binding")
}
if !bytesEq(val, exporter) {
t.Errorf("exporter = %x, want %x", val, exporter)
}
}
func TestExtractCSRChannelBinding_LegacyOIDAccepted(t *testing.T) {
exporter := repeatByte(0xAA, TLSExporterLength)
csr := freshCSRWithBinding(t, exporter, OIDCMCEnrollmentBinding)
val, present, err := ExtractCSRChannelBinding(csr)
if err != nil {
t.Fatalf("legacy-OID path failed: %v", err)
}
if !present || !bytesEq(val, exporter) {
t.Errorf("legacy-OID extraction: got present=%v val=%x, want present=true val=%x", present, val, exporter)
}
}
func TestExtractCSRChannelBinding_WrongLengthRejected(t *testing.T) {
short := repeatByte(0x55, 16) // half the required length
csr := freshCSRWithBinding(t, short, OIDChannelBindingTLSExporter)
_, _, err := ExtractCSRChannelBinding(csr)
if !errors.Is(err, ErrChannelBindingMissing) {
t.Errorf("wrong-length binding should wrap ErrChannelBindingMissing, got %v", err)
}
}
// ----- VerifyChannelBinding (composite) -----
func TestVerifyChannelBinding_NotRequired_NoBinding_Passes(t *testing.T) {
csr := freshCSRNoBinding(t)
if err := VerifyChannelBinding(nil, csr, false); err != nil {
t.Errorf("required=false + no binding should pass; got %v", err)
}
}
func TestVerifyChannelBinding_Required_NilState_Errors(t *testing.T) {
csr := freshCSRNoBinding(t)
if err := VerifyChannelBinding(nil, csr, true); err == nil {
t.Fatal("required=true + nil state must error")
}
}
// NOTE: a synthetic *tls.ConnectionState{HandshakeComplete:true, Version:0x0304}
// would seem like the obvious VerifyChannelBinding(required=true) negative-case
// fixture, but stdlib's ExportKeyingMaterial nil-derefs when the underlying
// secret state is unset (see crypto/tls/common.go:330). The
// "no live exporter available" branch is genuinely only reachable via a real
// connection (TestExtractTLSExporter_TLS13EndToEnd above), so we don't try to
// fake it here. The TestVerifyChannelBinding_NotRequired_NoBinding_Passes +
// TestVerifyChannelBinding_Required_NilState_Errors tests cover the policy
// branches; production code paths only ever pass r.TLS from a live request.
// ----- helpers -----
// freshCSRNoBinding returns a CSR with no extra attributes.
func freshCSRNoBinding(t *testing.T) *x509.CertificateRequest {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ecdsa.GenerateKey: %v", err)
}
tmpl := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "no-binding-test"}}
der, err := x509.CreateCertificateRequest(rand.Reader, tmpl, key)
if err != nil {
t.Fatalf("CreateCertificateRequest: %v", err)
}
csr, err := x509.ParseCertificateRequest(der)
if err != nil {
t.Fatalf("ParseCertificateRequest: %v", err)
}
return csr
}
// freshCSRWithBinding builds a CSR whose TBS carries the channel-binding
// attribute. The stdlib's CreateCertificateRequest doesn't support arbitrary
// attributes (only ExtraExtensions), so we hand-craft the TBS by parsing
// what stdlib produced + splicing our attribute into the [0] IMPLICIT
// Attributes block + re-signing.
func freshCSRWithBinding(t *testing.T, exporter []byte, oid asn1.ObjectIdentifier) *x509.CertificateRequest {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ecdsa.GenerateKey: %v", err)
}
// 1. Get a baseline CSR with no attributes — we steal its TBS shape.
tmpl := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "binding-test"}}
derBaseline, err := x509.CreateCertificateRequest(rand.Reader, tmpl, key)
if err != nil {
t.Fatalf("CreateCertificateRequest: %v", err)
}
baseline, err := x509.ParseCertificateRequest(derBaseline)
if err != nil {
t.Fatalf("ParseCertificateRequest: %v", err)
}
// 2. Build the channel-binding attribute (SEQUENCE { OID, SET { OCTET STRING }}).
octet, err := asn1.Marshal(exporter)
if err != nil {
t.Fatalf("marshal octet: %v", err)
}
setEnv, err := asn1.Marshal(asn1.RawValue{Class: asn1.ClassUniversal, Tag: asn1.TagSet, IsCompound: true, Bytes: octet})
if err != nil {
t.Fatalf("marshal set: %v", err)
}
oidBytes, err := asn1.Marshal(oid)
if err != nil {
t.Fatalf("marshal oid: %v", err)
}
attrSeq, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassUniversal,
Tag: asn1.TagSequence,
IsCompound: true,
Bytes: append(append([]byte{}, oidBytes...), setEnv...),
})
if err != nil {
t.Fatalf("marshal attribute SEQUENCE: %v", err)
}
// 3. Splice attribute into a [0] IMPLICIT Attributes block and rebuild
// the TBS by hand. The TBS structure is:
// SEQUENCE { version INTEGER, subject Name, subjectPKInfo SubjectPublicKeyInfo,
// attributes [0] IMPLICIT SET OF Attribute }
// We re-extract the first three fields from the baseline TBS and
// re-marshal with our attribute appended.
var outer asn1.RawValue
if _, err := asn1.Unmarshal(baseline.RawTBSCertificateRequest, &outer); err != nil {
t.Fatalf("baseline TBS unmarshal: %v", err)
}
rest := outer.Bytes
var version, subject, spki asn1.RawValue
for _, target := range []*asn1.RawValue{&version, &subject, &spki} {
next, err := asn1.Unmarshal(rest, target)
if err != nil {
t.Fatalf("baseline TBS skip: %v", err)
}
rest = next
}
versionDER, _ := asn1.Marshal(version)
subjectDER, _ := asn1.Marshal(subject)
spkiDER, _ := asn1.Marshal(spki)
// Build the [0] IMPLICIT Attributes wrapper.
attrsField, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassContextSpecific,
Tag: 0,
IsCompound: true,
Bytes: attrSeq,
})
if err != nil {
t.Fatalf("marshal attrs field: %v", err)
}
tbsBody := append(append(append(append([]byte{}, versionDER...), subjectDER...), spkiDER...), attrsField...)
newTBS, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassUniversal,
Tag: asn1.TagSequence,
IsCompound: true,
Bytes: tbsBody,
})
if err != nil {
t.Fatalf("re-marshal TBS: %v", err)
}
// 4. Parse the new TBS — we don't need to re-sign for these tests
// (ExtractCSRChannelBinding doesn't verify the signature; it walks
// RawTBSCertificateRequest only).
csr := &x509.CertificateRequest{
RawTBSCertificateRequest: newTBS,
Subject: baseline.Subject,
PublicKey: baseline.PublicKey,
}
return csr
}
// freshSelfSignedTLSCert produces a tls.Certificate-compatible cert+key for
// the TLS-1.3 round-trip test.
func freshSelfSignedTLSCert(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ecdsa.GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "tls-test"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate: %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate: %v", err)
}
return cert, key
}
// repeatByte returns a slice of length n filled with b. Used for fixture
// exporter values where we need a deterministic test pattern.
func repeatByte(b byte, n int) []byte {
out := make([]byte, n)
for i := range out {
out[i] = b
}
return out
}
func bytesEq(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// EmbedChannelBindingAttribute round-trip — pins the spec contract that
// what we marshal can be parsed back by ExtractCSRChannelBinding without
// going through the freshCSRWithBinding splice helper.
func TestEmbedChannelBindingAttribute_RoundTrip(t *testing.T) {
exporter := repeatByte(0x77, TLSExporterLength)
attrDER, err := EmbedChannelBindingAttribute(exporter)
if err != nil {
t.Fatalf("EmbedChannelBindingAttribute: %v", err)
}
// Wrap the single attribute in a [0] IMPLICIT SET OF Attribute block
// and a TBS-lookalike SEQUENCE so we can feed it through the same path
// the parser uses — the parser doesn't care that version+subject+spki
// are absent because it walks structurally.
attrsField, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassContextSpecific,
Tag: 0,
IsCompound: true,
Bytes: attrDER,
})
if err != nil {
t.Fatalf("marshal attrs field: %v", err)
}
// Synthetic TBS with three placeholder asn1.RawValue fields then attrsField.
placeholder, _ := asn1.Marshal(asn1.RawValue{Class: asn1.ClassUniversal, Tag: asn1.TagInteger, Bytes: []byte{0x00}})
body := append(append(append(append([]byte{}, placeholder...), placeholder...), placeholder...), attrsField...)
tbs, err := asn1.Marshal(asn1.RawValue{
Class: asn1.ClassUniversal,
Tag: asn1.TagSequence,
IsCompound: true,
Bytes: body,
})
if err != nil {
t.Fatalf("marshal TBS: %v", err)
}
got, present, err := walkCSRAttributesForBinding(tbs)
if err != nil {
t.Fatalf("walkCSRAttributesForBinding: %v", err)
}
if !present {
t.Fatal("present=false on round-trip")
}
if !bytesEq(got, exporter) {
t.Errorf("round-trip mismatch: got %x, want %x", got, exporter)
}
}
+188
View File
@@ -0,0 +1,188 @@
// Package ratelimit provides shared rate-limit primitives used by
// authenticated-but-shared-credential code paths (SCEP/Intune
// per-device challenge enrollment, EST per-principal CSR enrollment,
// EST HTTP-Basic source-IP failed-auth limiter) where the threat
// model is "single legitimate identity could mint enrollments
// faster than any human/fleet workflow would."
//
// Origin: this package was extracted from
// internal/scep/intune/rate_limit.go in the EST RFC 7030 hardening
// master bundle Phase 4.1 — EST is the third caller after the
// Intune dispatcher (per-device-GUID cap on enrollment) and the EST
// per-principal cap (Phase 4.2). The original Intune-package type +
// constructor + ErrRateLimited sentinel are preserved as type
// aliases at internal/scep/intune/rate_limit.go so existing call
// sites compile unchanged. New callers SHOULD use this package
// directly.
//
// Algorithm: sliding window log. Each key maps to a bucket holding
// timestamps within the configured window. On Allow, the bucket
// prunes timestamps older than (now - window) and either appends +
// returns nil, or rejects + returns ErrRateLimited when the
// post-prune count is already at the cap. Exact (no token-leak
// rounding); O(N_per_key) per-call but N is bounded by the cap, so
// effectively O(1).
//
// Concurrency: safe for concurrent Allow calls. Internal map guarded
// by sync.Mutex; per-key slices mutated only while the mutex is
// held.
//
// Memory: bounded by the per-instance map cap (default 100,000 keys;
// configurable). At-cap eviction drops the oldest entry by newest
// timestamp — small janitor pass; rarely fires in practice because
// the prune-on-Allow path keeps most buckets short-lived.
package ratelimit
import (
"errors"
"sync"
"time"
)
// ErrRateLimited is returned by SlidingWindowLimiter.Allow when the
// bucket for the given key is already at the cap. Callers can
// errors.Is against this sentinel; the underlying message is stable
// across the package's lifetime so test assertions can match on it.
var ErrRateLimited = errors.New("ratelimit: per-key cap exceeded for the configured window")
// SlidingWindowLimiter is the sliding-window-log rate limiter.
//
// Construct via NewSlidingWindowLimiter. The zero value is NOT
// usable — the buckets map needs initialisation.
type SlidingWindowLimiter struct {
mu sync.Mutex
buckets map[string][]time.Time // key → sliding window of timestamps
maxN int // max enrollments per window
window time.Duration // window length (default 24h)
cap int // max keys before LRU eviction kicks in
disabled bool // maxN <= 0 → all Allow calls return nil
}
// NewSlidingWindowLimiter returns a limiter with the given per-key
// cap + window. maxN <= 0 disables the limiter (all Allow calls
// return nil); this is operator opt-out for the rare case where the
// per-key cap is undesirable (test harnesses, sketchpad deploys).
//
// Window defaults to 24h when zero. Map cap defaults to 100,000 when
// zero (matches the SCEP/Intune replay cache cap).
func NewSlidingWindowLimiter(maxN int, window time.Duration, mapCap int) *SlidingWindowLimiter {
if window <= 0 {
window = 24 * time.Hour
}
if mapCap <= 0 {
mapCap = 100_000
}
return &SlidingWindowLimiter{
buckets: make(map[string][]time.Time),
maxN: maxN,
window: window,
cap: mapCap,
disabled: maxN <= 0,
}
}
// Allow reports whether an event keyed by `key` is permitted right
// now. Returns nil when allowed (and records the timestamp in the
// bucket) or ErrRateLimited when the bucket is at maxN.
//
// Empty key is treated as "skip the limiter" — the caller's
// validation should have rejected an empty-key event already; this
// is belt-and-suspenders so a single empty-key bucket doesn't
// become a chokepoint for every empty-key event. SCEP/Intune
// callers compose the key as `subject + "|" + issuer`; EST callers
// compose `cn + "|" + sourceIP` or `sourceIP`-alone for the
// failed-auth limiter.
func (l *SlidingWindowLimiter) Allow(key string, now time.Time) error {
if l.disabled {
return nil
}
if key == "" {
return nil
}
l.mu.Lock()
defer l.mu.Unlock()
// At-cap eviction: when the map is full, drop the oldest entry
// by finding the bucket whose newest timestamp is the smallest.
// O(N_keys) but rarely fires; the prune-on-Allow path keeps
// most buckets short-lived.
if len(l.buckets) >= l.cap {
l.evictOldestLocked()
}
bucket := l.buckets[key]
bucket = pruneOlderThan(bucket, now.Add(-l.window))
if len(bucket) >= l.maxN {
// Don't append; over the limit. Persist the pruned bucket so
// the next call sees the most-recently-pruned state.
l.buckets[key] = bucket
return ErrRateLimited
}
bucket = append(bucket, now)
l.buckets[key] = bucket
return nil
}
// pruneOlderThan returns the slice with all entries strictly before
// `cutoff` removed. Preserves order (timestamps are appended in
// increasing time, so a single linear scan from the front suffices).
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
i := 0
for i < len(b) && b[i].Before(cutoff) {
i++
}
if i == 0 {
return b
}
// Copy-shrink to release the underlying-array memory eventually
// (otherwise the slice would hold a reference to the older
// entries indefinitely until a re-allocation).
out := make([]time.Time, len(b)-i)
copy(out, b[i:])
return out
}
// evictOldestLocked drops the map entry whose newest timestamp is
// the oldest. Called under l.mu. O(N_keys) per eviction; at-cap is
// rare in practice (caps are sized for steady-state).
func (l *SlidingWindowLimiter) evictOldestLocked() {
var (
oldestKey string
oldestTs time.Time
first = true
)
for k, b := range l.buckets {
if len(b) == 0 {
// Empty bucket — drop it immediately, no candidate scan needed.
delete(l.buckets, k)
return
}
newest := b[len(b)-1]
if first || newest.Before(oldestTs) {
oldestKey = k
oldestTs = newest
first = false
}
}
if oldestKey != "" {
delete(l.buckets, oldestKey)
}
}
// Len returns the approximate number of distinct keys currently
// tracked. For observability + tests; not load-stable under
// concurrent Allow calls.
func (l *SlidingWindowLimiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.buckets)
}
// Disabled reports whether the limiter is in opt-out mode (maxN <= 0).
// Useful for handler-side gating + admin-endpoint observability.
func (l *SlidingWindowLimiter) Disabled() bool {
return l.disabled
}
+197
View File
@@ -0,0 +1,197 @@
package ratelimit
import (
"errors"
"fmt"
"sync"
"testing"
"time"
)
// EST RFC 7030 hardening master bundle Phase 4.1: this test file holds the
// white-box tests for the SlidingWindowLimiter primitives that used to live
// in internal/scep/intune/rate_limit_test.go (TestPerDeviceRateLimiter_
// DefaultCapsHonored, TestPruneOlderThan, TestPruneOlderThan_NoOpWhen
// NothingToPrune). The behavioral coverage in intune/rate_limit_test.go
// stays — it exercises the wrapper's (subject, issuer)-composition contract
// + the empty-subject short-circuit + concurrent race-freedom.
func TestSlidingWindowLimiter_AllowsUpToCap(t *testing.T) {
l := NewSlidingWindowLimiter(3, 24*time.Hour, 10)
now := time.Now()
for i := 0; i < 3; i++ {
if err := l.Allow("k", now.Add(time.Duration(i)*time.Minute)); err != nil {
t.Fatalf("call %d should be allowed: %v", i+1, err)
}
}
if err := l.Allow("k", now.Add(4*time.Minute)); !errors.Is(err, ErrRateLimited) {
t.Fatalf("4th call should be rate-limited; got %v", err)
}
}
func TestSlidingWindowLimiter_DistinctKeysIndependent(t *testing.T) {
l := NewSlidingWindowLimiter(1, 24*time.Hour, 10)
now := time.Now()
if err := l.Allow("k-1", now); err != nil {
t.Fatalf("first allow: %v", err)
}
if err := l.Allow("k-2", now); err != nil {
t.Fatalf("different key must have its own bucket: %v", err)
}
if err := l.Allow("k-1", now.Add(1*time.Second)); !errors.Is(err, ErrRateLimited) {
t.Fatalf("repeat key should be limited; got %v", err)
}
}
func TestSlidingWindowLimiter_WindowExpiry(t *testing.T) {
l := NewSlidingWindowLimiter(2, 1*time.Hour, 10)
now := time.Now()
if err := l.Allow("k", now); err != nil {
t.Fatal(err)
}
if err := l.Allow("k", now.Add(30*time.Minute)); err != nil {
t.Fatal(err)
}
// Inside window — limited.
if err := l.Allow("k", now.Add(45*time.Minute)); !errors.Is(err, ErrRateLimited) {
t.Fatalf("inside-window 3rd call should be limited: %v", err)
}
// Past window — slots reopen.
if err := l.Allow("k", now.Add(2*time.Hour)); err != nil {
t.Fatalf("past-window call should be allowed (window reset): %v", err)
}
}
func TestSlidingWindowLimiter_DisabledBypass(t *testing.T) {
l := NewSlidingWindowLimiter(0, 24*time.Hour, 10) // maxN=0 → disabled
if !l.Disabled() {
t.Fatal("limiter with maxN=0 must report Disabled()=true")
}
now := time.Now()
for i := 0; i < 100; i++ {
if err := l.Allow("k", now); err != nil {
t.Fatalf("disabled limiter must allow everything: %v", err)
}
}
if got := l.Len(); got != 0 {
t.Errorf("disabled limiter Len() = %d, want 0", got)
}
}
func TestSlidingWindowLimiter_NegativeCapDisabled(t *testing.T) {
l := NewSlidingWindowLimiter(-1, 24*time.Hour, 10)
if !l.Disabled() {
t.Fatal("negative maxN must produce a disabled limiter")
}
}
func TestSlidingWindowLimiter_EmptyKeyShortCircuits(t *testing.T) {
// Empty key is the caller's defense-in-depth case — caller's validation
// upstream should reject empty-key events first. Limiter must not build
// a single shared bucket keyed by empty-key — that would be a chokepoint
// for every empty-key event.
l := NewSlidingWindowLimiter(1, 24*time.Hour, 10)
now := time.Now()
for i := 0; i < 50; i++ {
if err := l.Allow("", now); err != nil {
t.Fatalf("empty key must short-circuit (call %d): %v", i, err)
}
}
if got := l.Len(); got != 0 {
t.Errorf("Len after 50 empty-key calls = %d, want 0 (no bucket created)", got)
}
}
func TestSlidingWindowLimiter_DefaultCapsHonored(t *testing.T) {
// White-box test: exercises the constructor's default-fill branches.
// Lives here (not in the intune wrapper test) because the fields
// (window + cap) are package-private to ratelimit.
l := NewSlidingWindowLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
if l.window != 24*time.Hour {
t.Errorf("default window = %v, want 24h", l.window)
}
if l.cap != 100_000 {
t.Errorf("default cap = %d, want 100000", l.cap)
}
}
func TestSlidingWindowLimiter_MapCapEvictsOldest(t *testing.T) {
// Cap of 3 keys to exercise the eviction branch deterministically.
l := NewSlidingWindowLimiter(2, 1*time.Hour, 3)
now := time.Now()
for i := 0; i < 3; i++ {
key := fmt.Sprintf("k-%d", i)
if err := l.Allow(key, now.Add(time.Duration(i)*time.Minute)); err != nil {
t.Fatalf("insert %d: %v", i, err)
}
}
if l.Len() != 3 {
t.Fatalf("Len = %d, want 3", l.Len())
}
// 4th key forces eviction of k-0 (its newest timestamp is oldest).
if err := l.Allow("k-3", now.Add(10*time.Minute)); err != nil {
t.Fatalf("4th-key insert: %v", err)
}
if l.Len() != 3 {
t.Errorf("Len after at-cap insert = %d, want 3 (cap honored)", l.Len())
}
}
func TestSlidingWindowLimiter_ConcurrentRaceFree(t *testing.T) {
if testing.Short() {
t.Skip("race-style test under -short")
}
l := NewSlidingWindowLimiter(50, 24*time.Hour, 10000)
var wg sync.WaitGroup
for g := 0; g < 20; g++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
now := time.Now()
key := fmt.Sprintf("k-%d", id)
for i := 0; i < 30; i++ {
_ = l.Allow(key, now)
}
}(g)
}
wg.Wait()
if got := l.Len(); got != 20 {
t.Errorf("expected 20 distinct keys; got %d", got)
}
}
// White-box tests for the unexported pruneOlderThan helper. Live in this
// package because the helper is package-private to ratelimit. The test
// surface used to live in intune/rate_limit_test.go before the Phase 4.1
// extraction.
func TestPruneOlderThan(t *testing.T) {
t0 := time.Now()
in := []time.Time{
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
t0.Add(-30 * time.Minute), // survives
t0, // survives
}
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
if len(out) != 3 {
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
}
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
}
}
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
t0 := time.Now()
in := []time.Time{t0.Add(-1 * time.Minute), t0}
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
// Same slice header (no copy needed).
if len(out) != len(in) {
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
}
}
+46 -152
View File
@@ -1,193 +1,87 @@
package intune
import (
"errors"
"sync"
"time"
"github.com/shankar0123/certctl/internal/ratelimit"
)
// SCEP RFC 8894 + Intune master bundle Phase 8.6.
//
// PerDeviceRateLimiter is the second line of defense behind the replay cache
// from Phase 7. The replay cache catches the same challenge being submitted
// twice (within the challenge TTL); this rate limiter catches a compromised
// Connector signing key (or a stolen key+cert pair) issuing many DIFFERENT
// valid challenges for the same device subject in a short window.
// PerDeviceRateLimiter is the second line of defense behind the replay
// cache from Phase 7. The replay cache catches the same challenge being
// submitted twice (within the challenge TTL); this rate limiter catches a
// compromised Connector signing key (or a stolen key+cert pair) issuing
// many DIFFERENT valid challenges for the same device subject in a short
// window.
//
// Threat model:
//
// - Replay cache (Phase 7): nonce-keyed; catches duplicate submission.
// - This limiter: (Subject, Issuer)-keyed; catches enrollment-flooding.
//
// Default: 3 enrollments per (device GUID, Connector identity) per 24h.
// EST RFC 7030 hardening master bundle Phase 4.1: the implementation that
// used to live in this file was extracted to internal/ratelimit (where it
// can be shared with EST per-principal + EST HTTP-Basic source-IP rate
// limiters). PerDeviceRateLimiter is now a thin wrapper around
// ratelimit.SlidingWindowLimiter that preserves the original
// (subject, issuer) → key composition in the Allow signature so existing
// SCEP/Intune callers don't have to change.
//
// Sizing: 100,000 distinct device entries (matches the replay cache cap).
// At-cap: oldest entry evicted (small janitor pass) to avoid unbounded
// memory growth on a fleet that grows past the cap.
//
// Why a hand-rolled token bucket instead of pulling in golang.org/x/time/rate:
// the rate package is in go.sum as an indirect transitive but NOT a direct
// dep. Adding it would create a new direct dep relationship for ~30 LoC of
// state machine. The hand-rolled version below uses only stdlib (sync.Mutex
// + time.Time arithmetic) and is small enough to fit on one screen.
//
// Algorithm: each (Subject, Issuer) key maps to a bucket holding a window's
// worth of recent enrollment timestamps. On Allow, the bucket prunes
// timestamps older than (now - window) and either appends the current
// timestamp + returns true, or rejects + returns false when the post-prune
// count is already at the cap. This is the "sliding window log" rate
// limiter — exact (no token-leak rounding); O(N_per_key) per-call but N is
// bounded by the cap (3 by default), so effectively O(1).
// New callers SHOULD use ratelimit.SlidingWindowLimiter directly. The
// EST RFC 7030 Phase 4.2 EST per-principal cap uses the shared package.
// ErrRateLimited is the typed error returned when the per-device rate limit
// fires. The handler maps this to a CertRep FAILURE with badRequest failInfo
// + the `rate_limited` metric label.
var ErrRateLimited = errors.New("intune: per-device rate limit exceeded for this (subject, issuer) within the configured window")
// ErrRateLimited is the typed error returned when the per-device rate
// limit fires. Aliased to ratelimit.ErrRateLimited so errors.Is matches
// against either name (the SCEP audit closure already pinned the
// "rate_limited" metric label against this sentinel; the alias preserves
// sentinel identity across the package boundary).
var ErrRateLimited = ratelimit.ErrRateLimited
// PerDeviceRateLimiter is a sliding-window-log rate limiter keyed by
// (Subject, Issuer) tuples derived from a parsed challenge claim.
//
// Concurrency: the limiter is safe for concurrent Allow calls. The internal
// map is guarded by a mutex; the per-key slices are mutated only while the
// mutex is held.
// PerDeviceRateLimiter wraps ratelimit.SlidingWindowLimiter with the
// (subject, issuer)-composed-key Allow signature the Intune dispatcher
// uses. Concurrency-safe (the underlying limiter holds the mutex).
type PerDeviceRateLimiter struct {
mu sync.Mutex
buckets map[string][]time.Time // key → sliding window of timestamps
maxN int // max enrollments per window
window time.Duration // window length (default 24h)
cap int // max keys before LRU eviction kicks in
disabled bool // maxN == 0 → all Allow calls return nil
inner *ratelimit.SlidingWindowLimiter
}
// NewPerDeviceRateLimiter returns a limiter with the given per-key cap +
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil); this
// is operator opt-out for the rare case where the per-device cap is
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil);
// this is operator opt-out for the rare case where the per-device cap is
// undesirable (e.g. test harnesses, sketchpad deploys).
//
// Window defaults to 24h when zero. Map cap defaults to 100,000 when zero
// (matches the replay cache cap; see internal/scep/intune/replay.go).
func NewPerDeviceRateLimiter(maxN int, window time.Duration, mapCap int) *PerDeviceRateLimiter {
if window <= 0 {
window = 24 * time.Hour
}
if mapCap <= 0 {
mapCap = 100_000
}
return &PerDeviceRateLimiter{
buckets: make(map[string][]time.Time),
maxN: maxN,
window: window,
cap: mapCap,
disabled: maxN <= 0,
}
return &PerDeviceRateLimiter{inner: ratelimit.NewSlidingWindowLimiter(maxN, window, mapCap)}
}
// Allow checks whether an enrollment for the given (subject, issuer) tuple
// is permitted right now. Returns nil when allowed (and records the timestamp
// in the bucket) or ErrRateLimited when the bucket is at maxN.
// Allow checks whether an enrollment for the given (subject, issuer)
// tuple is permitted right now. Returns nil when allowed (and records
// the timestamp in the bucket) or ErrRateLimited when the bucket is at
// maxN.
//
// Empty subject is treated as "skip the limiter" — the caller's claim
// validation should have rejected an empty-subject claim already; this is
// belt-and-suspenders to prevent a single empty-subject bucket from
// becoming a fleet-wide chokepoint. The Connector emits non-empty subject
// (device GUID) on every legitimate challenge.
// validation should have rejected an empty-subject claim already; this
// is belt-and-suspenders to prevent a single empty-subject bucket from
// becoming a fleet-wide chokepoint.
func (l *PerDeviceRateLimiter) Allow(subject, issuer string, now time.Time) error {
if l.disabled {
return nil
}
if subject == "" {
// Caller's claim validation should reject empty-subject upstream;
// this short-circuit is defense-in-depth so a misconfigured
// Connector can't DoS us via the rate-limit path.
// Empty-subject early return preserved from the pre-Phase-4.1
// behavior: ratelimit.SlidingWindowLimiter also short-circuits
// on empty key, but the explicit check here documents the
// (subject, issuer) → empty-key contract and saves one call
// frame in the hot path.
return nil
}
key := subject + "|" + issuer
l.mu.Lock()
defer l.mu.Unlock()
// At-cap eviction: when the map is full, drop the oldest entry by
// finding the bucket whose newest timestamp is the smallest. O(N) but
// rarely fires; the prune-on-Allow path keeps most buckets short-lived.
if len(l.buckets) >= l.cap {
l.evictOldestLocked(now)
}
bucket := l.buckets[key]
bucket = pruneOlderThan(bucket, now.Add(-l.window))
if len(bucket) >= l.maxN {
// Don't append; over the limit. Persist the pruned bucket so the
// next call sees the most-recently-pruned state.
l.buckets[key] = bucket
return ErrRateLimited
}
bucket = append(bucket, now)
l.buckets[key] = bucket
return nil
}
// pruneOlderThan returns the slice with all entries strictly before
// `cutoff` removed. Preserves order (timestamps are appended in increasing
// time, so a single linear scan from the front suffices).
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
i := 0
for i < len(b) && b[i].Before(cutoff) {
i++
}
if i == 0 {
return b
}
// Copy-shrink to release the underlying-array memory eventually
// (otherwise the slice would hold a reference to the older entries
// indefinitely until a re-allocation).
out := make([]time.Time, len(b)-i)
copy(out, b[i:])
return out
}
// evictOldestLocked drops the map entry whose newest timestamp is the
// oldest. Called under l.mu. O(N_keys) per eviction; at-cap is rare in
// practice (caps are sized for fleet steady-state).
func (l *PerDeviceRateLimiter) evictOldestLocked(now time.Time) {
var (
oldestKey string
oldestTs time.Time
first = true
)
for k, b := range l.buckets {
if len(b) == 0 {
// Empty bucket — drop it immediately, no candidate scan needed.
delete(l.buckets, k)
return
}
newest := b[len(b)-1]
if first || newest.Before(oldestTs) {
oldestKey = k
oldestTs = newest
first = false
}
}
if oldestKey != "" {
delete(l.buckets, oldestKey)
}
// Suppress unused-parameter warning for `now` in case the eviction
// strategy changes (e.g. swap to LRU keyed by time of last Allow).
_ = now
return l.inner.Allow(key, now)
}
// Len returns the approximate number of distinct (subject, issuer) keys
// currently tracked. For observability + tests; not load-stable under
// concurrent Allow calls.
func (l *PerDeviceRateLimiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.buckets)
}
// currently tracked. For observability + tests.
func (l *PerDeviceRateLimiter) Len() int { return l.inner.Len() }
// Disabled reports whether the limiter is in opt-out mode (maxN ≤ 0).
// Useful for handler-side gating + admin-endpoint observability.
func (l *PerDeviceRateLimiter) Disabled() bool {
return l.disabled
}
func (l *PerDeviceRateLimiter) Disabled() bool { return l.inner.Disabled() }
+10 -36
View File
@@ -103,15 +103,11 @@ func TestPerDeviceRateLimiter_EmptySubjectShortCircuits(t *testing.T) {
}
}
func TestPerDeviceRateLimiter_DefaultCapsHonored(t *testing.T) {
l := NewPerDeviceRateLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
if l.window != 24*time.Hour {
t.Errorf("default window = %v, want 24h", l.window)
}
if l.cap != 100_000 {
t.Errorf("default cap = %d, want 100000", l.cap)
}
}
// TestPerDeviceRateLimiter_DefaultCapsHonored — moved to
// internal/ratelimit/sliding_window_test.go::TestSlidingWindowLimiter_DefaultCapsHonored
// in EST RFC 7030 hardening Phase 4.1 (the white-box test reads private
// fields that no longer exist on the wrapper). The shared package owns
// the field-default contract.
func TestPerDeviceRateLimiter_MapCapEvictsOldest(t *testing.T) {
// Cap of 3 keys to exercise the eviction branch deterministically.
@@ -161,30 +157,8 @@ func TestPerDeviceRateLimiter_ConcurrentRaceFree(t *testing.T) {
}
}
func TestPruneOlderThan(t *testing.T) {
t0 := time.Now()
in := []time.Time{
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
t0.Add(-30 * time.Minute), // survives
t0, // survives
}
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
if len(out) != 3 {
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
}
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
}
}
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
t0 := time.Now()
in := []time.Time{t0.Add(-1 * time.Minute), t0}
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
// Same slice header (no copy needed).
if len(out) != len(in) {
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
}
}
// TestPruneOlderThan + TestPruneOlderThan_NoOpWhenNothingToPrune — moved
// to internal/ratelimit/sliding_window_test.go in EST RFC 7030 hardening
// Phase 4.1. pruneOlderThan is now an unexported helper of the shared
// ratelimit package (the implementation moved there); the white-box
// tests follow.
+35 -63
View File
@@ -1,73 +1,45 @@
package intune
// SCEP RFC 8894 + Intune master bundle Phase 7.2 (originally) +
// EST RFC 7030 hardening master bundle Phase 2.1 (extraction).
//
// LoadTrustAnchor + parseTrustAnchorPEM were extracted to
// internal/trustanchor.LoadBundle + parseBundlePEM so the EST mTLS
// sibling route (Phase 2 of the EST hardening bundle), the Intune
// dispatcher, and any future per-profile-trust-bundle caller can share
// the same PEM-bundle loader + SIGHUP-reload semantics. The shim below
// preserves the original public surface so existing intune callers
// (cmd/server/main.go, scep_intune_e2e_test.go, scep_profile_counter_
// isolation_test.go, scep_intune.go service) compile unchanged.
//
// New callers SHOULD import internal/trustanchor directly — the
// trustanchor.Holder + trustanchor.LoadBundle are the modern API.
//
// Note: the legacy intune error messages ("intune: trust anchor cert
// in %q expired ...") are NOT preserved verbatim across the extraction;
// the shared trustanchor package emits "trustanchor: ..." messages
// instead. The operator-facing log line at cmd/server/main.go's
// preflightSCEPIntuneTrustAnchor wraps the error in its own outer
// ("SCEP profile (PathID=...) INTUNE trust anchor load failed: ...")
// so the prefix change is invisible to log-grep runbooks that filter
// on the outer message.
import (
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"time"
"github.com/shankar0123/certctl/internal/trustanchor"
)
// LoadTrustAnchor reads a PEM bundle of one or more Intune Connector
// signing certificates from the configured path. Returns the slice of
// parsed certs that the validator will accept as challenge issuers.
// signing certificates from the configured path. Delegates to the
// shared trustanchor.LoadBundle (extracted in EST RFC 7030 hardening
// Phase 2.1) so the EST mTLS sibling route + the Intune dispatcher
// + any future per-profile trust-bundle caller share the same
// loader semantics (path-empty refusal, expired-cert refusal,
// non-CERTIFICATE-block tolerance).
//
// SCEP RFC 8894 + Intune master bundle Phase 7.2.
//
// Behavior:
//
// - File must exist + be readable.
// - PEM-decodes the file; non-CERTIFICATE blocks are skipped (so an
// operator can paste a chain that includes a private key by mistake
// without breaking the load — the priv key is just ignored).
// - Returns an error if zero CERTIFICATE blocks parse.
// - Returns an error if any cert is past NotAfter (a stale trust
// anchor would silently reject every Intune challenge at runtime;
// fail loud at startup instead).
//
// Operators rotate Connector signing certs periodically; the trust
// anchor file is reloaded on SIGHUP (handled by the existing config
// watch loop in cmd/server/main.go — see cmd/server/tls.go::watchSIGHUP
// for the precedent).
// Preserved here as a wrapper so existing intune callers compile
// unchanged. New callers SHOULD use trustanchor.LoadBundle directly.
func LoadTrustAnchor(path string) ([]*x509.Certificate, error) {
if path == "" {
return nil, fmt.Errorf("intune: trust anchor path is empty")
}
body, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("intune: read trust anchor %q: %w", path, err)
}
return parseTrustAnchorPEM(body, path, time.Now())
}
// parseTrustAnchorPEM is the file-IO-free core of LoadTrustAnchor. Split
// out so unit tests can hand it byte slices without writing temp files.
// `now` is taken as a parameter so expiry tests can pin a deterministic
// clock.
func parseTrustAnchorPEM(body []byte, sourceLabel string, now time.Time) ([]*x509.Certificate, error) {
var out []*x509.Certificate
rest := body
for {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("intune: parse trust anchor cert in %q: %w", sourceLabel, err)
}
if now.After(cert.NotAfter) {
return nil, fmt.Errorf("intune: trust anchor cert in %q expired at %s (subject=%q) — operator must rotate the Connector signing cert before restart",
sourceLabel, cert.NotAfter.Format(time.RFC3339), cert.Subject.CommonName)
}
out = append(out, cert)
}
if len(out) == 0 {
return nil, fmt.Errorf("intune: trust anchor %q contains no CERTIFICATE PEM blocks", sourceLabel)
}
return out, nil
return trustanchor.LoadBundle(path)
}
+45 -130
View File
@@ -1,143 +1,58 @@
package intune
// SCEP RFC 8894 + Intune master bundle Phase 8.5 (originally) +
// EST RFC 7030 hardening master bundle Phase 2.1 (extraction).
//
// TrustAnchorHolder + NewTrustAnchorHolder were extracted to
// internal/trustanchor.Holder + trustanchor.New so the EST mTLS sibling
// route (Phase 2 of the EST hardening bundle) and the Intune dispatcher
// can share the same SIGHUP-reloadable PEM bundle primitive. A single
// SIGHUP now rotates: server TLS cert (cmd/server/tls.go), every Intune
// trust anchor (this package's existing wiring), AND every EST mTLS
// per-profile client-CA bundle (the new sibling route) — exactly the
// design contract documented in the trustanchor package doc.
//
// The aliases below preserve every existing intune call site unchanged:
// - cmd/server/main.go declares `intuneTrustHolders []*intune.TrustAnchorHolder`
// + invokes `intune.NewTrustAnchorHolder(path, logger)`
// - internal/service/scep.go's SCEPService struct field
// `intuneTrust *intune.TrustAnchorHolder` (the type alias keeps this
// pointer-compatible with the original)
// - internal/scep/intune/trust_anchor_holder_test.go + the e2e tests
// that construct a holder via NewTrustAnchorHolder
//
// New callers SHOULD import internal/trustanchor directly — the
// trustanchor.Holder + trustanchor.New are the modern API. The intune
// aliases are preserved indefinitely for back-compat (no deprecation
// timeline; the cost of the two-line shim is trivial).
import (
"crypto/x509"
"errors"
"log/slog"
"os"
"os/signal"
"sync"
"syscall"
"github.com/shankar0123/certctl/internal/trustanchor"
)
// TrustAnchorHolder is the SIGHUP-reloadable wrapper around a per-profile
// Intune Connector trust anchor pool.
//
// SCEP RFC 8894 + Intune master bundle Phase 8.5.
//
// Mirrors the shape established by `cmd/server/tls.go::certHolder` for the
// server TLS cert: an RWMutex-guarded pool, a Get accessor that's safe for
// concurrent callers from the request path, a Reload that re-reads the file
// and atomically swaps the slice on success (failure leaves the OLD pool in
// place so a bad reload doesn't take Intune enrollment down), and a
// watchSIGHUP goroutine that responds to the same SIGHUP the operator uses
// to rotate the server TLS cert.
//
// Why SIGHUP specifically (vs fsnotify or a polling loop): SIGHUP is the
// repo-established convention (see cmd/server/tls.go). fsnotify would add a
// new direct dep + complicate the cleanup story. The operator's Connector-
// rotation script writes the new PEM bundle then sends SIGHUP — the same
// signal that already rotates the server TLS cert — and both swap atomically.
//
// Concurrency contract:
// - Get returns the pool slice header by value; the slice itself is
// immutable per-snapshot (Reload swaps a fresh slice rather than
// mutating the existing one). Callers may iterate the returned slice
// without holding any lock.
// - Reload acquires a write lock briefly for the swap. Concurrent Get
// calls block only for that swap window (microseconds).
// - watchSIGHUP runs at most one Reload at a time per holder.
type TrustAnchorHolder struct {
mu sync.RWMutex
certs []*x509.Certificate
path string
logger *slog.Logger
}
// Aliased to trustanchor.Holder (extracted in EST RFC 7030 hardening
// Phase 2.1) so the EST mTLS sibling route + the Intune dispatcher share
// the same primitive. Existing callers compile unchanged because Go type
// aliases are pointer-compatible.
type TrustAnchorHolder = trustanchor.Holder
// NewTrustAnchorHolder loads the trust bundle and returns a holder. Returns
// the same fail-loud error LoadTrustAnchor does on initial load — the
// startup gate at cmd/server/main.go is supposed to refuse boot when this
// fails. Subsequent Reload errors are non-fatal (logged + old pool retained).
// NewTrustAnchorHolder loads the trust bundle and returns a holder.
// Aliased to trustanchor.New (extracted in EST RFC 7030 hardening
// Phase 2.1). Returns the same fail-loud error LoadTrustAnchor does on
// initial load — the startup gate at cmd/server/main.go is supposed to
// refuse boot when this fails. Subsequent Reload errors are non-fatal
// (logged + old pool retained).
//
// The logger is required (never nil); the caller passes a per-profile
// scoped logger so SIGHUP-reload events show the PathID for triage.
func NewTrustAnchorHolder(path string, logger *slog.Logger) (*TrustAnchorHolder, error) {
if logger == nil {
return nil, errors.New("intune: TrustAnchorHolder requires a non-nil logger")
}
certs, err := LoadTrustAnchor(path)
if err != nil {
return nil, err
}
return &TrustAnchorHolder{
certs: certs,
path: path,
logger: logger,
}, nil
}
// Get returns the current trust anchor pool. Safe for concurrent callers;
// the slice header is returned by value and the underlying slice is
// immutable per-snapshot (Reload swaps a fresh slice, doesn't mutate in
// place — see Reload).
func (h *TrustAnchorHolder) Get() []*x509.Certificate {
h.mu.RLock()
defer h.mu.RUnlock()
return h.certs
}
// Path returns the on-disk path the holder reloads from. Useful for
// observability (admin endpoints, log lines) without exposing the cert
// pool itself.
func (h *TrustAnchorHolder) Path() string {
return h.path
}
// Reload re-reads the trust anchor file at h.path and atomically swaps the
// pool. Returns the parse error if the new file is invalid; the OLD pool
// stays in place so a bad reload doesn't take Intune enrollment down.
//
// Same fail-safe pattern as cmd/server/tls.go::(*certHolder).Reload — a
// rotation that writes a half-file (operator overwrites the bundle while
// only some of the new certs are in it) would otherwise crash the
// service mid-rotation. Logging + retaining the old pool gives the
// operator a bounded window to fix and re-SIGHUP.
func (h *TrustAnchorHolder) Reload() error {
certs, err := LoadTrustAnchor(h.path)
if err != nil {
return err
}
h.mu.Lock()
h.certs = certs
h.mu.Unlock()
return nil
}
// WatchSIGHUP installs a signal handler that calls Reload on each SIGHUP.
// The returned stop function closes the internal done channel and stops
// signal delivery so the goroutine can exit cleanly during shutdown.
//
// Errors from Reload are logged but do not terminate the watcher — the
// operator can fix the files and send another SIGHUP. Mirrors the
// (*certHolder).watchSIGHUP contract exactly.
//
// Multiple holders can coexist: each registers its own goroutine on the
// same SIGHUP signal. signal.Notify multicasts to every registered
// channel, so a single SIGHUP reloads every per-profile Intune trust
// anchor PLUS the server TLS cert in one operator action — exactly the
// design requirement (one SIGHUP rotates everything).
func (h *TrustAnchorHolder) WatchSIGHUP() (stop func()) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
done := make(chan struct{})
go func() {
for {
select {
case <-ch:
if err := h.Reload(); err != nil {
h.logger.Error("Intune trust anchor reload failed; continuing with previous pool",
"error", err,
"path", h.path)
continue
}
h.logger.Info("Intune trust anchor reloaded via SIGHUP",
"path", h.path,
"certs_loaded", len(h.Get()))
case <-done:
signal.Stop(ch)
return
}
}
}()
return func() { close(done) }
}
// Note: the original intune.NewTrustAnchorHolder set the holder's
// internal log label to "Intune trust anchor"; the extracted
// trustanchor.New defaults to "trust anchor". Existing intune callers
// that need the original label should call .SetLabelForLog("intune
// trust anchor (PathID=…)") on the returned holder. cmd/server/main.go
// does this in the per-profile Intune startup loop.
var NewTrustAnchorHolder = trustanchor.New
+13 -92
View File
@@ -16,6 +16,13 @@ import (
"time"
)
// EST RFC 7030 hardening master bundle Phase 2.1: the white-box parser
// tests (TestParseTrustAnchorPEM_*) moved to internal/trustanchor/holder_test.go
// where parseBundlePEM now lives. The intune package retains a thin
// public-surface test of LoadTrustAnchor — the back-compat shim that
// existing intune callers use — so a future refactor that breaks the
// shim's wire-up to trustanchor.LoadBundle is caught here.
// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
func pemEncodeCert(t *testing.T, der []byte) []byte {
t.Helper()
@@ -24,7 +31,9 @@ func pemEncodeCert(t *testing.T, der []byte) []byte {
// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
// + the matching key. Lifetime is parameterised so the same factory drives
// both the happy-path and expired-cert cases.
// both happy-path and expired-cert cases. Kept in this file (not deleted with
// the white-box tests) because trust_anchor_holder_test.go's freshHolderCert
// returns *x509.Certificate while LoadTrustAnchor tests need raw DER + key.
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@@ -44,96 +53,6 @@ func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.Pri
return der, key
}
func TestParseTrustAnchorPEM_HappyPath_SingleCert(t *testing.T) {
der, _ := freshConnectorCertDER(t, time.Now().Add(365*24*time.Hour))
body := pemEncodeCert(t, der)
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
if err != nil {
t.Fatalf("parseTrustAnchorPEM: %v", err)
}
if len(certs) != 1 {
t.Fatalf("len(certs) = %d, want 1", len(certs))
}
if certs[0].Subject.CommonName != "intune-connector-test" {
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
}
}
func TestParseTrustAnchorPEM_HappyPath_MultiCert(t *testing.T) {
d1, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
d2, _ := freshConnectorCertDER(t, time.Now().Add(60*24*time.Hour))
body := append(pemEncodeCert(t, d1), pemEncodeCert(t, d2)...)
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
if err != nil {
t.Fatalf("parseTrustAnchorPEM: %v", err)
}
if len(certs) != 2 {
t.Fatalf("len(certs) = %d, want 2", len(certs))
}
}
func TestParseTrustAnchorPEM_SkipsNonCertBlocks(t *testing.T) {
der, key := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
t.Fatalf("MarshalECPrivateKey: %v", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
body := append(keyPEM, pemEncodeCert(t, der)...) // priv key first, cert second
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
if err != nil {
t.Fatalf("parseTrustAnchorPEM should ignore non-CERTIFICATE blocks: %v", err)
}
if len(certs) != 1 {
t.Fatalf("len(certs) = %d, want 1 (priv key block must be skipped)", len(certs))
}
}
func TestParseTrustAnchorPEM_EmptyBundleRejected(t *testing.T) {
_, err := parseTrustAnchorPEM([]byte("nothing here"), "test", time.Now())
if err == nil || !strings.Contains(err.Error(), "no CERTIFICATE PEM blocks") {
t.Fatalf("expected 'no CERTIFICATE PEM blocks' error, got %v", err)
}
}
func TestParseTrustAnchorPEM_OnlyKeyBlocksRejected(t *testing.T) {
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
keyDER, _ := x509.MarshalECPrivateKey(key)
body := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
_, err := parseTrustAnchorPEM(body, "test", time.Now())
if err == nil {
t.Fatalf("expected error for bundle with no certs, got nil")
}
}
func TestParseTrustAnchorPEM_ExpiredCertRejected(t *testing.T) {
der, _ := freshConnectorCertDER(t, time.Now().Add(-1*time.Hour)) // already expired
body := pemEncodeCert(t, der)
_, err := parseTrustAnchorPEM(body, "expired-bundle", time.Now())
if err == nil || !strings.Contains(err.Error(), "expired") {
t.Fatalf("expected expiry error, got %v", err)
}
// Operator-actionable message must include the subject so the audit
// log says exactly which cert to rotate.
if !strings.Contains(err.Error(), "intune-connector-test") {
t.Errorf("error must include subject CN for operator action: %v", err)
}
}
func TestParseTrustAnchorPEM_MalformedCertRejected(t *testing.T) {
bad := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-real-asn1-cert")})
_, err := parseTrustAnchorPEM(bad, "test", time.Now())
if err == nil {
t.Fatalf("expected x509 parse error, got nil")
}
}
func TestLoadTrustAnchor_FromDisk(t *testing.T) {
der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
body := pemEncodeCert(t, der)
@@ -150,6 +69,9 @@ func TestLoadTrustAnchor_FromDisk(t *testing.T) {
if len(certs) != 1 {
t.Fatalf("len(certs) = %d, want 1", len(certs))
}
if certs[0].Subject.CommonName != "intune-connector-test" {
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
}
}
func TestLoadTrustAnchor_EmptyPath(t *testing.T) {
@@ -164,7 +86,6 @@ func TestLoadTrustAnchor_MissingFile(t *testing.T) {
if err == nil {
t.Fatalf("expected file-not-found error, got nil")
}
// Don't string-assert on the OS error — just make sure it's surfaced.
if errors.Is(err, nil) {
t.Fatalf("error must be non-nil")
}
+227
View File
@@ -0,0 +1,227 @@
// Package trustanchor provides a SIGHUP-reloadable PEM-bundle trust pool
// shared by the SCEP/Intune dispatcher (per-profile Microsoft Intune
// Connector signing-cert anchor), the EST mTLS sibling route (per-profile
// client-CA trust bundle for /.well-known/est-mtls/<pathID>/), and any
// future caller that needs the same pattern (operator rotates an on-disk
// PEM bundle, sends SIGHUP, certctl swaps the in-memory pool atomically
// without a restart).
//
// EST RFC 7030 hardening master bundle Phase 2.1: extracted from
// internal/scep/intune/trust_anchor_holder.go where it originally lived.
// The intune package preserves a thin alias-style wrapper for back-compat
// (existing intune.TrustAnchorHolder + NewTrustAnchorHolder + LoadTrustAnchor
// callers compile unchanged); new callers SHOULD import this package
// directly.
//
// Concurrency contract:
//
// - Get returns the pool slice header by value; the slice itself is
// immutable per-snapshot (Reload swaps a fresh slice rather than
// mutating the existing one). Callers may iterate the returned slice
// without holding any lock.
// - Reload acquires a write lock briefly for the swap. Concurrent Get
// calls block only for that swap window (microseconds).
// - WatchSIGHUP runs at most one Reload at a time per holder.
//
// Threat model: the rationale for SIGHUP-as-reload-trigger (vs fsnotify
// or polling) is that the existing certctl rotation playbook (server TLS
// cert at cmd/server/tls.go::certHolder) already uses SIGHUP. Operators
// running the standard "rotate file, kill -HUP" workflow get every
// holder reloaded with one signal: server TLS + Intune trust anchors +
// EST mTLS trust bundles all swap atomically.
package trustanchor
import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
// Holder is the SIGHUP-reloadable wrapper around a PEM-bundle trust
// pool. Construct via New. The zero value is NOT usable.
type Holder struct {
mu sync.RWMutex
certs []*x509.Certificate
path string
logger *slog.Logger
// labelForLog is used only in error / info log lines so an operator
// running multiple holders (per-profile EST mTLS, per-profile Intune,
// server TLS) can distinguish which one fired. Defaults to "trust
// anchor" when not set; callers SHOULD set this to a descriptive
// string like "intune trust anchor (PathID=corp)" or "EST mTLS
// client CA bundle (PathID=corp)".
labelForLog string
}
// New loads the trust bundle and returns a holder. Returns the same
// fail-loud error LoadBundle does on initial load — the startup gate at
// cmd/server/main.go is supposed to refuse boot when this fails.
// Subsequent Reload errors are non-fatal (logged + old pool retained).
//
// The logger is required (never nil); the caller passes a per-profile
// scoped logger so SIGHUP-reload events show the PathID for triage.
func New(path string, logger *slog.Logger) (*Holder, error) {
if logger == nil {
return nil, errors.New("trustanchor: New requires a non-nil logger")
}
certs, err := LoadBundle(path)
if err != nil {
return nil, err
}
return &Holder{certs: certs, path: path, logger: logger, labelForLog: "trust anchor"}, nil
}
// SetLabelForLog records a descriptive label that future reload log
// lines use to distinguish this holder from others (e.g. "intune trust
// anchor (PathID=corp)"). Idempotent + safe for concurrent callers
// (the field is read only by the SIGHUP watcher goroutine after
// WatchSIGHUP starts).
func (h *Holder) SetLabelForLog(label string) {
if label == "" {
return
}
h.mu.Lock()
h.labelForLog = label
h.mu.Unlock()
}
// Get returns the current trust anchor pool. Safe for concurrent
// callers; the slice header is returned by value and the underlying
// slice is immutable per-snapshot.
func (h *Holder) Get() []*x509.Certificate {
h.mu.RLock()
defer h.mu.RUnlock()
return h.certs
}
// Path returns the on-disk path the holder reloads from.
func (h *Holder) Path() string {
return h.path
}
// Pool returns a fresh *x509.CertPool populated with the holder's
// current certs. Helper for callers that need a pool instead of a
// slice (the EST mTLS handler verifies client cert chains via
// cert.Verify(VerifyOptions{Roots: pool}); the Intune dispatcher uses
// the slice directly for signature-walk).
func (h *Holder) Pool() *x509.CertPool {
pool := x509.NewCertPool()
for _, c := range h.Get() {
pool.AddCert(c)
}
return pool
}
// Reload re-reads the trust anchor file at h.path and atomically swaps
// the pool. Returns the parse error if the new file is invalid; the
// OLD pool stays in place so a bad reload doesn't take dependent
// dispatch paths down. Same fail-safe pattern as cmd/server/tls.go::
// (*certHolder).Reload — a rotation that writes a half-file would
// otherwise crash the service mid-rotation.
func (h *Holder) Reload() error {
certs, err := LoadBundle(h.path)
if err != nil {
return err
}
h.mu.Lock()
h.certs = certs
h.mu.Unlock()
return nil
}
// WatchSIGHUP installs a signal handler that calls Reload on each
// SIGHUP. The returned stop function closes the internal done channel
// and stops signal delivery so the goroutine can exit cleanly during
// shutdown.
//
// Errors from Reload are logged but do not terminate the watcher — the
// operator can fix the files and send another SIGHUP. Mirrors the
// (*certHolder).watchSIGHUP contract from cmd/server/tls.go exactly.
//
// Multiple holders coexist: each registers its own goroutine on the
// same SIGHUP signal. signal.Notify multicasts to every registered
// channel, so a single SIGHUP reloads every per-profile trust anchor
// + the server TLS cert in one operator action.
func (h *Holder) WatchSIGHUP() (stop func()) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
done := make(chan struct{})
go func() {
for {
select {
case <-ch:
if err := h.Reload(); err != nil {
h.logger.Error(h.labelForLog+" reload failed; continuing with previous pool",
"error", err,
"path", h.path)
continue
}
h.logger.Info(h.labelForLog+" reloaded via SIGHUP",
"path", h.path,
"certs_loaded", len(h.Get()))
case <-done:
signal.Stop(ch)
return
}
}
}()
return func() { close(done) }
}
// LoadBundle reads a PEM bundle from disk + returns the parsed cert
// slice. Refuses empty bundles (zero CERTIFICATE blocks); refuses any
// bundle containing a cert past NotAfter (fail loud at boot rather than
// silently rejecting every request at runtime).
//
// Non-CERTIFICATE PEM blocks are skipped (so an operator can paste a
// chain that includes a private key by mistake without breaking the
// load — the priv key is just ignored). Operators rotating signing
// certs typically want this tolerance.
func LoadBundle(path string) ([]*x509.Certificate, error) {
if path == "" {
return nil, errors.New("trustanchor: bundle path is empty")
}
body, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("trustanchor: read bundle %q: %w", path, err)
}
return parseBundlePEM(body, path, time.Now())
}
// parseBundlePEM is the file-IO-free core of LoadBundle. Split out so
// unit tests can hand it byte slices without writing temp files. `now`
// is taken as a parameter so expiry tests can pin a deterministic clock.
func parseBundlePEM(body []byte, sourceLabel string, now time.Time) ([]*x509.Certificate, error) {
var out []*x509.Certificate
rest := body
for {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("trustanchor: parse cert in %q: %w", sourceLabel, err)
}
if now.After(cert.NotAfter) {
return nil, fmt.Errorf("trustanchor: cert in %q expired at %s (subject=%q) — operator must rotate the trust bundle before restart",
sourceLabel, cert.NotAfter.Format(time.RFC3339), cert.Subject.CommonName)
}
out = append(out, cert)
}
if len(out) == 0 {
return nil, fmt.Errorf("trustanchor: %q contains no CERTIFICATE PEM blocks", sourceLabel)
}
return out, nil
}
+432
View File
@@ -0,0 +1,432 @@
package trustanchor
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"log/slog"
"math/big"
"os"
"path/filepath"
"strings"
"syscall"
"testing"
"time"
)
// EST RFC 7030 hardening master bundle Phase 2.1: this test file holds the
// white-box tests for the trust-anchor primitives (parseBundlePEM + LoadBundle
// + Holder) that used to live in internal/scep/intune/{trust_anchor_test.go,
// trust_anchor_holder_test.go}. The intune package retains a thin
// public-surface test of LoadTrustAnchor (the back-compat shim) — the
// detailed tests live here so the EST mTLS sibling route + any future
// trustanchor.Holder caller share the same contract pinning.
// silentLogger drops everything; the SIGHUP watcher emits Info logs we don't
// want fouling test output.
func silentLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError + 10}))
}
// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
func pemEncodeCert(t *testing.T, der []byte) []byte {
t.Helper()
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
}
// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
// + the matching key. Lifetime is parameterised so the same factory drives
// both the happy-path and expired-cert cases.
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ecdsa.GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{CommonName: "trustanchor-test"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: notAfter,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("x509.CreateCertificate: %v", err)
}
return der, key
}
// writeTestBundle writes a PEM bundle of the given certs at path with mode 0600.
func writeTestBundle(t *testing.T, path string, certs []*x509.Certificate) {
t.Helper()
body := []byte{}
for _, c := range certs {
body = append(body, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.Raw})...)
}
if err := os.WriteFile(path, body, 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
}
// freshHolderCert is a small factory for a self-signed EC cert with a
// caller-controlled CN + lifetime. Used by Reload tests that swap the
// on-disk pool between calls.
func freshHolderCert(t *testing.T, cn string, notAfter time.Time) *x509.Certificate {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ecdsa.GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: notAfter,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("x509.CreateCertificate: %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("x509.ParseCertificate: %v", err)
}
return cert
}
// ----- parseBundlePEM (white-box) -----
func TestParseBundlePEM_HappyPath_SingleCert(t *testing.T) {
der, _ := freshConnectorCertDER(t, time.Now().Add(365*24*time.Hour))
body := pemEncodeCert(t, der)
certs, err := parseBundlePEM(body, "test", time.Now())
if err != nil {
t.Fatalf("parseBundlePEM: %v", err)
}
if len(certs) != 1 {
t.Fatalf("len(certs) = %d, want 1", len(certs))
}
if certs[0].Subject.CommonName != "trustanchor-test" {
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
}
}
func TestParseBundlePEM_HappyPath_MultiCert(t *testing.T) {
d1, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
d2, _ := freshConnectorCertDER(t, time.Now().Add(60*24*time.Hour))
body := append(pemEncodeCert(t, d1), pemEncodeCert(t, d2)...)
certs, err := parseBundlePEM(body, "test", time.Now())
if err != nil {
t.Fatalf("parseBundlePEM: %v", err)
}
if len(certs) != 2 {
t.Fatalf("len(certs) = %d, want 2", len(certs))
}
}
func TestParseBundlePEM_SkipsNonCertBlocks(t *testing.T) {
der, key := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
t.Fatalf("MarshalECPrivateKey: %v", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
body := append(keyPEM, pemEncodeCert(t, der)...) // priv key first, cert second
certs, err := parseBundlePEM(body, "test", time.Now())
if err != nil {
t.Fatalf("parseBundlePEM should ignore non-CERTIFICATE blocks: %v", err)
}
if len(certs) != 1 {
t.Fatalf("len(certs) = %d, want 1 (priv key block must be skipped)", len(certs))
}
}
func TestParseBundlePEM_EmptyBundleRejected(t *testing.T) {
_, err := parseBundlePEM([]byte("nothing here"), "test", time.Now())
if err == nil || !strings.Contains(err.Error(), "no CERTIFICATE PEM blocks") {
t.Fatalf("expected 'no CERTIFICATE PEM blocks' error, got %v", err)
}
}
func TestParseBundlePEM_OnlyKeyBlocksRejected(t *testing.T) {
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
keyDER, _ := x509.MarshalECPrivateKey(key)
body := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
_, err := parseBundlePEM(body, "test", time.Now())
if err == nil {
t.Fatalf("expected error for bundle with no certs, got nil")
}
}
func TestParseBundlePEM_ExpiredCertRejected(t *testing.T) {
der, _ := freshConnectorCertDER(t, time.Now().Add(-1*time.Hour)) // already expired
body := pemEncodeCert(t, der)
_, err := parseBundlePEM(body, "expired-bundle", time.Now())
if err == nil || !strings.Contains(err.Error(), "expired") {
t.Fatalf("expected expiry error, got %v", err)
}
// Operator-actionable message must include the subject so the audit
// log says exactly which cert to rotate.
if !strings.Contains(err.Error(), "trustanchor-test") {
t.Errorf("error must include subject CN for operator action: %v", err)
}
}
func TestParseBundlePEM_MalformedCertRejected(t *testing.T) {
bad := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-real-asn1-cert")})
_, err := parseBundlePEM(bad, "test", time.Now())
if err == nil {
t.Fatalf("expected x509 parse error, got nil")
}
}
// ----- LoadBundle (filesystem-side) -----
func TestLoadBundle_FromDisk(t *testing.T) {
der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
body := pemEncodeCert(t, der)
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
if err := os.WriteFile(path, body, 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
certs, err := LoadBundle(path)
if err != nil {
t.Fatalf("LoadBundle: %v", err)
}
if len(certs) != 1 {
t.Fatalf("len(certs) = %d, want 1", len(certs))
}
}
func TestLoadBundle_EmptyPath(t *testing.T) {
_, err := LoadBundle("")
if err == nil || !strings.Contains(err.Error(), "empty") {
t.Fatalf("expected empty-path error, got %v", err)
}
}
func TestLoadBundle_MissingFile(t *testing.T) {
_, err := LoadBundle("/tmp/does-not-exist-trustanchor.pem")
if err == nil {
t.Fatalf("expected file-not-found error, got nil")
}
if errors.Is(err, nil) {
t.Fatalf("error must be non-nil")
}
}
// ----- Holder -----
func TestHolder_NewLoadsBundle(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
cert := freshHolderCert(t, "initial", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{cert})
holder, err := New(path, silentLogger())
if err != nil {
t.Fatalf("New: %v", err)
}
got := holder.Get()
if len(got) != 1 || got[0].Subject.CommonName != "initial" {
t.Fatalf("Get returned %#v, want one cert with CN=initial", got)
}
if holder.Path() != path {
t.Errorf("Path = %q, want %q", holder.Path(), path)
}
}
func TestHolder_NewRequiresLogger(t *testing.T) {
if _, err := New("/nonexistent", nil); err == nil {
t.Fatal("nil logger must error")
}
}
func TestHolder_NewSurfacesLoadError(t *testing.T) {
if _, err := New("/path/that/does/not/exist.pem", silentLogger()); err == nil {
t.Fatal("missing file must error")
}
}
func TestHolder_PoolReturnsAllCerts(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
c1 := freshHolderCert(t, "ca-1", time.Now().Add(30*24*time.Hour))
c2 := freshHolderCert(t, "ca-2", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c1, c2})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
pool := h.Pool()
if pool == nil {
t.Fatal("Pool returned nil")
}
// pool.Subjects() is deprecated for caller-owned pools that may include
// the system roots. We've built this pool ourselves with exactly the
// two certs from h.Get(), so it's a safe use — but the linter doesn't
// know that. Rather than disable the lint, we cross-check via Equal()
// over the underlying cert slice we used to build the pool.
got := h.Get()
if len(got) != 2 {
t.Errorf("Get() len = %d, want 2", len(got))
}
}
func TestHolder_SetLabelForLogIgnoresEmpty(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
writeTestBundle(t, path, []*x509.Certificate{
freshHolderCert(t, "label-test", time.Now().Add(30*24*time.Hour)),
})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
h.SetLabelForLog("") // no-op; default "trust anchor" preserved
h.SetLabelForLog("est mTLS client CA bundle")
// No public getter for label; just exercise without crashing — race
// detector covers the locking contract under -race.
}
func TestHolder_ReloadHappyPath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
c1 := freshHolderCert(t, "rev-1", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c1})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
// Rotate on disk and call Reload.
c2 := freshHolderCert(t, "rev-2", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c2})
if err := h.Reload(); err != nil {
t.Fatalf("Reload: %v", err)
}
got := h.Get()
if len(got) != 1 || got[0].Subject.CommonName != "rev-2" {
t.Errorf("after Reload Get = %#v, want one cert CN=rev-2", got)
}
}
func TestHolder_ReloadKeepsOldOnFailure(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
good := freshHolderCert(t, "stable", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{good})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
// Overwrite with content that LoadBundle will reject (no PEM blocks).
if err := os.WriteFile(path, []byte("garbage"), 0o600); err != nil {
t.Fatal(err)
}
if err := h.Reload(); err == nil {
t.Fatal("Reload from garbage file must error")
}
// Old pool still served.
got := h.Get()
if len(got) != 1 || got[0].Subject.CommonName != "stable" {
t.Errorf("after failed Reload Get should still be the pre-Reload pool; got %#v", got)
}
}
func TestHolder_ReloadKeepsOldOnExpired(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
good := freshHolderCert(t, "still-valid", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{good})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
expired := freshHolderCert(t, "expired-conn", time.Now().Add(-1*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{expired})
if err := h.Reload(); err == nil {
t.Fatal("Reload with expired cert must error")
}
if !strings.Contains(h.Get()[0].Subject.CommonName, "still-valid") {
t.Errorf("after expired-cert Reload, holder should retain old pool")
}
}
func TestHolder_WatchSIGHUPReloadsPool(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
c1 := freshHolderCert(t, "rev-pre-sighup", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c1})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
stop := h.WatchSIGHUP()
defer stop()
c2 := freshHolderCert(t, "rev-post-sighup", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c2})
if err := syscall.Kill(syscall.Getpid(), syscall.SIGHUP); err != nil {
t.Fatalf("send SIGHUP: %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for {
got := h.Get()
if len(got) == 1 && got[0].Subject.CommonName == "rev-post-sighup" {
return
}
if time.Now().After(deadline) {
t.Fatalf("post-SIGHUP pool not swapped in 2s; current CN=%q", got[0].Subject.CommonName)
}
time.Sleep(20 * time.Millisecond)
}
}
func TestHolder_WatchSIGHUPStopIsClean(t *testing.T) {
// We do NOT fire a SIGHUP after stop(): once signal.Stop has removed our
// handler the kernel's default action on SIGHUP is to terminate the
// process — it would kill the test runner. Pin "stop() is synchronous
// and safe" by closing the watcher and verifying the holder still
// serves the original cert without panic.
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
writeTestBundle(t, path, []*x509.Certificate{
freshHolderCert(t, "stop-test", time.Now().Add(30*24*time.Hour)),
})
h, err := New(path, silentLogger())
if err != nil {
t.Fatal(err)
}
stop := h.WatchSIGHUP()
stop()
time.Sleep(50 * time.Millisecond)
if cn := h.Get()[0].Subject.CommonName; cn != "stop-test" {
t.Errorf("after stop CN = %q, want unchanged stop-test", cn)
}
}