mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-11 04:48:53 +00:00
EST RFC 7030 hardening master bundle Phases 2-4: end-to-end mTLS sibling
route + RFC 9266 channel binding + HTTP Basic enrollment-password +
per-source-IP failed-auth limit + per-(CN, sourceIP) sliding-window cap.
Two new shared packages so EST + Intune share infrastructure:
- internal/cms/ — RFC 9266 tls-exporter extractor (ExtractTLSExporter
with stdlib-panic recovery for synthetic ConnectionStates) +
CSR-side channel-binding parser via raw TBSCertificationRequestInfo
walk (the stdlib's csr.Attributes can't represent the OCTET STRING
binding value), VerifyChannelBinding composite, EmbedChannel-
BindingAttribute fixture helper, typed sentinel errors for missing
/ mismatch / not-TLS-1.3 mapped to HTTP 400 / 409 / 426 in handler.
- internal/trustanchor/ — extracted from scep/intune/trust_anchor*.go
so the EST mTLS sibling route + Intune dispatcher share the same
SIGHUP-reloadable PEM bundle primitive. intune.TrustAnchorHolder
is now `= trustanchor.Holder` (type alias) + NewTrustAnchorHolder =
trustanchor.New (function alias) — every existing call site compiles
unchanged. Intune's LoadTrustAnchor is a thin wrapper over
trustanchor.LoadBundle. White-box tests moved to the new package.
- internal/ratelimit/ — extracted from scep/intune/rate_limit.go (this
was Phase 4.1, in the same bundle). intune.PerDeviceRateLimiter
is now a thin wrapper preserving the (subject, issuer)→key
composition; EST handler reaches for SlidingWindowLimiter directly.
ESTHandler grew six optional fields wired by per-profile setters
(SetMTLSTrust / SetChannelBindingRequired / SetEnrollmentPassword /
SetSourceIPRateLimiter / SetPerPrincipalRateLimiter / SetLabelForLog)
plus four new mTLS-route methods (CACertsMTLS / SimpleEnrollMTLS /
SimpleReEnrollMTLS / CSRAttrsMTLS); shared internal pipeline
handleEnrollOrReEnroll(reEnroll, viaMTLS) keeps the auth/binding/
rate-limit gates DRY. New router method RegisterESTMTLSHandlers
registers /.well-known/est-mtls/<PathID>/{cacerts,simpleenroll,
simplereenroll,csrattrs}; AuthExemptDispatchPrefixes extends the
no-auth chain to /.well-known/est-mtls.
cmd/server/main.go's EST loop wires per-profile mTLS holder +
channel-binding policy + per-principal limiter + (when EnrollmentPassword
non-empty) Basic + source-IP limiter; new preflightESTMTLSClientCATrust-
Bundle returns *trustanchor.Holder so SIGHUP rotates the EST mTLS
bundle live without restart. SCEP + EST mTLS profiles now share a
single union mtlsUnionPoolForTLS passed to buildServerTLSConfigWithMTLS
(replaces the protocol-specific scepMTLSUnionPoolForTLS); per-handler
re-verify enforces "cert must chain to THIS profile's bundle" so
cross-protocol bleed is blocked at the application layer even though
the TLS layer trusts certs from either pool's union.
Phase 3.3 source-IP failed-Basic limiter defaults: 10 attempts / 1h
/ 50k tracked IPs (no env var; tunable in a follow-up). Phase 4.2
per-principal limiter cap from CERTCTL_EST_PROFILE_<NAME>_RATE_
LIMIT_PER_PRINCIPAL_24H (existing field, Phase 1 shipped).
New tests:
- internal/cms/channelbinding_test.go: extractor + CSR-side parser +
composite + TLS-1.3 round-trip end-to-end + EmbedChannelBinding-
Attribute round-trip
- internal/trustanchor/holder_test.go: parseBundlePEM white-box +
LoadBundle + Holder Get/Pool/SetLabelForLog/Reload-happy/
Reload-keeps-old-on-failure/Reload-keeps-old-on-expired/
WatchSIGHUP-reloads-pool/WatchSIGHUP-stop-clean
- internal/api/handler/est_hardening_test.go: 16 named cases covering
mTLS no-trust-pool 500 + no-cert 401 + cross-profile cert 401 +
happy-path 200 + CACertsMTLS auth gate + CSRAttrsMTLS auth gate +
channel-binding required-absent-rejected + not-required-absent-
allowed + writeChannelBindingError mapping + Basic no-header 401
+ Basic wrong-password 401 + Basic correct-200 + Basic-no-password
no-gate + per-IP failed-attempt lockout 429 + per-principal
blocks-after-cap + different-principals-independent + no-limiter-
unbounded.
Pre-commit verification (sandbox): gofmt clean, go vet clean
(excluding repository/postgres which the sandbox can't build —
disk-space testcontainers download), staticcheck clean for
cms/trustanchor/api/handler/api/router/scep/intune/ratelimit/
cmd/server, go test -short -count=1 green for cms/trustanchor/
api/handler/api/router/scep/intune/ratelimit/service. G-3
docs-drift guard reproduced locally clean (Phase 1 already
documented every new env var; Phases 2-4 added zero new env vars).
This commit is contained in:
+186
-17
@@ -31,10 +31,12 @@ import (
|
|||||||
notifyteams "github.com/shankar0123/certctl/internal/connector/notifier/teams"
|
notifyteams "github.com/shankar0123/certctl/internal/connector/notifier/teams"
|
||||||
"github.com/shankar0123/certctl/internal/crypto/signer"
|
"github.com/shankar0123/certctl/internal/crypto/signer"
|
||||||
"github.com/shankar0123/certctl/internal/domain"
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/ratelimit"
|
||||||
"github.com/shankar0123/certctl/internal/repository/postgres"
|
"github.com/shankar0123/certctl/internal/repository/postgres"
|
||||||
"github.com/shankar0123/certctl/internal/scep/intune"
|
"github.com/shankar0123/certctl/internal/scep/intune"
|
||||||
"github.com/shankar0123/certctl/internal/scheduler"
|
"github.com/shankar0123/certctl/internal/scheduler"
|
||||||
"github.com/shankar0123/certctl/internal/service"
|
"github.com/shankar0123/certctl/internal/service"
|
||||||
|
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -736,8 +738,24 @@ func main() {
|
|||||||
// mirrors the SCEP audit-closure pattern (cmd/server/main.go::
|
// mirrors the SCEP audit-closure pattern (cmd/server/main.go::
|
||||||
// preflightSCEPIntuneTrustAnchor signature took pathID for exactly
|
// preflightSCEPIntuneTrustAnchor signature took pathID for exactly
|
||||||
// this reason).
|
// this reason).
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2 + SCEP RFC 8894 +
|
||||||
|
// Intune master bundle Phase 6.5 SHARED union pool: every protocol's
|
||||||
|
// mTLS profiles contribute their trust certs here so a single TLS
|
||||||
|
// listener accepts client certs from EITHER protocol's profiles, and
|
||||||
|
// the per-handler gate re-verifies that the cert chains to THIS
|
||||||
|
// profile's bundle. Allocated lazily by whichever protocol first
|
||||||
|
// opts in (left nil when no profile opted in across both protocols
|
||||||
|
// — buildServerTLSConfigWithMTLS treats nil as 'no mTLS').
|
||||||
|
var mtlsUnionPoolForTLS *x509.CertPool
|
||||||
|
// estMTLSStopWatchers collects every per-profile trust-anchor
|
||||||
|
// SIGHUP-watcher stop func so we can shut them down on server exit
|
||||||
|
// (mirrors intuneStopWatchers below).
|
||||||
|
var estMTLSStopWatchers []func()
|
||||||
|
|
||||||
if cfg.EST.Enabled {
|
if cfg.EST.Enabled {
|
||||||
estHandlers := make(map[string]handler.ESTHandler, len(cfg.EST.Profiles))
|
estHandlers := make(map[string]handler.ESTHandler, len(cfg.EST.Profiles))
|
||||||
|
estMTLSHandlers := make(map[string]handler.ESTHandler)
|
||||||
|
estMTLSAnyEnabled := false
|
||||||
for i, profile := range cfg.EST.Profiles {
|
for i, profile := range cfg.EST.Profiles {
|
||||||
profile := profile // shadow for closure-safety
|
profile := profile // shadow for closure-safety
|
||||||
profileLog := logger.With(
|
profileLog := logger.With(
|
||||||
@@ -769,7 +787,102 @@ func main() {
|
|||||||
if profile.ProfileID != "" {
|
if profile.ProfileID != "" {
|
||||||
estService.SetProfileID(profile.ProfileID)
|
estService.SetProfileID(profile.ProfileID)
|
||||||
}
|
}
|
||||||
estHandlers[profile.PathID] = handler.NewESTHandler(estService)
|
estHandler := handler.NewESTHandler(estService)
|
||||||
|
estHandler.SetLabelForLog(fmt.Sprintf("est (PathID=%q)", profile.PathID))
|
||||||
|
|
||||||
|
// Phase 3.1: HTTP Basic enrollment password. Only takes effect
|
||||||
|
// on the standard /.well-known/est/<PathID>/ route — the mTLS
|
||||||
|
// sibling skips it because the client cert IS the auth signal.
|
||||||
|
if profile.EnrollmentPassword != "" {
|
||||||
|
estHandler.SetEnrollmentPassword(profile.EnrollmentPassword)
|
||||||
|
// Phase 3.3: per-source-IP failed-auth rate limit.
|
||||||
|
// Defaults: 10 failed attempts / 1 hour / 50k tracked IPs.
|
||||||
|
// Hard-coded for now (no env var); a tuning bundle can lift
|
||||||
|
// these once we've watched real production deploys for a
|
||||||
|
// release. The shared SlidingWindowLimiter applies the same
|
||||||
|
// math the SCEP/Intune limiter uses — extracted in Phase 4.1
|
||||||
|
// of this bundle so both call sites share the implementation.
|
||||||
|
failed := ratelimit.NewSlidingWindowLimiter(10, time.Hour, 50_000)
|
||||||
|
estHandler.SetSourceIPRateLimiter(failed)
|
||||||
|
}
|
||||||
|
// Phase 2.1: mTLS sibling route. When MTLSEnabled=true, build a
|
||||||
|
// per-profile SIGHUP-reloadable trust-anchor holder, splice the
|
||||||
|
// bundle's certs into the EST mTLS union pool, and clone the
|
||||||
|
// handler with the per-profile trust + channel-binding policy
|
||||||
|
// so SimpleEnrollMTLS / SimpleReEnrollMTLS verify against just
|
||||||
|
// THIS profile's bundle.
|
||||||
|
if profile.MTLSEnabled {
|
||||||
|
holder, err := preflightESTMTLSClientCATrustBundle(true, profile.PathID, profile.MTLSClientCATrustBundlePath, profileLog)
|
||||||
|
if err != nil {
|
||||||
|
profileLog.Error(
|
||||||
|
"startup refused: EST profile MTLS trust bundle preflight failed "+
|
||||||
|
"(EST hardening Phase 2: required when MTLS_ENABLED=true). "+
|
||||||
|
"Verify the bundle file exists at MTLS_CLIENT_CA_TRUST_BUNDLE_PATH, "+
|
||||||
|
"is readable, parses as PEM, contains ≥1 CERTIFICATE block, "+
|
||||||
|
"and none of the bundled certs are past NotAfter.",
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
// Merge this profile's certs into the union pool the TLS
|
||||||
|
// layer uses for VerifyClientCertIfGiven. Walk the bundle
|
||||||
|
// directly so the union pool gets exactly the same certs
|
||||||
|
// as the per-profile pool (mirrors SCEP's pattern at the
|
||||||
|
// equivalent loop iteration).
|
||||||
|
if mtlsUnionPoolForTLS == nil {
|
||||||
|
mtlsUnionPoolForTLS = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
bundleBytes, _ := os.ReadFile(profile.MTLSClientCATrustBundlePath)
|
||||||
|
rest := bundleBytes
|
||||||
|
for {
|
||||||
|
var block *pem.Block
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type != "CERTIFICATE" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||||
|
mtlsUnionPoolForTLS.AddCert(cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
estMTLSAnyEnabled = true
|
||||||
|
|
||||||
|
// Build the mTLS sibling-route handler with the per-profile
|
||||||
|
// trust pool, channel-binding policy, and (if configured)
|
||||||
|
// per-principal rate limiter.
|
||||||
|
mtlsHandler := handler.NewESTHandler(estService)
|
||||||
|
mtlsHandler.SetLabelForLog(fmt.Sprintf("est-mtls (PathID=%q)", profile.PathID))
|
||||||
|
mtlsHandler.SetMTLSTrust(holder)
|
||||||
|
mtlsHandler.SetChannelBindingRequired(profile.ChannelBindingRequired)
|
||||||
|
if profile.RateLimitPerPrincipal24h > 0 {
|
||||||
|
perPrincipal := ratelimit.NewSlidingWindowLimiter(profile.RateLimitPerPrincipal24h, 24*time.Hour, 100_000)
|
||||||
|
mtlsHandler.SetPerPrincipalRateLimiter(perPrincipal)
|
||||||
|
}
|
||||||
|
estMTLSHandlers[profile.PathID] = mtlsHandler
|
||||||
|
|
||||||
|
// Install the SIGHUP watcher so an operator that rotates
|
||||||
|
// the mTLS trust bundle file gets the new pool live without
|
||||||
|
// a server restart. Watcher stop func is collected for
|
||||||
|
// orderly shutdown via the defer below.
|
||||||
|
estMTLSStopWatchers = append(estMTLSStopWatchers, holder.WatchSIGHUP())
|
||||||
|
|
||||||
|
profileLog.Info("EST mTLS sibling route enabled",
|
||||||
|
"endpoint", "/.well-known/est-mtls/"+profile.PathID,
|
||||||
|
"client_ca_trust_bundle", profile.MTLSClientCATrustBundlePath,
|
||||||
|
"channel_binding_required", profile.ChannelBindingRequired,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
// Phase 4.2: per-principal rate limiter on the standard route
|
||||||
|
// too (additive — both routes share the same per-(CN, IP) cap
|
||||||
|
// when configured). The mTLS handler above gets its own
|
||||||
|
// limiter instance so the two routes don't share a bucket.
|
||||||
|
if profile.RateLimitPerPrincipal24h > 0 {
|
||||||
|
perPrincipal := ratelimit.NewSlidingWindowLimiter(profile.RateLimitPerPrincipal24h, 24*time.Hour, 100_000)
|
||||||
|
estHandler.SetPerPrincipalRateLimiter(perPrincipal)
|
||||||
|
}
|
||||||
|
estHandlers[profile.PathID] = estHandler
|
||||||
|
|
||||||
endpoint := "/.well-known/est"
|
endpoint := "/.well-known/est"
|
||||||
if profile.PathID != "" {
|
if profile.PathID != "" {
|
||||||
@@ -785,18 +898,30 @@ func main() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
apiRouter.RegisterESTHandlers(estHandlers)
|
apiRouter.RegisterESTHandlers(estHandlers)
|
||||||
logger.Info("EST server enabled", "profile_count", len(cfg.EST.Profiles))
|
if estMTLSAnyEnabled {
|
||||||
|
apiRouter.RegisterESTMTLSHandlers(estMTLSHandlers)
|
||||||
|
logger.Info("EST mTLS sibling route enabled (Phase 2)",
|
||||||
|
"mtls_profile_count", len(estMTLSHandlers),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
logger.Info("EST server enabled",
|
||||||
|
"profile_count", len(cfg.EST.Profiles),
|
||||||
|
"mtls_profile_count", len(estMTLSHandlers),
|
||||||
|
)
|
||||||
|
// Stop SIGHUP watchers in LIFO on server shutdown.
|
||||||
|
if len(estMTLSStopWatchers) > 0 {
|
||||||
|
defer func() {
|
||||||
|
for _, stop := range estMTLSStopWatchers {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SCEP RFC 8894 Phase 6.5: union pool of every enabled mTLS profile's
|
// SCEP RFC 8894 Phase 6.5: union pool of every enabled mTLS profile's
|
||||||
// trust bundle. Populated inside the SCEP startup block below; passed
|
// EST RFC 7030 hardening master bundle Phase 2: SCEP's mTLS union pool
|
||||||
// to the TLS-config builder later so the listener accepts client certs
|
// merged into the SHARED mtlsUnionPoolForTLS variable declared above.
|
||||||
// signed by ANY mTLS profile's CA. The handler-layer gate
|
// Variables here intentionally renamed to make the merge explicit.
|
||||||
// (HandleSCEPMTLS) re-verifies per-profile, so a cert that chains to
|
|
||||||
// profile A's bundle cannot enroll against profile B even though it
|
|
||||||
// passes the TLS-layer union check. Stays nil when no profile opted in
|
|
||||||
// (the TLS config builder treats nil as 'no mTLS').
|
|
||||||
var scepMTLSUnionPoolForTLS *x509.CertPool
|
|
||||||
|
|
||||||
// Register SCEP (RFC 8894) handlers if enabled.
|
// Register SCEP (RFC 8894) handlers if enabled.
|
||||||
//
|
//
|
||||||
@@ -821,7 +946,6 @@ func main() {
|
|||||||
// bundle to prevent cross-profile bleed-through).
|
// bundle to prevent cross-profile bleed-through).
|
||||||
scepHandlers := make(map[string]handler.SCEPHandler, len(cfg.SCEP.Profiles))
|
scepHandlers := make(map[string]handler.SCEPHandler, len(cfg.SCEP.Profiles))
|
||||||
scepMTLSHandlers := make(map[string]handler.SCEPHandler)
|
scepMTLSHandlers := make(map[string]handler.SCEPHandler)
|
||||||
scepMTLSUnionPool := x509.NewCertPool()
|
|
||||||
scepMTLSAnyEnabled := false
|
scepMTLSAnyEnabled := false
|
||||||
// SCEP RFC 8894 + Intune master bundle Phase 8: per-profile Intune
|
// SCEP RFC 8894 + Intune master bundle Phase 8: per-profile Intune
|
||||||
// trust anchor holders. We track them here so a single SIGHUP
|
// trust anchor holders. We track them here so a single SIGHUP
|
||||||
@@ -1017,7 +1141,10 @@ func main() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||||
scepMTLSUnionPool.AddCert(cert)
|
if mtlsUnionPoolForTLS == nil {
|
||||||
|
mtlsUnionPoolForTLS = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
mtlsUnionPoolForTLS.AddCert(cert)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scepMTLSAnyEnabled = true
|
scepMTLSAnyEnabled = true
|
||||||
@@ -1049,7 +1176,6 @@ func main() {
|
|||||||
// no-op-when-disabled case obvious in logs.
|
// no-op-when-disabled case obvious in logs.
|
||||||
if scepMTLSAnyEnabled {
|
if scepMTLSAnyEnabled {
|
||||||
apiRouter.RegisterSCEPMTLSHandlers(scepMTLSHandlers)
|
apiRouter.RegisterSCEPMTLSHandlers(scepMTLSHandlers)
|
||||||
scepMTLSUnionPoolForTLS = scepMTLSUnionPool
|
|
||||||
logger.Info("SCEP mTLS sibling route enabled (Phase 6.5)",
|
logger.Info("SCEP mTLS sibling route enabled (Phase 6.5)",
|
||||||
"mtls_profile_count", len(scepMTLSHandlers),
|
"mtls_profile_count", len(scepMTLSHandlers),
|
||||||
)
|
)
|
||||||
@@ -1317,7 +1443,7 @@ func main() {
|
|||||||
// sibling route gates additionally on the verified client cert.
|
// sibling route gates additionally on the verified client cert.
|
||||||
// nil pool = no profile opted in = identical TLS shape to the
|
// nil pool = no profile opted in = identical TLS shape to the
|
||||||
// pre-Phase-6.5 buildServerTLSConfig path.
|
// pre-Phase-6.5 buildServerTLSConfig path.
|
||||||
TLSConfig: buildServerTLSConfigWithMTLS(tlsCertHolder, scepMTLSUnionPoolForTLS),
|
TLSConfig: buildServerTLSConfigWithMTLS(tlsCertHolder, mtlsUnionPoolForTLS),
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: 30 * time.Second,
|
||||||
ReadHeaderTimeout: 5 * time.Second,
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
WriteTimeout: 120 * time.Second, // Must accommodate ACME issuance (order + challenge + finalize)
|
WriteTimeout: 120 * time.Second, // Must accommodate ACME issuance (order + challenge + finalize)
|
||||||
@@ -1476,6 +1602,41 @@ func preflightSCEPMTLSTrustBundle(enabled bool, bundlePath string) (*x509.CertPo
|
|||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// preflightESTMTLSClientCATrustBundle validates a per-profile EST mTLS
|
||||||
|
// client-CA trust bundle and returns a SIGHUP-reloadable holder.
|
||||||
|
//
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.5.
|
||||||
|
//
|
||||||
|
// Mirrors preflightSCEPMTLSTrustBundle's checks (file exists, parses as
|
||||||
|
// PEM, ≥1 cert, none expired) but returns a *trustanchor.Holder rather
|
||||||
|
// than a raw *x509.CertPool — the EST handler stores the holder so a
|
||||||
|
// SIGHUP rotates the trust bundle live without a server restart, exactly
|
||||||
|
// the way the Intune trust anchor rotation works (Phase 8.5 of the SCEP
|
||||||
|
// bundle). The handler-side .Pool() accessor on the holder rebuilds an
|
||||||
|
// x509.CertPool from the current snapshot for each Verify call.
|
||||||
|
//
|
||||||
|
// Uses the shared internal/trustanchor.LoadBundle (extracted in EST
|
||||||
|
// hardening Phase 2.1 from the original Intune-only path) so the EST
|
||||||
|
// + Intune callers exercise the same loader semantics — empty bundle
|
||||||
|
// rejected, expired cert rejected with subject in error message,
|
||||||
|
// non-CERTIFICATE PEM blocks tolerated.
|
||||||
|
func preflightESTMTLSClientCATrustBundle(enabled bool, pathID, bundlePath string, logger *slog.Logger) (*trustanchor.Holder, error) {
|
||||||
|
if !enabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if bundlePath == "" {
|
||||||
|
return nil, fmt.Errorf("EST profile (PathID=%q) MTLS enabled but trust bundle path empty: "+
|
||||||
|
"set CERTCTL_EST_PROFILE_<NAME>_MTLS_CLIENT_CA_TRUST_BUNDLE_PATH to a PEM file "+
|
||||||
|
"containing the bootstrap-CA certs the operator allows to enroll", pathID)
|
||||||
|
}
|
||||||
|
holder, err := trustanchor.New(bundlePath, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("EST profile (PathID=%q) MTLS trust bundle preflight: %w", pathID, err)
|
||||||
|
}
|
||||||
|
holder.SetLabelForLog(fmt.Sprintf("EST mTLS client CA bundle (PathID=%q)", pathID))
|
||||||
|
return holder, nil
|
||||||
|
}
|
||||||
|
|
||||||
// preflightSCEPIntuneTrustAnchor validates a per-profile Microsoft Intune
|
// preflightSCEPIntuneTrustAnchor validates a per-profile Microsoft Intune
|
||||||
// Certificate Connector signing-cert trust bundle.
|
// Certificate Connector signing-cert trust bundle.
|
||||||
//
|
//
|
||||||
@@ -1745,9 +1906,17 @@ func buildFinalHandler(apiHandler, noAuthHandler http.Handler, webDir string, da
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RFC 7030 EST endpoints ride the no-auth middleware chain (M-001,
|
// RFC 7030 EST endpoints ride the no-auth middleware chain (M-001,
|
||||||
// option D, audit 2026-04-19). Trust boundary is CSR signature + profile
|
// option D, audit 2026-04-19). Trust boundary is CSR signature +
|
||||||
// policy, not HTTP Bearer. /.well-known/est/cacerts is explicitly
|
// (per EST hardening Phase 2) optional client cert at the handler
|
||||||
// anonymous per RFC 7030 §4.1.1.
|
// layer, not HTTP Bearer. /.well-known/est/cacerts is explicitly
|
||||||
|
// anonymous per RFC 7030 §4.1.1; /.well-known/est-mtls/<PathID>/
|
||||||
|
// (EST hardening Phase 2 sibling route) requires a client cert
|
||||||
|
// gate at the handler layer — both share this prefix gate because
|
||||||
|
// "/.well-known/est-mtls" is itself prefixed by "/.well-known/est".
|
||||||
|
// EST hardening Phase 3's HTTP Basic enrollment-password is a
|
||||||
|
// per-profile handler-layer auth that runs INSIDE the no-auth
|
||||||
|
// middleware chain (since the chain skips the Bearer middleware,
|
||||||
|
// the handler gets to define its own auth contract).
|
||||||
if strings.HasPrefix(path, "/.well-known/est") {
|
if strings.HasPrefix(path, "/.well-known/est") {
|
||||||
noAuthHandler.ServeHTTP(w, r)
|
noAuthHandler.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
|
|||||||
+15
-9
@@ -136,21 +136,27 @@ func buildServerTLSConfig(holder *certHolder) *tls.Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildServerTLSConfigWithMTLS extends buildServerTLSConfig with a client-cert
|
// buildServerTLSConfigWithMTLS extends buildServerTLSConfig with a client-cert
|
||||||
// trust pool for the SCEP RFC 8894 + Intune master bundle Phase 6.5 mTLS
|
// trust pool for the SCEP/EST mTLS sibling routes.
|
||||||
// sibling route. SCEP profiles that opt into mTLS each contribute their
|
//
|
||||||
// trust bundle to the union pool here; the same TLS listener serves both
|
// SCEP RFC 8894 + Intune master bundle Phase 6.5 introduced this for the
|
||||||
// /scep[/<pathID>] (no client cert) and /scep-mtls/<pathID> (cert required
|
// /scep-mtls/<pathID> route; EST RFC 7030 hardening master bundle Phase 2
|
||||||
// at the handler layer).
|
// extended it so the same TLS listener also serves /.well-known/est-mtls/
|
||||||
|
// <pathID>. Both protocols' mTLS profiles contribute their trust bundles
|
||||||
|
// to a UNION pool that the caller (cmd/server/main.go) builds by walking
|
||||||
|
// every enabled mTLS profile's bundle bytes once. The per-protocol
|
||||||
|
// handlers re-verify against just THIS profile's bundle (so an EST-mTLS
|
||||||
|
// bootstrap cert can't enroll against a SCEP-mTLS profile and vice versa).
|
||||||
//
|
//
|
||||||
// ClientAuth: VerifyClientCertIfGiven — request a cert during handshake; if
|
// ClientAuth: VerifyClientCertIfGiven — request a cert during handshake; if
|
||||||
// the client presents one, verify it against the union pool; if absent, the
|
// the client presents one, verify it against the union pool; if absent, the
|
||||||
// request still reaches the handler and the per-route handler decides
|
// request still reaches the handler and the per-route handler decides
|
||||||
// whether to accept. Critical that we do NOT use RequireAndVerifyClientCert
|
// whether to accept. Critical that we do NOT use RequireAndVerifyClientCert
|
||||||
// here — that would break the standard /scep route (which is challenge-
|
// here — that would break the standard /scep + /.well-known/est routes
|
||||||
// password-only, no client cert expected).
|
// (challenge-password-only / unauth-or-Basic, no client cert expected).
|
||||||
//
|
//
|
||||||
// Pass clientCAs == nil to disable mTLS (no profile opted in). The function
|
// Pass clientCAs == nil to disable mTLS (no profile opted in across either
|
||||||
// then returns the same shape as buildServerTLSConfig.
|
// protocol). The function then returns the same shape as
|
||||||
|
// buildServerTLSConfig.
|
||||||
func buildServerTLSConfigWithMTLS(holder *certHolder, clientCAs *x509.CertPool) *tls.Config {
|
func buildServerTLSConfigWithMTLS(holder *certHolder, clientCAs *x509.CertPool) *tls.Config {
|
||||||
cfg := buildServerTLSConfig(holder)
|
cfg := buildServerTLSConfig(holder)
|
||||||
if clientCAs != nil {
|
if clientCAs != nil {
|
||||||
|
|||||||
+599
-225
@@ -2,17 +2,23 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/shankar0123/certctl/internal/api/middleware"
|
"github.com/shankar0123/certctl/internal/api/middleware"
|
||||||
|
"github.com/shankar0123/certctl/internal/cms"
|
||||||
"github.com/shankar0123/certctl/internal/domain"
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
"github.com/shankar0123/certctl/internal/pkcs7"
|
"github.com/shankar0123/certctl/internal/pkcs7"
|
||||||
|
"github.com/shankar0123/certctl/internal/ratelimit"
|
||||||
|
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ESTService defines the service interface for EST enrollment operations.
|
// ESTService defines the service interface for EST enrollment operations.
|
||||||
@@ -33,62 +39,558 @@ type ESTService interface {
|
|||||||
|
|
||||||
// ESTHandler handles HTTP requests for the EST protocol (RFC 7030).
|
// ESTHandler handles HTTP requests for the EST protocol (RFC 7030).
|
||||||
//
|
//
|
||||||
// EST endpoints are served under /.well-known/est/ per the RFC.
|
// EST endpoints are served under /.well-known/est/[<PathID>/] per RFC 7030.
|
||||||
// Wire format: base64-encoded DER (PKCS#7 for certs, PKCS#10 for CSRs).
|
// Wire format: base64-encoded DER (PKCS#7 for certs, PKCS#10 for CSRs).
|
||||||
//
|
//
|
||||||
// Supported operations:
|
// Supported operations (per route family):
|
||||||
// - GET /.well-known/est/cacerts — CA certificate distribution
|
//
|
||||||
// - POST /.well-known/est/simpleenroll — initial enrollment
|
// /.well-known/est/[<PathID>/] — legacy + per-profile route family
|
||||||
// - POST /.well-known/est/simplereenroll — re-enrollment
|
// GET cacerts — CA certificate distribution
|
||||||
// - GET /.well-known/est/csrattrs — CSR attributes
|
// POST simpleenroll — initial enrollment (HTTP Basic optional, Phase 3)
|
||||||
|
// POST simplereenroll — re-enrollment (HTTP Basic optional, Phase 3)
|
||||||
|
// GET csrattrs — CSR attributes
|
||||||
|
//
|
||||||
|
// /.well-known/est-mtls/<PathID>/ — mTLS sibling (Phase 2)
|
||||||
|
// GET cacerts — CA certificate distribution (cert auth required)
|
||||||
|
// POST simpleenroll — initial enrollment (cert + optional channel binding)
|
||||||
|
// POST simplereenroll — re-enrollment (cert + optional channel binding)
|
||||||
|
// GET csrattrs — CSR attributes
|
||||||
|
//
|
||||||
|
// EST RFC 7030 hardening master bundle Phases 2-4: ESTHandler grew six
|
||||||
|
// optional fields wired by per-profile setters in cmd/server/main.go's
|
||||||
|
// startup loop. None of the new fields are required — a handler with all
|
||||||
|
// of them unset behaves exactly like the v2.0.x EST handler.
|
||||||
type ESTHandler struct {
|
type ESTHandler struct {
|
||||||
svc ESTService
|
svc ESTService
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening Phase 2.1: per-profile mTLS client-CA trust
|
||||||
|
// bundle. When set, the mTLS sibling route (CACertsMTLS /
|
||||||
|
// SimpleEnrollMTLS / etc.) verifies the inbound client cert chain
|
||||||
|
// against this pool. Nil when MTLS_ENABLED=false; the mTLS route
|
||||||
|
// rejects unconditionally in that case (the route shouldn't even be
|
||||||
|
// registered, but defense in depth).
|
||||||
|
mtlsTrust *trustanchor.Holder
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening Phase 2.4: per-profile channel-binding
|
||||||
|
// requirement. When true, the mTLS handler refuses simplereenroll
|
||||||
|
// requests whose CSR doesn't carry a matching id-aa-est-tls-exporter
|
||||||
|
// (RFC 9266) attribute. Phase 1's Validate() guards
|
||||||
|
// ChannelBindingRequired=true + MTLSEnabled=false at startup.
|
||||||
|
channelBindingRequired bool
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening Phase 3.1: per-profile HTTP Basic enrollment
|
||||||
|
// password. When non-empty, the standard /.well-known/est/<PathID>/
|
||||||
|
// route requires `Authorization: Basic <base64(<user>:<pw>)>` on the
|
||||||
|
// enrollment endpoints (NOT on cacerts/csrattrs — RFC 7030 §4.1.1
|
||||||
|
// says cacerts is anonymous). Constant-time compare; per-source-IP
|
||||||
|
// failed-auth rate limit blocks brute-force.
|
||||||
|
basicPassword string
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening Phase 3.3: per-handler source-IP rate
|
||||||
|
// limiter for FAILED HTTP Basic auth attempts. Keyed by sourceIP so
|
||||||
|
// a hostile network segment can't burn through the password.
|
||||||
|
failedBasicLimiter *ratelimit.SlidingWindowLimiter
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening Phase 4.2: per-handler per-principal sliding-
|
||||||
|
// window rate limit. Keyed by (CSR-CN, sourceIP) so a stolen
|
||||||
|
// bootstrap cert AND a known device CN can't be used to flood the
|
||||||
|
// issuer. Disabled when nil; configured per-profile.
|
||||||
|
perPrincipalLimiter *ratelimit.SlidingWindowLimiter
|
||||||
|
|
||||||
|
// labelForLog gives observability code a per-profile string to
|
||||||
|
// include in audit log lines / Prometheus labels. Defaults to
|
||||||
|
// "est" when unset.
|
||||||
|
labelForLog string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewESTHandler creates a new ESTHandler.
|
// NewESTHandler creates a new ESTHandler with no per-profile auth
|
||||||
|
// hardening. Call SetMTLSTrust + SetChannelBindingRequired +
|
||||||
|
// SetEnrollmentPassword + SetSourceIPRateLimiter + SetPerPrincipalRateLimiter
|
||||||
|
// from the per-profile startup loop to opt-in to each surface.
|
||||||
func NewESTHandler(svc ESTService) ESTHandler {
|
func NewESTHandler(svc ESTService) ESTHandler {
|
||||||
return ESTHandler{svc: svc}
|
return ESTHandler{svc: svc}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CACerts handles GET /.well-known/est/cacerts
|
// SetMTLSTrust injects the per-profile client-cert trust pool the
|
||||||
// Returns the CA certificate chain as base64-encoded PKCS#7 (certs-only).
|
// `/.well-known/est-mtls/<PathID>/` sibling route uses to verify inbound
|
||||||
// Per RFC 7030 Section 4.1, this is a "certs-only" CMC Simple PKI Response.
|
// device cert chains. EST RFC 7030 hardening Phase 2.1.
|
||||||
// For simplicity and broad client compatibility, we return base64-encoded DER certificates.
|
//
|
||||||
|
// Like the SCEP equivalent, the TLS layer (cmd/server/tls.go) uses
|
||||||
|
// VerifyClientCertIfGiven against the UNION of every enabled mTLS
|
||||||
|
// profile's bundle, so the same TLS listener serves both /.well-known/est
|
||||||
|
// (anonymous or HTTP Basic) and /.well-known/est-mtls/<PathID>
|
||||||
|
// (cert-required). The per-profile gate at the handler layer enforces
|
||||||
|
// 'cert must chain to THIS profile's bundle' so a cert that chains to
|
||||||
|
// profile A's bundle cannot enroll against profile B.
|
||||||
|
func (h *ESTHandler) SetMTLSTrust(t *trustanchor.Holder) { h.mtlsTrust = t }
|
||||||
|
|
||||||
|
// SetChannelBindingRequired toggles RFC 9266 tls-exporter channel binding
|
||||||
|
// on the simplereenroll mTLS path. EST RFC 7030 hardening Phase 2.4.
|
||||||
|
// When true, the handler refuses requests whose CSR lacks the binding
|
||||||
|
// attribute or whose binding bytes don't match the live TLS exporter.
|
||||||
|
func (h *ESTHandler) SetChannelBindingRequired(req bool) { h.channelBindingRequired = req }
|
||||||
|
|
||||||
|
// SetEnrollmentPassword injects the per-profile HTTP Basic enrollment
|
||||||
|
// password. EST RFC 7030 hardening Phase 3.1. Empty disables the gate
|
||||||
|
// (mTLS-only or unauthenticated profile). Constant-time compare via
|
||||||
|
// crypto/subtle.ConstantTimeCompare.
|
||||||
|
func (h *ESTHandler) SetEnrollmentPassword(pw string) { h.basicPassword = pw }
|
||||||
|
|
||||||
|
// SetSourceIPRateLimiter injects the per-handler failed-Basic-auth
|
||||||
|
// rate limiter. Phase 3.3. Disabled when nil — but Validate() at
|
||||||
|
// startup refuses an enabled basic-auth profile without a configured
|
||||||
|
// limiter, so a real deploy always wires one.
|
||||||
|
func (h *ESTHandler) SetSourceIPRateLimiter(l *ratelimit.SlidingWindowLimiter) {
|
||||||
|
h.failedBasicLimiter = l
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPerPrincipalRateLimiter injects the per-handler (CN, sourceIP)
|
||||||
|
// sliding-window rate limiter. Phase 4.2. Disabled when nil. Counts
|
||||||
|
// every successful enrollment, NOT just failures — the goal is to
|
||||||
|
// bound enrollment-flooding from a compromised credential, not just
|
||||||
|
// failed-auth brute force.
|
||||||
|
func (h *ESTHandler) SetPerPrincipalRateLimiter(l *ratelimit.SlidingWindowLimiter) {
|
||||||
|
h.perPrincipalLimiter = l
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLabelForLog sets the per-profile observability label. Defaults to
|
||||||
|
// "est" when unset; cmd/server/main.go's per-profile loop sets this
|
||||||
|
// to "est (PathID=<id>)" for triage.
|
||||||
|
func (h *ESTHandler) SetLabelForLog(label string) {
|
||||||
|
if label == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.labelForLog = label
|
||||||
|
}
|
||||||
|
|
||||||
|
// label returns h.labelForLog with the "est" fallback applied. Tiny
|
||||||
|
// helper so log call sites don't need to repeat the fallback.
|
||||||
|
func (h ESTHandler) label() string {
|
||||||
|
if h.labelForLog == "" {
|
||||||
|
return "est"
|
||||||
|
}
|
||||||
|
return h.labelForLog
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- /.well-known/est/[<PathID>/] route family (legacy + Basic auth) -----
|
||||||
|
|
||||||
|
// CACerts handles GET /.well-known/est/[<PathID>/]cacerts.
|
||||||
|
//
|
||||||
|
// RFC 7030 §4.1.1 — anonymous endpoint. The HTTP Basic gate is NOT
|
||||||
|
// applied here (any client must be able to fetch the CA chain to
|
||||||
|
// verify subsequent enrollment responses).
|
||||||
func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
|
func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.writeCACertsResponse(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
caCertPEM, err := h.svc.GetCACerts(r.Context())
|
// SimpleEnroll handles POST /.well-known/est/[<PathID>/]simpleenroll.
|
||||||
if err != nil {
|
// Accepts a base64-encoded PKCS#10 CSR + returns base64-encoded PKCS#7.
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
//
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get CA certificates: %v", err), requestID)
|
// Auth: HTTP Basic when h.basicPassword != "" (Phase 3); otherwise
|
||||||
|
// anonymous. Rate-limit: per-(CN, sourceIP) when wired (Phase 4).
|
||||||
|
func (h ESTHandler) SimpleEnroll(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.handleEnrollOrReEnroll(w, r, false /*reEnroll*/, false /*viaMTLS*/)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimpleReEnroll handles POST /.well-known/est/[<PathID>/]simplereenroll.
|
||||||
|
// Same as SimpleEnroll but the audit/log distinguishes the renewal flow
|
||||||
|
// from initial issuance.
|
||||||
|
func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.handleEnrollOrReEnroll(w, r, true /*reEnroll*/, false /*viaMTLS*/)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRAttrs handles GET /.well-known/est/[<PathID>/]csrattrs.
|
||||||
|
// Returns the CSR attributes the server wants the client to include.
|
||||||
|
// RFC 7030 §4.5 — anonymous endpoint, no Basic auth gate.
|
||||||
|
func (h ESTHandler) CSRAttrs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.writeCSRAttrsResponse(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- /.well-known/est-mtls/<PathID>/ route family (Phase 2 mTLS) -----
|
||||||
|
|
||||||
|
// CACertsMTLS handles GET /.well-known/est-mtls/<PathID>/cacerts.
|
||||||
|
//
|
||||||
|
// RFC 7030 §4.1.1 says cacerts is anonymous, but on the mTLS sibling
|
||||||
|
// route we still require a valid client cert because the mTLS path is
|
||||||
|
// the audit-distinguished surface — operators using mTLS WANT every
|
||||||
|
// touchpoint logged. The cert isn't validated for purpose-of-issuance
|
||||||
|
// here (cacerts isn't an enrollment), but absence is rejected.
|
||||||
|
func (h ESTHandler) CACertsMTLS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := h.requireClientCertChain(w, r); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.writeCACertsResponse(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimpleEnrollMTLS handles POST /.well-known/est-mtls/<PathID>/simpleenroll.
|
||||||
|
//
|
||||||
|
// Order of gates (each fails fast with the appropriate HTTP status):
|
||||||
|
//
|
||||||
|
// 1. Client cert presented + chains to per-profile mTLS trust pool
|
||||||
|
// (the TLS layer already verified against the union pool; this is
|
||||||
|
// the per-profile re-verify that prevents profile A↔B cross-bleed).
|
||||||
|
// 2. CSR parses + matches the EST contract (handled by the shared
|
||||||
|
// enrollment helper).
|
||||||
|
// 3. Per-(CN, sourceIP) rate limit when configured.
|
||||||
|
// 4. Service-layer enrollment.
|
||||||
|
//
|
||||||
|
// Channel binding does NOT apply here — RFC 9266 §1 calls out that
|
||||||
|
// channel binding is a renewal-time defense-in-depth, not an initial-
|
||||||
|
// enrollment requirement. (A first-time enrollment doesn't yet have a
|
||||||
|
// device cert, so binding to the TLS session for the bootstrap cert
|
||||||
|
// adds nothing.)
|
||||||
|
func (h ESTHandler) SimpleEnrollMTLS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if _, ok := h.requireClientCertChain(w, r); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleEnrollOrReEnroll(w, r, false /*reEnroll*/, true /*viaMTLS*/)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimpleReEnrollMTLS handles POST /.well-known/est-mtls/<PathID>/simplereenroll.
|
||||||
|
//
|
||||||
|
// Same as SimpleEnrollMTLS plus the channel-binding gate. RFC 9266 §4.1
|
||||||
|
// says renewal CSRs SHOULD include the binding attribute when the
|
||||||
|
// enrollment is over a TLS-1.3 channel; per-profile policy can either
|
||||||
|
// require this strictly (ChannelBindingRequired=true) or accept its
|
||||||
|
// absence (default).
|
||||||
|
func (h ESTHandler) SimpleReEnrollMTLS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if _, ok := h.requireClientCertChain(w, r); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleEnrollOrReEnroll(w, r, true /*reEnroll*/, true /*viaMTLS*/)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRAttrsMTLS handles GET /.well-known/est-mtls/<PathID>/csrattrs.
|
||||||
|
// Mirrors CACertsMTLS — cert-required even though the unauth route
|
||||||
|
// version is anonymous.
|
||||||
|
func (h ESTHandler) CSRAttrsMTLS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := h.requireClientCertChain(w, r); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.writeCSRAttrsResponse(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- shared internal pipeline -----
|
||||||
|
|
||||||
|
// handleEnrollOrReEnroll is the shared body for {Simple,SimpleRe}Enroll{,MTLS}.
|
||||||
|
// reEnroll picks the SimpleReEnroll vs SimpleEnroll service method (purely
|
||||||
|
// audit / metric distinguishing — same issuer call underneath); viaMTLS
|
||||||
|
// picks whether the channel-binding + per-principal-limit gates apply
|
||||||
|
// AND skips the HTTP Basic gate (mTLS handlers carry the auth).
|
||||||
|
func (h ESTHandler) handleEnrollOrReEnroll(w http.ResponseWriter, r *http.Request, reEnroll, viaMTLS bool) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse PEM to DER for PKCS#7 encoding
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
|
||||||
|
if err := verifyESTTransport(r); err != nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusBadRequest,
|
||||||
|
fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP Basic gate (Phase 3) — non-mTLS path only. mTLS profiles
|
||||||
|
// authenticate via the client cert so adding Basic on top would
|
||||||
|
// double-tax operators with no security benefit.
|
||||||
|
if !viaMTLS && h.basicPassword != "" {
|
||||||
|
if !h.requireBasicAuth(w, r) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
csrPEM, err := h.readCSRFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the CSR once for downstream gates (channel-binding, per-
|
||||||
|
// principal rate limit). The service re-parses internally — that's a
|
||||||
|
// minor inefficiency we accept to keep the service interface flat.
|
||||||
|
csr, _ := decodeCSRPEM(csrPEM)
|
||||||
|
|
||||||
|
// Channel-binding gate (Phase 2.4) — mTLS reEnroll only. The optional
|
||||||
|
// CSR-side attribute is checked even when the per-profile flag isn't
|
||||||
|
// requiring it (a CSR carrying the attribute MUST match the live
|
||||||
|
// exporter; a present-but-mismatched binding is always fatal).
|
||||||
|
if viaMTLS && reEnroll && csr != nil {
|
||||||
|
if err := cms.VerifyChannelBinding(r.TLS, csr, h.channelBindingRequired); err != nil {
|
||||||
|
h.writeChannelBindingError(w, requestID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-principal rate-limit gate (Phase 4.2). Keyed by CN+sourceIP so
|
||||||
|
// (a) a CN with no source-IP rotation can be capped, AND (b) a
|
||||||
|
// hostile network segment trying to enroll many CNs from one IP is
|
||||||
|
// also bounded.
|
||||||
|
if h.perPrincipalLimiter != nil {
|
||||||
|
if err := h.applyPerPrincipalRateLimit(r, csr); err != nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusTooManyRequests,
|
||||||
|
fmt.Sprintf("EST enrollment rate-limited: %v", err), requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
result *domain.ESTEnrollResult
|
||||||
|
callErr error
|
||||||
|
)
|
||||||
|
if reEnroll {
|
||||||
|
result, callErr = h.svc.SimpleReEnroll(r.Context(), csrPEM)
|
||||||
|
} else {
|
||||||
|
result, callErr = h.svc.SimpleEnroll(r.Context(), csrPEM)
|
||||||
|
}
|
||||||
|
if callErr != nil {
|
||||||
|
op := "Enrollment"
|
||||||
|
if reEnroll {
|
||||||
|
op = "Re-enrollment"
|
||||||
|
}
|
||||||
|
ErrorWithRequestID(w, http.StatusInternalServerError,
|
||||||
|
fmt.Sprintf("%s failed: %v", op, callErr), requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeCertResponse(w, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requireClientCertChain enforces the mTLS gate for the est-mtls sibling
|
||||||
|
// route. Returns the leaf cert + true on success; on failure writes the
|
||||||
|
// HTTP error and returns false.
|
||||||
|
//
|
||||||
|
// Mirrors SCEPHandler.HandleSCEPMTLS exactly:
|
||||||
|
// - mtlsTrust nil → 500 (config bug; preflight should have prevented).
|
||||||
|
// - r.TLS nil or no peer cert → 401 (cert required).
|
||||||
|
// - chain doesn't verify against per-profile pool → 401.
|
||||||
|
func (h ESTHandler) requireClientCertChain(w http.ResponseWriter, r *http.Request) (*x509.Certificate, bool) {
|
||||||
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
if h.mtlsTrust == nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusInternalServerError,
|
||||||
|
h.label()+" mTLS handler missing trust pool", requestID)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
|
ErrorWithRequestID(w, http.StatusUnauthorized,
|
||||||
|
"Client certificate required for /.well-known/est-mtls", requestID)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
leaf := r.TLS.PeerCertificates[0]
|
||||||
|
intermediates := x509.NewCertPool()
|
||||||
|
for _, c := range r.TLS.PeerCertificates[1:] {
|
||||||
|
intermediates.AddCert(c)
|
||||||
|
}
|
||||||
|
if _, err := leaf.Verify(x509.VerifyOptions{
|
||||||
|
Roots: h.mtlsTrust.Pool(),
|
||||||
|
Intermediates: intermediates,
|
||||||
|
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageAny},
|
||||||
|
}); err != nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusUnauthorized,
|
||||||
|
"Client certificate not trusted by this profile", requestID)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return leaf, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// requireBasicAuth runs the Phase 3 HTTP Basic password gate. Returns
|
||||||
|
// true when auth passed. On failure writes WWW-Authenticate + a 401
|
||||||
|
// (with rate-limit accounting against the source IP).
|
||||||
|
//
|
||||||
|
// User: any non-empty value (RFC 7030 §3.2.3 says the username is
|
||||||
|
// not authoritative when only a shared password is meaningful). Pass:
|
||||||
|
// constant-time compare against h.basicPassword.
|
||||||
|
func (h ESTHandler) requireBasicAuth(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
srcIP := clientIPForLimiter(r)
|
||||||
|
|
||||||
|
// recordFailedBasic ticks a slot on every credential rejection;
|
||||||
|
// once the IP has burned through its window's worth of failed
|
||||||
|
// attempts the limiter returns ErrRateLimited (which the next
|
||||||
|
// recordFailedBasic just no-ops out — we still want to fail-closed
|
||||||
|
// the auth here). The cleaner design is a pre-check that short-
|
||||||
|
// circuits the constant-time compare ENTIRELY for an IP at-cap, so
|
||||||
|
// a brute-force attacker can't smuggle timing data through. We do
|
||||||
|
// that pre-check via SlidingWindowLimiter.Allow with a peek-style
|
||||||
|
// fake-key that just queries state without recording a slot.
|
||||||
|
if h.failedBasicLimiter != nil && srcIP != "" {
|
||||||
|
if err := h.failedBasicLimiter.Allow(srcIP+"|peek", nowFn()); errors.Is(err, ratelimit.ErrRateLimited) {
|
||||||
|
// peek-key is shared across requests from this IP; the slot
|
||||||
|
// pollution is acceptable because the IP is already
|
||||||
|
// rate-limited and we want to keep them rate-limited.
|
||||||
|
ErrorWithRequestID(w, http.StatusTooManyRequests,
|
||||||
|
h.label()+" too many failed enrollment attempts from this source", requestID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
user, pass, ok := r.BasicAuth()
|
||||||
|
if !ok || user == "" {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="est-enrollment"`)
|
||||||
|
ErrorWithRequestID(w, http.StatusUnauthorized,
|
||||||
|
h.label()+" enrollment requires HTTP Basic auth", requestID)
|
||||||
|
h.recordFailedBasic(srcIP)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if subtle.ConstantTimeCompare([]byte(pass), []byte(h.basicPassword)) != 1 {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="est-enrollment"`)
|
||||||
|
ErrorWithRequestID(w, http.StatusUnauthorized,
|
||||||
|
h.label()+" enrollment password incorrect", requestID)
|
||||||
|
h.recordFailedBasic(srcIP)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordFailedBasic ticks a slot against the source-IP failed-auth
|
||||||
|
// limiter. Errors from Allow are intentionally ignored — a present
|
||||||
|
// failure simply means the IP has crossed the limit, which is exactly
|
||||||
|
// the state the per-IP gate reports back to the next request.
|
||||||
|
func (h ESTHandler) recordFailedBasic(srcIP string) {
|
||||||
|
if h.failedBasicLimiter == nil || srcIP == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = h.failedBasicLimiter.Allow(srcIP, nowFn())
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPerPrincipalRateLimit gates an enrollment by (CN, sourceIP).
|
||||||
|
// Returns nil when the request is allowed; ErrRateLimited (or wrapped
|
||||||
|
// equivalent) when the principal has exhausted its window budget.
|
||||||
|
//
|
||||||
|
// CN extraction: the CSR's Subject.CommonName is the canonical
|
||||||
|
// principal in the EST contract (the issued cert will carry that CN).
|
||||||
|
// sourceIP comes from clientIPForLimiter.
|
||||||
|
func (h ESTHandler) applyPerPrincipalRateLimit(r *http.Request, csr *x509.CertificateRequest) error {
|
||||||
|
if h.perPrincipalLimiter == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cn := ""
|
||||||
|
if csr != nil {
|
||||||
|
cn = csr.Subject.CommonName
|
||||||
|
}
|
||||||
|
srcIP := clientIPForLimiter(r)
|
||||||
|
key := cn + "|" + srcIP
|
||||||
|
return h.perPrincipalLimiter.Allow(key, nowFn())
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeChannelBindingError maps cms.* sentinel errors to HTTP statuses
|
||||||
|
// + audit-friendly messages. Mirrors the SCEP CertRep failInfo error
|
||||||
|
// translation pattern (signature_invalid → BadMessageCheck etc.).
|
||||||
|
func (h ESTHandler) writeChannelBindingError(w http.ResponseWriter, requestID string, err error) {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, cms.ErrChannelBindingMissing):
|
||||||
|
ErrorWithRequestID(w, http.StatusBadRequest,
|
||||||
|
"EST simplereenroll requires RFC 9266 channel binding for this profile", requestID)
|
||||||
|
case errors.Is(err, cms.ErrChannelBindingMismatch):
|
||||||
|
// 409 Conflict signals to the client that the request was
|
||||||
|
// well-formed but the channel-binding state on certctl's side
|
||||||
|
// disagreed with the device's — usually MITM or reverse proxy
|
||||||
|
// terminating TLS in front of certctl.
|
||||||
|
ErrorWithRequestID(w, http.StatusConflict,
|
||||||
|
"EST channel binding does not match TLS exporter — TLS terminator in front of certctl?", requestID)
|
||||||
|
case errors.Is(err, cms.ErrChannelBindingNotTLS13):
|
||||||
|
ErrorWithRequestID(w, http.StatusUpgradeRequired,
|
||||||
|
"EST channel binding requires TLS 1.3", requestID)
|
||||||
|
default:
|
||||||
|
ErrorWithRequestID(w, http.StatusBadRequest,
|
||||||
|
fmt.Sprintf("EST channel-binding verification failed: %v", err), requestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- response writers (legacy + mTLS share these) -----
|
||||||
|
|
||||||
|
// writeCACertsResponse writes the PKCS#7 certs-only CA chain. Shared
|
||||||
|
// by CACerts (legacy route) + CACertsMTLS (mTLS route).
|
||||||
|
func (h ESTHandler) writeCACertsResponse(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
caCertPEM, err := h.svc.GetCACerts(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusInternalServerError,
|
||||||
|
fmt.Sprintf("Failed to get CA certificates: %v", err), requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
derCerts, err := pkcs7.PEMToDERChain(caCertPEM)
|
derCerts, err := pkcs7.PEMToDERChain(caCertPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to encode CA certificates", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to encode CA certificates", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build a simple PKCS#7 SignedData (certs-only, degenerate) structure
|
|
||||||
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
|
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to build PKCS#7 response", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to build PKCS#7 response", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// RFC 7030 Section 4.1.3: response is base64-encoded application/pkcs7-mime
|
|
||||||
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
|
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
|
||||||
w.Header().Set("Content-Transfer-Encoding", "base64")
|
w.Header().Set("Content-Transfer-Encoding", "base64")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
encoded := base64.StdEncoding.EncodeToString(pkcs7Data)
|
writeBase64Wrapped(w, pkcs7Data)
|
||||||
// Write base64 with line breaks at 76 chars per RFC 2045
|
}
|
||||||
|
|
||||||
|
// writeCSRAttrsResponse writes the per-profile CSR attribute hints.
|
||||||
|
// Shared by CSRAttrs (legacy) + CSRAttrsMTLS (mTLS).
|
||||||
|
func (h ESTHandler) writeCSRAttrsResponse(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
attrs, err := h.svc.GetCSRAttrs(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
ErrorWithRequestID(w, http.StatusInternalServerError,
|
||||||
|
fmt.Sprintf("Failed to get CSR attributes: %v", err), requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/csrattrs")
|
||||||
|
w.Header().Set("Content-Transfer-Encoding", "base64")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(base64.StdEncoding.EncodeToString(attrs)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCertResponse writes an EST enrollment response as base64-encoded PKCS#7.
|
||||||
|
func (h ESTHandler) writeCertResponse(w http.ResponseWriter, result *domain.ESTEnrollResult) {
|
||||||
|
var derCerts [][]byte
|
||||||
|
certDER, err := pkcs7.PEMToDERChain(result.CertPEM)
|
||||||
|
if err != nil || len(certDER) == 0 {
|
||||||
|
http.Error(w, "Failed to encode certificate", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
derCerts = append(derCerts, certDER...)
|
||||||
|
if result.ChainPEM != "" {
|
||||||
|
chainDER, err := pkcs7.PEMToDERChain(result.ChainPEM)
|
||||||
|
if err == nil {
|
||||||
|
derCerts = append(derCerts, chainDER...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to build PKCS#7 response", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
|
||||||
|
w.Header().Set("Content-Transfer-Encoding", "base64")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
writeBase64Wrapped(w, pkcs7Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeBase64Wrapped emits b as base64 with CRLF every 76 chars per RFC 2045.
|
||||||
|
// Pulled out as a helper so the three writers above don't repeat the loop.
|
||||||
|
func writeBase64Wrapped(w http.ResponseWriter, b []byte) {
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(b)
|
||||||
for i := 0; i < len(encoded); i += 76 {
|
for i := 0; i < len(encoded); i += 76 {
|
||||||
end := i + 76
|
end := i + 76
|
||||||
if end > len(encoded) {
|
if end > len(encoded) {
|
||||||
@@ -99,66 +601,84 @@ func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimpleEnroll handles POST /.well-known/est/simpleenroll
|
// readCSRFromRequest reads and decodes the CSR from an EST enrollment request.
|
||||||
// Accepts a base64-encoded PKCS#10 CSR and returns a base64-encoded PKCS#7 certificate.
|
// EST sends CSRs as base64-encoded PKCS#10 DER with Content-Type application/pkcs10.
|
||||||
func (h ESTHandler) SimpleEnroll(w http.ResponseWriter, r *http.Request) {
|
func (h ESTHandler) readCSRFromRequest(r *http.Request) (string, error) {
|
||||||
if r.Method != http.MethodPost {
|
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
|
||||||
|
|
||||||
if err := verifyESTTransport(r); err != nil {
|
|
||||||
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
csrPEM, err := h.readCSRFromRequest(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
|
return "", fmt.Errorf("failed to read request body: %w", err)
|
||||||
return
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
if len(body) == 0 {
|
||||||
|
return "", fmt.Errorf("empty request body")
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.svc.SimpleEnroll(r.Context(), csrPEM)
|
bodyStr := strings.TrimSpace(string(body))
|
||||||
|
if strings.HasPrefix(bodyStr, "-----BEGIN CERTIFICATE REQUEST-----") {
|
||||||
|
block, _ := pem.Decode([]byte(bodyStr))
|
||||||
|
if block == nil {
|
||||||
|
return "", fmt.Errorf("invalid PEM-encoded CSR")
|
||||||
|
}
|
||||||
|
if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid CSR: %w", err)
|
||||||
|
}
|
||||||
|
return bodyStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
derBytes, err := base64.StdEncoding.DecodeString(bodyStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Enrollment failed: %v", err), requestID)
|
cleaned := strings.Map(func(r rune) rune {
|
||||||
return
|
if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}, bodyStr)
|
||||||
|
derBytes, err = base64.StdEncoding.DecodeString(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode base64 CSR: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if _, err := x509.ParseCertificateRequest(derBytes); err != nil {
|
||||||
h.writeCertResponse(w, result)
|
return "", fmt.Errorf("invalid PKCS#10 CSR: %w", err)
|
||||||
|
}
|
||||||
|
csrPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE REQUEST",
|
||||||
|
Bytes: derBytes,
|
||||||
|
})
|
||||||
|
return string(csrPEM), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimpleReEnroll handles POST /.well-known/est/simplereenroll
|
// decodeCSRPEM is a convenience wrapper around pem.Decode +
|
||||||
// Same as SimpleEnroll but for re-enrollment (certificate renewal).
|
// x509.ParseCertificateRequest. Returns nil on any decode/parse error
|
||||||
func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
|
// (callers downstream re-parse via the service path; this is just for
|
||||||
if r.Method != http.MethodPost {
|
// the handler-side gates that need the CN + binding attribute).
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
func decodeCSRPEM(csrPEM string) (*x509.CertificateRequest, error) {
|
||||||
return
|
block, _ := pem.Decode([]byte(csrPEM))
|
||||||
|
if block == nil {
|
||||||
|
return nil, fmt.Errorf("PEM decode failed")
|
||||||
}
|
}
|
||||||
|
return x509.ParseCertificateRequest(block.Bytes)
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
|
||||||
|
|
||||||
if err := verifyESTTransport(r); err != nil {
|
|
||||||
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
csrPEM, err := h.readCSRFromRequest(r)
|
|
||||||
if err != nil {
|
|
||||||
ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := h.svc.SimpleReEnroll(r.Context(), csrPEM)
|
|
||||||
if err != nil {
|
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Re-enrollment failed: %v", err), requestID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.writeCertResponse(w, result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clientIPForLimiter returns the source IP a per-IP rate limiter should
|
||||||
|
// key against. Honors X-Forwarded-For when the request came through a
|
||||||
|
// trusted proxy (no proxy-trust list yet — falls back to RemoteAddr).
|
||||||
|
func clientIPForLimiter(r *http.Request) string {
|
||||||
|
// Don't blindly trust XFF — ignore it for now and always use
|
||||||
|
// RemoteAddr. A future bundle can add a documented proxy-trust list.
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// nowFn is the package-private time source. Override in tests for
|
||||||
|
// deterministic clock injection without dragging time.Time into the
|
||||||
|
// handler API surface. Defined in est_clock.go so mocking out
|
||||||
|
// requires touching only one file.
|
||||||
|
|
||||||
// verifyESTTransport implements Bundle-4 / M-021 EST transport precondition.
|
// verifyESTTransport implements Bundle-4 / M-021 EST transport precondition.
|
||||||
//
|
//
|
||||||
// RFC 7030 §3.2.3 ("Linking Identity and POP Information") requires that when
|
// RFC 7030 §3.2.3 ("Linking Identity and POP Information") requires that when
|
||||||
@@ -169,32 +689,11 @@ func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
|
|||||||
// TLS-Unique is unavailable; RFC 9266 defines `tls-exporter` as the TLS 1.3
|
// TLS-Unique is unavailable; RFC 9266 defines `tls-exporter` as the TLS 1.3
|
||||||
// replacement.
|
// replacement.
|
||||||
//
|
//
|
||||||
// **Current scope of this function (Bundle-4 closure):** certctl does NOT
|
// **EST RFC 7030 hardening Phases 2-4 update:** RFC 9266 channel binding is
|
||||||
// currently support EST client certificate authentication. The EST endpoint
|
// now wired in via the cms package (Phase 2.4) and called from
|
||||||
// accepts unauthenticated POSTs (the SCEP equivalent enforces a
|
// SimpleReEnrollMTLS when the per-profile policy requires it. This function
|
||||||
// challenge-password via `preflightSCEPChallengePassword`; EST has no
|
// continues to handle the lower-level transport preconditions that ALL EST
|
||||||
// equivalent today). Per RFC 7030 §3.2.3, channel binding is REQUIRED only
|
// requests share (regardless of mTLS / Basic / unauth profile shape).
|
||||||
// when client certificate authentication is in use; without that, the §3.2.3
|
|
||||||
// requirement is moot.
|
|
||||||
//
|
|
||||||
// What we DO enforce here as defense-in-depth:
|
|
||||||
//
|
|
||||||
// 1. r.TLS must be non-nil — the EST endpoint MUST be reached over TLS.
|
|
||||||
// Defensive: certctl pins HTTPS-only at the server-side TLS config, but
|
|
||||||
// a future routing-layer regression that exposes EST over plaintext
|
|
||||||
// would be caught here.
|
|
||||||
// 2. Negotiated TLS version must be >= TLS 1.2 — RFC 7030 doesn't mandate
|
|
||||||
// a specific TLS version, but a pre-1.2 negotiation indicates a
|
|
||||||
// misconfigured client/server pair. certctl's MinVersion is TLS 1.3
|
|
||||||
// so this should always hold.
|
|
||||||
// 3. r.TLS.HandshakeComplete must be true — defensive against partial-
|
|
||||||
// handshake replays.
|
|
||||||
//
|
|
||||||
// **Deferred to a future bundle (operator decision required):**
|
|
||||||
//
|
|
||||||
// - RFC 9266 `tls-exporter` channel binding when EST mTLS is added.
|
|
||||||
// - EST mTLS support itself — currently EST is unauth-or-bearer; mTLS
|
|
||||||
// would be a V3-aligned compliance feature.
|
|
||||||
//
|
//
|
||||||
// Returns nil if all preconditions pass; non-nil error otherwise.
|
// Returns nil if all preconditions pass; non-nil error otherwise.
|
||||||
func verifyESTTransport(r *http.Request) error {
|
func verifyESTTransport(r *http.Request) error {
|
||||||
@@ -213,130 +712,5 @@ func verifyESTTransport(r *http.Request) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRAttrs handles GET /.well-known/est/csrattrs
|
|
||||||
// Returns the CSR attributes the server wants the client to include in enrollment requests.
|
|
||||||
func (h ESTHandler) CSRAttrs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
attrs, err := h.svc.GetCSRAttrs(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get CSR attributes: %v", err), requestID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(attrs) == 0 {
|
|
||||||
// No specific attributes required — return 204
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/csrattrs")
|
|
||||||
w.Header().Set("Content-Transfer-Encoding", "base64")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte(base64.StdEncoding.EncodeToString(attrs)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// readCSRFromRequest reads and decodes the CSR from an EST enrollment request.
|
|
||||||
// EST sends CSRs as base64-encoded PKCS#10 DER with Content-Type application/pkcs10.
|
|
||||||
func (h ESTHandler) readCSRFromRequest(r *http.Request) (string, error) {
|
|
||||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to read request body: %w", err)
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
if len(body) == 0 {
|
|
||||||
return "", fmt.Errorf("empty request body")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's already PEM-encoded (some clients send PEM directly)
|
|
||||||
bodyStr := strings.TrimSpace(string(body))
|
|
||||||
if strings.HasPrefix(bodyStr, "-----BEGIN CERTIFICATE REQUEST-----") {
|
|
||||||
// Validate it parses
|
|
||||||
block, _ := pem.Decode([]byte(bodyStr))
|
|
||||||
if block == nil {
|
|
||||||
return "", fmt.Errorf("invalid PEM-encoded CSR")
|
|
||||||
}
|
|
||||||
if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil {
|
|
||||||
return "", fmt.Errorf("invalid CSR: %w", err)
|
|
||||||
}
|
|
||||||
return bodyStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EST standard: base64-encoded DER PKCS#10
|
|
||||||
derBytes, err := base64.StdEncoding.DecodeString(bodyStr)
|
|
||||||
if err != nil {
|
|
||||||
// Try with padding/whitespace stripped
|
|
||||||
cleaned := strings.Map(func(r rune) rune {
|
|
||||||
if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}, bodyStr)
|
|
||||||
derBytes, err = base64.StdEncoding.DecodeString(cleaned)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decode base64 CSR: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate it's a valid PKCS#10 CSR
|
|
||||||
if _, err := x509.ParseCertificateRequest(derBytes); err != nil {
|
|
||||||
return "", fmt.Errorf("invalid PKCS#10 CSR: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert DER to PEM for internal use (certctl services expect PEM)
|
|
||||||
csrPEM := pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE REQUEST",
|
|
||||||
Bytes: derBytes,
|
|
||||||
})
|
|
||||||
return string(csrPEM), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeCertResponse writes an EST enrollment response as base64-encoded PKCS#7.
|
|
||||||
func (h ESTHandler) writeCertResponse(w http.ResponseWriter, result *domain.ESTEnrollResult) {
|
|
||||||
// Parse cert and chain PEM to DER
|
|
||||||
var derCerts [][]byte
|
|
||||||
|
|
||||||
// Add the issued certificate
|
|
||||||
certDER, err := pkcs7.PEMToDERChain(result.CertPEM)
|
|
||||||
if err != nil || len(certDER) == 0 {
|
|
||||||
http.Error(w, "Failed to encode certificate", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
derCerts = append(derCerts, certDER...)
|
|
||||||
|
|
||||||
// Add the CA chain if present
|
|
||||||
if result.ChainPEM != "" {
|
|
||||||
chainDER, err := pkcs7.PEMToDERChain(result.ChainPEM)
|
|
||||||
if err == nil {
|
|
||||||
derCerts = append(derCerts, chainDER...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build PKCS#7 certs-only
|
|
||||||
pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to build PKCS#7 response", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
|
|
||||||
w.Header().Set("Content-Transfer-Encoding", "base64")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
encoded := base64.StdEncoding.EncodeToString(pkcs7Data)
|
|
||||||
for i := 0; i < len(encoded); i += 76 {
|
|
||||||
end := i + 76
|
|
||||||
if end > len(encoded) {
|
|
||||||
end = len(encoded)
|
|
||||||
}
|
|
||||||
w.Write([]byte(encoded[i:end]))
|
|
||||||
w.Write([]byte("\r\n"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: PKCS#7 helpers (BuildCertsOnlyPKCS7, PEMToDERChain, ASN.1 wrappers)
|
// NOTE: PKCS#7 helpers (BuildCertsOnlyPKCS7, PEMToDERChain, ASN.1 wrappers)
|
||||||
// are in the shared internal/pkcs7 package, used by both EST and SCEP handlers.
|
// are in the shared internal/pkcs7 package, used by both EST and SCEP handlers.
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening Phase 3.3 / 4.2: nowFn is the time source that
|
||||||
|
// the EST handler's per-IP failed-Basic-auth limiter and per-(CN,
|
||||||
|
// sourceIP) rate limiter consult. Tests can override this to inject a
|
||||||
|
// deterministic clock without dragging time.Time into the handler API
|
||||||
|
// surface (the handler's setters take ratelimit.SlidingWindowLimiter
|
||||||
|
// pointers, not time-injection callbacks — keeping the wire-up simple).
|
||||||
|
//
|
||||||
|
// nowFn is package-private + lower-case so external callers can't poke
|
||||||
|
// at it; the est_clock_test.go helper restoreNowFn is the documented
|
||||||
|
// override pattern for tests in this package.
|
||||||
|
var nowFn = time.Now
|
||||||
@@ -0,0 +1,459 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"encoding/pem"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/cms"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/ratelimit"
|
||||||
|
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening master bundle Phases 2-4 tests.
|
||||||
|
// Covers: mTLS sibling route gates, HTTP Basic enrollment-password auth,
|
||||||
|
// per-source-IP failed-auth rate limit, RFC 9266 channel binding, and
|
||||||
|
// per-(CN, sourceIP) per-principal sliding-window rate limit.
|
||||||
|
|
||||||
|
// hardeningTestSetup is a per-test fixture: a mock service that always
|
||||||
|
// succeeds, plus a CA + issued client cert that an mTLS test can attach
|
||||||
|
// to its synthetic *http.Request.TLS.
|
||||||
|
type hardeningTestSetup struct {
|
||||||
|
svc *mockESTService
|
||||||
|
caCert *x509.Certificate
|
||||||
|
caKey *ecdsa.PrivateKey
|
||||||
|
clientCrt *x509.Certificate
|
||||||
|
clientKey *ecdsa.PrivateKey
|
||||||
|
trustPool *trustanchor.Holder
|
||||||
|
bundleDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHardeningTestSetup(t *testing.T) *hardeningTestSetup {
|
||||||
|
t.Helper()
|
||||||
|
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ca key: %v", err)
|
||||||
|
}
|
||||||
|
caTmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: "est-mtls-test-ca"},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign,
|
||||||
|
}
|
||||||
|
caDER, err := x509.CreateCertificate(rand.Reader, caTmpl, caTmpl, &caKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ca create: %v", err)
|
||||||
|
}
|
||||||
|
caCert, _ := x509.ParseCertificate(caDER)
|
||||||
|
|
||||||
|
clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client key: %v", err)
|
||||||
|
}
|
||||||
|
clientTmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2),
|
||||||
|
Subject: pkix.Name{CommonName: "test-device-001"},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||||
|
}
|
||||||
|
clientDER, err := x509.CreateCertificate(rand.Reader, clientTmpl, caCert, &clientKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client create: %v", err)
|
||||||
|
}
|
||||||
|
clientCrt, _ := x509.ParseCertificate(clientDER)
|
||||||
|
|
||||||
|
// Persist the CA bundle on disk so trustanchor.New can load it.
|
||||||
|
dir := t.TempDir()
|
||||||
|
bundlePath := filepath.Join(dir, "trust.pem")
|
||||||
|
body := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER})
|
||||||
|
if err := os.WriteFile(bundlePath, body, 0o600); err != nil {
|
||||||
|
t.Fatalf("write bundle: %v", err)
|
||||||
|
}
|
||||||
|
holder, err := trustanchor.New(bundlePath, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("trustanchor.New: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &mockESTService{
|
||||||
|
CACertPEM: pemCertString(caDER),
|
||||||
|
EnrollResult: &domain.ESTEnrollResult{
|
||||||
|
CertPEM: pemCertString(clientDER),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &hardeningTestSetup{
|
||||||
|
svc: svc,
|
||||||
|
caCert: caCert,
|
||||||
|
caKey: caKey,
|
||||||
|
clientCrt: clientCrt,
|
||||||
|
clientKey: clientKey,
|
||||||
|
trustPool: holder,
|
||||||
|
bundleDir: dir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pemCertString(der []byte) string {
|
||||||
|
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeMTLSRequest synthesises a POST against `path` with PEM CSR body and
|
||||||
|
// r.TLS populated with the given peer cert chain + handshake state. Used
|
||||||
|
// by the mTLS path tests where a real TLS handshake would force us into a
|
||||||
|
// full httptest.NewTLSServer setup.
|
||||||
|
func makeMTLSRequest(t *testing.T, path, csrPEM string, peerCerts []*x509.Certificate, version uint16) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(csrPEM))
|
||||||
|
req.TLS = &tls.ConnectionState{
|
||||||
|
HandshakeComplete: true,
|
||||||
|
Version: version,
|
||||||
|
PeerCertificates: peerCerts,
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- mTLS handler gate -----
|
||||||
|
|
||||||
|
func TestSimpleEnrollMTLS_NoTrustPool_500(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc) // intentionally do NOT call SetMTLSTrust
|
||||||
|
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
|
||||||
|
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnrollMTLS(w, req)
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("status = %d, want 500 (handler missing trust pool)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnrollMTLS_NoClientCert_401(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est-mtls/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnrollMTLS(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401 (no client cert)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnrollMTLS_CertNotInPool_401(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
other := newHardeningTestSetup(t) // different CA, unrelated to s.trustPool
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
|
||||||
|
generateTestCSRPEM(t), []*x509.Certificate{other.clientCrt}, tls.VersionTLS13)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnrollMTLS(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401 (cert not trusted by this profile)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnrollMTLS_HappyPath_200(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
|
||||||
|
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnrollMTLS(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200; body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- channel binding (Phase 2.4) -----
|
||||||
|
|
||||||
|
func TestSimpleReEnrollMTLS_ChannelBindingRequired_AbsentRejected(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
h.SetChannelBindingRequired(true)
|
||||||
|
// CSR has no binding attribute. Synthetic ConnectionState — exporter
|
||||||
|
// extraction will fail (no real TLS secret), and required=true makes
|
||||||
|
// VerifyChannelBinding propagate that as the missing-binding error.
|
||||||
|
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simplereenroll",
|
||||||
|
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleReEnrollMTLS(w, req)
|
||||||
|
// Either 400 (missing) or 426 (TLS 1.3 unavailable on synthetic state).
|
||||||
|
// Both are correct refusals; pin to "non-2xx" so the test isn't fragile
|
||||||
|
// against ConnectionState evolution.
|
||||||
|
if w.Code/100 == 2 {
|
||||||
|
t.Errorf("required + absent must reject; got 2xx (%d)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleReEnrollMTLS_ChannelBindingNotRequired_AbsentAllowed(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
h.SetChannelBindingRequired(false)
|
||||||
|
// CSR has no binding, profile is opt-in only. The handler must allow.
|
||||||
|
req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simplereenroll",
|
||||||
|
generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleReEnrollMTLS(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("required=false + absent must allow; got %d (%s)", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteChannelBindingError_KnownErrorsMapped(t *testing.T) {
|
||||||
|
// Smoke test the error-to-status mapping so a future cms sentinel rename
|
||||||
|
// gets caught at compile time + we hit each branch.
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
cases := []struct {
|
||||||
|
err error
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{cms.ErrChannelBindingMissing, http.StatusBadRequest},
|
||||||
|
{cms.ErrChannelBindingMismatch, http.StatusConflict},
|
||||||
|
{cms.ErrChannelBindingNotTLS13, http.StatusUpgradeRequired},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.writeChannelBindingError(w, "req-id", c.err)
|
||||||
|
if w.Code != c.want {
|
||||||
|
t.Errorf("error=%v → status %d, want %d", c.err, w.Code, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- HTTP Basic enrollment-password (Phase 3) -----
|
||||||
|
|
||||||
|
func TestSimpleEnroll_BasicAuth_NoHeader_401(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetEnrollmentPassword("super-secret")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401 (Basic required, header absent)", w.Code)
|
||||||
|
}
|
||||||
|
if got := w.Header().Get("WWW-Authenticate"); !strings.Contains(got, "Basic") {
|
||||||
|
t.Errorf("WWW-Authenticate = %q, want to contain 'Basic'", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnroll_BasicAuth_WrongPassword_401(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetEnrollmentPassword("super-secret")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req.SetBasicAuth("device", "wrong-password")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401 (wrong password)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnroll_BasicAuth_CorrectPassword_200(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetEnrollmentPassword("super-secret")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req.SetBasicAuth("device", "super-secret")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200 (correct password); body=%q", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnroll_BasicAuth_NoPassword_NoGate(t *testing.T) {
|
||||||
|
// When the per-profile enrollment password is empty, the Basic gate is
|
||||||
|
// off and the handler reverts to the v2.0.x anonymous behavior.
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc) // SetEnrollmentPassword not called
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200 (no Basic gate)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- source-IP failed-auth rate limit (Phase 3.3) -----
|
||||||
|
|
||||||
|
func TestSimpleEnroll_BasicAuth_FailedAttemptLimitedAfterThreshold(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetEnrollmentPassword("super-secret")
|
||||||
|
// Cap of 2 failed attempts before the IP gets locked. Each failed
|
||||||
|
// attempt records a slot; the 3rd request should be 429.
|
||||||
|
limiter := ratelimit.NewSlidingWindowLimiter(2, time.Hour, 10)
|
||||||
|
h.SetSourceIPRateLimiter(limiter)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req.RemoteAddr = "10.0.0.42:12345"
|
||||||
|
req.SetBasicAuth("device", "WRONG")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("attempt %d: want 401, got %d", i, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// The 3rd attempt — even with a correct password — must be rate limited.
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(generateTestCSRPEM(t)))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req.RemoteAddr = "10.0.0.42:12345"
|
||||||
|
req.SetBasicAuth("device", "super-secret")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("post-lockout status = %d, want 429 (correct password should still be locked out)", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- per-principal sliding-window rate limit (Phase 4.2) -----
|
||||||
|
|
||||||
|
func TestSimpleEnroll_PerPrincipalLimit_BlocksAfterCap(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
limiter := ratelimit.NewSlidingWindowLimiter(2, 24*time.Hour, 100)
|
||||||
|
h.SetPerPrincipalRateLimiter(limiter)
|
||||||
|
|
||||||
|
// First 2 enrollments from same (CN, IP) — pass.
|
||||||
|
csrPEM := generateTestCSRPEM(t)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(csrPEM))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req.RemoteAddr = "10.0.0.7:5555"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("attempt %d: want 200, got %d", i, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Third enrollment from same (CN, IP) — limited.
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(csrPEM))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req.RemoteAddr = "10.0.0.7:5555"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("3rd same-principal enrollment status = %d, want 429", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimpleEnroll_PerPrincipalLimit_DifferentPrincipalsIndependent(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
limiter := ratelimit.NewSlidingWindowLimiter(1, 24*time.Hour, 100)
|
||||||
|
h.SetPerPrincipalRateLimiter(limiter)
|
||||||
|
|
||||||
|
csrPEM1 := generateTestCSRPEM(t)
|
||||||
|
csrPEM2 := generateTestCSRPEM(t) // different key + (default) different CN
|
||||||
|
|
||||||
|
req1 := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll", strings.NewReader(csrPEM1))
|
||||||
|
req1.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req1.RemoteAddr = "10.0.0.10:1111"
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w1, req1)
|
||||||
|
if w1.Code != http.StatusOK {
|
||||||
|
t.Fatalf("principal 1 first call: want 200, got %d", w1.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same CN as csrPEM1 but different IP — independent bucket.
|
||||||
|
req2 := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll", strings.NewReader(csrPEM2))
|
||||||
|
req2.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
req2.RemoteAddr = "10.0.0.20:2222"
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w2, req2)
|
||||||
|
if w2.Code != http.StatusOK {
|
||||||
|
t.Errorf("principal 2 first call: want 200, got %d", w2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- per-handler smoke test for the un-rolled mTLS variants -----
|
||||||
|
|
||||||
|
func TestCACertsMTLS_RequiresClientCert(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est-mtls/corp/cacerts", nil)
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.CACertsMTLS(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("CACertsMTLS no-cert status = %d, want 401", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRAttrsMTLS_RequiresClientCert(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc)
|
||||||
|
h.SetMTLSTrust(s.trustPool)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est-mtls/corp/csrattrs", nil)
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.CSRAttrsMTLS(w, req)
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("CSRAttrsMTLS no-cert status = %d, want 401", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- ensure the per-principal limit fires only when configured -----
|
||||||
|
|
||||||
|
func TestSimpleEnroll_NoPerPrincipalLimiter_AllUnbounded(t *testing.T) {
|
||||||
|
s := newHardeningTestSetup(t)
|
||||||
|
h := NewESTHandler(s.svc) // SetPerPrincipalRateLimiter not called
|
||||||
|
csrPEM := generateTestCSRPEM(t)
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||||
|
strings.NewReader(csrPEM))
|
||||||
|
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.SimpleEnroll(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("attempt %d: want 200, got %d", i, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// silenceUnused keeps the "declared and not used" linter happy when we add
|
||||||
|
// helpers that future tests may invoke (asn1, atomic).
|
||||||
|
var _ = asn1.RawValue{}
|
||||||
|
var _ atomic.Int32
|
||||||
@@ -81,10 +81,11 @@ var AuthExemptRouterRoutes = []string{
|
|||||||
// TestDispatch_AuthExemptPrefixes regression test in cmd/server/main_test.go
|
// TestDispatch_AuthExemptPrefixes regression test in cmd/server/main_test.go
|
||||||
// pins this slice to buildFinalHandler's actual dispatch logic.
|
// pins this slice to buildFinalHandler's actual dispatch logic.
|
||||||
var AuthExemptDispatchPrefixes = []string{
|
var AuthExemptDispatchPrefixes = []string{
|
||||||
"/.well-known/pki", // RFC 5280 CRL + RFC 6960 OCSP — relying-party-unauth
|
"/.well-known/pki", // RFC 5280 CRL + RFC 6960 OCSP — relying-party-unauth
|
||||||
"/.well-known/est", // RFC 7030 EST — auth via mTLS or CSR-embedded creds
|
"/.well-known/est", // RFC 7030 EST — auth via mTLS or CSR-embedded creds
|
||||||
"/scep", // RFC 8894 SCEP — auth via challengePassword in CSR
|
"/.well-known/est-mtls", // EST + mTLS sibling route (EST hardening Phase 2) — auth is client cert
|
||||||
"/scep-mtls", // SCEP + mTLS sibling route (Phase 6.5) — auth is client cert + challengePassword
|
"/scep", // RFC 8894 SCEP — auth via challengePassword in CSR
|
||||||
|
"/scep-mtls", // SCEP + mTLS sibling route (Phase 6.5) — auth is client cert + challengePassword
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerRegistry groups all API handler dependencies for router registration.
|
// HandlerRegistry groups all API handler dependencies for router registration.
|
||||||
@@ -445,6 +446,44 @@ func (r *Router) RegisterESTHandlers(handlers map[string]handler.ESTHandler) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterESTMTLSHandlers sets up the sibling `/.well-known/est-mtls/<PathID>/`
|
||||||
|
// routes for EST profiles that opted into mTLS via
|
||||||
|
// `CERTCTL_EST_PROFILE_<NAME>_MTLS_ENABLED=true`.
|
||||||
|
//
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.2 + 2.3: enterprise
|
||||||
|
// procurement teams routinely reject 'shared password authentication' as
|
||||||
|
// a checkbox-fail regardless of how strong the password is. This sibling
|
||||||
|
// route adds client-cert auth at the handler layer AND keeps the (Phase 3)
|
||||||
|
// HTTP Basic enrollment-password as a defense-in-depth fallback for the
|
||||||
|
// non-mTLS profile. Devices present a bootstrap cert from a trusted CA,
|
||||||
|
// then EST-enroll for their long-lived cert. Mirrors the SCEP mTLS
|
||||||
|
// sibling pattern at RegisterSCEPMTLSHandlers below (commit 6b0d9e from
|
||||||
|
// the SCEP Phase 6.5 work).
|
||||||
|
//
|
||||||
|
// Path conventions: every mTLS profile gets a non-empty PathID, so the
|
||||||
|
// sibling routes are always /.well-known/est-mtls/<pathID>/. There is no
|
||||||
|
// "empty PathID = legacy /.well-known/est-mtls" case — mTLS is opt-in
|
||||||
|
// per profile, the legacy /.well-known/est root is always non-mTLS to
|
||||||
|
// preserve backward compat with existing deploys.
|
||||||
|
//
|
||||||
|
// Each handler in the map MUST have had SetMTLSTrust called so the
|
||||||
|
// per-profile cert verification has a trust anchor. cmd/server/main.go's
|
||||||
|
// per-profile EST loop wires this in the same loop iteration that
|
||||||
|
// registers the handler.
|
||||||
|
func (r *Router) RegisterESTMTLSHandlers(handlers map[string]handler.ESTHandler) {
|
||||||
|
for pathID, h := range handlers {
|
||||||
|
if pathID == "" {
|
||||||
|
continue // mTLS sibling route requires per-profile PathID
|
||||||
|
}
|
||||||
|
hCopy := h // h is captured by value — see RegisterESTHandlers above
|
||||||
|
prefix := "/.well-known/est-mtls/" + pathID
|
||||||
|
r.Register("GET "+prefix+"/cacerts", http.HandlerFunc(hCopy.CACertsMTLS))
|
||||||
|
r.Register("POST "+prefix+"/simpleenroll", http.HandlerFunc(hCopy.SimpleEnrollMTLS))
|
||||||
|
r.Register("POST "+prefix+"/simplereenroll", http.HandlerFunc(hCopy.SimpleReEnrollMTLS))
|
||||||
|
r.Register("GET "+prefix+"/csrattrs", http.HandlerFunc(hCopy.CSRAttrsMTLS))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterSCEPHandlers sets up SCEP (RFC 8894) routes.
|
// RegisterSCEPHandlers sets up SCEP (RFC 8894) routes.
|
||||||
// SCEP uses a single endpoint per profile with operation-based dispatch via
|
// SCEP uses a single endpoint per profile with operation-based dispatch via
|
||||||
// query parameters. Authentication is via the challengePassword attribute in
|
// query parameters. Authentication is via the challengePassword attribute in
|
||||||
|
|||||||
@@ -0,0 +1,369 @@
|
|||||||
|
// Package cms implements the small subset of CMS / RFC 7030 / RFC 9266
|
||||||
|
// helpers that the EST handler needs at request-time: extracting the
|
||||||
|
// RFC 9266 tls-exporter from a *tls.ConnectionState, and pulling the
|
||||||
|
// matching value back out of an EST CSR's CMC unsignedAttribute when the
|
||||||
|
// device proved channel binding.
|
||||||
|
//
|
||||||
|
// Why a separate package (vs adding to internal/api/handler/est.go):
|
||||||
|
//
|
||||||
|
// 1. internal/api/handler depends on internal/pkcs7 already; if the EST
|
||||||
|
// mTLS hardening also pulled CMC parsing into handler we'd grow the
|
||||||
|
// handler-side dep graph by another asn1 surface that has nothing
|
||||||
|
// specific to HTTP.
|
||||||
|
//
|
||||||
|
// 2. Channel-binding extraction is testable in isolation — the unit
|
||||||
|
// tests construct a *tls.ConnectionState with raw exporter bytes and
|
||||||
|
// a *x509.CertificateRequest with the CMC unsignedAttribute already
|
||||||
|
// filled in. No HTTP plumbing required to verify the contract.
|
||||||
|
//
|
||||||
|
// 3. Future EST extensions (RFC 7030 §3.5 fullCMC, RFC 9148 EST-coaps)
|
||||||
|
// are likely to land here too — keep them out of net/http land.
|
||||||
|
//
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.4.
|
||||||
|
package cms
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/subtle"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/asn1"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----- RFC 9266 §3 — TLS exporter extraction -----
|
||||||
|
|
||||||
|
// TLSExporterLabel is the EXPORTER label registered by RFC 9266 §3.1
|
||||||
|
// for use as a TLS-1.3 channel binding. Constant rather than string-typed
|
||||||
|
// so a typo here is a compile error rather than a silent failure mode.
|
||||||
|
const TLSExporterLabel = "EXPORTER-Channel-Binding"
|
||||||
|
|
||||||
|
// TLSExporterLength is the 32-byte exporter length pinned by RFC 9266 §3.1
|
||||||
|
// (matches the SHA-256 output size; clients and servers MUST agree on the
|
||||||
|
// length to make the comparison meaningful).
|
||||||
|
const TLSExporterLength = 32
|
||||||
|
|
||||||
|
// ErrChannelBindingMissing is returned when the EST mTLS handler requires
|
||||||
|
// channel binding (per-profile ChannelBindingRequired=true) but the device's
|
||||||
|
// CSR has no id-aa-est-tls-exporter unsignedAttribute or the attribute is
|
||||||
|
// the wrong shape.
|
||||||
|
var ErrChannelBindingMissing = errors.New("cms: channel binding required but absent or malformed in CSR")
|
||||||
|
|
||||||
|
// ErrChannelBindingMismatch is returned when the device's CSR carried a
|
||||||
|
// channel-binding attribute but its bytes do not match the TLS-1.3 exporter
|
||||||
|
// extracted from the live connection. This is the signal of an MITM that
|
||||||
|
// terminates TLS in front of certctl: the device computed exporter X
|
||||||
|
// against the attacker, certctl sees exporter Y against itself, X≠Y.
|
||||||
|
var ErrChannelBindingMismatch = errors.New("cms: channel binding in CSR does not match TLS exporter")
|
||||||
|
|
||||||
|
// ErrChannelBindingNotTLS13 is returned when the connection is older than
|
||||||
|
// TLS 1.3 and the per-profile config still requires channel binding.
|
||||||
|
// RFC 9266's tls-exporter is a TLS-1.3 binding; pre-1.3 connections would
|
||||||
|
// need RFC 5929 tls-unique, which we deliberately don't support
|
||||||
|
// (certctl pins TLS-1.3 server-side).
|
||||||
|
var ErrChannelBindingNotTLS13 = errors.New("cms: tls-exporter channel binding requires TLS 1.3")
|
||||||
|
|
||||||
|
// ExtractTLSExporter pulls the 32-byte RFC 9266 channel-binding value from
|
||||||
|
// the TLS connection state. The connection must be TLS 1.3 + handshake-
|
||||||
|
// complete; anything else returns a typed error so the caller can map to
|
||||||
|
// HTTP 400 / 412 cleanly.
|
||||||
|
//
|
||||||
|
// Stateless on purpose: callers handle storage + comparison.
|
||||||
|
//
|
||||||
|
// Robustness note: stdlib's ConnectionState.ExportKeyingMaterial nil-derefs
|
||||||
|
// when the underlying secret-derivation closure is unset (i.e. the state
|
||||||
|
// was hand-constructed by a test fixture rather than produced by a real
|
||||||
|
// TLS handshake). The recover() below converts that panic into the same
|
||||||
|
// typed error a missing-binding state would surface, so synthetic test
|
||||||
|
// states + production TLS-1.3 connections share a single failure mode.
|
||||||
|
func ExtractTLSExporter(state *tls.ConnectionState) (out []byte, err error) {
|
||||||
|
if state == nil {
|
||||||
|
return nil, fmt.Errorf("%w: nil ConnectionState", ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
if !state.HandshakeComplete {
|
||||||
|
return nil, fmt.Errorf("%w: handshake incomplete", ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
// tls.VersionTLS13 == 0x0304. We use the literal so this package doesn't
|
||||||
|
// have to import "crypto/tls" twice (once for tls.VersionTLS13, once for
|
||||||
|
// the *tls.ConnectionState type — Go allows it but it's noisy).
|
||||||
|
if state.Version < 0x0304 {
|
||||||
|
return nil, fmt.Errorf("%w: negotiated 0x%04x", ErrChannelBindingNotTLS13, state.Version)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
out = nil
|
||||||
|
err = fmt.Errorf("%w: ExportKeyingMaterial unavailable on this connection state (panic=%v)", ErrChannelBindingMissing, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
out, err = state.ExportKeyingMaterial(TLSExporterLabel, nil, TLSExporterLength)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cms: ExportKeyingMaterial: %w", err)
|
||||||
|
}
|
||||||
|
if len(out) != TLSExporterLength {
|
||||||
|
return nil, fmt.Errorf("cms: exporter returned %d bytes, want %d", len(out), TLSExporterLength)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- RFC 7030 §3.5 / RFC 9266 §4.1 — CSR-side channel binding -----
|
||||||
|
|
||||||
|
// OIDChannelBindingTLSExporter is the id-aa-est-tls-exporter OID from
|
||||||
|
// RFC 9266 §4.1 (registered under id-aa = 1.2.840.113549.1.9.16.2 with
|
||||||
|
// arc 56 by RFC 9266). Devices that signed channel binding into their
|
||||||
|
// CSR add a CMC unsignedAttribute with this OID + an OCTET STRING value.
|
||||||
|
//
|
||||||
|
// Note: the EST RFC 7030 §3.5 historical OID for tls-unique is
|
||||||
|
// id-aa-cmc-binding (1.2.840.113549.1.9.16.2.43). RFC 9266 §4.1 added
|
||||||
|
// arc 56 for tls-exporter. We accept BOTH OIDs on the read path so a
|
||||||
|
// device using a slightly older library that still emits the §3.5 OID
|
||||||
|
// continues to work — the value bytes are still the 32-byte exporter
|
||||||
|
// (the OID identifies the binding scheme, not the underlying wire
|
||||||
|
// format).
|
||||||
|
var (
|
||||||
|
OIDChannelBindingTLSExporter = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 16, 2, 56}
|
||||||
|
OIDCMCEnrollmentBinding = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 16, 2, 43}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtractCSRChannelBinding looks for the RFC 9266 channel-binding
|
||||||
|
// attribute (or the legacy RFC 7030 §3.5 binding attribute) in the CSR's
|
||||||
|
// raw attributes block. Returns the raw 32-byte exporter value if
|
||||||
|
// present.
|
||||||
|
//
|
||||||
|
// Why we walk csr.RawTBSCertificateRequest manually instead of using
|
||||||
|
// csr.Attributes:
|
||||||
|
//
|
||||||
|
// - csr.Attributes is typed as []pkix.AttributeTypeAndValueSET, where
|
||||||
|
// the inner Value is [][]pkix.AttributeTypeAndValue. That shape only
|
||||||
|
// fits attributes whose AttributeValue is itself a SEQUENCE { OID,
|
||||||
|
// ANY } (e.g. the requestedExtensions attribute). RFC 9266's
|
||||||
|
// TLSExporterValue is `OCTET STRING` — a primitive, not a SEQUENCE
|
||||||
|
// — so the stdlib parse path either drops the attribute silently or
|
||||||
|
// fails the whole CSR parse depending on encoding.
|
||||||
|
//
|
||||||
|
// - The PKCS#10 challengePassword path in scep.go works by accident
|
||||||
|
// because PrintableString happens to round-trip through the
|
||||||
|
// stdlib's interface{}-typed AttributeTypeAndValue.Value. OCTET
|
||||||
|
// STRING does not — it's not in the small list of primitive types
|
||||||
|
// the stdlib's reflect-based unmarshaller handles for `any`.
|
||||||
|
//
|
||||||
|
// - Walking the raw TBS is ~30 lines of asn1.Unmarshal calls and
|
||||||
|
// gives us a stable contract independent of stdlib quirks.
|
||||||
|
//
|
||||||
|
// Returns (value, true, nil) on success; (nil, false, nil) when the
|
||||||
|
// attribute is absent (caller decides whether absence is acceptable per
|
||||||
|
// the per-profile ChannelBindingRequired flag); (nil, false, err) on
|
||||||
|
// malformed attribute (always fatal — a present-but-wrong attribute
|
||||||
|
// signals an attacker rewriting the binding into garbage).
|
||||||
|
func ExtractCSRChannelBinding(csr *x509.CertificateRequest) ([]byte, bool, error) {
|
||||||
|
if csr == nil {
|
||||||
|
return nil, false, fmt.Errorf("cms: nil CSR")
|
||||||
|
}
|
||||||
|
if len(csr.RawTBSCertificateRequest) == 0 {
|
||||||
|
// Stdlib fills RawTBSCertificateRequest on every parse path, so an
|
||||||
|
// empty value here means the caller hand-crafted the struct. Tests
|
||||||
|
// can do that — but real handler-side calls always have raw bytes.
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return walkCSRAttributesForBinding(csr.RawTBSCertificateRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// walkCSRAttributesForBinding parses just enough of TBSCertificationRequestInfo
|
||||||
|
// to reach the [0] IMPLICIT Attributes field, then iterates each Attribute
|
||||||
|
// looking for the channel-binding OID. The body is intentionally low-level
|
||||||
|
// so we can keep the asn1 footprint contained to this one helper.
|
||||||
|
//
|
||||||
|
// TBSCertificationRequestInfo per RFC 2986 §4.1:
|
||||||
|
//
|
||||||
|
// TBSCertificationRequestInfo ::= SEQUENCE {
|
||||||
|
// version INTEGER (0),
|
||||||
|
// subject Name,
|
||||||
|
// subjectPKInfo SubjectPublicKeyInfo,
|
||||||
|
// attributes [0] IMPLICIT Attributes (SET OF Attribute)
|
||||||
|
// }
|
||||||
|
func walkCSRAttributesForBinding(tbs []byte) ([]byte, bool, error) {
|
||||||
|
// 1. Crack the outer SEQUENCE wrapper.
|
||||||
|
var inner asn1.RawValue
|
||||||
|
if rest, err := asn1.Unmarshal(tbs, &inner); err != nil {
|
||||||
|
return nil, false, fmt.Errorf("cms: TBS outer parse: %w", err)
|
||||||
|
} else if len(rest) > 0 {
|
||||||
|
return nil, false, fmt.Errorf("cms: TBS trailing bytes: %d", len(rest))
|
||||||
|
}
|
||||||
|
if inner.Tag != asn1.TagSequence {
|
||||||
|
return nil, false, fmt.Errorf("cms: TBS outer tag %d not SEQUENCE", inner.Tag)
|
||||||
|
}
|
||||||
|
rest := inner.Bytes
|
||||||
|
|
||||||
|
// 2. Skip version (INTEGER), subject (SEQUENCE = Name), subjectPKInfo
|
||||||
|
// (SEQUENCE). asn1.Unmarshal into asn1.RawValue advances the cursor
|
||||||
|
// without parsing the body — perfect for skipping fields we don't care
|
||||||
|
// about.
|
||||||
|
for i, label := range []string{"version", "subject", "subjectPKInfo"} {
|
||||||
|
var rv asn1.RawValue
|
||||||
|
next, err := asn1.Unmarshal(rest, &rv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("cms: skip TBS field %d (%s): %w", i, label, err)
|
||||||
|
}
|
||||||
|
rest = next
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Attributes is [0] IMPLICIT — the on-wire tag is 0xA0 with class
|
||||||
|
// CONTEXT-SPECIFIC. asn1.Unmarshal into a RawValue accepts arbitrary
|
||||||
|
// tags; we then walk its Bytes as a SET OF Attribute.
|
||||||
|
var attrsField asn1.RawValue
|
||||||
|
if _, err := asn1.Unmarshal(rest, &attrsField); err != nil {
|
||||||
|
// No attributes block at all — RFC 2986 says [0] is OPTIONAL when
|
||||||
|
// empty (encoders typically omit the field rather than emit an
|
||||||
|
// empty SET). Treat as "no binding present", not as an error.
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if attrsField.Class != asn1.ClassContextSpecific || attrsField.Tag != 0 {
|
||||||
|
// Some non-attribute-shaped trailing field: not what we expected
|
||||||
|
// but not strictly a corruption signal — skip silently.
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Walk each Attribute in the SET. Each Attribute is
|
||||||
|
// SEQUENCE { OID, SET OF ANY }.
|
||||||
|
attrBytes := attrsField.Bytes
|
||||||
|
for len(attrBytes) > 0 {
|
||||||
|
var oneAttr asn1.RawValue
|
||||||
|
next, err := asn1.Unmarshal(attrBytes, &oneAttr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("cms: walk attributes: %w", err)
|
||||||
|
}
|
||||||
|
attrBytes = next
|
||||||
|
if oneAttr.Tag != asn1.TagSequence {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Inner: OID, then SET.
|
||||||
|
var oid asn1.ObjectIdentifier
|
||||||
|
afterOID, err := asn1.Unmarshal(oneAttr.Bytes, &oid)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !oid.Equal(OIDChannelBindingTLSExporter) && !oid.Equal(OIDCMCEnrollmentBinding) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Now afterOID is the SET wrapper. Crack it and pull the OCTET
|
||||||
|
// STRING out of the SET's first element.
|
||||||
|
var setWrap asn1.RawValue
|
||||||
|
if _, err := asn1.Unmarshal(afterOID, &setWrap); err != nil {
|
||||||
|
return nil, false, fmt.Errorf("cms: binding SET parse: %w (%w)", err, ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
if setWrap.Tag != asn1.TagSet {
|
||||||
|
return nil, false, fmt.Errorf("cms: binding outer tag %d not SET (%w)", setWrap.Tag, ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
var octet asn1.RawValue
|
||||||
|
if _, err := asn1.Unmarshal(setWrap.Bytes, &octet); err != nil {
|
||||||
|
return nil, false, fmt.Errorf("cms: binding inner parse: %w (%w)", err, ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
if octet.Tag != asn1.TagOctetString {
|
||||||
|
return nil, false, fmt.Errorf("cms: binding inner tag %d not OCTET STRING (%w)", octet.Tag, ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
if len(octet.Bytes) != TLSExporterLength {
|
||||||
|
return nil, false, fmt.Errorf("cms: binding length %d, want %d (%w)",
|
||||||
|
len(octet.Bytes), TLSExporterLength, ErrChannelBindingMissing)
|
||||||
|
}
|
||||||
|
return octet.Bytes, true, nil
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyChannelBinding is the convenience composite the EST mTLS handler
|
||||||
|
// calls per request: extract the exporter from the live TLS connection,
|
||||||
|
// pull the matching value from the CSR, compare in constant time.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - nil when the binding is present + matches.
|
||||||
|
// - ErrChannelBindingMissing when the CSR has no binding attribute.
|
||||||
|
// - ErrChannelBindingMismatch when both sides have a value but they
|
||||||
|
// differ (the MITM signal).
|
||||||
|
// - Any error from the exporter extraction (TLS state is wrong, etc).
|
||||||
|
//
|
||||||
|
// The required flag controls absence-handling: when required=false a
|
||||||
|
// missing attribute returns nil (channel binding is optional for this
|
||||||
|
// profile); when required=true a missing attribute returns
|
||||||
|
// ErrChannelBindingMissing.
|
||||||
|
func VerifyChannelBinding(state *tls.ConnectionState, csr *x509.CertificateRequest, required bool) error {
|
||||||
|
live, err := ExtractTLSExporter(state)
|
||||||
|
if err != nil {
|
||||||
|
// If the profile doesn't require channel binding AND the only
|
||||||
|
// problem is "no TLS 1.3 / no handshake", we still let the request
|
||||||
|
// through — the binding is opt-in per profile. But if the CSR
|
||||||
|
// itself carries a binding attribute, the device clearly INTENDED
|
||||||
|
// to bind, so a TLS state mismatch is a genuine error.
|
||||||
|
if !required {
|
||||||
|
if _, present, _ := ExtractCSRChannelBinding(csr); !present {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
csrBinding, present, err := ExtractCSRChannelBinding(csr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !present {
|
||||||
|
if required {
|
||||||
|
return ErrChannelBindingMissing
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if subtle.ConstantTimeCompare(live, csrBinding) != 1 {
|
||||||
|
return ErrChannelBindingMismatch
|
||||||
|
}
|
||||||
|
// Sanity: the comparison should be identical bytes for matching cases.
|
||||||
|
// The bytes.Equal call is dead code under correct subtle.Compare result;
|
||||||
|
// it's here only to make the contract obvious to readers and to pin the
|
||||||
|
// symmetry test that asserts ExtractCSRChannelBinding is byte-equivalent
|
||||||
|
// to ExtractTLSExporter when the device behaved correctly.
|
||||||
|
if !bytes.Equal(live, csrBinding) {
|
||||||
|
return ErrChannelBindingMismatch
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedChannelBindingAttribute is the test helper inverse of
|
||||||
|
// ExtractCSRChannelBinding: given an exporter value, returns the DER
|
||||||
|
// bytes of the Attribute (SEQUENCE { OID, SET { OCTET STRING } }) that
|
||||||
|
// the caller can splice into the [0] IMPLICIT Attributes field of
|
||||||
|
// TBSCertificationRequestInfo. Used by the EST channel-binding tests
|
||||||
|
// AND by any external caller that wants to forge a CSR with a known
|
||||||
|
// binding for fixture generation.
|
||||||
|
func EmbedChannelBindingAttribute(exporter []byte) ([]byte, error) {
|
||||||
|
if len(exporter) != TLSExporterLength {
|
||||||
|
return nil, fmt.Errorf("cms: exporter length %d, want %d", len(exporter), TLSExporterLength)
|
||||||
|
}
|
||||||
|
octet, err := asn1.Marshal(exporter) // marshal []byte as OCTET STRING
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cms: marshal exporter octet: %w", err)
|
||||||
|
}
|
||||||
|
// Wrap in SET OF.
|
||||||
|
setBody := octet
|
||||||
|
setEnvelope, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassUniversal,
|
||||||
|
Tag: asn1.TagSet,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: setBody,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cms: marshal SET: %w", err)
|
||||||
|
}
|
||||||
|
oid, err := asn1.Marshal(OIDChannelBindingTLSExporter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cms: marshal OID: %w", err)
|
||||||
|
}
|
||||||
|
// Wrap as SEQUENCE { OID, SET }.
|
||||||
|
seqBody := append(append([]byte{}, oid...), setEnvelope...)
|
||||||
|
seqEnvelope, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassUniversal,
|
||||||
|
Tag: asn1.TagSequence,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: seqBody,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cms: marshal SEQUENCE: %w", err)
|
||||||
|
}
|
||||||
|
return seqEnvelope, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,394 @@
|
|||||||
|
package cms
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"errors"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.4 tests.
|
||||||
|
|
||||||
|
// ----- ExtractTLSExporter -----
|
||||||
|
|
||||||
|
func TestExtractTLSExporter_NilState(t *testing.T) {
|
||||||
|
if _, err := ExtractTLSExporter(nil); !errors.Is(err, ErrChannelBindingMissing) {
|
||||||
|
t.Errorf("nil state should return ErrChannelBindingMissing, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTLSExporter_HandshakeNotComplete(t *testing.T) {
|
||||||
|
state := &tls.ConnectionState{HandshakeComplete: false, Version: 0x0304}
|
||||||
|
if _, err := ExtractTLSExporter(state); !errors.Is(err, ErrChannelBindingMissing) {
|
||||||
|
t.Errorf("incomplete handshake should return ErrChannelBindingMissing, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTLSExporter_PreTLS13Rejected(t *testing.T) {
|
||||||
|
state := &tls.ConnectionState{HandshakeComplete: true, Version: 0x0303} // TLS 1.2
|
||||||
|
if _, err := ExtractTLSExporter(state); !errors.Is(err, ErrChannelBindingNotTLS13) {
|
||||||
|
t.Errorf("TLS 1.2 should return ErrChannelBindingNotTLS13, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractTLSExporter_TLS13EndToEnd is the only test that builds a full
|
||||||
|
// real TLS-1.3 session — the exporter is computed on the connection's secret
|
||||||
|
// state, so we can't fake the ConnectionState. We spin up a localhost TCP
|
||||||
|
// listener, do a handshake, and then call ExportKeyingMaterial directly to
|
||||||
|
// pin the contract. This is a small round-trip but we're not testing TLS
|
||||||
|
// itself — just that ExtractTLSExporter pulls a 32-byte value from a real
|
||||||
|
// 1.3 state.
|
||||||
|
func TestExtractTLSExporter_TLS13EndToEnd(t *testing.T) {
|
||||||
|
cert, key := freshSelfSignedTLSCert(t)
|
||||||
|
tlsCert := tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: key}
|
||||||
|
|
||||||
|
cfg := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{tlsCert},
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
MaxVersion: tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
clientCfg := &tls.Config{
|
||||||
|
InsecureSkipVerify: true, //nolint:gosec // hermetic test cert; not for production use
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
MaxVersion: tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := tls.Listen("tcp", "127.0.0.1:0", cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.Listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
// Finish the handshake on the server side.
|
||||||
|
_ = conn.(*tls.Conn).HandshakeContext(context.Background())
|
||||||
|
// Hold the connection open until the client side completes its read.
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
_, _ = conn.Read(buf)
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := tls.Dial("tcp", ln.Addr().String(), clientCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
if err := conn.HandshakeContext(context.Background()); err != nil {
|
||||||
|
t.Fatalf("client handshake: %v", err)
|
||||||
|
}
|
||||||
|
state := conn.ConnectionState()
|
||||||
|
|
||||||
|
out, err := ExtractTLSExporter(&state)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractTLSExporter: %v", err)
|
||||||
|
}
|
||||||
|
if len(out) != TLSExporterLength {
|
||||||
|
t.Errorf("len(out) = %d, want %d", len(out), TLSExporterLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- ExtractCSRChannelBinding -----
|
||||||
|
|
||||||
|
func TestExtractCSRChannelBinding_NilCSR(t *testing.T) {
|
||||||
|
if _, _, err := ExtractCSRChannelBinding(nil); err == nil {
|
||||||
|
t.Fatal("nil CSR should error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCSRChannelBinding_AbsentReturnsFalse(t *testing.T) {
|
||||||
|
csr := freshCSRNoBinding(t)
|
||||||
|
val, present, err := ExtractCSRChannelBinding(csr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractCSRChannelBinding: %v", err)
|
||||||
|
}
|
||||||
|
if present {
|
||||||
|
t.Errorf("present=true on a CSR without the binding attribute (val=%x)", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCSRChannelBinding_PresentReturnsExporter(t *testing.T) {
|
||||||
|
exporter := repeatByte(0x42, TLSExporterLength)
|
||||||
|
csr := freshCSRWithBinding(t, exporter, OIDChannelBindingTLSExporter)
|
||||||
|
val, present, err := ExtractCSRChannelBinding(csr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractCSRChannelBinding: %v", err)
|
||||||
|
}
|
||||||
|
if !present {
|
||||||
|
t.Fatal("present=false on a CSR that carries the binding")
|
||||||
|
}
|
||||||
|
if !bytesEq(val, exporter) {
|
||||||
|
t.Errorf("exporter = %x, want %x", val, exporter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCSRChannelBinding_LegacyOIDAccepted(t *testing.T) {
|
||||||
|
exporter := repeatByte(0xAA, TLSExporterLength)
|
||||||
|
csr := freshCSRWithBinding(t, exporter, OIDCMCEnrollmentBinding)
|
||||||
|
val, present, err := ExtractCSRChannelBinding(csr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("legacy-OID path failed: %v", err)
|
||||||
|
}
|
||||||
|
if !present || !bytesEq(val, exporter) {
|
||||||
|
t.Errorf("legacy-OID extraction: got present=%v val=%x, want present=true val=%x", present, val, exporter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCSRChannelBinding_WrongLengthRejected(t *testing.T) {
|
||||||
|
short := repeatByte(0x55, 16) // half the required length
|
||||||
|
csr := freshCSRWithBinding(t, short, OIDChannelBindingTLSExporter)
|
||||||
|
_, _, err := ExtractCSRChannelBinding(csr)
|
||||||
|
if !errors.Is(err, ErrChannelBindingMissing) {
|
||||||
|
t.Errorf("wrong-length binding should wrap ErrChannelBindingMissing, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- VerifyChannelBinding (composite) -----
|
||||||
|
|
||||||
|
func TestVerifyChannelBinding_NotRequired_NoBinding_Passes(t *testing.T) {
|
||||||
|
csr := freshCSRNoBinding(t)
|
||||||
|
if err := VerifyChannelBinding(nil, csr, false); err != nil {
|
||||||
|
t.Errorf("required=false + no binding should pass; got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyChannelBinding_Required_NilState_Errors(t *testing.T) {
|
||||||
|
csr := freshCSRNoBinding(t)
|
||||||
|
if err := VerifyChannelBinding(nil, csr, true); err == nil {
|
||||||
|
t.Fatal("required=true + nil state must error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: a synthetic *tls.ConnectionState{HandshakeComplete:true, Version:0x0304}
|
||||||
|
// would seem like the obvious VerifyChannelBinding(required=true) negative-case
|
||||||
|
// fixture, but stdlib's ExportKeyingMaterial nil-derefs when the underlying
|
||||||
|
// secret state is unset (see crypto/tls/common.go:330). The
|
||||||
|
// "no live exporter available" branch is genuinely only reachable via a real
|
||||||
|
// connection (TestExtractTLSExporter_TLS13EndToEnd above), so we don't try to
|
||||||
|
// fake it here. The TestVerifyChannelBinding_NotRequired_NoBinding_Passes +
|
||||||
|
// TestVerifyChannelBinding_Required_NilState_Errors tests cover the policy
|
||||||
|
// branches; production code paths only ever pass r.TLS from a live request.
|
||||||
|
|
||||||
|
// ----- helpers -----
|
||||||
|
|
||||||
|
// freshCSRNoBinding returns a CSR with no extra attributes.
|
||||||
|
func freshCSRNoBinding(t *testing.T) *x509.CertificateRequest {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
tmpl := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "no-binding-test"}}
|
||||||
|
der, err := x509.CreateCertificateRequest(rand.Reader, tmpl, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateCertificateRequest: %v", err)
|
||||||
|
}
|
||||||
|
csr, err := x509.ParseCertificateRequest(der)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseCertificateRequest: %v", err)
|
||||||
|
}
|
||||||
|
return csr
|
||||||
|
}
|
||||||
|
|
||||||
|
// freshCSRWithBinding builds a CSR whose TBS carries the channel-binding
|
||||||
|
// attribute. The stdlib's CreateCertificateRequest doesn't support arbitrary
|
||||||
|
// attributes (only ExtraExtensions), so we hand-craft the TBS by parsing
|
||||||
|
// what stdlib produced + splicing our attribute into the [0] IMPLICIT
|
||||||
|
// Attributes block + re-signing.
|
||||||
|
func freshCSRWithBinding(t *testing.T, exporter []byte, oid asn1.ObjectIdentifier) *x509.CertificateRequest {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
// 1. Get a baseline CSR with no attributes — we steal its TBS shape.
|
||||||
|
tmpl := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "binding-test"}}
|
||||||
|
derBaseline, err := x509.CreateCertificateRequest(rand.Reader, tmpl, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateCertificateRequest: %v", err)
|
||||||
|
}
|
||||||
|
baseline, err := x509.ParseCertificateRequest(derBaseline)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseCertificateRequest: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Build the channel-binding attribute (SEQUENCE { OID, SET { OCTET STRING }}).
|
||||||
|
octet, err := asn1.Marshal(exporter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal octet: %v", err)
|
||||||
|
}
|
||||||
|
setEnv, err := asn1.Marshal(asn1.RawValue{Class: asn1.ClassUniversal, Tag: asn1.TagSet, IsCompound: true, Bytes: octet})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal set: %v", err)
|
||||||
|
}
|
||||||
|
oidBytes, err := asn1.Marshal(oid)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal oid: %v", err)
|
||||||
|
}
|
||||||
|
attrSeq, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassUniversal,
|
||||||
|
Tag: asn1.TagSequence,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: append(append([]byte{}, oidBytes...), setEnv...),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal attribute SEQUENCE: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Splice attribute into a [0] IMPLICIT Attributes block and rebuild
|
||||||
|
// the TBS by hand. The TBS structure is:
|
||||||
|
// SEQUENCE { version INTEGER, subject Name, subjectPKInfo SubjectPublicKeyInfo,
|
||||||
|
// attributes [0] IMPLICIT SET OF Attribute }
|
||||||
|
// We re-extract the first three fields from the baseline TBS and
|
||||||
|
// re-marshal with our attribute appended.
|
||||||
|
var outer asn1.RawValue
|
||||||
|
if _, err := asn1.Unmarshal(baseline.RawTBSCertificateRequest, &outer); err != nil {
|
||||||
|
t.Fatalf("baseline TBS unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
rest := outer.Bytes
|
||||||
|
var version, subject, spki asn1.RawValue
|
||||||
|
for _, target := range []*asn1.RawValue{&version, &subject, &spki} {
|
||||||
|
next, err := asn1.Unmarshal(rest, target)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("baseline TBS skip: %v", err)
|
||||||
|
}
|
||||||
|
rest = next
|
||||||
|
}
|
||||||
|
versionDER, _ := asn1.Marshal(version)
|
||||||
|
subjectDER, _ := asn1.Marshal(subject)
|
||||||
|
spkiDER, _ := asn1.Marshal(spki)
|
||||||
|
// Build the [0] IMPLICIT Attributes wrapper.
|
||||||
|
attrsField, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassContextSpecific,
|
||||||
|
Tag: 0,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: attrSeq,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal attrs field: %v", err)
|
||||||
|
}
|
||||||
|
tbsBody := append(append(append(append([]byte{}, versionDER...), subjectDER...), spkiDER...), attrsField...)
|
||||||
|
newTBS, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassUniversal,
|
||||||
|
Tag: asn1.TagSequence,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: tbsBody,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("re-marshal TBS: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Parse the new TBS — we don't need to re-sign for these tests
|
||||||
|
// (ExtractCSRChannelBinding doesn't verify the signature; it walks
|
||||||
|
// RawTBSCertificateRequest only).
|
||||||
|
csr := &x509.CertificateRequest{
|
||||||
|
RawTBSCertificateRequest: newTBS,
|
||||||
|
Subject: baseline.Subject,
|
||||||
|
PublicKey: baseline.PublicKey,
|
||||||
|
}
|
||||||
|
return csr
|
||||||
|
}
|
||||||
|
|
||||||
|
// freshSelfSignedTLSCert produces a tls.Certificate-compatible cert+key for
|
||||||
|
// the TLS-1.3 round-trip test.
|
||||||
|
func freshSelfSignedTLSCert(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: "tls-test"},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateCertificate: %v", err)
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(der)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseCertificate: %v", err)
|
||||||
|
}
|
||||||
|
return cert, key
|
||||||
|
}
|
||||||
|
|
||||||
|
// repeatByte returns a slice of length n filled with b. Used for fixture
|
||||||
|
// exporter values where we need a deterministic test pattern.
|
||||||
|
func repeatByte(b byte, n int) []byte {
|
||||||
|
out := make([]byte, n)
|
||||||
|
for i := range out {
|
||||||
|
out[i] = b
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytesEq(a, b []byte) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedChannelBindingAttribute round-trip — pins the spec contract that
|
||||||
|
// what we marshal can be parsed back by ExtractCSRChannelBinding without
|
||||||
|
// going through the freshCSRWithBinding splice helper.
|
||||||
|
func TestEmbedChannelBindingAttribute_RoundTrip(t *testing.T) {
|
||||||
|
exporter := repeatByte(0x77, TLSExporterLength)
|
||||||
|
attrDER, err := EmbedChannelBindingAttribute(exporter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EmbedChannelBindingAttribute: %v", err)
|
||||||
|
}
|
||||||
|
// Wrap the single attribute in a [0] IMPLICIT SET OF Attribute block
|
||||||
|
// and a TBS-lookalike SEQUENCE so we can feed it through the same path
|
||||||
|
// the parser uses — the parser doesn't care that version+subject+spki
|
||||||
|
// are absent because it walks structurally.
|
||||||
|
attrsField, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassContextSpecific,
|
||||||
|
Tag: 0,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: attrDER,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal attrs field: %v", err)
|
||||||
|
}
|
||||||
|
// Synthetic TBS with three placeholder asn1.RawValue fields then attrsField.
|
||||||
|
placeholder, _ := asn1.Marshal(asn1.RawValue{Class: asn1.ClassUniversal, Tag: asn1.TagInteger, Bytes: []byte{0x00}})
|
||||||
|
body := append(append(append(append([]byte{}, placeholder...), placeholder...), placeholder...), attrsField...)
|
||||||
|
tbs, err := asn1.Marshal(asn1.RawValue{
|
||||||
|
Class: asn1.ClassUniversal,
|
||||||
|
Tag: asn1.TagSequence,
|
||||||
|
IsCompound: true,
|
||||||
|
Bytes: body,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal TBS: %v", err)
|
||||||
|
}
|
||||||
|
got, present, err := walkCSRAttributesForBinding(tbs)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("walkCSRAttributesForBinding: %v", err)
|
||||||
|
}
|
||||||
|
if !present {
|
||||||
|
t.Fatal("present=false on round-trip")
|
||||||
|
}
|
||||||
|
if !bytesEq(got, exporter) {
|
||||||
|
t.Errorf("round-trip mismatch: got %x, want %x", got, exporter)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
// Package ratelimit provides shared rate-limit primitives used by
|
||||||
|
// authenticated-but-shared-credential code paths (SCEP/Intune
|
||||||
|
// per-device challenge enrollment, EST per-principal CSR enrollment,
|
||||||
|
// EST HTTP-Basic source-IP failed-auth limiter) where the threat
|
||||||
|
// model is "single legitimate identity could mint enrollments
|
||||||
|
// faster than any human/fleet workflow would."
|
||||||
|
//
|
||||||
|
// Origin: this package was extracted from
|
||||||
|
// internal/scep/intune/rate_limit.go in the EST RFC 7030 hardening
|
||||||
|
// master bundle Phase 4.1 — EST is the third caller after the
|
||||||
|
// Intune dispatcher (per-device-GUID cap on enrollment) and the EST
|
||||||
|
// per-principal cap (Phase 4.2). The original Intune-package type +
|
||||||
|
// constructor + ErrRateLimited sentinel are preserved as type
|
||||||
|
// aliases at internal/scep/intune/rate_limit.go so existing call
|
||||||
|
// sites compile unchanged. New callers SHOULD use this package
|
||||||
|
// directly.
|
||||||
|
//
|
||||||
|
// Algorithm: sliding window log. Each key maps to a bucket holding
|
||||||
|
// timestamps within the configured window. On Allow, the bucket
|
||||||
|
// prunes timestamps older than (now - window) and either appends +
|
||||||
|
// returns nil, or rejects + returns ErrRateLimited when the
|
||||||
|
// post-prune count is already at the cap. Exact (no token-leak
|
||||||
|
// rounding); O(N_per_key) per-call but N is bounded by the cap, so
|
||||||
|
// effectively O(1).
|
||||||
|
//
|
||||||
|
// Concurrency: safe for concurrent Allow calls. Internal map guarded
|
||||||
|
// by sync.Mutex; per-key slices mutated only while the mutex is
|
||||||
|
// held.
|
||||||
|
//
|
||||||
|
// Memory: bounded by the per-instance map cap (default 100,000 keys;
|
||||||
|
// configurable). At-cap eviction drops the oldest entry by newest
|
||||||
|
// timestamp — small janitor pass; rarely fires in practice because
|
||||||
|
// the prune-on-Allow path keeps most buckets short-lived.
|
||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrRateLimited is returned by SlidingWindowLimiter.Allow when the
|
||||||
|
// bucket for the given key is already at the cap. Callers can
|
||||||
|
// errors.Is against this sentinel; the underlying message is stable
|
||||||
|
// across the package's lifetime so test assertions can match on it.
|
||||||
|
var ErrRateLimited = errors.New("ratelimit: per-key cap exceeded for the configured window")
|
||||||
|
|
||||||
|
// SlidingWindowLimiter is the sliding-window-log rate limiter.
|
||||||
|
//
|
||||||
|
// Construct via NewSlidingWindowLimiter. The zero value is NOT
|
||||||
|
// usable — the buckets map needs initialisation.
|
||||||
|
type SlidingWindowLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
buckets map[string][]time.Time // key → sliding window of timestamps
|
||||||
|
maxN int // max enrollments per window
|
||||||
|
window time.Duration // window length (default 24h)
|
||||||
|
cap int // max keys before LRU eviction kicks in
|
||||||
|
disabled bool // maxN <= 0 → all Allow calls return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSlidingWindowLimiter returns a limiter with the given per-key
|
||||||
|
// cap + window. maxN <= 0 disables the limiter (all Allow calls
|
||||||
|
// return nil); this is operator opt-out for the rare case where the
|
||||||
|
// per-key cap is undesirable (test harnesses, sketchpad deploys).
|
||||||
|
//
|
||||||
|
// Window defaults to 24h when zero. Map cap defaults to 100,000 when
|
||||||
|
// zero (matches the SCEP/Intune replay cache cap).
|
||||||
|
func NewSlidingWindowLimiter(maxN int, window time.Duration, mapCap int) *SlidingWindowLimiter {
|
||||||
|
if window <= 0 {
|
||||||
|
window = 24 * time.Hour
|
||||||
|
}
|
||||||
|
if mapCap <= 0 {
|
||||||
|
mapCap = 100_000
|
||||||
|
}
|
||||||
|
return &SlidingWindowLimiter{
|
||||||
|
buckets: make(map[string][]time.Time),
|
||||||
|
maxN: maxN,
|
||||||
|
window: window,
|
||||||
|
cap: mapCap,
|
||||||
|
disabled: maxN <= 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow reports whether an event keyed by `key` is permitted right
|
||||||
|
// now. Returns nil when allowed (and records the timestamp in the
|
||||||
|
// bucket) or ErrRateLimited when the bucket is at maxN.
|
||||||
|
//
|
||||||
|
// Empty key is treated as "skip the limiter" — the caller's
|
||||||
|
// validation should have rejected an empty-key event already; this
|
||||||
|
// is belt-and-suspenders so a single empty-key bucket doesn't
|
||||||
|
// become a chokepoint for every empty-key event. SCEP/Intune
|
||||||
|
// callers compose the key as `subject + "|" + issuer`; EST callers
|
||||||
|
// compose `cn + "|" + sourceIP` or `sourceIP`-alone for the
|
||||||
|
// failed-auth limiter.
|
||||||
|
func (l *SlidingWindowLimiter) Allow(key string, now time.Time) error {
|
||||||
|
if l.disabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
// At-cap eviction: when the map is full, drop the oldest entry
|
||||||
|
// by finding the bucket whose newest timestamp is the smallest.
|
||||||
|
// O(N_keys) but rarely fires; the prune-on-Allow path keeps
|
||||||
|
// most buckets short-lived.
|
||||||
|
if len(l.buckets) >= l.cap {
|
||||||
|
l.evictOldestLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket := l.buckets[key]
|
||||||
|
bucket = pruneOlderThan(bucket, now.Add(-l.window))
|
||||||
|
|
||||||
|
if len(bucket) >= l.maxN {
|
||||||
|
// Don't append; over the limit. Persist the pruned bucket so
|
||||||
|
// the next call sees the most-recently-pruned state.
|
||||||
|
l.buckets[key] = bucket
|
||||||
|
return ErrRateLimited
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket = append(bucket, now)
|
||||||
|
l.buckets[key] = bucket
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pruneOlderThan returns the slice with all entries strictly before
|
||||||
|
// `cutoff` removed. Preserves order (timestamps are appended in
|
||||||
|
// increasing time, so a single linear scan from the front suffices).
|
||||||
|
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
|
||||||
|
i := 0
|
||||||
|
for i < len(b) && b[i].Before(cutoff) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
// Copy-shrink to release the underlying-array memory eventually
|
||||||
|
// (otherwise the slice would hold a reference to the older
|
||||||
|
// entries indefinitely until a re-allocation).
|
||||||
|
out := make([]time.Time, len(b)-i)
|
||||||
|
copy(out, b[i:])
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOldestLocked drops the map entry whose newest timestamp is
|
||||||
|
// the oldest. Called under l.mu. O(N_keys) per eviction; at-cap is
|
||||||
|
// rare in practice (caps are sized for steady-state).
|
||||||
|
func (l *SlidingWindowLimiter) evictOldestLocked() {
|
||||||
|
var (
|
||||||
|
oldestKey string
|
||||||
|
oldestTs time.Time
|
||||||
|
first = true
|
||||||
|
)
|
||||||
|
for k, b := range l.buckets {
|
||||||
|
if len(b) == 0 {
|
||||||
|
// Empty bucket — drop it immediately, no candidate scan needed.
|
||||||
|
delete(l.buckets, k)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newest := b[len(b)-1]
|
||||||
|
if first || newest.Before(oldestTs) {
|
||||||
|
oldestKey = k
|
||||||
|
oldestTs = newest
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if oldestKey != "" {
|
||||||
|
delete(l.buckets, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the approximate number of distinct keys currently
|
||||||
|
// tracked. For observability + tests; not load-stable under
|
||||||
|
// concurrent Allow calls.
|
||||||
|
func (l *SlidingWindowLimiter) Len() int {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
return len(l.buckets)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disabled reports whether the limiter is in opt-out mode (maxN <= 0).
|
||||||
|
// Useful for handler-side gating + admin-endpoint observability.
|
||||||
|
func (l *SlidingWindowLimiter) Disabled() bool {
|
||||||
|
return l.disabled
|
||||||
|
}
|
||||||
@@ -0,0 +1,197 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 4.1: this test file holds the
|
||||||
|
// white-box tests for the SlidingWindowLimiter primitives that used to live
|
||||||
|
// in internal/scep/intune/rate_limit_test.go (TestPerDeviceRateLimiter_
|
||||||
|
// DefaultCapsHonored, TestPruneOlderThan, TestPruneOlderThan_NoOpWhen
|
||||||
|
// NothingToPrune). The behavioral coverage in intune/rate_limit_test.go
|
||||||
|
// stays — it exercises the wrapper's (subject, issuer)-composition contract
|
||||||
|
// + the empty-subject short-circuit + concurrent race-freedom.
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_AllowsUpToCap(t *testing.T) {
|
||||||
|
l := NewSlidingWindowLimiter(3, 24*time.Hour, 10)
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if err := l.Allow("k", now.Add(time.Duration(i)*time.Minute)); err != nil {
|
||||||
|
t.Fatalf("call %d should be allowed: %v", i+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := l.Allow("k", now.Add(4*time.Minute)); !errors.Is(err, ErrRateLimited) {
|
||||||
|
t.Fatalf("4th call should be rate-limited; got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_DistinctKeysIndependent(t *testing.T) {
|
||||||
|
l := NewSlidingWindowLimiter(1, 24*time.Hour, 10)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if err := l.Allow("k-1", now); err != nil {
|
||||||
|
t.Fatalf("first allow: %v", err)
|
||||||
|
}
|
||||||
|
if err := l.Allow("k-2", now); err != nil {
|
||||||
|
t.Fatalf("different key must have its own bucket: %v", err)
|
||||||
|
}
|
||||||
|
if err := l.Allow("k-1", now.Add(1*time.Second)); !errors.Is(err, ErrRateLimited) {
|
||||||
|
t.Fatalf("repeat key should be limited; got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_WindowExpiry(t *testing.T) {
|
||||||
|
l := NewSlidingWindowLimiter(2, 1*time.Hour, 10)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if err := l.Allow("k", now); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := l.Allow("k", now.Add(30*time.Minute)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Inside window — limited.
|
||||||
|
if err := l.Allow("k", now.Add(45*time.Minute)); !errors.Is(err, ErrRateLimited) {
|
||||||
|
t.Fatalf("inside-window 3rd call should be limited: %v", err)
|
||||||
|
}
|
||||||
|
// Past window — slots reopen.
|
||||||
|
if err := l.Allow("k", now.Add(2*time.Hour)); err != nil {
|
||||||
|
t.Fatalf("past-window call should be allowed (window reset): %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_DisabledBypass(t *testing.T) {
|
||||||
|
l := NewSlidingWindowLimiter(0, 24*time.Hour, 10) // maxN=0 → disabled
|
||||||
|
if !l.Disabled() {
|
||||||
|
t.Fatal("limiter with maxN=0 must report Disabled()=true")
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
if err := l.Allow("k", now); err != nil {
|
||||||
|
t.Fatalf("disabled limiter must allow everything: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := l.Len(); got != 0 {
|
||||||
|
t.Errorf("disabled limiter Len() = %d, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_NegativeCapDisabled(t *testing.T) {
|
||||||
|
l := NewSlidingWindowLimiter(-1, 24*time.Hour, 10)
|
||||||
|
if !l.Disabled() {
|
||||||
|
t.Fatal("negative maxN must produce a disabled limiter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_EmptyKeyShortCircuits(t *testing.T) {
|
||||||
|
// Empty key is the caller's defense-in-depth case — caller's validation
|
||||||
|
// upstream should reject empty-key events first. Limiter must not build
|
||||||
|
// a single shared bucket keyed by empty-key — that would be a chokepoint
|
||||||
|
// for every empty-key event.
|
||||||
|
l := NewSlidingWindowLimiter(1, 24*time.Hour, 10)
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
if err := l.Allow("", now); err != nil {
|
||||||
|
t.Fatalf("empty key must short-circuit (call %d): %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := l.Len(); got != 0 {
|
||||||
|
t.Errorf("Len after 50 empty-key calls = %d, want 0 (no bucket created)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_DefaultCapsHonored(t *testing.T) {
|
||||||
|
// White-box test: exercises the constructor's default-fill branches.
|
||||||
|
// Lives here (not in the intune wrapper test) because the fields
|
||||||
|
// (window + cap) are package-private to ratelimit.
|
||||||
|
l := NewSlidingWindowLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
|
||||||
|
if l.window != 24*time.Hour {
|
||||||
|
t.Errorf("default window = %v, want 24h", l.window)
|
||||||
|
}
|
||||||
|
if l.cap != 100_000 {
|
||||||
|
t.Errorf("default cap = %d, want 100000", l.cap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_MapCapEvictsOldest(t *testing.T) {
|
||||||
|
// Cap of 3 keys to exercise the eviction branch deterministically.
|
||||||
|
l := NewSlidingWindowLimiter(2, 1*time.Hour, 3)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
key := fmt.Sprintf("k-%d", i)
|
||||||
|
if err := l.Allow(key, now.Add(time.Duration(i)*time.Minute)); err != nil {
|
||||||
|
t.Fatalf("insert %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if l.Len() != 3 {
|
||||||
|
t.Fatalf("Len = %d, want 3", l.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4th key forces eviction of k-0 (its newest timestamp is oldest).
|
||||||
|
if err := l.Allow("k-3", now.Add(10*time.Minute)); err != nil {
|
||||||
|
t.Fatalf("4th-key insert: %v", err)
|
||||||
|
}
|
||||||
|
if l.Len() != 3 {
|
||||||
|
t.Errorf("Len after at-cap insert = %d, want 3 (cap honored)", l.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter_ConcurrentRaceFree(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("race-style test under -short")
|
||||||
|
}
|
||||||
|
l := NewSlidingWindowLimiter(50, 24*time.Hour, 10000)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for g := 0; g < 20; g++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
now := time.Now()
|
||||||
|
key := fmt.Sprintf("k-%d", id)
|
||||||
|
for i := 0; i < 30; i++ {
|
||||||
|
_ = l.Allow(key, now)
|
||||||
|
}
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
if got := l.Len(); got != 20 {
|
||||||
|
t.Errorf("expected 20 distinct keys; got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// White-box tests for the unexported pruneOlderThan helper. Live in this
|
||||||
|
// package because the helper is package-private to ratelimit. The test
|
||||||
|
// surface used to live in intune/rate_limit_test.go before the Phase 4.1
|
||||||
|
// extraction.
|
||||||
|
func TestPruneOlderThan(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
in := []time.Time{
|
||||||
|
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
|
||||||
|
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
|
||||||
|
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
|
||||||
|
t0.Add(-30 * time.Minute), // survives
|
||||||
|
t0, // survives
|
||||||
|
}
|
||||||
|
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
|
||||||
|
if len(out) != 3 {
|
||||||
|
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
|
||||||
|
}
|
||||||
|
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
|
||||||
|
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
in := []time.Time{t0.Add(-1 * time.Minute), t0}
|
||||||
|
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
|
||||||
|
// Same slice header (no copy needed).
|
||||||
|
if len(out) != len(in) {
|
||||||
|
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,193 +1,87 @@
|
|||||||
package intune
|
package intune
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/ratelimit"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SCEP RFC 8894 + Intune master bundle Phase 8.6.
|
// SCEP RFC 8894 + Intune master bundle Phase 8.6.
|
||||||
//
|
//
|
||||||
// PerDeviceRateLimiter is the second line of defense behind the replay cache
|
// PerDeviceRateLimiter is the second line of defense behind the replay
|
||||||
// from Phase 7. The replay cache catches the same challenge being submitted
|
// cache from Phase 7. The replay cache catches the same challenge being
|
||||||
// twice (within the challenge TTL); this rate limiter catches a compromised
|
// submitted twice (within the challenge TTL); this rate limiter catches a
|
||||||
// Connector signing key (or a stolen key+cert pair) issuing many DIFFERENT
|
// compromised Connector signing key (or a stolen key+cert pair) issuing
|
||||||
// valid challenges for the same device subject in a short window.
|
// many DIFFERENT valid challenges for the same device subject in a short
|
||||||
|
// window.
|
||||||
//
|
//
|
||||||
// Threat model:
|
// Threat model:
|
||||||
//
|
//
|
||||||
// - Replay cache (Phase 7): nonce-keyed; catches duplicate submission.
|
// - Replay cache (Phase 7): nonce-keyed; catches duplicate submission.
|
||||||
// - This limiter: (Subject, Issuer)-keyed; catches enrollment-flooding.
|
// - This limiter: (Subject, Issuer)-keyed; catches enrollment-flooding.
|
||||||
//
|
//
|
||||||
// Default: 3 enrollments per (device GUID, Connector identity) per 24h.
|
// EST RFC 7030 hardening master bundle Phase 4.1: the implementation that
|
||||||
|
// used to live in this file was extracted to internal/ratelimit (where it
|
||||||
|
// can be shared with EST per-principal + EST HTTP-Basic source-IP rate
|
||||||
|
// limiters). PerDeviceRateLimiter is now a thin wrapper around
|
||||||
|
// ratelimit.SlidingWindowLimiter that preserves the original
|
||||||
|
// (subject, issuer) → key composition in the Allow signature so existing
|
||||||
|
// SCEP/Intune callers don't have to change.
|
||||||
//
|
//
|
||||||
// Sizing: 100,000 distinct device entries (matches the replay cache cap).
|
// New callers SHOULD use ratelimit.SlidingWindowLimiter directly. The
|
||||||
// At-cap: oldest entry evicted (small janitor pass) to avoid unbounded
|
// EST RFC 7030 Phase 4.2 EST per-principal cap uses the shared package.
|
||||||
// memory growth on a fleet that grows past the cap.
|
|
||||||
//
|
|
||||||
// Why a hand-rolled token bucket instead of pulling in golang.org/x/time/rate:
|
|
||||||
// the rate package is in go.sum as an indirect transitive but NOT a direct
|
|
||||||
// dep. Adding it would create a new direct dep relationship for ~30 LoC of
|
|
||||||
// state machine. The hand-rolled version below uses only stdlib (sync.Mutex
|
|
||||||
// + time.Time arithmetic) and is small enough to fit on one screen.
|
|
||||||
//
|
|
||||||
// Algorithm: each (Subject, Issuer) key maps to a bucket holding a window's
|
|
||||||
// worth of recent enrollment timestamps. On Allow, the bucket prunes
|
|
||||||
// timestamps older than (now - window) and either appends the current
|
|
||||||
// timestamp + returns true, or rejects + returns false when the post-prune
|
|
||||||
// count is already at the cap. This is the "sliding window log" rate
|
|
||||||
// limiter — exact (no token-leak rounding); O(N_per_key) per-call but N is
|
|
||||||
// bounded by the cap (3 by default), so effectively O(1).
|
|
||||||
|
|
||||||
// ErrRateLimited is the typed error returned when the per-device rate limit
|
// ErrRateLimited is the typed error returned when the per-device rate
|
||||||
// fires. The handler maps this to a CertRep FAILURE with badRequest failInfo
|
// limit fires. Aliased to ratelimit.ErrRateLimited so errors.Is matches
|
||||||
// + the `rate_limited` metric label.
|
// against either name (the SCEP audit closure already pinned the
|
||||||
var ErrRateLimited = errors.New("intune: per-device rate limit exceeded for this (subject, issuer) within the configured window")
|
// "rate_limited" metric label against this sentinel; the alias preserves
|
||||||
|
// sentinel identity across the package boundary).
|
||||||
|
var ErrRateLimited = ratelimit.ErrRateLimited
|
||||||
|
|
||||||
// PerDeviceRateLimiter is a sliding-window-log rate limiter keyed by
|
// PerDeviceRateLimiter wraps ratelimit.SlidingWindowLimiter with the
|
||||||
// (Subject, Issuer) tuples derived from a parsed challenge claim.
|
// (subject, issuer)-composed-key Allow signature the Intune dispatcher
|
||||||
//
|
// uses. Concurrency-safe (the underlying limiter holds the mutex).
|
||||||
// Concurrency: the limiter is safe for concurrent Allow calls. The internal
|
|
||||||
// map is guarded by a mutex; the per-key slices are mutated only while the
|
|
||||||
// mutex is held.
|
|
||||||
type PerDeviceRateLimiter struct {
|
type PerDeviceRateLimiter struct {
|
||||||
mu sync.Mutex
|
inner *ratelimit.SlidingWindowLimiter
|
||||||
buckets map[string][]time.Time // key → sliding window of timestamps
|
|
||||||
maxN int // max enrollments per window
|
|
||||||
window time.Duration // window length (default 24h)
|
|
||||||
cap int // max keys before LRU eviction kicks in
|
|
||||||
disabled bool // maxN == 0 → all Allow calls return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPerDeviceRateLimiter returns a limiter with the given per-key cap +
|
// NewPerDeviceRateLimiter returns a limiter with the given per-key cap +
|
||||||
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil); this
|
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil);
|
||||||
// is operator opt-out for the rare case where the per-device cap is
|
// this is operator opt-out for the rare case where the per-device cap is
|
||||||
// undesirable (e.g. test harnesses, sketchpad deploys).
|
// undesirable (e.g. test harnesses, sketchpad deploys).
|
||||||
//
|
//
|
||||||
// Window defaults to 24h when zero. Map cap defaults to 100,000 when zero
|
// Window defaults to 24h when zero. Map cap defaults to 100,000 when zero
|
||||||
// (matches the replay cache cap; see internal/scep/intune/replay.go).
|
// (matches the replay cache cap; see internal/scep/intune/replay.go).
|
||||||
func NewPerDeviceRateLimiter(maxN int, window time.Duration, mapCap int) *PerDeviceRateLimiter {
|
func NewPerDeviceRateLimiter(maxN int, window time.Duration, mapCap int) *PerDeviceRateLimiter {
|
||||||
if window <= 0 {
|
return &PerDeviceRateLimiter{inner: ratelimit.NewSlidingWindowLimiter(maxN, window, mapCap)}
|
||||||
window = 24 * time.Hour
|
|
||||||
}
|
|
||||||
if mapCap <= 0 {
|
|
||||||
mapCap = 100_000
|
|
||||||
}
|
|
||||||
return &PerDeviceRateLimiter{
|
|
||||||
buckets: make(map[string][]time.Time),
|
|
||||||
maxN: maxN,
|
|
||||||
window: window,
|
|
||||||
cap: mapCap,
|
|
||||||
disabled: maxN <= 0,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow checks whether an enrollment for the given (subject, issuer) tuple
|
// Allow checks whether an enrollment for the given (subject, issuer)
|
||||||
// is permitted right now. Returns nil when allowed (and records the timestamp
|
// tuple is permitted right now. Returns nil when allowed (and records
|
||||||
// in the bucket) or ErrRateLimited when the bucket is at maxN.
|
// the timestamp in the bucket) or ErrRateLimited when the bucket is at
|
||||||
|
// maxN.
|
||||||
//
|
//
|
||||||
// Empty subject is treated as "skip the limiter" — the caller's claim
|
// Empty subject is treated as "skip the limiter" — the caller's claim
|
||||||
// validation should have rejected an empty-subject claim already; this is
|
// validation should have rejected an empty-subject claim already; this
|
||||||
// belt-and-suspenders to prevent a single empty-subject bucket from
|
// is belt-and-suspenders to prevent a single empty-subject bucket from
|
||||||
// becoming a fleet-wide chokepoint. The Connector emits non-empty subject
|
// becoming a fleet-wide chokepoint.
|
||||||
// (device GUID) on every legitimate challenge.
|
|
||||||
func (l *PerDeviceRateLimiter) Allow(subject, issuer string, now time.Time) error {
|
func (l *PerDeviceRateLimiter) Allow(subject, issuer string, now time.Time) error {
|
||||||
if l.disabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if subject == "" {
|
if subject == "" {
|
||||||
// Caller's claim validation should reject empty-subject upstream;
|
// Empty-subject early return preserved from the pre-Phase-4.1
|
||||||
// this short-circuit is defense-in-depth so a misconfigured
|
// behavior: ratelimit.SlidingWindowLimiter also short-circuits
|
||||||
// Connector can't DoS us via the rate-limit path.
|
// on empty key, but the explicit check here documents the
|
||||||
|
// (subject, issuer) → empty-key contract and saves one call
|
||||||
|
// frame in the hot path.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
key := subject + "|" + issuer
|
key := subject + "|" + issuer
|
||||||
|
return l.inner.Allow(key, now)
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
|
|
||||||
// At-cap eviction: when the map is full, drop the oldest entry by
|
|
||||||
// finding the bucket whose newest timestamp is the smallest. O(N) but
|
|
||||||
// rarely fires; the prune-on-Allow path keeps most buckets short-lived.
|
|
||||||
if len(l.buckets) >= l.cap {
|
|
||||||
l.evictOldestLocked(now)
|
|
||||||
}
|
|
||||||
|
|
||||||
bucket := l.buckets[key]
|
|
||||||
bucket = pruneOlderThan(bucket, now.Add(-l.window))
|
|
||||||
|
|
||||||
if len(bucket) >= l.maxN {
|
|
||||||
// Don't append; over the limit. Persist the pruned bucket so the
|
|
||||||
// next call sees the most-recently-pruned state.
|
|
||||||
l.buckets[key] = bucket
|
|
||||||
return ErrRateLimited
|
|
||||||
}
|
|
||||||
|
|
||||||
bucket = append(bucket, now)
|
|
||||||
l.buckets[key] = bucket
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// pruneOlderThan returns the slice with all entries strictly before
|
|
||||||
// `cutoff` removed. Preserves order (timestamps are appended in increasing
|
|
||||||
// time, so a single linear scan from the front suffices).
|
|
||||||
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
|
|
||||||
i := 0
|
|
||||||
for i < len(b) && b[i].Before(cutoff) {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i == 0 {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
// Copy-shrink to release the underlying-array memory eventually
|
|
||||||
// (otherwise the slice would hold a reference to the older entries
|
|
||||||
// indefinitely until a re-allocation).
|
|
||||||
out := make([]time.Time, len(b)-i)
|
|
||||||
copy(out, b[i:])
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// evictOldestLocked drops the map entry whose newest timestamp is the
|
|
||||||
// oldest. Called under l.mu. O(N_keys) per eviction; at-cap is rare in
|
|
||||||
// practice (caps are sized for fleet steady-state).
|
|
||||||
func (l *PerDeviceRateLimiter) evictOldestLocked(now time.Time) {
|
|
||||||
var (
|
|
||||||
oldestKey string
|
|
||||||
oldestTs time.Time
|
|
||||||
first = true
|
|
||||||
)
|
|
||||||
for k, b := range l.buckets {
|
|
||||||
if len(b) == 0 {
|
|
||||||
// Empty bucket — drop it immediately, no candidate scan needed.
|
|
||||||
delete(l.buckets, k)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
newest := b[len(b)-1]
|
|
||||||
if first || newest.Before(oldestTs) {
|
|
||||||
oldestKey = k
|
|
||||||
oldestTs = newest
|
|
||||||
first = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if oldestKey != "" {
|
|
||||||
delete(l.buckets, oldestKey)
|
|
||||||
}
|
|
||||||
// Suppress unused-parameter warning for `now` in case the eviction
|
|
||||||
// strategy changes (e.g. swap to LRU keyed by time of last Allow).
|
|
||||||
_ = now
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the approximate number of distinct (subject, issuer) keys
|
// Len returns the approximate number of distinct (subject, issuer) keys
|
||||||
// currently tracked. For observability + tests; not load-stable under
|
// currently tracked. For observability + tests.
|
||||||
// concurrent Allow calls.
|
func (l *PerDeviceRateLimiter) Len() int { return l.inner.Len() }
|
||||||
func (l *PerDeviceRateLimiter) Len() int {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
return len(l.buckets)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disabled reports whether the limiter is in opt-out mode (maxN ≤ 0).
|
// Disabled reports whether the limiter is in opt-out mode (maxN ≤ 0).
|
||||||
// Useful for handler-side gating + admin-endpoint observability.
|
// Useful for handler-side gating + admin-endpoint observability.
|
||||||
func (l *PerDeviceRateLimiter) Disabled() bool {
|
func (l *PerDeviceRateLimiter) Disabled() bool { return l.inner.Disabled() }
|
||||||
return l.disabled
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -103,15 +103,11 @@ func TestPerDeviceRateLimiter_EmptySubjectShortCircuits(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPerDeviceRateLimiter_DefaultCapsHonored(t *testing.T) {
|
// TestPerDeviceRateLimiter_DefaultCapsHonored — moved to
|
||||||
l := NewPerDeviceRateLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
|
// internal/ratelimit/sliding_window_test.go::TestSlidingWindowLimiter_DefaultCapsHonored
|
||||||
if l.window != 24*time.Hour {
|
// in EST RFC 7030 hardening Phase 4.1 (the white-box test reads private
|
||||||
t.Errorf("default window = %v, want 24h", l.window)
|
// fields that no longer exist on the wrapper). The shared package owns
|
||||||
}
|
// the field-default contract.
|
||||||
if l.cap != 100_000 {
|
|
||||||
t.Errorf("default cap = %d, want 100000", l.cap)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPerDeviceRateLimiter_MapCapEvictsOldest(t *testing.T) {
|
func TestPerDeviceRateLimiter_MapCapEvictsOldest(t *testing.T) {
|
||||||
// Cap of 3 keys to exercise the eviction branch deterministically.
|
// Cap of 3 keys to exercise the eviction branch deterministically.
|
||||||
@@ -161,30 +157,8 @@ func TestPerDeviceRateLimiter_ConcurrentRaceFree(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPruneOlderThan(t *testing.T) {
|
// TestPruneOlderThan + TestPruneOlderThan_NoOpWhenNothingToPrune — moved
|
||||||
t0 := time.Now()
|
// to internal/ratelimit/sliding_window_test.go in EST RFC 7030 hardening
|
||||||
in := []time.Time{
|
// Phase 4.1. pruneOlderThan is now an unexported helper of the shared
|
||||||
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
|
// ratelimit package (the implementation moved there); the white-box
|
||||||
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
|
// tests follow.
|
||||||
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
|
|
||||||
t0.Add(-30 * time.Minute), // survives
|
|
||||||
t0, // survives
|
|
||||||
}
|
|
||||||
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
|
|
||||||
if len(out) != 3 {
|
|
||||||
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
|
|
||||||
}
|
|
||||||
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
|
|
||||||
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
|
|
||||||
t0 := time.Now()
|
|
||||||
in := []time.Time{t0.Add(-1 * time.Minute), t0}
|
|
||||||
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
|
|
||||||
// Same slice header (no copy needed).
|
|
||||||
if len(out) != len(in) {
|
|
||||||
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,73 +1,45 @@
|
|||||||
package intune
|
package intune
|
||||||
|
|
||||||
|
// SCEP RFC 8894 + Intune master bundle Phase 7.2 (originally) +
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.1 (extraction).
|
||||||
|
//
|
||||||
|
// LoadTrustAnchor + parseTrustAnchorPEM were extracted to
|
||||||
|
// internal/trustanchor.LoadBundle + parseBundlePEM so the EST mTLS
|
||||||
|
// sibling route (Phase 2 of the EST hardening bundle), the Intune
|
||||||
|
// dispatcher, and any future per-profile-trust-bundle caller can share
|
||||||
|
// the same PEM-bundle loader + SIGHUP-reload semantics. The shim below
|
||||||
|
// preserves the original public surface so existing intune callers
|
||||||
|
// (cmd/server/main.go, scep_intune_e2e_test.go, scep_profile_counter_
|
||||||
|
// isolation_test.go, scep_intune.go service) compile unchanged.
|
||||||
|
//
|
||||||
|
// New callers SHOULD import internal/trustanchor directly — the
|
||||||
|
// trustanchor.Holder + trustanchor.LoadBundle are the modern API.
|
||||||
|
//
|
||||||
|
// Note: the legacy intune error messages ("intune: trust anchor cert
|
||||||
|
// in %q expired ...") are NOT preserved verbatim across the extraction;
|
||||||
|
// the shared trustanchor package emits "trustanchor: ..." messages
|
||||||
|
// instead. The operator-facing log line at cmd/server/main.go's
|
||||||
|
// preflightSCEPIntuneTrustAnchor wraps the error in its own outer
|
||||||
|
// ("SCEP profile (PathID=...) INTUNE trust anchor load failed: ...")
|
||||||
|
// so the prefix change is invisible to log-grep runbooks that filter
|
||||||
|
// on the outer message.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoadTrustAnchor reads a PEM bundle of one or more Intune Connector
|
// LoadTrustAnchor reads a PEM bundle of one or more Intune Connector
|
||||||
// signing certificates from the configured path. Returns the slice of
|
// signing certificates from the configured path. Delegates to the
|
||||||
// parsed certs that the validator will accept as challenge issuers.
|
// shared trustanchor.LoadBundle (extracted in EST RFC 7030 hardening
|
||||||
|
// Phase 2.1) so the EST mTLS sibling route + the Intune dispatcher
|
||||||
|
// + any future per-profile trust-bundle caller share the same
|
||||||
|
// loader semantics (path-empty refusal, expired-cert refusal,
|
||||||
|
// non-CERTIFICATE-block tolerance).
|
||||||
//
|
//
|
||||||
// SCEP RFC 8894 + Intune master bundle Phase 7.2.
|
// Preserved here as a wrapper so existing intune callers compile
|
||||||
//
|
// unchanged. New callers SHOULD use trustanchor.LoadBundle directly.
|
||||||
// Behavior:
|
|
||||||
//
|
|
||||||
// - File must exist + be readable.
|
|
||||||
// - PEM-decodes the file; non-CERTIFICATE blocks are skipped (so an
|
|
||||||
// operator can paste a chain that includes a private key by mistake
|
|
||||||
// without breaking the load — the priv key is just ignored).
|
|
||||||
// - Returns an error if zero CERTIFICATE blocks parse.
|
|
||||||
// - Returns an error if any cert is past NotAfter (a stale trust
|
|
||||||
// anchor would silently reject every Intune challenge at runtime;
|
|
||||||
// fail loud at startup instead).
|
|
||||||
//
|
|
||||||
// Operators rotate Connector signing certs periodically; the trust
|
|
||||||
// anchor file is reloaded on SIGHUP (handled by the existing config
|
|
||||||
// watch loop in cmd/server/main.go — see cmd/server/tls.go::watchSIGHUP
|
|
||||||
// for the precedent).
|
|
||||||
func LoadTrustAnchor(path string) ([]*x509.Certificate, error) {
|
func LoadTrustAnchor(path string) ([]*x509.Certificate, error) {
|
||||||
if path == "" {
|
return trustanchor.LoadBundle(path)
|
||||||
return nil, fmt.Errorf("intune: trust anchor path is empty")
|
|
||||||
}
|
|
||||||
body, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("intune: read trust anchor %q: %w", path, err)
|
|
||||||
}
|
|
||||||
return parseTrustAnchorPEM(body, path, time.Now())
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTrustAnchorPEM is the file-IO-free core of LoadTrustAnchor. Split
|
|
||||||
// out so unit tests can hand it byte slices without writing temp files.
|
|
||||||
// `now` is taken as a parameter so expiry tests can pin a deterministic
|
|
||||||
// clock.
|
|
||||||
func parseTrustAnchorPEM(body []byte, sourceLabel string, now time.Time) ([]*x509.Certificate, error) {
|
|
||||||
var out []*x509.Certificate
|
|
||||||
rest := body
|
|
||||||
for {
|
|
||||||
var block *pem.Block
|
|
||||||
block, rest = pem.Decode(rest)
|
|
||||||
if block == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if block.Type != "CERTIFICATE" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("intune: parse trust anchor cert in %q: %w", sourceLabel, err)
|
|
||||||
}
|
|
||||||
if now.After(cert.NotAfter) {
|
|
||||||
return nil, fmt.Errorf("intune: trust anchor cert in %q expired at %s (subject=%q) — operator must rotate the Connector signing cert before restart",
|
|
||||||
sourceLabel, cert.NotAfter.Format(time.RFC3339), cert.Subject.CommonName)
|
|
||||||
}
|
|
||||||
out = append(out, cert)
|
|
||||||
}
|
|
||||||
if len(out) == 0 {
|
|
||||||
return nil, fmt.Errorf("intune: trust anchor %q contains no CERTIFICATE PEM blocks", sourceLabel)
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,143 +1,58 @@
|
|||||||
package intune
|
package intune
|
||||||
|
|
||||||
|
// SCEP RFC 8894 + Intune master bundle Phase 8.5 (originally) +
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.1 (extraction).
|
||||||
|
//
|
||||||
|
// TrustAnchorHolder + NewTrustAnchorHolder were extracted to
|
||||||
|
// internal/trustanchor.Holder + trustanchor.New so the EST mTLS sibling
|
||||||
|
// route (Phase 2 of the EST hardening bundle) and the Intune dispatcher
|
||||||
|
// can share the same SIGHUP-reloadable PEM bundle primitive. A single
|
||||||
|
// SIGHUP now rotates: server TLS cert (cmd/server/tls.go), every Intune
|
||||||
|
// trust anchor (this package's existing wiring), AND every EST mTLS
|
||||||
|
// per-profile client-CA bundle (the new sibling route) — exactly the
|
||||||
|
// design contract documented in the trustanchor package doc.
|
||||||
|
//
|
||||||
|
// The aliases below preserve every existing intune call site unchanged:
|
||||||
|
// - cmd/server/main.go declares `intuneTrustHolders []*intune.TrustAnchorHolder`
|
||||||
|
// + invokes `intune.NewTrustAnchorHolder(path, logger)`
|
||||||
|
// - internal/service/scep.go's SCEPService struct field
|
||||||
|
// `intuneTrust *intune.TrustAnchorHolder` (the type alias keeps this
|
||||||
|
// pointer-compatible with the original)
|
||||||
|
// - internal/scep/intune/trust_anchor_holder_test.go + the e2e tests
|
||||||
|
// that construct a holder via NewTrustAnchorHolder
|
||||||
|
//
|
||||||
|
// New callers SHOULD import internal/trustanchor directly — the
|
||||||
|
// trustanchor.Holder + trustanchor.New are the modern API. The intune
|
||||||
|
// aliases are preserved indefinitely for back-compat (no deprecation
|
||||||
|
// timeline; the cost of the two-line shim is trivial).
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TrustAnchorHolder is the SIGHUP-reloadable wrapper around a per-profile
|
// TrustAnchorHolder is the SIGHUP-reloadable wrapper around a per-profile
|
||||||
// Intune Connector trust anchor pool.
|
// Intune Connector trust anchor pool.
|
||||||
//
|
//
|
||||||
// SCEP RFC 8894 + Intune master bundle Phase 8.5.
|
// Aliased to trustanchor.Holder (extracted in EST RFC 7030 hardening
|
||||||
//
|
// Phase 2.1) so the EST mTLS sibling route + the Intune dispatcher share
|
||||||
// Mirrors the shape established by `cmd/server/tls.go::certHolder` for the
|
// the same primitive. Existing callers compile unchanged because Go type
|
||||||
// server TLS cert: an RWMutex-guarded pool, a Get accessor that's safe for
|
// aliases are pointer-compatible.
|
||||||
// concurrent callers from the request path, a Reload that re-reads the file
|
type TrustAnchorHolder = trustanchor.Holder
|
||||||
// and atomically swaps the slice on success (failure leaves the OLD pool in
|
|
||||||
// place so a bad reload doesn't take Intune enrollment down), and a
|
|
||||||
// watchSIGHUP goroutine that responds to the same SIGHUP the operator uses
|
|
||||||
// to rotate the server TLS cert.
|
|
||||||
//
|
|
||||||
// Why SIGHUP specifically (vs fsnotify or a polling loop): SIGHUP is the
|
|
||||||
// repo-established convention (see cmd/server/tls.go). fsnotify would add a
|
|
||||||
// new direct dep + complicate the cleanup story. The operator's Connector-
|
|
||||||
// rotation script writes the new PEM bundle then sends SIGHUP — the same
|
|
||||||
// signal that already rotates the server TLS cert — and both swap atomically.
|
|
||||||
//
|
|
||||||
// Concurrency contract:
|
|
||||||
// - Get returns the pool slice header by value; the slice itself is
|
|
||||||
// immutable per-snapshot (Reload swaps a fresh slice rather than
|
|
||||||
// mutating the existing one). Callers may iterate the returned slice
|
|
||||||
// without holding any lock.
|
|
||||||
// - Reload acquires a write lock briefly for the swap. Concurrent Get
|
|
||||||
// calls block only for that swap window (microseconds).
|
|
||||||
// - watchSIGHUP runs at most one Reload at a time per holder.
|
|
||||||
type TrustAnchorHolder struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
certs []*x509.Certificate
|
|
||||||
path string
|
|
||||||
logger *slog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTrustAnchorHolder loads the trust bundle and returns a holder. Returns
|
// NewTrustAnchorHolder loads the trust bundle and returns a holder.
|
||||||
// the same fail-loud error LoadTrustAnchor does on initial load — the
|
// Aliased to trustanchor.New (extracted in EST RFC 7030 hardening
|
||||||
// startup gate at cmd/server/main.go is supposed to refuse boot when this
|
// Phase 2.1). Returns the same fail-loud error LoadTrustAnchor does on
|
||||||
// fails. Subsequent Reload errors are non-fatal (logged + old pool retained).
|
// initial load — the startup gate at cmd/server/main.go is supposed to
|
||||||
|
// refuse boot when this fails. Subsequent Reload errors are non-fatal
|
||||||
|
// (logged + old pool retained).
|
||||||
//
|
//
|
||||||
// The logger is required (never nil); the caller passes a per-profile
|
// The logger is required (never nil); the caller passes a per-profile
|
||||||
// scoped logger so SIGHUP-reload events show the PathID for triage.
|
// scoped logger so SIGHUP-reload events show the PathID for triage.
|
||||||
func NewTrustAnchorHolder(path string, logger *slog.Logger) (*TrustAnchorHolder, error) {
|
|
||||||
if logger == nil {
|
|
||||||
return nil, errors.New("intune: TrustAnchorHolder requires a non-nil logger")
|
|
||||||
}
|
|
||||||
certs, err := LoadTrustAnchor(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &TrustAnchorHolder{
|
|
||||||
certs: certs,
|
|
||||||
path: path,
|
|
||||||
logger: logger,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the current trust anchor pool. Safe for concurrent callers;
|
|
||||||
// the slice header is returned by value and the underlying slice is
|
|
||||||
// immutable per-snapshot (Reload swaps a fresh slice, doesn't mutate in
|
|
||||||
// place — see Reload).
|
|
||||||
func (h *TrustAnchorHolder) Get() []*x509.Certificate {
|
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
return h.certs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Path returns the on-disk path the holder reloads from. Useful for
|
|
||||||
// observability (admin endpoints, log lines) without exposing the cert
|
|
||||||
// pool itself.
|
|
||||||
func (h *TrustAnchorHolder) Path() string {
|
|
||||||
return h.path
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload re-reads the trust anchor file at h.path and atomically swaps the
|
|
||||||
// pool. Returns the parse error if the new file is invalid; the OLD pool
|
|
||||||
// stays in place so a bad reload doesn't take Intune enrollment down.
|
|
||||||
//
|
//
|
||||||
// Same fail-safe pattern as cmd/server/tls.go::(*certHolder).Reload — a
|
// Note: the original intune.NewTrustAnchorHolder set the holder's
|
||||||
// rotation that writes a half-file (operator overwrites the bundle while
|
// internal log label to "Intune trust anchor"; the extracted
|
||||||
// only some of the new certs are in it) would otherwise crash the
|
// trustanchor.New defaults to "trust anchor". Existing intune callers
|
||||||
// service mid-rotation. Logging + retaining the old pool gives the
|
// that need the original label should call .SetLabelForLog("intune
|
||||||
// operator a bounded window to fix and re-SIGHUP.
|
// trust anchor (PathID=…)") on the returned holder. cmd/server/main.go
|
||||||
func (h *TrustAnchorHolder) Reload() error {
|
// does this in the per-profile Intune startup loop.
|
||||||
certs, err := LoadTrustAnchor(h.path)
|
var NewTrustAnchorHolder = trustanchor.New
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
h.mu.Lock()
|
|
||||||
h.certs = certs
|
|
||||||
h.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WatchSIGHUP installs a signal handler that calls Reload on each SIGHUP.
|
|
||||||
// The returned stop function closes the internal done channel and stops
|
|
||||||
// signal delivery so the goroutine can exit cleanly during shutdown.
|
|
||||||
//
|
|
||||||
// Errors from Reload are logged but do not terminate the watcher — the
|
|
||||||
// operator can fix the files and send another SIGHUP. Mirrors the
|
|
||||||
// (*certHolder).watchSIGHUP contract exactly.
|
|
||||||
//
|
|
||||||
// Multiple holders can coexist: each registers its own goroutine on the
|
|
||||||
// same SIGHUP signal. signal.Notify multicasts to every registered
|
|
||||||
// channel, so a single SIGHUP reloads every per-profile Intune trust
|
|
||||||
// anchor PLUS the server TLS cert in one operator action — exactly the
|
|
||||||
// design requirement (one SIGHUP rotates everything).
|
|
||||||
func (h *TrustAnchorHolder) WatchSIGHUP() (stop func()) {
|
|
||||||
ch := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(ch, syscall.SIGHUP)
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
if err := h.Reload(); err != nil {
|
|
||||||
h.logger.Error("Intune trust anchor reload failed; continuing with previous pool",
|
|
||||||
"error", err,
|
|
||||||
"path", h.path)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
h.logger.Info("Intune trust anchor reloaded via SIGHUP",
|
|
||||||
"path", h.path,
|
|
||||||
"certs_loaded", len(h.Get()))
|
|
||||||
case <-done:
|
|
||||||
signal.Stop(ch)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return func() { close(done) }
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -16,6 +16,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.1: the white-box parser
|
||||||
|
// tests (TestParseTrustAnchorPEM_*) moved to internal/trustanchor/holder_test.go
|
||||||
|
// where parseBundlePEM now lives. The intune package retains a thin
|
||||||
|
// public-surface test of LoadTrustAnchor — the back-compat shim that
|
||||||
|
// existing intune callers use — so a future refactor that breaks the
|
||||||
|
// shim's wire-up to trustanchor.LoadBundle is caught here.
|
||||||
|
|
||||||
// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
|
// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
|
||||||
func pemEncodeCert(t *testing.T, der []byte) []byte {
|
func pemEncodeCert(t *testing.T, der []byte) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
@@ -24,7 +31,9 @@ func pemEncodeCert(t *testing.T, der []byte) []byte {
|
|||||||
|
|
||||||
// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
|
// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
|
||||||
// + the matching key. Lifetime is parameterised so the same factory drives
|
// + the matching key. Lifetime is parameterised so the same factory drives
|
||||||
// both the happy-path and expired-cert cases.
|
// both happy-path and expired-cert cases. Kept in this file (not deleted with
|
||||||
|
// the white-box tests) because trust_anchor_holder_test.go's freshHolderCert
|
||||||
|
// returns *x509.Certificate while LoadTrustAnchor tests need raw DER + key.
|
||||||
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
|
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
@@ -44,96 +53,6 @@ func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.Pri
|
|||||||
return der, key
|
return der, key
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_HappyPath_SingleCert(t *testing.T) {
|
|
||||||
der, _ := freshConnectorCertDER(t, time.Now().Add(365*24*time.Hour))
|
|
||||||
body := pemEncodeCert(t, der)
|
|
||||||
|
|
||||||
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parseTrustAnchorPEM: %v", err)
|
|
||||||
}
|
|
||||||
if len(certs) != 1 {
|
|
||||||
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
|
||||||
}
|
|
||||||
if certs[0].Subject.CommonName != "intune-connector-test" {
|
|
||||||
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_HappyPath_MultiCert(t *testing.T) {
|
|
||||||
d1, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
|
||||||
d2, _ := freshConnectorCertDER(t, time.Now().Add(60*24*time.Hour))
|
|
||||||
body := append(pemEncodeCert(t, d1), pemEncodeCert(t, d2)...)
|
|
||||||
|
|
||||||
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parseTrustAnchorPEM: %v", err)
|
|
||||||
}
|
|
||||||
if len(certs) != 2 {
|
|
||||||
t.Fatalf("len(certs) = %d, want 2", len(certs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_SkipsNonCertBlocks(t *testing.T) {
|
|
||||||
der, key := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
|
||||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("MarshalECPrivateKey: %v", err)
|
|
||||||
}
|
|
||||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
|
||||||
body := append(keyPEM, pemEncodeCert(t, der)...) // priv key first, cert second
|
|
||||||
|
|
||||||
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parseTrustAnchorPEM should ignore non-CERTIFICATE blocks: %v", err)
|
|
||||||
}
|
|
||||||
if len(certs) != 1 {
|
|
||||||
t.Fatalf("len(certs) = %d, want 1 (priv key block must be skipped)", len(certs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_EmptyBundleRejected(t *testing.T) {
|
|
||||||
_, err := parseTrustAnchorPEM([]byte("nothing here"), "test", time.Now())
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "no CERTIFICATE PEM blocks") {
|
|
||||||
t.Fatalf("expected 'no CERTIFICATE PEM blocks' error, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_OnlyKeyBlocksRejected(t *testing.T) {
|
|
||||||
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
keyDER, _ := x509.MarshalECPrivateKey(key)
|
|
||||||
body := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
|
||||||
|
|
||||||
_, err := parseTrustAnchorPEM(body, "test", time.Now())
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error for bundle with no certs, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_ExpiredCertRejected(t *testing.T) {
|
|
||||||
der, _ := freshConnectorCertDER(t, time.Now().Add(-1*time.Hour)) // already expired
|
|
||||||
body := pemEncodeCert(t, der)
|
|
||||||
|
|
||||||
_, err := parseTrustAnchorPEM(body, "expired-bundle", time.Now())
|
|
||||||
if err == nil || !strings.Contains(err.Error(), "expired") {
|
|
||||||
t.Fatalf("expected expiry error, got %v", err)
|
|
||||||
}
|
|
||||||
// Operator-actionable message must include the subject so the audit
|
|
||||||
// log says exactly which cert to rotate.
|
|
||||||
if !strings.Contains(err.Error(), "intune-connector-test") {
|
|
||||||
t.Errorf("error must include subject CN for operator action: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseTrustAnchorPEM_MalformedCertRejected(t *testing.T) {
|
|
||||||
bad := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-real-asn1-cert")})
|
|
||||||
|
|
||||||
_, err := parseTrustAnchorPEM(bad, "test", time.Now())
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected x509 parse error, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadTrustAnchor_FromDisk(t *testing.T) {
|
func TestLoadTrustAnchor_FromDisk(t *testing.T) {
|
||||||
der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||||
body := pemEncodeCert(t, der)
|
body := pemEncodeCert(t, der)
|
||||||
@@ -150,6 +69,9 @@ func TestLoadTrustAnchor_FromDisk(t *testing.T) {
|
|||||||
if len(certs) != 1 {
|
if len(certs) != 1 {
|
||||||
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
||||||
}
|
}
|
||||||
|
if certs[0].Subject.CommonName != "intune-connector-test" {
|
||||||
|
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadTrustAnchor_EmptyPath(t *testing.T) {
|
func TestLoadTrustAnchor_EmptyPath(t *testing.T) {
|
||||||
@@ -164,7 +86,6 @@ func TestLoadTrustAnchor_MissingFile(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected file-not-found error, got nil")
|
t.Fatalf("expected file-not-found error, got nil")
|
||||||
}
|
}
|
||||||
// Don't string-assert on the OS error — just make sure it's surfaced.
|
|
||||||
if errors.Is(err, nil) {
|
if errors.Is(err, nil) {
|
||||||
t.Fatalf("error must be non-nil")
|
t.Fatalf("error must be non-nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,227 @@
|
|||||||
|
// Package trustanchor provides a SIGHUP-reloadable PEM-bundle trust pool
|
||||||
|
// shared by the SCEP/Intune dispatcher (per-profile Microsoft Intune
|
||||||
|
// Connector signing-cert anchor), the EST mTLS sibling route (per-profile
|
||||||
|
// client-CA trust bundle for /.well-known/est-mtls/<pathID>/), and any
|
||||||
|
// future caller that needs the same pattern (operator rotates an on-disk
|
||||||
|
// PEM bundle, sends SIGHUP, certctl swaps the in-memory pool atomically
|
||||||
|
// without a restart).
|
||||||
|
//
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.1: extracted from
|
||||||
|
// internal/scep/intune/trust_anchor_holder.go where it originally lived.
|
||||||
|
// The intune package preserves a thin alias-style wrapper for back-compat
|
||||||
|
// (existing intune.TrustAnchorHolder + NewTrustAnchorHolder + LoadTrustAnchor
|
||||||
|
// callers compile unchanged); new callers SHOULD import this package
|
||||||
|
// directly.
|
||||||
|
//
|
||||||
|
// Concurrency contract:
|
||||||
|
//
|
||||||
|
// - Get returns the pool slice header by value; the slice itself is
|
||||||
|
// immutable per-snapshot (Reload swaps a fresh slice rather than
|
||||||
|
// mutating the existing one). Callers may iterate the returned slice
|
||||||
|
// without holding any lock.
|
||||||
|
// - Reload acquires a write lock briefly for the swap. Concurrent Get
|
||||||
|
// calls block only for that swap window (microseconds).
|
||||||
|
// - WatchSIGHUP runs at most one Reload at a time per holder.
|
||||||
|
//
|
||||||
|
// Threat model: the rationale for SIGHUP-as-reload-trigger (vs fsnotify
|
||||||
|
// or polling) is that the existing certctl rotation playbook (server TLS
|
||||||
|
// cert at cmd/server/tls.go::certHolder) already uses SIGHUP. Operators
|
||||||
|
// running the standard "rotate file, kill -HUP" workflow get every
|
||||||
|
// holder reloaded with one signal: server TLS + Intune trust anchors +
|
||||||
|
// EST mTLS trust bundles all swap atomically.
|
||||||
|
package trustanchor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Holder is the SIGHUP-reloadable wrapper around a PEM-bundle trust
|
||||||
|
// pool. Construct via New. The zero value is NOT usable.
|
||||||
|
type Holder struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
certs []*x509.Certificate
|
||||||
|
path string
|
||||||
|
logger *slog.Logger
|
||||||
|
// labelForLog is used only in error / info log lines so an operator
|
||||||
|
// running multiple holders (per-profile EST mTLS, per-profile Intune,
|
||||||
|
// server TLS) can distinguish which one fired. Defaults to "trust
|
||||||
|
// anchor" when not set; callers SHOULD set this to a descriptive
|
||||||
|
// string like "intune trust anchor (PathID=corp)" or "EST mTLS
|
||||||
|
// client CA bundle (PathID=corp)".
|
||||||
|
labelForLog string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New loads the trust bundle and returns a holder. Returns the same
|
||||||
|
// fail-loud error LoadBundle does on initial load — the startup gate at
|
||||||
|
// cmd/server/main.go is supposed to refuse boot when this fails.
|
||||||
|
// Subsequent Reload errors are non-fatal (logged + old pool retained).
|
||||||
|
//
|
||||||
|
// The logger is required (never nil); the caller passes a per-profile
|
||||||
|
// scoped logger so SIGHUP-reload events show the PathID for triage.
|
||||||
|
func New(path string, logger *slog.Logger) (*Holder, error) {
|
||||||
|
if logger == nil {
|
||||||
|
return nil, errors.New("trustanchor: New requires a non-nil logger")
|
||||||
|
}
|
||||||
|
certs, err := LoadBundle(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Holder{certs: certs, path: path, logger: logger, labelForLog: "trust anchor"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLabelForLog records a descriptive label that future reload log
|
||||||
|
// lines use to distinguish this holder from others (e.g. "intune trust
|
||||||
|
// anchor (PathID=corp)"). Idempotent + safe for concurrent callers
|
||||||
|
// (the field is read only by the SIGHUP watcher goroutine after
|
||||||
|
// WatchSIGHUP starts).
|
||||||
|
func (h *Holder) SetLabelForLog(label string) {
|
||||||
|
if label == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
h.labelForLog = label
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the current trust anchor pool. Safe for concurrent
|
||||||
|
// callers; the slice header is returned by value and the underlying
|
||||||
|
// slice is immutable per-snapshot.
|
||||||
|
func (h *Holder) Get() []*x509.Certificate {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return h.certs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path returns the on-disk path the holder reloads from.
|
||||||
|
func (h *Holder) Path() string {
|
||||||
|
return h.path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pool returns a fresh *x509.CertPool populated with the holder's
|
||||||
|
// current certs. Helper for callers that need a pool instead of a
|
||||||
|
// slice (the EST mTLS handler verifies client cert chains via
|
||||||
|
// cert.Verify(VerifyOptions{Roots: pool}); the Intune dispatcher uses
|
||||||
|
// the slice directly for signature-walk).
|
||||||
|
func (h *Holder) Pool() *x509.CertPool {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
for _, c := range h.Get() {
|
||||||
|
pool.AddCert(c)
|
||||||
|
}
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload re-reads the trust anchor file at h.path and atomically swaps
|
||||||
|
// the pool. Returns the parse error if the new file is invalid; the
|
||||||
|
// OLD pool stays in place so a bad reload doesn't take dependent
|
||||||
|
// dispatch paths down. Same fail-safe pattern as cmd/server/tls.go::
|
||||||
|
// (*certHolder).Reload — a rotation that writes a half-file would
|
||||||
|
// otherwise crash the service mid-rotation.
|
||||||
|
func (h *Holder) Reload() error {
|
||||||
|
certs, err := LoadBundle(h.path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
h.certs = certs
|
||||||
|
h.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchSIGHUP installs a signal handler that calls Reload on each
|
||||||
|
// SIGHUP. The returned stop function closes the internal done channel
|
||||||
|
// and stops signal delivery so the goroutine can exit cleanly during
|
||||||
|
// shutdown.
|
||||||
|
//
|
||||||
|
// Errors from Reload are logged but do not terminate the watcher — the
|
||||||
|
// operator can fix the files and send another SIGHUP. Mirrors the
|
||||||
|
// (*certHolder).watchSIGHUP contract from cmd/server/tls.go exactly.
|
||||||
|
//
|
||||||
|
// Multiple holders coexist: each registers its own goroutine on the
|
||||||
|
// same SIGHUP signal. signal.Notify multicasts to every registered
|
||||||
|
// channel, so a single SIGHUP reloads every per-profile trust anchor
|
||||||
|
// + the server TLS cert in one operator action.
|
||||||
|
func (h *Holder) WatchSIGHUP() (stop func()) {
|
||||||
|
ch := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(ch, syscall.SIGHUP)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
if err := h.Reload(); err != nil {
|
||||||
|
h.logger.Error(h.labelForLog+" reload failed; continuing with previous pool",
|
||||||
|
"error", err,
|
||||||
|
"path", h.path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.logger.Info(h.labelForLog+" reloaded via SIGHUP",
|
||||||
|
"path", h.path,
|
||||||
|
"certs_loaded", len(h.Get()))
|
||||||
|
case <-done:
|
||||||
|
signal.Stop(ch)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() { close(done) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadBundle reads a PEM bundle from disk + returns the parsed cert
|
||||||
|
// slice. Refuses empty bundles (zero CERTIFICATE blocks); refuses any
|
||||||
|
// bundle containing a cert past NotAfter (fail loud at boot rather than
|
||||||
|
// silently rejecting every request at runtime).
|
||||||
|
//
|
||||||
|
// Non-CERTIFICATE PEM blocks are skipped (so an operator can paste a
|
||||||
|
// chain that includes a private key by mistake without breaking the
|
||||||
|
// load — the priv key is just ignored). Operators rotating signing
|
||||||
|
// certs typically want this tolerance.
|
||||||
|
func LoadBundle(path string) ([]*x509.Certificate, error) {
|
||||||
|
if path == "" {
|
||||||
|
return nil, errors.New("trustanchor: bundle path is empty")
|
||||||
|
}
|
||||||
|
body, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("trustanchor: read bundle %q: %w", path, err)
|
||||||
|
}
|
||||||
|
return parseBundlePEM(body, path, time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseBundlePEM is the file-IO-free core of LoadBundle. Split out so
|
||||||
|
// unit tests can hand it byte slices without writing temp files. `now`
|
||||||
|
// is taken as a parameter so expiry tests can pin a deterministic clock.
|
||||||
|
func parseBundlePEM(body []byte, sourceLabel string, now time.Time) ([]*x509.Certificate, error) {
|
||||||
|
var out []*x509.Certificate
|
||||||
|
rest := body
|
||||||
|
for {
|
||||||
|
var block *pem.Block
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type != "CERTIFICATE" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("trustanchor: parse cert in %q: %w", sourceLabel, err)
|
||||||
|
}
|
||||||
|
if now.After(cert.NotAfter) {
|
||||||
|
return nil, fmt.Errorf("trustanchor: cert in %q expired at %s (subject=%q) — operator must rotate the trust bundle before restart",
|
||||||
|
sourceLabel, cert.NotAfter.Format(time.RFC3339), cert.Subject.CommonName)
|
||||||
|
}
|
||||||
|
out = append(out, cert)
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil, fmt.Errorf("trustanchor: %q contains no CERTIFICATE PEM blocks", sourceLabel)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,432 @@
|
|||||||
|
package trustanchor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EST RFC 7030 hardening master bundle Phase 2.1: this test file holds the
|
||||||
|
// white-box tests for the trust-anchor primitives (parseBundlePEM + LoadBundle
|
||||||
|
// + Holder) that used to live in internal/scep/intune/{trust_anchor_test.go,
|
||||||
|
// trust_anchor_holder_test.go}. The intune package retains a thin
|
||||||
|
// public-surface test of LoadTrustAnchor (the back-compat shim) — the
|
||||||
|
// detailed tests live here so the EST mTLS sibling route + any future
|
||||||
|
// trustanchor.Holder caller share the same contract pinning.
|
||||||
|
|
||||||
|
// silentLogger drops everything; the SIGHUP watcher emits Info logs we don't
|
||||||
|
// want fouling test output.
|
||||||
|
func silentLogger() *slog.Logger {
|
||||||
|
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError + 10}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
|
||||||
|
func pemEncodeCert(t *testing.T, der []byte) []byte {
|
||||||
|
t.Helper()
|
||||||
|
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
}
|
||||||
|
|
||||||
|
// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
|
||||||
|
// + the matching key. Lifetime is parameterised so the same factory drives
|
||||||
|
// both the happy-path and expired-cert cases.
|
||||||
|
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
||||||
|
Subject: pkix.Name{CommonName: "trustanchor-test"},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: notAfter,
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x509.CreateCertificate: %v", err)
|
||||||
|
}
|
||||||
|
return der, key
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTestBundle writes a PEM bundle of the given certs at path with mode 0600.
|
||||||
|
func writeTestBundle(t *testing.T, path string, certs []*x509.Certificate) {
|
||||||
|
t.Helper()
|
||||||
|
body := []byte{}
|
||||||
|
for _, c := range certs {
|
||||||
|
body = append(body, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.Raw})...)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, body, 0o600); err != nil {
|
||||||
|
t.Fatalf("WriteFile: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// freshHolderCert is a small factory for a self-signed EC cert with a
|
||||||
|
// caller-controlled CN + lifetime. Used by Reload tests that swap the
|
||||||
|
// on-disk pool between calls.
|
||||||
|
func freshHolderCert(t *testing.T, cn string, notAfter time.Time) *x509.Certificate {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
||||||
|
Subject: pkix.Name{CommonName: cn},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: notAfter,
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x509.CreateCertificate: %v", err)
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(der)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x509.ParseCertificate: %v", err)
|
||||||
|
}
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- parseBundlePEM (white-box) -----
|
||||||
|
|
||||||
|
func TestParseBundlePEM_HappyPath_SingleCert(t *testing.T) {
|
||||||
|
der, _ := freshConnectorCertDER(t, time.Now().Add(365*24*time.Hour))
|
||||||
|
body := pemEncodeCert(t, der)
|
||||||
|
|
||||||
|
certs, err := parseBundlePEM(body, "test", time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseBundlePEM: %v", err)
|
||||||
|
}
|
||||||
|
if len(certs) != 1 {
|
||||||
|
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
||||||
|
}
|
||||||
|
if certs[0].Subject.CommonName != "trustanchor-test" {
|
||||||
|
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBundlePEM_HappyPath_MultiCert(t *testing.T) {
|
||||||
|
d1, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||||
|
d2, _ := freshConnectorCertDER(t, time.Now().Add(60*24*time.Hour))
|
||||||
|
body := append(pemEncodeCert(t, d1), pemEncodeCert(t, d2)...)
|
||||||
|
|
||||||
|
certs, err := parseBundlePEM(body, "test", time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseBundlePEM: %v", err)
|
||||||
|
}
|
||||||
|
if len(certs) != 2 {
|
||||||
|
t.Fatalf("len(certs) = %d, want 2", len(certs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBundlePEM_SkipsNonCertBlocks(t *testing.T) {
|
||||||
|
der, key := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalECPrivateKey: %v", err)
|
||||||
|
}
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
body := append(keyPEM, pemEncodeCert(t, der)...) // priv key first, cert second
|
||||||
|
|
||||||
|
certs, err := parseBundlePEM(body, "test", time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseBundlePEM should ignore non-CERTIFICATE blocks: %v", err)
|
||||||
|
}
|
||||||
|
if len(certs) != 1 {
|
||||||
|
t.Fatalf("len(certs) = %d, want 1 (priv key block must be skipped)", len(certs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBundlePEM_EmptyBundleRejected(t *testing.T) {
|
||||||
|
_, err := parseBundlePEM([]byte("nothing here"), "test", time.Now())
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "no CERTIFICATE PEM blocks") {
|
||||||
|
t.Fatalf("expected 'no CERTIFICATE PEM blocks' error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBundlePEM_OnlyKeyBlocksRejected(t *testing.T) {
|
||||||
|
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
keyDER, _ := x509.MarshalECPrivateKey(key)
|
||||||
|
body := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
|
||||||
|
_, err := parseBundlePEM(body, "test", time.Now())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for bundle with no certs, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBundlePEM_ExpiredCertRejected(t *testing.T) {
|
||||||
|
der, _ := freshConnectorCertDER(t, time.Now().Add(-1*time.Hour)) // already expired
|
||||||
|
body := pemEncodeCert(t, der)
|
||||||
|
|
||||||
|
_, err := parseBundlePEM(body, "expired-bundle", time.Now())
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "expired") {
|
||||||
|
t.Fatalf("expected expiry error, got %v", err)
|
||||||
|
}
|
||||||
|
// Operator-actionable message must include the subject so the audit
|
||||||
|
// log says exactly which cert to rotate.
|
||||||
|
if !strings.Contains(err.Error(), "trustanchor-test") {
|
||||||
|
t.Errorf("error must include subject CN for operator action: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBundlePEM_MalformedCertRejected(t *testing.T) {
|
||||||
|
bad := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-real-asn1-cert")})
|
||||||
|
|
||||||
|
_, err := parseBundlePEM(bad, "test", time.Now())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected x509 parse error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- LoadBundle (filesystem-side) -----
|
||||||
|
|
||||||
|
func TestLoadBundle_FromDisk(t *testing.T) {
|
||||||
|
der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||||
|
body := pemEncodeCert(t, der)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
if err := os.WriteFile(path, body, 0o600); err != nil {
|
||||||
|
t.Fatalf("WriteFile: %v", err)
|
||||||
|
}
|
||||||
|
certs, err := LoadBundle(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadBundle: %v", err)
|
||||||
|
}
|
||||||
|
if len(certs) != 1 {
|
||||||
|
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadBundle_EmptyPath(t *testing.T) {
|
||||||
|
_, err := LoadBundle("")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "empty") {
|
||||||
|
t.Fatalf("expected empty-path error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadBundle_MissingFile(t *testing.T) {
|
||||||
|
_, err := LoadBundle("/tmp/does-not-exist-trustanchor.pem")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected file-not-found error, got nil")
|
||||||
|
}
|
||||||
|
if errors.Is(err, nil) {
|
||||||
|
t.Fatalf("error must be non-nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Holder -----
|
||||||
|
|
||||||
|
func TestHolder_NewLoadsBundle(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
cert := freshHolderCert(t, "initial", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{cert})
|
||||||
|
|
||||||
|
holder, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New: %v", err)
|
||||||
|
}
|
||||||
|
got := holder.Get()
|
||||||
|
if len(got) != 1 || got[0].Subject.CommonName != "initial" {
|
||||||
|
t.Fatalf("Get returned %#v, want one cert with CN=initial", got)
|
||||||
|
}
|
||||||
|
if holder.Path() != path {
|
||||||
|
t.Errorf("Path = %q, want %q", holder.Path(), path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_NewRequiresLogger(t *testing.T) {
|
||||||
|
if _, err := New("/nonexistent", nil); err == nil {
|
||||||
|
t.Fatal("nil logger must error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_NewSurfacesLoadError(t *testing.T) {
|
||||||
|
if _, err := New("/path/that/does/not/exist.pem", silentLogger()); err == nil {
|
||||||
|
t.Fatal("missing file must error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_PoolReturnsAllCerts(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
c1 := freshHolderCert(t, "ca-1", time.Now().Add(30*24*time.Hour))
|
||||||
|
c2 := freshHolderCert(t, "ca-2", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{c1, c2})
|
||||||
|
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
pool := h.Pool()
|
||||||
|
if pool == nil {
|
||||||
|
t.Fatal("Pool returned nil")
|
||||||
|
}
|
||||||
|
// pool.Subjects() is deprecated for caller-owned pools that may include
|
||||||
|
// the system roots. We've built this pool ourselves with exactly the
|
||||||
|
// two certs from h.Get(), so it's a safe use — but the linter doesn't
|
||||||
|
// know that. Rather than disable the lint, we cross-check via Equal()
|
||||||
|
// over the underlying cert slice we used to build the pool.
|
||||||
|
got := h.Get()
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Errorf("Get() len = %d, want 2", len(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_SetLabelForLogIgnoresEmpty(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{
|
||||||
|
freshHolderCert(t, "label-test", time.Now().Add(30*24*time.Hour)),
|
||||||
|
})
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
h.SetLabelForLog("") // no-op; default "trust anchor" preserved
|
||||||
|
h.SetLabelForLog("est mTLS client CA bundle")
|
||||||
|
// No public getter for label; just exercise without crashing — race
|
||||||
|
// detector covers the locking contract under -race.
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_ReloadHappyPath(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
c1 := freshHolderCert(t, "rev-1", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{c1})
|
||||||
|
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate on disk and call Reload.
|
||||||
|
c2 := freshHolderCert(t, "rev-2", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{c2})
|
||||||
|
if err := h.Reload(); err != nil {
|
||||||
|
t.Fatalf("Reload: %v", err)
|
||||||
|
}
|
||||||
|
got := h.Get()
|
||||||
|
if len(got) != 1 || got[0].Subject.CommonName != "rev-2" {
|
||||||
|
t.Errorf("after Reload Get = %#v, want one cert CN=rev-2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_ReloadKeepsOldOnFailure(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
good := freshHolderCert(t, "stable", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{good})
|
||||||
|
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overwrite with content that LoadBundle will reject (no PEM blocks).
|
||||||
|
if err := os.WriteFile(path, []byte("garbage"), 0o600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := h.Reload(); err == nil {
|
||||||
|
t.Fatal("Reload from garbage file must error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old pool still served.
|
||||||
|
got := h.Get()
|
||||||
|
if len(got) != 1 || got[0].Subject.CommonName != "stable" {
|
||||||
|
t.Errorf("after failed Reload Get should still be the pre-Reload pool; got %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_ReloadKeepsOldOnExpired(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
good := freshHolderCert(t, "still-valid", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{good})
|
||||||
|
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expired := freshHolderCert(t, "expired-conn", time.Now().Add(-1*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{expired})
|
||||||
|
|
||||||
|
if err := h.Reload(); err == nil {
|
||||||
|
t.Fatal("Reload with expired cert must error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(h.Get()[0].Subject.CommonName, "still-valid") {
|
||||||
|
t.Errorf("after expired-cert Reload, holder should retain old pool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_WatchSIGHUPReloadsPool(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
c1 := freshHolderCert(t, "rev-pre-sighup", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{c1})
|
||||||
|
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
stop := h.WatchSIGHUP()
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
c2 := freshHolderCert(t, "rev-post-sighup", time.Now().Add(30*24*time.Hour))
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{c2})
|
||||||
|
if err := syscall.Kill(syscall.Getpid(), syscall.SIGHUP); err != nil {
|
||||||
|
t.Fatalf("send SIGHUP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for {
|
||||||
|
got := h.Get()
|
||||||
|
if len(got) == 1 && got[0].Subject.CommonName == "rev-post-sighup" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
t.Fatalf("post-SIGHUP pool not swapped in 2s; current CN=%q", got[0].Subject.CommonName)
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHolder_WatchSIGHUPStopIsClean(t *testing.T) {
|
||||||
|
// We do NOT fire a SIGHUP after stop(): once signal.Stop has removed our
|
||||||
|
// handler the kernel's default action on SIGHUP is to terminate the
|
||||||
|
// process — it would kill the test runner. Pin "stop() is synchronous
|
||||||
|
// and safe" by closing the watcher and verifying the holder still
|
||||||
|
// serves the original cert without panic.
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "trust.pem")
|
||||||
|
writeTestBundle(t, path, []*x509.Certificate{
|
||||||
|
freshHolderCert(t, "stop-test", time.Now().Add(30*24*time.Hour)),
|
||||||
|
})
|
||||||
|
|
||||||
|
h, err := New(path, silentLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
stop := h.WatchSIGHUP()
|
||||||
|
stop()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
if cn := h.Get()[0].Subject.CommonName; cn != "stop-test" {
|
||||||
|
t.Errorf("after stop CN = %q, want unchanged stop-test", cn)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user