mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 13:41:30 +00:00
EST RFC 7030 hardening master bundle Phases 2-4: end-to-end mTLS sibling
route + RFC 9266 channel binding + HTTP Basic enrollment-password +
per-source-IP failed-auth limit + per-(CN, sourceIP) sliding-window cap.
Three new shared packages so EST + Intune share infrastructure:
- internal/cms/ — RFC 9266 tls-exporter extractor (ExtractTLSExporter
with stdlib-panic recovery for synthetic ConnectionStates) +
CSR-side channel-binding parser via raw TBSCertificationRequestInfo
walk (the stdlib's csr.Attributes can't represent the OCTET STRING
binding value), VerifyChannelBinding composite,
EmbedChannelBindingAttribute fixture helper, typed sentinel errors for
missing / mismatch / not-TLS-1.3 mapped to HTTP 400 / 409 / 426 in handler.
- internal/trustanchor/ — extracted from scep/intune/trust_anchor*.go
so the EST mTLS sibling route + Intune dispatcher share the same
SIGHUP-reloadable PEM bundle primitive. intune.TrustAnchorHolder
is now `= trustanchor.Holder` (type alias) + NewTrustAnchorHolder =
trustanchor.New (function alias) — every existing call site compiles
unchanged. Intune's LoadTrustAnchor is a thin wrapper over
trustanchor.LoadBundle. White-box tests moved to the new package.
- internal/ratelimit/ — extracted from scep/intune/rate_limit.go (this
was Phase 4.1, in the same bundle). intune.PerDeviceRateLimiter
is now a thin wrapper preserving the (subject, issuer)→key
composition; EST handler reaches for SlidingWindowLimiter directly.
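As an illustration of the tls-exporter extraction described in the internal/cms/ bullet above, here is a minimal sketch. It is not the commit's ExtractTLSExporter: the names tlsExporter, bindingMatches, syntheticState, errNoTLS and errNotTLS13 are hypothetical. It assumes only what RFC 9266 specifies (label "EXPORTER-Channel-Binding", zero-length context, 32 bytes) and the standard library's ConnectionState.ExportKeyingMaterial.

```go
package main

import (
	"crypto/subtle"
	"crypto/tls"
	"errors"
	"fmt"
)

// Hypothetical sentinels; the commit's internal/cms package defines its own.
var (
	errNoTLS    = errors.New("channel binding: request did not arrive over TLS")
	errNotTLS13 = errors.New("channel binding: tls-exporter requires TLS 1.3")
)

// tlsExporter derives the RFC 9266 tls-exporter value: 32 bytes of exported
// keying material under the label "EXPORTER-Channel-Binding". The nil context
// is treated as zero-length by TLS 1.3.
func tlsExporter(cs *tls.ConnectionState) (out []byte, err error) {
	if cs == nil {
		return nil, errNoTLS
	}
	if cs.Version != tls.VersionTLS13 {
		return nil, errNotTLS13
	}
	// A ConnectionState built by hand (for example in a unit test) has no
	// exporter behind it and the stdlib call can panic; turn that into an
	// ordinary error instead of crashing the handler.
	defer func() {
		if r := recover(); r != nil {
			out, err = nil, fmt.Errorf("channel binding: exporter unavailable: %v", r)
		}
	}()
	return cs.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32)
}

// bindingMatches compares the value carried in the CSR against the live
// exporter in constant time.
func bindingMatches(cs *tls.ConnectionState, fromCSR []byte) (bool, error) {
	want, err := tlsExporter(cs)
	if err != nil {
		return false, err
	}
	return subtle.ConstantTimeCompare(want, fromCSR) == 1, nil
}

// syntheticState returns a handshake-less ConnectionState for the given
// protocol version, the kind of value a test fixture might construct.
func syntheticState(version uint16) *tls.ConnectionState {
	return &tls.ConnectionState{Version: version}
}
```

In a handler this would be called as bindingMatches(r.TLS, bindingFromCSR), with the three error cases mapped to distinct HTTP statuses as the commit describes.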
ESTHandler grew six optional fields wired by per-profile setters
(SetMTLSTrust / SetChannelBindingRequired / SetEnrollmentPassword /
SetSourceIPRateLimiter / SetPerPrincipalRateLimiter / SetLabelForLog)
plus four new mTLS-route methods (CACertsMTLS / SimpleEnrollMTLS /
SimpleReEnrollMTLS / CSRAttrsMTLS); shared internal pipeline
handleEnrollOrReEnroll(reEnroll, viaMTLS) keeps the auth/binding/
rate-limit gates DRY. New router method RegisterESTMTLSHandlers
registers /.well-known/est-mtls/<PathID>/{cacerts,simpleenroll,
simplereenroll,csrattrs}; AuthExemptDispatchPrefixes extends the
no-auth chain to /.well-known/est-mtls.
cmd/server/main.go's EST loop wires per-profile mTLS holder +
channel-binding policy + per-principal limiter + (when EnrollmentPassword
non-empty) Basic + source-IP limiter; new
preflightESTMTLSClientCATrustBundle returns *trustanchor.Holder so
SIGHUP rotates the EST mTLS
bundle live without restart. SCEP + EST mTLS profiles now share a
single union mtlsUnionPoolForTLS passed to buildServerTLSConfigWithMTLS
(replaces the protocol-specific scepMTLSUnionPoolForTLS); per-handler
re-verify enforces "cert must chain to THIS profile's bundle" so
cross-protocol bleed is blocked at the application layer even though
the TLS layer trusts the union of both protocols' pools.
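The union-pool versus per-profile distinction above can be sketched with two throwaway CAs. This is an illustration under stated assumptions, not the commit's code: issue, chainsToProfile and demoCrossProfile are hypothetical helpers, and the sketch checks only ExtKeyUsageClientAuth, whereas the handler in this diff also lists ExtKeyUsageAny.

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"math/big"
	"time"
)

// issue creates a certificate for cn. With parent == nil it is a self-signed
// CA; otherwise it is a client-auth leaf signed by parent/parentKey.
func issue(cn string, serial int64, parent *x509.Certificate, parentKey *ecdsa.PrivateKey) (*x509.Certificate, *ecdsa.PrivateKey, error) {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, nil, err
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(serial),
		Subject:      pkix.Name{CommonName: cn},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
	}
	signer, signerKey := tmpl, key
	if parent == nil {
		tmpl.IsCA = true
		tmpl.BasicConstraintsValid = true
		tmpl.KeyUsage = x509.KeyUsageCertSign
	} else {
		tmpl.KeyUsage = x509.KeyUsageDigitalSignature
		tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
		signer, signerKey = parent, parentKey
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, signer, &key.PublicKey, signerKey)
	if err != nil {
		return nil, nil, err
	}
	cert, err := x509.ParseCertificate(der)
	return cert, key, err
}

// chainsToProfile re-verifies a leaf against ONE profile's roots, which is
// the application-layer check that runs after the TLS handshake.
func chainsToProfile(leaf *x509.Certificate, roots *x509.CertPool) bool {
	_, err := leaf.Verify(x509.VerifyOptions{
		Roots:     roots,
		KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	})
	return err == nil
}

// demoCrossProfile issues a device cert under profile A's CA and reports
// whether it verifies against the union pool, profile A's pool, and
// profile B's pool.
func demoCrossProfile() (union, profileA, profileB bool, err error) {
	caA, keyA, err := issue("profile-a-ca", 1, nil, nil)
	if err != nil {
		return false, false, false, err
	}
	caB, _, err := issue("profile-b-ca", 1, nil, nil)
	if err != nil {
		return false, false, false, err
	}
	leaf, _, err := issue("device-001", 2, caA, keyA)
	if err != nil {
		return false, false, false, err
	}
	poolA, poolB, poolUnion := x509.NewCertPool(), x509.NewCertPool(), x509.NewCertPool()
	poolA.AddCert(caA)
	poolB.AddCert(caB)
	poolUnion.AddCert(caA)
	poolUnion.AddCert(caB)
	return chainsToProfile(leaf, poolUnion), chainsToProfile(leaf, poolA), chainsToProfile(leaf, poolB), nil
}
```

The leaf passes the union pool (what the TLS listener sees) and profile A's pool, but fails profile B's pool, which is why the handler-level re-verify is needed.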
Phase 3.3 source-IP failed-Basic limiter defaults: 10 attempts / 1h
/ 50k tracked IPs (no env var; tunable in a follow-up). Phase 4.2
per-principal limiter cap comes from
CERTCTL_EST_PROFILE_<NAME>_RATE_LIMIT_PER_PRINCIPAL_24H (existing field,
shipped in Phase 1).
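To make the limiter defaults above concrete, here is a minimal sliding-window sketch. It is not the commit's ratelimit.SlidingWindowLimiter: slidingWindow, newSlidingWindow, errRateLimited and demoFailedBasicLockout are hypothetical, and the eviction policy for the tracked-key cap is a placeholder assumption. Only the Allow(key, now) call shape and the 10 / 1h / 50k figures come from the commit.

```go
package main

import (
	"errors"
	"sync"
	"time"
)

// errRateLimited is a hypothetical sentinel standing in for the commit's
// ratelimit.ErrRateLimited.
var errRateLimited = errors.New("rate limited")

// slidingWindow allows at most limit events per key within any window-long
// span and tracks at most maxKeys keys at once.
type slidingWindow struct {
	mu      sync.Mutex
	limit   int
	window  time.Duration
	maxKeys int
	hits    map[string][]time.Time
}

func newSlidingWindow(limit int, window time.Duration, maxKeys int) *slidingWindow {
	return &slidingWindow{limit: limit, window: window, maxKeys: maxKeys, hits: make(map[string][]time.Time)}
}

// Allow records one event for key at time now, or returns errRateLimited
// when the key already has limit events inside the window.
func (s *slidingWindow) Allow(key string, now time.Time) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Drop timestamps that have aged out of the window.
	cutoff := now.Add(-s.window)
	old := s.hits[key]
	kept := old[:0]
	for _, t := range old {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	if len(kept) >= s.limit {
		s.hits[key] = kept
		return errRateLimited
	}
	// Placeholder eviction policy (an assumption): when the table is full,
	// forget one arbitrary key to make room for a new one.
	if _, tracked := s.hits[key]; !tracked && len(s.hits) >= s.maxKeys {
		for victim := range s.hits {
			delete(s.hits, victim)
			break
		}
	}
	s.hits[key] = append(kept, now)
	return nil
}

// demoFailedBasicLockout replays 12 attempts one minute apart under the
// 10-per-hour default, then one more attempt two hours after the first.
func demoFailedBasicLockout() (allowedInBurst int, allowedLater bool) {
	lim := newSlidingWindow(10, time.Hour, 50000)
	start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	for i := 0; i < 12; i++ {
		if lim.Allow("203.0.113.7", start.Add(time.Duration(i)*time.Minute)) == nil {
			allowedInBurst++
		}
	}
	allowedLater = lim.Allow("203.0.113.7", start.Add(2*time.Hour)) == nil
	return allowedInBurst, allowedLater
}
```

The same type serves both gates: keyed by source IP for failed Basic attempts, and by CN plus source IP for the per-principal cap.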
New tests:
- internal/cms/channelbinding_test.go: extractor + CSR-side parser +
composite + TLS-1.3 round-trip end-to-end +
EmbedChannelBindingAttribute round-trip
- internal/trustanchor/holder_test.go: parseBundlePEM white-box +
LoadBundle + Holder Get/Pool/SetLabelForLog/Reload-happy/
Reload-keeps-old-on-failure/Reload-keeps-old-on-expired/
WatchSIGHUP-reloads-pool/WatchSIGHUP-stop-clean
- internal/api/handler/est_hardening_test.go: 16 named cases covering
mTLS no-trust-pool 500 + no-cert 401 + cross-profile cert 401 +
happy-path 200 + CACertsMTLS auth gate + CSRAttrsMTLS auth gate +
channel-binding required-absent-rejected + not-required-absent-
allowed + writeChannelBindingError mapping + Basic no-header 401
+ Basic wrong-password 401 + Basic correct-200 + Basic-no-password
no-gate + per-IP failed-attempt lockout 429 + per-principal
blocks-after-cap + different-principals-independent + no-limiter-
unbounded.
Pre-commit verification (sandbox): gofmt clean, go vet clean
(excluding repository/postgres which the sandbox can't build —
disk-space testcontainers download), staticcheck clean for
cms/trustanchor/api/handler/api/router/scep/intune/ratelimit/
cmd/server, go test -short -count=1 green for cms/trustanchor/
api/handler/api/router/scep/intune/ratelimit/service. G-3
docs-drift guard reproduced locally clean (Phase 1 already
documented every new env var; Phases 2-4 added zero new env vars).
+599
-225
@@ -2,17 +2,23 @@ package handler

import (
	"context"
	"crypto/subtle"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"

	"github.com/shankar0123/certctl/internal/api/middleware"
	"github.com/shankar0123/certctl/internal/cms"
	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/pkcs7"
	"github.com/shankar0123/certctl/internal/ratelimit"
	"github.com/shankar0123/certctl/internal/trustanchor"
)

// ESTService defines the service interface for EST enrollment operations.
@@ -33,62 +39,558 @@ type ESTService interface {

// ESTHandler handles HTTP requests for the EST protocol (RFC 7030).
//
// EST endpoints are served under /.well-known/est/ per the RFC.
// EST endpoints are served under /.well-known/est/[<PathID>/] per RFC 7030.
// Wire format: base64-encoded DER (PKCS#7 for certs, PKCS#10 for CSRs).
//
// Supported operations:
// - GET /.well-known/est/cacerts — CA certificate distribution
// - POST /.well-known/est/simpleenroll — initial enrollment
// - POST /.well-known/est/simplereenroll — re-enrollment
// - GET /.well-known/est/csrattrs — CSR attributes
// Supported operations (per route family):
//
//	/.well-known/est/[<PathID>/] — legacy + per-profile route family
//	  GET  cacerts        — CA certificate distribution
//	  POST simpleenroll   — initial enrollment (HTTP Basic optional, Phase 3)
//	  POST simplereenroll — re-enrollment (HTTP Basic optional, Phase 3)
//	  GET  csrattrs       — CSR attributes
//
//	/.well-known/est-mtls/<PathID>/ — mTLS sibling (Phase 2)
//	  GET  cacerts        — CA certificate distribution (cert auth required)
//	  POST simpleenroll   — initial enrollment (cert + optional channel binding)
//	  POST simplereenroll — re-enrollment (cert + optional channel binding)
//	  GET  csrattrs       — CSR attributes
//
// EST RFC 7030 hardening master bundle Phases 2-4: ESTHandler grew six
// optional fields wired by per-profile setters in cmd/server/main.go's
// startup loop. None of the new fields are required — a handler with all
// of them unset behaves exactly like the v2.0.x EST handler.
type ESTHandler struct {
	svc ESTService

	// EST RFC 7030 hardening Phase 2.1: per-profile mTLS client-CA trust
	// bundle. When set, the mTLS sibling route (CACertsMTLS /
	// SimpleEnrollMTLS / etc.) verifies the inbound client cert chain
	// against this pool. Nil when MTLS_ENABLED=false; the mTLS route
	// rejects unconditionally in that case (the route shouldn't even be
	// registered, but defense in depth).
	mtlsTrust *trustanchor.Holder

	// EST RFC 7030 hardening Phase 2.4: per-profile channel-binding
	// requirement. When true, the mTLS handler refuses simplereenroll
	// requests whose CSR doesn't carry a matching id-aa-est-tls-exporter
	// (RFC 9266) attribute. Phase 1's Validate() guards
	// ChannelBindingRequired=true + MTLSEnabled=false at startup.
	channelBindingRequired bool

	// EST RFC 7030 hardening Phase 3.1: per-profile HTTP Basic enrollment
	// password. When non-empty, the standard /.well-known/est/<PathID>/
	// route requires `Authorization: Basic <base64(<user>:<pw>)>` on the
	// enrollment endpoints (NOT on cacerts/csrattrs — RFC 7030 §4.1.1
	// says cacerts is anonymous). Constant-time compare; per-source-IP
	// failed-auth rate limit blocks brute-force.
	basicPassword string

	// EST RFC 7030 hardening Phase 3.3: per-handler source-IP rate
	// limiter for FAILED HTTP Basic auth attempts. Keyed by sourceIP so
	// a hostile network segment can't burn through the password.
	failedBasicLimiter *ratelimit.SlidingWindowLimiter

	// EST RFC 7030 hardening Phase 4.2: per-handler per-principal sliding-
	// window rate limit. Keyed by (CSR-CN, sourceIP) so a stolen
	// bootstrap cert AND a known device CN can't be used to flood the
	// issuer. Disabled when nil; configured per-profile.
	perPrincipalLimiter *ratelimit.SlidingWindowLimiter

	// labelForLog gives observability code a per-profile string to
	// include in audit log lines / Prometheus labels. Defaults to
	// "est" when unset.
	labelForLog string
}

// NewESTHandler creates a new ESTHandler.
// NewESTHandler creates a new ESTHandler with no per-profile auth
// hardening. Call SetMTLSTrust + SetChannelBindingRequired +
// SetEnrollmentPassword + SetSourceIPRateLimiter + SetPerPrincipalRateLimiter
// from the per-profile startup loop to opt-in to each surface.
func NewESTHandler(svc ESTService) ESTHandler {
	return ESTHandler{svc: svc}
}

// CACerts handles GET /.well-known/est/cacerts
// Returns the CA certificate chain as base64-encoded PKCS#7 (certs-only).
// Per RFC 7030 Section 4.1, this is a "certs-only" CMC Simple PKI Response.
// For simplicity and broad client compatibility, we return base64-encoded DER certificates.
// SetMTLSTrust injects the per-profile client-cert trust pool the
// `/.well-known/est-mtls/<PathID>/` sibling route uses to verify inbound
// device cert chains. EST RFC 7030 hardening Phase 2.1.
//
// Like the SCEP equivalent, the TLS layer (cmd/server/tls.go) uses
// VerifyClientCertIfGiven against the UNION of every enabled mTLS
// profile's bundle, so the same TLS listener serves both /.well-known/est
// (anonymous or HTTP Basic) and /.well-known/est-mtls/<PathID>
// (cert-required). The per-profile gate at the handler layer enforces
// 'cert must chain to THIS profile's bundle' so a cert that chains to
// profile A's bundle cannot enroll against profile B.
func (h *ESTHandler) SetMTLSTrust(t *trustanchor.Holder) { h.mtlsTrust = t }

// SetChannelBindingRequired toggles RFC 9266 tls-exporter channel binding
// on the simplereenroll mTLS path. EST RFC 7030 hardening Phase 2.4.
// When true, the handler refuses requests whose CSR lacks the binding
// attribute or whose binding bytes don't match the live TLS exporter.
func (h *ESTHandler) SetChannelBindingRequired(req bool) { h.channelBindingRequired = req }

// SetEnrollmentPassword injects the per-profile HTTP Basic enrollment
// password. EST RFC 7030 hardening Phase 3.1. Empty disables the gate
// (mTLS-only or unauthenticated profile). Constant-time compare via
// crypto/subtle.ConstantTimeCompare.
func (h *ESTHandler) SetEnrollmentPassword(pw string) { h.basicPassword = pw }

// SetSourceIPRateLimiter injects the per-handler failed-Basic-auth
// rate limiter. Phase 3.3. Disabled when nil — but Validate() at
// startup refuses an enabled basic-auth profile without a configured
// limiter, so a real deploy always wires one.
func (h *ESTHandler) SetSourceIPRateLimiter(l *ratelimit.SlidingWindowLimiter) {
	h.failedBasicLimiter = l
}

// SetPerPrincipalRateLimiter injects the per-handler (CN, sourceIP)
// sliding-window rate limiter. Phase 4.2. Disabled when nil. Counts
// every successful enrollment, NOT just failures — the goal is to
// bound enrollment-flooding from a compromised credential, not just
// failed-auth brute force.
func (h *ESTHandler) SetPerPrincipalRateLimiter(l *ratelimit.SlidingWindowLimiter) {
	h.perPrincipalLimiter = l
}

// SetLabelForLog sets the per-profile observability label. Defaults to
// "est" when unset; cmd/server/main.go's per-profile loop sets this
// to "est (PathID=<id>)" for triage.
func (h *ESTHandler) SetLabelForLog(label string) {
	if label == "" {
		return
	}
	h.labelForLog = label
}

// label returns h.labelForLog with the "est" fallback applied. Tiny
// helper so log call sites don't need to repeat the fallback.
func (h ESTHandler) label() string {
	if h.labelForLog == "" {
		return "est"
	}
	return h.labelForLog
}

// ----- /.well-known/est/[<PathID>/] route family (legacy + Basic auth) -----

// CACerts handles GET /.well-known/est/[<PathID>/]cacerts.
//
// RFC 7030 §4.1.1 — anonymous endpoint. The HTTP Basic gate is NOT
// applied here (any client must be able to fetch the CA chain to
// verify subsequent enrollment responses).
func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	h.writeCACertsResponse(w, r)
}

	caCertPEM, err := h.svc.GetCACerts(r.Context())
	if err != nil {
		requestID := middleware.GetRequestID(r.Context())
		ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get CA certificates: %v", err), requestID)
// SimpleEnroll handles POST /.well-known/est/[<PathID>/]simpleenroll.
// Accepts a base64-encoded PKCS#10 CSR + returns base64-encoded PKCS#7.
//
// Auth: HTTP Basic when h.basicPassword != "" (Phase 3); otherwise
// anonymous. Rate-limit: per-(CN, sourceIP) when wired (Phase 4).
func (h ESTHandler) SimpleEnroll(w http.ResponseWriter, r *http.Request) {
	h.handleEnrollOrReEnroll(w, r, false /*reEnroll*/, false /*viaMTLS*/)
}

// SimpleReEnroll handles POST /.well-known/est/[<PathID>/]simplereenroll.
// Same as SimpleEnroll but the audit/log distinguishes the renewal flow
// from initial issuance.
func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
	h.handleEnrollOrReEnroll(w, r, true /*reEnroll*/, false /*viaMTLS*/)
}

// CSRAttrs handles GET /.well-known/est/[<PathID>/]csrattrs.
// Returns the CSR attributes the server wants the client to include.
// RFC 7030 §4.5 — anonymous endpoint, no Basic auth gate.
func (h ESTHandler) CSRAttrs(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	h.writeCSRAttrsResponse(w, r)
}

// ----- /.well-known/est-mtls/<PathID>/ route family (Phase 2 mTLS) -----

// CACertsMTLS handles GET /.well-known/est-mtls/<PathID>/cacerts.
//
// RFC 7030 §4.1.1 says cacerts is anonymous, but on the mTLS sibling
// route we still require a valid client cert because the mTLS path is
// the audit-distinguished surface — operators using mTLS WANT every
// touchpoint logged. The cert isn't validated for purpose-of-issuance
// here (cacerts isn't an enrollment), but absence is rejected.
func (h ESTHandler) CACertsMTLS(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if _, ok := h.requireClientCertChain(w, r); !ok {
		return
	}
	h.writeCACertsResponse(w, r)
}

// SimpleEnrollMTLS handles POST /.well-known/est-mtls/<PathID>/simpleenroll.
//
// Order of gates (each fails fast with the appropriate HTTP status):
//
//  1. Client cert presented + chains to per-profile mTLS trust pool
//     (the TLS layer already verified against the union pool; this is
//     the per-profile re-verify that prevents profile A↔B cross-bleed).
//  2. CSR parses + matches the EST contract (handled by the shared
//     enrollment helper).
//  3. Per-(CN, sourceIP) rate limit when configured.
//  4. Service-layer enrollment.
//
// Channel binding does NOT apply here — RFC 9266 §1 calls out that
// channel binding is a renewal-time defense-in-depth, not an initial-
// enrollment requirement. (A first-time enrollment doesn't yet have a
// device cert, so binding to the TLS session for the bootstrap cert
// adds nothing.)
func (h ESTHandler) SimpleEnrollMTLS(w http.ResponseWriter, r *http.Request) {
	if _, ok := h.requireClientCertChain(w, r); !ok {
		return
	}
	h.handleEnrollOrReEnroll(w, r, false /*reEnroll*/, true /*viaMTLS*/)
}

// SimpleReEnrollMTLS handles POST /.well-known/est-mtls/<PathID>/simplereenroll.
//
// Same as SimpleEnrollMTLS plus the channel-binding gate. RFC 9266 §4.1
// says renewal CSRs SHOULD include the binding attribute when the
// enrollment is over a TLS-1.3 channel; per-profile policy can either
// require this strictly (ChannelBindingRequired=true) or accept its
// absence (default).
func (h ESTHandler) SimpleReEnrollMTLS(w http.ResponseWriter, r *http.Request) {
	if _, ok := h.requireClientCertChain(w, r); !ok {
		return
	}
	h.handleEnrollOrReEnroll(w, r, true /*reEnroll*/, true /*viaMTLS*/)
}

// CSRAttrsMTLS handles GET /.well-known/est-mtls/<PathID>/csrattrs.
// Mirrors CACertsMTLS — cert-required even though the unauth route
// version is anonymous.
func (h ESTHandler) CSRAttrsMTLS(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if _, ok := h.requireClientCertChain(w, r); !ok {
		return
	}
	h.writeCSRAttrsResponse(w, r)
}

// ----- shared internal pipeline -----

// handleEnrollOrReEnroll is the shared body for {Simple,SimpleRe}Enroll{,MTLS}.
// reEnroll picks the SimpleReEnroll vs SimpleEnroll service method (purely
// audit / metric distinguishing — same issuer call underneath); viaMTLS
// picks whether the channel-binding + per-principal-limit gates apply
// AND skips the HTTP Basic gate (mTLS handlers carry the auth).
func (h ESTHandler) handleEnrollOrReEnroll(w http.ResponseWriter, r *http.Request, reEnroll, viaMTLS bool) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Parse PEM to DER for PKCS#7 encoding
	requestID := middleware.GetRequestID(r.Context())

	if err := verifyESTTransport(r); err != nil {
		ErrorWithRequestID(w, http.StatusBadRequest,
			fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
		return
	}

	// HTTP Basic gate (Phase 3) — non-mTLS path only. mTLS profiles
	// authenticate via the client cert so adding Basic on top would
	// double-tax operators with no security benefit.
	if !viaMTLS && h.basicPassword != "" {
		if !h.requireBasicAuth(w, r) {
			return
		}
	}

	csrPEM, err := h.readCSRFromRequest(r)
	if err != nil {
		ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
		return
	}

	// Parse the CSR once for downstream gates (channel-binding, per-
	// principal rate limit). The service re-parses internally — that's a
	// minor inefficiency we accept to keep the service interface flat.
	csr, _ := decodeCSRPEM(csrPEM)

	// Channel-binding gate (Phase 2.4) — mTLS reEnroll only. The optional
	// CSR-side attribute is checked even when the per-profile flag isn't
	// requiring it (a CSR carrying the attribute MUST match the live
	// exporter; a present-but-mismatched binding is always fatal).
	if viaMTLS && reEnroll && csr != nil {
		if err := cms.VerifyChannelBinding(r.TLS, csr, h.channelBindingRequired); err != nil {
			h.writeChannelBindingError(w, requestID, err)
			return
		}
	}

	// Per-principal rate-limit gate (Phase 4.2). Keyed by CN+sourceIP so
	// (a) a CN with no source-IP rotation can be capped, AND (b) a
	// hostile network segment trying to enroll many CNs from one IP is
	// also bounded.
	if h.perPrincipalLimiter != nil {
		if err := h.applyPerPrincipalRateLimit(r, csr); err != nil {
			ErrorWithRequestID(w, http.StatusTooManyRequests,
				fmt.Sprintf("EST enrollment rate-limited: %v", err), requestID)
			return
		}
	}

	var (
		result  *domain.ESTEnrollResult
		callErr error
	)
	if reEnroll {
		result, callErr = h.svc.SimpleReEnroll(r.Context(), csrPEM)
	} else {
		result, callErr = h.svc.SimpleEnroll(r.Context(), csrPEM)
	}
	if callErr != nil {
		op := "Enrollment"
		if reEnroll {
			op = "Re-enrollment"
		}
		ErrorWithRequestID(w, http.StatusInternalServerError,
			fmt.Sprintf("%s failed: %v", op, callErr), requestID)
		return
	}

	h.writeCertResponse(w, result)
}

// requireClientCertChain enforces the mTLS gate for the est-mtls sibling
// route. Returns the leaf cert + true on success; on failure writes the
// HTTP error and returns false.
//
// Mirrors SCEPHandler.HandleSCEPMTLS exactly:
//   - mtlsTrust nil → 500 (config bug; preflight should have prevented).
//   - r.TLS nil or no peer cert → 401 (cert required).
//   - chain doesn't verify against per-profile pool → 401.
func (h ESTHandler) requireClientCertChain(w http.ResponseWriter, r *http.Request) (*x509.Certificate, bool) {
	requestID := middleware.GetRequestID(r.Context())
	if h.mtlsTrust == nil {
		ErrorWithRequestID(w, http.StatusInternalServerError,
			h.label()+" mTLS handler missing trust pool", requestID)
		return nil, false
	}
	if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
		ErrorWithRequestID(w, http.StatusUnauthorized,
			"Client certificate required for /.well-known/est-mtls", requestID)
		return nil, false
	}
	leaf := r.TLS.PeerCertificates[0]
	intermediates := x509.NewCertPool()
	for _, c := range r.TLS.PeerCertificates[1:] {
		intermediates.AddCert(c)
	}
	if _, err := leaf.Verify(x509.VerifyOptions{
		Roots:         h.mtlsTrust.Pool(),
		Intermediates: intermediates,
		KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageAny},
	}); err != nil {
		ErrorWithRequestID(w, http.StatusUnauthorized,
			"Client certificate not trusted by this profile", requestID)
		return nil, false
	}
	return leaf, true
}

// requireBasicAuth runs the Phase 3 HTTP Basic password gate. Returns
// true when auth passed. On failure writes WWW-Authenticate + a 401
// (with rate-limit accounting against the source IP).
//
// User: any non-empty value (RFC 7030 §3.2.3 says the username is
// not authoritative when only a shared password is meaningful). Pass:
// constant-time compare against h.basicPassword.
func (h ESTHandler) requireBasicAuth(w http.ResponseWriter, r *http.Request) bool {
	requestID := middleware.GetRequestID(r.Context())
	srcIP := clientIPForLimiter(r)

	// recordFailedBasic ticks a slot on every credential rejection;
	// once the IP has burned through its window's worth of failed
	// attempts the limiter returns ErrRateLimited (which the next
	// recordFailedBasic just no-ops out — we still want to fail-closed
	// the auth here). The cleaner design is a pre-check that short-
	// circuits the constant-time compare ENTIRELY for an IP at-cap, so
	// a brute-force attacker can't smuggle timing data through. We do
	// that pre-check via SlidingWindowLimiter.Allow with a peek-style
	// fake-key that just queries state without recording a slot.
	if h.failedBasicLimiter != nil && srcIP != "" {
		if err := h.failedBasicLimiter.Allow(srcIP+"|peek", nowFn()); errors.Is(err, ratelimit.ErrRateLimited) {
			// peek-key is shared across requests from this IP; the slot
			// pollution is acceptable because the IP is already
			// rate-limited and we want to keep them rate-limited.
			ErrorWithRequestID(w, http.StatusTooManyRequests,
				h.label()+" too many failed enrollment attempts from this source", requestID)
			return false
		}
	}

	user, pass, ok := r.BasicAuth()
	if !ok || user == "" {
		w.Header().Set("WWW-Authenticate", `Basic realm="est-enrollment"`)
		ErrorWithRequestID(w, http.StatusUnauthorized,
			h.label()+" enrollment requires HTTP Basic auth", requestID)
		h.recordFailedBasic(srcIP)
		return false
	}
	if subtle.ConstantTimeCompare([]byte(pass), []byte(h.basicPassword)) != 1 {
		w.Header().Set("WWW-Authenticate", `Basic realm="est-enrollment"`)
		ErrorWithRequestID(w, http.StatusUnauthorized,
			h.label()+" enrollment password incorrect", requestID)
		h.recordFailedBasic(srcIP)
		return false
	}
	return true
}

// recordFailedBasic ticks a slot against the source-IP failed-auth
// limiter. Errors from Allow are intentionally ignored — a present
// failure simply means the IP has crossed the limit, which is exactly
// the state the per-IP gate reports back to the next request.
func (h ESTHandler) recordFailedBasic(srcIP string) {
	if h.failedBasicLimiter == nil || srcIP == "" {
		return
	}
	_ = h.failedBasicLimiter.Allow(srcIP, nowFn())
}

// applyPerPrincipalRateLimit gates an enrollment by (CN, sourceIP).
// Returns nil when the request is allowed; ErrRateLimited (or wrapped
// equivalent) when the principal has exhausted its window budget.
//
// CN extraction: the CSR's Subject.CommonName is the canonical
// principal in the EST contract (the issued cert will carry that CN).
// sourceIP comes from clientIPForLimiter.
func (h ESTHandler) applyPerPrincipalRateLimit(r *http.Request, csr *x509.CertificateRequest) error {
	if h.perPrincipalLimiter == nil {
		return nil
	}
	cn := ""
	if csr != nil {
		cn = csr.Subject.CommonName
	}
	srcIP := clientIPForLimiter(r)
	key := cn + "|" + srcIP
	return h.perPrincipalLimiter.Allow(key, nowFn())
}

// writeChannelBindingError maps cms.* sentinel errors to HTTP statuses
// + audit-friendly messages. Mirrors the SCEP CertRep failInfo error
// translation pattern (signature_invalid → BadMessageCheck etc.).
func (h ESTHandler) writeChannelBindingError(w http.ResponseWriter, requestID string, err error) {
	switch {
	case errors.Is(err, cms.ErrChannelBindingMissing):
		ErrorWithRequestID(w, http.StatusBadRequest,
			"EST simplereenroll requires RFC 9266 channel binding for this profile", requestID)
	case errors.Is(err, cms.ErrChannelBindingMismatch):
		// 409 Conflict signals to the client that the request was
		// well-formed but the channel-binding state on certctl's side
		// disagreed with the device's — usually MITM or reverse proxy
		// terminating TLS in front of certctl.
		ErrorWithRequestID(w, http.StatusConflict,
			"EST channel binding does not match TLS exporter — TLS terminator in front of certctl?", requestID)
	case errors.Is(err, cms.ErrChannelBindingNotTLS13):
		ErrorWithRequestID(w, http.StatusUpgradeRequired,
			"EST channel binding requires TLS 1.3", requestID)
	default:
		ErrorWithRequestID(w, http.StatusBadRequest,
			fmt.Sprintf("EST channel-binding verification failed: %v", err), requestID)
	}
}

// ----- response writers (legacy + mTLS share these) -----

// writeCACertsResponse writes the PKCS#7 certs-only CA chain. Shared
// by CACerts (legacy route) + CACertsMTLS (mTLS route).
func (h ESTHandler) writeCACertsResponse(w http.ResponseWriter, r *http.Request) {
	requestID := middleware.GetRequestID(r.Context())
	caCertPEM, err := h.svc.GetCACerts(r.Context())
	if err != nil {
		ErrorWithRequestID(w, http.StatusInternalServerError,
			fmt.Sprintf("Failed to get CA certificates: %v", err), requestID)
		return
	}
	derCerts, err := pkcs7.PEMToDERChain(caCertPEM)
	if err != nil {
		requestID := middleware.GetRequestID(r.Context())
		ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to encode CA certificates", requestID)
		return
	}

	// Build a simple PKCS#7 SignedData (certs-only, degenerate) structure
	pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
	if err != nil {
		requestID := middleware.GetRequestID(r.Context())
		ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to build PKCS#7 response", requestID)
		return
	}

	// RFC 7030 Section 4.1.3: response is base64-encoded application/pkcs7-mime
	w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
	w.Header().Set("Content-Transfer-Encoding", "base64")
	w.WriteHeader(http.StatusOK)
	encoded := base64.StdEncoding.EncodeToString(pkcs7Data)
	// Write base64 with line breaks at 76 chars per RFC 2045
	writeBase64Wrapped(w, pkcs7Data)
}

// writeCSRAttrsResponse writes the per-profile CSR attribute hints.
// Shared by CSRAttrs (legacy) + CSRAttrsMTLS (mTLS).
func (h ESTHandler) writeCSRAttrsResponse(w http.ResponseWriter, r *http.Request) {
	requestID := middleware.GetRequestID(r.Context())
	attrs, err := h.svc.GetCSRAttrs(r.Context())
	if err != nil {
		ErrorWithRequestID(w, http.StatusInternalServerError,
			fmt.Sprintf("Failed to get CSR attributes: %v", err), requestID)
		return
	}
	if len(attrs) == 0 {
		w.WriteHeader(http.StatusNoContent)
		return
	}
	w.Header().Set("Content-Type", "application/csrattrs")
	w.Header().Set("Content-Transfer-Encoding", "base64")
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(base64.StdEncoding.EncodeToString(attrs)))
}

// writeCertResponse writes an EST enrollment response as base64-encoded PKCS#7.
func (h ESTHandler) writeCertResponse(w http.ResponseWriter, result *domain.ESTEnrollResult) {
	var derCerts [][]byte
	certDER, err := pkcs7.PEMToDERChain(result.CertPEM)
	if err != nil || len(certDER) == 0 {
		http.Error(w, "Failed to encode certificate", http.StatusInternalServerError)
		return
	}
	derCerts = append(derCerts, certDER...)
	if result.ChainPEM != "" {
		chainDER, err := pkcs7.PEMToDERChain(result.ChainPEM)
		if err == nil {
			derCerts = append(derCerts, chainDER...)
		}
	}
	pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
	if err != nil {
		http.Error(w, "Failed to build PKCS#7 response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
	w.Header().Set("Content-Transfer-Encoding", "base64")
	w.WriteHeader(http.StatusOK)
	writeBase64Wrapped(w, pkcs7Data)
}

// writeBase64Wrapped emits b as base64 with CRLF every 76 chars per RFC 2045.
|
||||
// Pulled out as a helper so the three writers above don't repeat the loop.
|
||||
func writeBase64Wrapped(w http.ResponseWriter, b []byte) {
|
||||
encoded := base64.StdEncoding.EncodeToString(b)
|
||||
for i := 0; i < len(encoded); i += 76 {
|
||||
end := i + 76
|
||||
if end > len(encoded) {
|
||||
@@ -99,66 +601,84 @@ func (h ESTHandler) CACerts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
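
The 76-character wrapping that writeBase64Wrapped performs can be tried in isolation. The following is a minimal sketch, not the repository's code: `wrapBase64` is a hypothetical helper that returns a string instead of writing to an `http.ResponseWriter`, and it assumes the same loop shape as the pre-refactor inline version (encode once, then emit CRLF-terminated slices).

```go
package main

import (
	"encoding/base64"
	"strings"
)

// wrapBase64 is a hypothetical stand-in for the wrapping loop: it encodes b
// with the standard base64 alphabet and ends every line of at most width
// characters with CRLF.
func wrapBase64(b []byte, width int) string {
	encoded := base64.StdEncoding.EncodeToString(b)
	var sb strings.Builder
	for i := 0; i < len(encoded); i += width {
		end := i + width
		if end > len(encoded) {
			end = len(encoded) // final, shorter line
		}
		sb.WriteString(encoded[i:end])
		sb.WriteString("\r\n")
	}
	return sb.String()
}
```

With `width` set to 76, a 60-byte input encodes to 80 base64 characters, so the sketch yields one full 76-character line and one 4-character line, each followed by CRLF.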
// SimpleEnroll handles POST /.well-known/est/simpleenroll
// Accepts a base64-encoded PKCS#10 CSR and returns a base64-encoded PKCS#7 certificate.
func (h ESTHandler) SimpleEnroll(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	requestID := middleware.GetRequestID(r.Context())

	if err := verifyESTTransport(r); err != nil {
		ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
		return
	}

	csrPEM, err := h.readCSRFromRequest(r)
// readCSRFromRequest reads and decodes the CSR from an EST enrollment request.
// EST sends CSRs as base64-encoded PKCS#10 DER with Content-Type application/pkcs10.
func (h ESTHandler) readCSRFromRequest(r *http.Request) (string, error) {
	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
	if err != nil {
		ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
		return
		return "", fmt.Errorf("failed to read request body: %w", err)
	}
	defer r.Body.Close()

	if len(body) == 0 {
		return "", fmt.Errorf("empty request body")
	}

	result, err := h.svc.SimpleEnroll(r.Context(), csrPEM)
	bodyStr := strings.TrimSpace(string(body))
	if strings.HasPrefix(bodyStr, "-----BEGIN CERTIFICATE REQUEST-----") {
		block, _ := pem.Decode([]byte(bodyStr))
		if block == nil {
			return "", fmt.Errorf("invalid PEM-encoded CSR")
		}
		if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil {
			return "", fmt.Errorf("invalid CSR: %w", err)
		}
		return bodyStr, nil
	}

	derBytes, err := base64.StdEncoding.DecodeString(bodyStr)
	if err != nil {
		ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Enrollment failed: %v", err), requestID)
		return
		cleaned := strings.Map(func(r rune) rune {
			if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
				return -1
			}
			return r
		}, bodyStr)
		derBytes, err = base64.StdEncoding.DecodeString(cleaned)
		if err != nil {
			return "", fmt.Errorf("failed to decode base64 CSR: %w", err)
		}
	}

	h.writeCertResponse(w, result)
	if _, err := x509.ParseCertificateRequest(derBytes); err != nil {
		return "", fmt.Errorf("invalid PKCS#10 CSR: %w", err)
	}
	csrPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE REQUEST",
		Bytes: derBytes,
	})
	return string(csrPEM), nil
}
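
readCSRFromRequest accepts two body shapes: a PEM block, or the base64-encoded DER that EST sends by default (possibly with embedded whitespace). The sketch below isolates only that branching so it can be tried without an HTTP request. `decodeCSRBody` and `pemHeader` are hypothetical names, and the sketch deliberately skips the `x509.ParseCertificateRequest` validation that the handler performs on the decoded bytes.

```go
package main

import (
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
	"strings"
)

const pemHeader = "-----BEGIN CERTIFICATE REQUEST-----"

// decodeCSRBody (hypothetical) returns the DER bytes carried in an EST
// request body, accepting either PEM or whitespace-tolerant base64.
func decodeCSRBody(body string) ([]byte, error) {
	s := strings.TrimSpace(body)
	if s == "" {
		return nil, errors.New("empty request body")
	}
	if strings.HasPrefix(s, pemHeader) {
		block, _ := pem.Decode([]byte(s))
		if block == nil {
			return nil, errors.New("invalid PEM-encoded CSR")
		}
		return block.Bytes, nil
	}
	// Drop CR, LF, space, and tab before decoding, mirroring the handler's
	// fallback path for line-wrapped base64.
	cleaned := strings.Map(func(r rune) rune {
		switch r {
		case '\r', '\n', ' ', '\t':
			return -1
		}
		return r
	}, s)
	der, err := base64.StdEncoding.DecodeString(cleaned)
	if err != nil {
		return nil, fmt.Errorf("failed to decode base64 CSR: %w", err)
	}
	return der, nil
}
```

Unlike the handler, the sketch strips whitespace unconditionally rather than only after a first decode attempt fails; the accepted inputs are the same.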
// SimpleReEnroll handles POST /.well-known/est/simplereenroll
// Same as SimpleEnroll but for re-enrollment (certificate renewal).
func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
// decodeCSRPEM is a convenience wrapper around pem.Decode +
// x509.ParseCertificateRequest. Returns a nil request and a non-nil
// error on any decode/parse failure (callers downstream re-parse via
// the service path; this is just for the handler-side gates that need
// the CN + binding attribute).
func decodeCSRPEM(csrPEM string) (*x509.CertificateRequest, error) {
	block, _ := pem.Decode([]byte(csrPEM))
	if block == nil {
		return nil, fmt.Errorf("PEM decode failed")
	}

	requestID := middleware.GetRequestID(r.Context())

	if err := verifyESTTransport(r); err != nil {
		ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("EST transport precondition failed: %v", err), requestID)
		return
	}

	csrPEM, err := h.readCSRFromRequest(r)
	if err != nil {
		ErrorWithRequestID(w, http.StatusBadRequest, fmt.Sprintf("Invalid CSR: %v", err), requestID)
		return
	}

	result, err := h.svc.SimpleReEnroll(r.Context(), csrPEM)
	if err != nil {
		ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Re-enrollment failed: %v", err), requestID)
		return
	}

	h.writeCertResponse(w, result)
	return x509.ParseCertificateRequest(block.Bytes)
}

// clientIPForLimiter returns the source IP a per-IP rate limiter should
// key against. X-Forwarded-For is not honored yet: there is no
// proxy-trust list, so the function always uses RemoteAddr.
func clientIPForLimiter(r *http.Request) string {
	// Don't blindly trust XFF — ignore it for now and always use
	// RemoteAddr. A future bundle can add a documented proxy-trust list.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
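
Keying the limiter on the host part matters because the ephemeral source port changes on every connection. A small sketch of the same `net.SplitHostPort` fallback, operating on a plain string so it can be exercised without building a request (`hostOnly` is a hypothetical name):

```go
package main

import "net"

// hostOnly (hypothetical) strips the port from a RemoteAddr-style string.
// If the input has no port, or is otherwise unparseable, it is returned
// unchanged, matching the fallback in clientIPForLimiter.
func hostOnly(remoteAddr string) string {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return remoteAddr
	}
	return host
}
```

For example, `hostOnly("10.0.0.42:12345")` and `hostOnly("10.0.0.42:54321")` both give `10.0.0.42`, so two connections from the same client land in the same bucket; the bracketed IPv6 form `[::1]:443` gives `::1`.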
// nowFn is the package-private time source. Override in tests for
// deterministic clock injection without dragging time.Time into the
// handler API surface. Defined in est_clock.go so mocking it out
// requires touching only one file.

// verifyESTTransport implements Bundle-4 / M-021 EST transport precondition.
//
// RFC 7030 §3.2.3 ("Linking Identity and POP Information") requires that when
@@ -169,32 +689,11 @@ func (h ESTHandler) SimpleReEnroll(w http.ResponseWriter, r *http.Request) {
// TLS-Unique is unavailable; RFC 9266 defines `tls-exporter` as the TLS 1.3
// replacement.
//
// **Current scope of this function (Bundle-4 closure):** certctl does NOT
// currently support EST client certificate authentication. The EST endpoint
// accepts unauthenticated POSTs (the SCEP equivalent enforces a
// challenge-password via `preflightSCEPChallengePassword`; EST has no
// equivalent today). Per RFC 7030 §3.2.3, channel binding is REQUIRED only
// when client certificate authentication is in use; without that, the §3.2.3
// requirement is moot.
//
// What we DO enforce here as defense-in-depth:
//
//  1. r.TLS must be non-nil — the EST endpoint MUST be reached over TLS.
//     Defensive: certctl pins HTTPS-only at the server-side TLS config, but
//     a future routing-layer regression that exposes EST over plaintext
//     would be caught here.
//  2. Negotiated TLS version must be >= TLS 1.2 — RFC 7030 doesn't mandate
//     a specific TLS version, but a pre-1.2 negotiation indicates a
//     misconfigured client/server pair. certctl's MinVersion is TLS 1.3
//     so this should always hold.
//  3. r.TLS.HandshakeComplete must be true — defensive against partial-
//     handshake replays.
//
// **Deferred to a future bundle (operator decision required):**
//
//   - RFC 9266 `tls-exporter` channel binding when EST mTLS is added.
//   - EST mTLS support itself — currently EST is unauth-or-bearer; mTLS
//     would be a V3-aligned compliance feature.
// **EST RFC 7030 hardening Phases 2-4 update:** RFC 9266 channel binding is
// now wired in via the cms package (Phase 2.4) and called from
// SimpleReEnrollMTLS when the per-profile policy requires it. This function
// continues to handle the lower-level transport preconditions that ALL EST
// requests share (regardless of mTLS / Basic / unauth profile shape).
//
// Returns nil if all preconditions pass; non-nil error otherwise.
func verifyESTTransport(r *http.Request) error {
@@ -213,130 +712,5 @@ func verifyESTTransport(r *http.Request) error {
	return nil
}
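
The body of verifyESTTransport is elided by the diff, so the following is only a sketch of the three preconditions its doc comment names (TLS present, handshake complete, version at least TLS 1.2). `checkTransport`, the error strings, and the three example states are all hypothetical; the real function takes an `*http.Request` and may order or word its checks differently.

```go
package main

import (
	"crypto/tls"
	"errors"
)

// checkTransport (hypothetical) applies the three transport preconditions
// to a connection state such as the one found in http.Request.TLS.
func checkTransport(cs *tls.ConnectionState) error {
	if cs == nil {
		return errors.New("EST requires TLS")
	}
	if !cs.HandshakeComplete {
		return errors.New("TLS handshake not complete")
	}
	if cs.Version < tls.VersionTLS12 {
		return errors.New("TLS 1.2 or newer required")
	}
	return nil
}

// Example states, in the style of the synthetic ConnectionState values the
// handler tests attach to requests.
var (
	plaintext *tls.ConnectionState // nil: request did not arrive over TLS
	legacyTLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS11}
	modernTLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
)
```

Of the three example states, only `modernTLS` passes; `plaintext` and `legacyTLS` each trip one of the checks.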
// CSRAttrs handles GET /.well-known/est/csrattrs
// Returns the CSR attributes the server wants the client to include in enrollment requests.
func (h ESTHandler) CSRAttrs(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	attrs, err := h.svc.GetCSRAttrs(r.Context())
	if err != nil {
		requestID := middleware.GetRequestID(r.Context())
		ErrorWithRequestID(w, http.StatusInternalServerError, fmt.Sprintf("Failed to get CSR attributes: %v", err), requestID)
		return
	}

	if len(attrs) == 0 {
		// No specific attributes required — return 204
		w.WriteHeader(http.StatusNoContent)
		return
	}

	w.Header().Set("Content-Type", "application/csrattrs")
	w.Header().Set("Content-Transfer-Encoding", "base64")
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(base64.StdEncoding.EncodeToString(attrs)))
}

// readCSRFromRequest reads and decodes the CSR from an EST enrollment request.
// EST sends CSRs as base64-encoded PKCS#10 DER with Content-Type application/pkcs10.
func (h ESTHandler) readCSRFromRequest(r *http.Request) (string, error) {
	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
	if err != nil {
		return "", fmt.Errorf("failed to read request body: %w", err)
	}
	defer r.Body.Close()

	if len(body) == 0 {
		return "", fmt.Errorf("empty request body")
	}

	// Check if it's already PEM-encoded (some clients send PEM directly)
	bodyStr := strings.TrimSpace(string(body))
	if strings.HasPrefix(bodyStr, "-----BEGIN CERTIFICATE REQUEST-----") {
		// Validate it parses
		block, _ := pem.Decode([]byte(bodyStr))
		if block == nil {
			return "", fmt.Errorf("invalid PEM-encoded CSR")
		}
		if _, err := x509.ParseCertificateRequest(block.Bytes); err != nil {
			return "", fmt.Errorf("invalid CSR: %w", err)
		}
		return bodyStr, nil
	}

	// EST standard: base64-encoded DER PKCS#10
	derBytes, err := base64.StdEncoding.DecodeString(bodyStr)
	if err != nil {
		// Try with padding/whitespace stripped
		cleaned := strings.Map(func(r rune) rune {
			if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
				return -1
			}
			return r
		}, bodyStr)
		derBytes, err = base64.StdEncoding.DecodeString(cleaned)
		if err != nil {
			return "", fmt.Errorf("failed to decode base64 CSR: %w", err)
		}
	}

	// Validate it's a valid PKCS#10 CSR
	if _, err := x509.ParseCertificateRequest(derBytes); err != nil {
		return "", fmt.Errorf("invalid PKCS#10 CSR: %w", err)
	}

	// Convert DER to PEM for internal use (certctl services expect PEM)
	csrPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE REQUEST",
		Bytes: derBytes,
	})
	return string(csrPEM), nil
}

// writeCertResponse writes an EST enrollment response as base64-encoded PKCS#7.
func (h ESTHandler) writeCertResponse(w http.ResponseWriter, result *domain.ESTEnrollResult) {
	// Parse cert and chain PEM to DER
	var derCerts [][]byte

	// Add the issued certificate
	certDER, err := pkcs7.PEMToDERChain(result.CertPEM)
	if err != nil || len(certDER) == 0 {
		http.Error(w, "Failed to encode certificate", http.StatusInternalServerError)
		return
	}
	derCerts = append(derCerts, certDER...)

	// Add the CA chain if present
	if result.ChainPEM != "" {
		chainDER, err := pkcs7.PEMToDERChain(result.ChainPEM)
		if err == nil {
			derCerts = append(derCerts, chainDER...)
		}
	}

	// Build PKCS#7 certs-only
	pkcs7Data, err := pkcs7.BuildCertsOnlyPKCS7(derCerts)
	if err != nil {
		http.Error(w, "Failed to build PKCS#7 response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/pkcs7-mime; smime-type=certs-only")
	w.Header().Set("Content-Transfer-Encoding", "base64")
	w.WriteHeader(http.StatusOK)
	encoded := base64.StdEncoding.EncodeToString(pkcs7Data)
	for i := 0; i < len(encoded); i += 76 {
		end := i + 76
		if end > len(encoded) {
			end = len(encoded)
		}
		w.Write([]byte(encoded[i:end]))
		w.Write([]byte("\r\n"))
	}
}

// NOTE: PKCS#7 helpers (BuildCertsOnlyPKCS7, PEMToDERChain, ASN.1 wrappers)
// are in the shared internal/pkcs7 package, used by both EST and SCEP handlers.
@@ -0,0 +1,15 @@
package handler

import "time"

// EST RFC 7030 hardening Phase 3.3 / 4.2: nowFn is the time source that
// the EST handler's per-IP failed-Basic-auth limiter and per-(CN,
// sourceIP) rate limiter consult. Tests can override this to inject a
// deterministic clock without dragging time.Time into the handler API
// surface (the handler's setters take ratelimit.SlidingWindowLimiter
// pointers, not time-injection callbacks — keeping the wire-up simple).
//
// nowFn is package-private + lower-case so external callers can't poke
// at it; the est_clock_test.go helper restoreNowFn is the documented
// override pattern for tests in this package.
var nowFn = time.Now
@@ -0,0 +1,459 @@
package handler

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/pem"
	"io"
	"log/slog"
	"math/big"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/shankar0123/certctl/internal/cms"
	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/ratelimit"
	"github.com/shankar0123/certctl/internal/trustanchor"
)

// EST RFC 7030 hardening master bundle Phases 2-4 tests.
// Covers: mTLS sibling route gates, HTTP Basic enrollment-password auth,
// per-source-IP failed-auth rate limit, RFC 9266 channel binding, and
// per-(CN, sourceIP) per-principal sliding-window rate limit.

// hardeningTestSetup is a per-test fixture: a mock service that always
// succeeds, plus a CA + issued client cert that an mTLS test can attach
// to its synthetic *http.Request.TLS.
type hardeningTestSetup struct {
	svc       *mockESTService
	caCert    *x509.Certificate
	caKey     *ecdsa.PrivateKey
	clientCrt *x509.Certificate
	clientKey *ecdsa.PrivateKey
	trustPool *trustanchor.Holder
	bundleDir string
}

func newHardeningTestSetup(t *testing.T) *hardeningTestSetup {
	t.Helper()
	caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("ca key: %v", err)
	}
	caTmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "est-mtls-test-ca"},
		NotBefore:             time.Now().Add(-1 * time.Hour),
		NotAfter:              time.Now().Add(24 * time.Hour),
		IsCA:                  true,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageCertSign,
	}
	caDER, err := x509.CreateCertificate(rand.Reader, caTmpl, caTmpl, &caKey.PublicKey, caKey)
	if err != nil {
		t.Fatalf("ca create: %v", err)
	}
	caCert, _ := x509.ParseCertificate(caDER)

	clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("client key: %v", err)
	}
	clientTmpl := &x509.Certificate{
		SerialNumber: big.NewInt(2),
		Subject:      pkix.Name{CommonName: "test-device-001"},
		NotBefore:    time.Now().Add(-1 * time.Hour),
		NotAfter:     time.Now().Add(24 * time.Hour),
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}
	clientDER, err := x509.CreateCertificate(rand.Reader, clientTmpl, caCert, &clientKey.PublicKey, caKey)
	if err != nil {
		t.Fatalf("client create: %v", err)
	}
	clientCrt, _ := x509.ParseCertificate(clientDER)

	// Persist the CA bundle on disk so trustanchor.New can load it.
	dir := t.TempDir()
	bundlePath := filepath.Join(dir, "trust.pem")
	body := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER})
	if err := os.WriteFile(bundlePath, body, 0o600); err != nil {
		t.Fatalf("write bundle: %v", err)
	}
	holder, err := trustanchor.New(bundlePath, slog.New(slog.NewTextHandler(io.Discard, nil)))
	if err != nil {
		t.Fatalf("trustanchor.New: %v", err)
	}

	svc := &mockESTService{
		CACertPEM: pemCertString(caDER),
		EnrollResult: &domain.ESTEnrollResult{
			CertPEM: pemCertString(clientDER),
		},
	}
	return &hardeningTestSetup{
		svc:       svc,
		caCert:    caCert,
		caKey:     caKey,
		clientCrt: clientCrt,
		clientKey: clientKey,
		trustPool: holder,
		bundleDir: dir,
	}
}

func pemCertString(der []byte) string {
	return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}))
}

// makeMTLSRequest synthesises a POST against `path` with PEM CSR body and
// r.TLS populated with the given peer cert chain + handshake state. Used
// by the mTLS path tests where a real TLS handshake would force us into a
// full httptest.NewTLSServer setup.
func makeMTLSRequest(t *testing.T, path, csrPEM string, peerCerts []*x509.Certificate, version uint16) *http.Request {
	t.Helper()
	req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(csrPEM))
	req.TLS = &tls.ConnectionState{
		HandshakeComplete: true,
		Version:           version,
		PeerCertificates:  peerCerts,
	}
	return req
}

// ----- mTLS handler gate -----

func TestSimpleEnrollMTLS_NoTrustPool_500(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc) // intentionally do NOT call SetMTLSTrust
	req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
		generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
	w := httptest.NewRecorder()
	h.SimpleEnrollMTLS(w, req)
	if w.Code != http.StatusInternalServerError {
		t.Errorf("status = %d, want 500 (handler missing trust pool)", w.Code)
	}
}

func TestSimpleEnrollMTLS_NoClientCert_401(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetMTLSTrust(s.trustPool)
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est-mtls/corp/simpleenroll",
		strings.NewReader(generateTestCSRPEM(t)))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	w := httptest.NewRecorder()
	h.SimpleEnrollMTLS(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401 (no client cert)", w.Code)
	}
}

func TestSimpleEnrollMTLS_CertNotInPool_401(t *testing.T) {
	s := newHardeningTestSetup(t)
	other := newHardeningTestSetup(t) // different CA, unrelated to s.trustPool
	h := NewESTHandler(s.svc)
	h.SetMTLSTrust(s.trustPool)
	req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
		generateTestCSRPEM(t), []*x509.Certificate{other.clientCrt}, tls.VersionTLS13)
	w := httptest.NewRecorder()
	h.SimpleEnrollMTLS(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401 (cert not trusted by this profile)", w.Code)
	}
}

func TestSimpleEnrollMTLS_HappyPath_200(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetMTLSTrust(s.trustPool)
	req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simpleenroll",
		generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
	w := httptest.NewRecorder()
	h.SimpleEnrollMTLS(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200; body=%q", w.Code, w.Body.String())
	}
}

// ----- channel binding (Phase 2.4) -----

func TestSimpleReEnrollMTLS_ChannelBindingRequired_AbsentRejected(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetMTLSTrust(s.trustPool)
	h.SetChannelBindingRequired(true)
	// CSR has no binding attribute. Synthetic ConnectionState — exporter
	// extraction will fail (no real TLS secret), and required=true makes
	// VerifyChannelBinding propagate that as the missing-binding error.
	req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simplereenroll",
		generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
	w := httptest.NewRecorder()
	h.SimpleReEnrollMTLS(w, req)
	// Either 400 (missing) or 426 (TLS 1.3 unavailable on synthetic state).
	// Both are correct refusals; pin to "non-2xx" so the test isn't fragile
	// against ConnectionState evolution.
	if w.Code/100 == 2 {
		t.Errorf("required + absent must reject; got 2xx (%d)", w.Code)
	}
}

func TestSimpleReEnrollMTLS_ChannelBindingNotRequired_AbsentAllowed(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetMTLSTrust(s.trustPool)
	h.SetChannelBindingRequired(false)
	// CSR has no binding, profile is opt-in only. The handler must allow.
	req := makeMTLSRequest(t, "/.well-known/est-mtls/corp/simplereenroll",
		generateTestCSRPEM(t), []*x509.Certificate{s.clientCrt}, tls.VersionTLS13)
	w := httptest.NewRecorder()
	h.SimpleReEnrollMTLS(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("required=false + absent must allow; got %d (%s)", w.Code, w.Body.String())
	}
}

func TestWriteChannelBindingError_KnownErrorsMapped(t *testing.T) {
	// Smoke test the error-to-status mapping so a future cms sentinel rename
	// gets caught at compile time + we hit each branch.
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	cases := []struct {
		err  error
		want int
	}{
		{cms.ErrChannelBindingMissing, http.StatusBadRequest},
		{cms.ErrChannelBindingMismatch, http.StatusConflict},
		{cms.ErrChannelBindingNotTLS13, http.StatusUpgradeRequired},
	}
	for _, c := range cases {
		w := httptest.NewRecorder()
		h.writeChannelBindingError(w, "req-id", c.err)
		if w.Code != c.want {
			t.Errorf("error=%v → status %d, want %d", c.err, w.Code, c.want)
		}
	}
}

// ----- HTTP Basic enrollment-password (Phase 3) -----

func TestSimpleEnroll_BasicAuth_NoHeader_401(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetEnrollmentPassword("super-secret")
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
		strings.NewReader(generateTestCSRPEM(t)))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	w := httptest.NewRecorder()
	h.SimpleEnroll(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401 (Basic required, header absent)", w.Code)
	}
	if got := w.Header().Get("WWW-Authenticate"); !strings.Contains(got, "Basic") {
		t.Errorf("WWW-Authenticate = %q, want to contain 'Basic'", got)
	}
}

func TestSimpleEnroll_BasicAuth_WrongPassword_401(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetEnrollmentPassword("super-secret")
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
		strings.NewReader(generateTestCSRPEM(t)))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	req.SetBasicAuth("device", "wrong-password")
	w := httptest.NewRecorder()
	h.SimpleEnroll(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401 (wrong password)", w.Code)
	}
}

func TestSimpleEnroll_BasicAuth_CorrectPassword_200(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetEnrollmentPassword("super-secret")
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
		strings.NewReader(generateTestCSRPEM(t)))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	req.SetBasicAuth("device", "super-secret")
	w := httptest.NewRecorder()
	h.SimpleEnroll(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200 (correct password); body=%q", w.Code, w.Body.String())
	}
}

func TestSimpleEnroll_BasicAuth_NoPassword_NoGate(t *testing.T) {
	// When the per-profile enrollment password is empty, the Basic gate is
	// off and the handler reverts to the v2.0.x anonymous behavior.
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc) // SetEnrollmentPassword not called
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
		strings.NewReader(generateTestCSRPEM(t)))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	w := httptest.NewRecorder()
	h.SimpleEnroll(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200 (no Basic gate)", w.Code)
	}
}

// ----- source-IP failed-auth rate limit (Phase 3.3) -----

func TestSimpleEnroll_BasicAuth_FailedAttemptLimitedAfterThreshold(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	h.SetEnrollmentPassword("super-secret")
	// Cap of 2 failed attempts before the IP gets locked. Each failed
	// attempt records a slot; the 3rd request should be 429.
	limiter := ratelimit.NewSlidingWindowLimiter(2, time.Hour, 10)
	h.SetSourceIPRateLimiter(limiter)

	for i := 0; i < 2; i++ {
		req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
			strings.NewReader(generateTestCSRPEM(t)))
		req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
		req.RemoteAddr = "10.0.0.42:12345"
		req.SetBasicAuth("device", "WRONG")
		w := httptest.NewRecorder()
		h.SimpleEnroll(w, req)
		if w.Code != http.StatusUnauthorized {
			t.Fatalf("attempt %d: want 401, got %d", i, w.Code)
		}
	}
	// The 3rd attempt — even with a correct password — must be rate limited.
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
		strings.NewReader(generateTestCSRPEM(t)))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	req.RemoteAddr = "10.0.0.42:12345"
	req.SetBasicAuth("device", "super-secret")
	w := httptest.NewRecorder()
	h.SimpleEnroll(w, req)
	if w.Code != http.StatusTooManyRequests {
		t.Errorf("post-lockout status = %d, want 429 (correct password should still be locked out)", w.Code)
	}
}

// ----- per-principal sliding-window rate limit (Phase 4.2) -----

func TestSimpleEnroll_PerPrincipalLimit_BlocksAfterCap(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	limiter := ratelimit.NewSlidingWindowLimiter(2, 24*time.Hour, 100)
	h.SetPerPrincipalRateLimiter(limiter)

	// First 2 enrollments from same (CN, IP) — pass.
	csrPEM := generateTestCSRPEM(t)
	for i := 0; i < 2; i++ {
		req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
			strings.NewReader(csrPEM))
		req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
		req.RemoteAddr = "10.0.0.7:5555"
		w := httptest.NewRecorder()
		h.SimpleEnroll(w, req)
		if w.Code != http.StatusOK {
			t.Fatalf("attempt %d: want 200, got %d", i, w.Code)
		}
	}
	// Third enrollment from same (CN, IP) — limited.
	req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
		strings.NewReader(csrPEM))
	req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	req.RemoteAddr = "10.0.0.7:5555"
	w := httptest.NewRecorder()
	h.SimpleEnroll(w, req)
	if w.Code != http.StatusTooManyRequests {
		t.Errorf("3rd same-principal enrollment status = %d, want 429", w.Code)
	}
}

func TestSimpleEnroll_PerPrincipalLimit_DifferentPrincipalsIndependent(t *testing.T) {
	s := newHardeningTestSetup(t)
	h := NewESTHandler(s.svc)
	limiter := ratelimit.NewSlidingWindowLimiter(1, 24*time.Hour, 100)
	h.SetPerPrincipalRateLimiter(limiter)

	csrPEM1 := generateTestCSRPEM(t)
	csrPEM2 := generateTestCSRPEM(t) // different key + (default) different CN

	req1 := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll", strings.NewReader(csrPEM1))
	req1.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	req1.RemoteAddr = "10.0.0.10:1111"
	w1 := httptest.NewRecorder()
	h.SimpleEnroll(w1, req1)
	if w1.Code != http.StatusOK {
		t.Fatalf("principal 1 first call: want 200, got %d", w1.Code)
	}

	// Same CN as csrPEM1 but different IP — independent bucket.
	req2 := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll", strings.NewReader(csrPEM2))
	req2.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
	req2.RemoteAddr = "10.0.0.20:2222"
	w2 := httptest.NewRecorder()
	h.SimpleEnroll(w2, req2)
	if w2.Code != http.StatusOK {
		t.Errorf("principal 2 first call: want 200, got %d", w2.Code)
	}
}

// ----- per-handler smoke test for the un-rolled mTLS variants -----
|
||||
|
||||
func TestCACertsMTLS_RequiresClientCert(t *testing.T) {
|
||||
s := newHardeningTestSetup(t)
|
||||
h := NewESTHandler(s.svc)
|
||||
h.SetMTLSTrust(s.trustPool)
|
||||
req := httptest.NewRequest(http.MethodGet, "/.well-known/est-mtls/corp/cacerts", nil)
|
||||
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||
w := httptest.NewRecorder()
|
||||
h.CACertsMTLS(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("CACertsMTLS no-cert status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRAttrsMTLS_RequiresClientCert(t *testing.T) {
|
||||
s := newHardeningTestSetup(t)
|
||||
h := NewESTHandler(s.svc)
|
||||
h.SetMTLSTrust(s.trustPool)
|
||||
req := httptest.NewRequest(http.MethodGet, "/.well-known/est-mtls/corp/csrattrs", nil)
|
||||
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||
w := httptest.NewRecorder()
|
||||
h.CSRAttrsMTLS(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("CSRAttrsMTLS no-cert status = %d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- ensure the per-principal limit fires only when configured -----
|
||||
|
||||
func TestSimpleEnroll_NoPerPrincipalLimiter_AllUnbounded(t *testing.T) {
|
||||
s := newHardeningTestSetup(t)
|
||||
h := NewESTHandler(s.svc) // SetPerPrincipalRateLimiter not called
|
||||
csrPEM := generateTestCSRPEM(t)
|
||||
for i := 0; i < 50; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/corp/simpleenroll",
|
||||
strings.NewReader(csrPEM))
|
||||
req.TLS = &tls.ConnectionState{HandshakeComplete: true, Version: tls.VersionTLS13}
|
||||
w := httptest.NewRecorder()
|
||||
h.SimpleEnroll(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("attempt %d: want 200, got %d", i, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blank-identifier references keep the asn1 and atomic imports in use (Go
// rejects unused imports at compile time) until tests that need them land.
|
||||
var _ = asn1.RawValue{}
|
||||
var _ atomic.Int32
|
||||
@@ -81,10 +81,11 @@ var AuthExemptRouterRoutes = []string{
|
||||
// TestDispatch_AuthExemptPrefixes regression test in cmd/server/main_test.go
|
||||
// pins this slice to buildFinalHandler's actual dispatch logic.
|
||||
var AuthExemptDispatchPrefixes = []string{
|
||||
"/.well-known/pki", // RFC 5280 CRL + RFC 6960 OCSP — relying-party-unauth
|
||||
"/.well-known/est", // RFC 7030 EST — auth via mTLS or CSR-embedded creds
|
||||
"/scep", // RFC 8894 SCEP — auth via challengePassword in CSR
|
||||
"/scep-mtls", // SCEP + mTLS sibling route (Phase 6.5) — auth is client cert + challengePassword
|
||||
"/.well-known/pki", // RFC 5280 CRL + RFC 6960 OCSP — relying-party-unauth
|
||||
"/.well-known/est", // RFC 7030 EST — auth via mTLS or CSR-embedded creds
|
||||
"/.well-known/est-mtls", // EST + mTLS sibling route (EST hardening Phase 2) — auth is client cert
|
||||
"/scep", // RFC 8894 SCEP — auth via challengePassword in CSR
|
||||
"/scep-mtls", // SCEP + mTLS sibling route (Phase 6.5) — auth is client cert + challengePassword
|
||||
}
|
||||
|
||||
// HandlerRegistry groups all API handler dependencies for router registration.
|
||||
@@ -445,6 +446,44 @@ func (r *Router) RegisterESTHandlers(handlers map[string]handler.ESTHandler) {
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterESTMTLSHandlers sets up the sibling `/.well-known/est-mtls/<PathID>/`
|
||||
// routes for EST profiles that opted into mTLS via
|
||||
// `CERTCTL_EST_PROFILE_<NAME>_MTLS_ENABLED=true`.
|
||||
//
|
||||
// EST RFC 7030 hardening master bundle Phase 2.2 + 2.3: enterprise
|
||||
// procurement teams routinely reject 'shared password authentication' as
|
||||
// a checkbox-fail regardless of how strong the password is. This sibling
|
||||
// route adds client-cert auth at the handler layer AND keeps the (Phase 3)
|
||||
// HTTP Basic enrollment-password as a defense-in-depth fallback for the
|
||||
// non-mTLS profile. Devices present a bootstrap cert from a trusted CA,
|
||||
// then EST-enroll for their long-lived cert. Mirrors the SCEP mTLS
|
||||
// sibling pattern at RegisterSCEPMTLSHandlers below (commit 6b0d9e from
|
||||
// the SCEP Phase 6.5 work).
|
||||
//
|
||||
// Path conventions: every mTLS profile gets a non-empty PathID, so the
|
||||
// sibling routes are always /.well-known/est-mtls/<pathID>/. There is no
|
||||
// "empty PathID = legacy /.well-known/est-mtls" case — mTLS is opt-in
|
||||
// per profile, the legacy /.well-known/est root is always non-mTLS to
|
||||
// preserve backward compat with existing deploys.
|
||||
//
|
||||
// Each handler in the map MUST have had SetMTLSTrust called so the
|
||||
// per-profile cert verification has a trust anchor. cmd/server/main.go's
|
||||
// per-profile EST loop wires this in the same loop iteration that
|
||||
// registers the handler.
|
||||
func (r *Router) RegisterESTMTLSHandlers(handlers map[string]handler.ESTHandler) {
|
||||
for pathID, h := range handlers {
|
||||
if pathID == "" {
|
||||
continue // mTLS sibling route requires per-profile PathID
|
||||
}
|
||||
hCopy := h // h is captured by value — see RegisterESTHandlers above
|
||||
prefix := "/.well-known/est-mtls/" + pathID
|
||||
r.Register("GET "+prefix+"/cacerts", http.HandlerFunc(hCopy.CACertsMTLS))
|
||||
r.Register("POST "+prefix+"/simpleenroll", http.HandlerFunc(hCopy.SimpleEnrollMTLS))
|
||||
r.Register("POST "+prefix+"/simplereenroll", http.HandlerFunc(hCopy.SimpleReEnrollMTLS))
|
||||
r.Register("GET "+prefix+"/csrattrs", http.HandlerFunc(hCopy.CSRAttrsMTLS))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterSCEPHandlers sets up SCEP (RFC 8894) routes.
|
||||
// SCEP uses a single endpoint per profile with operation-based dispatch via
|
||||
// query parameters. Authentication is via the challengePassword attribute in
|
||||
|
||||
@@ -0,0 +1,369 @@
|
||||
// Package cms implements the small subset of CMS / RFC 7030 / RFC 9266
|
||||
// helpers that the EST handler needs at request-time: extracting the
|
||||
// RFC 9266 tls-exporter from a *tls.ConnectionState, and pulling the
|
||||
// matching value back out of an EST CSR's CMC unsignedAttribute when the
|
||||
// device proved channel binding.
|
||||
//
|
||||
// Why a separate package (vs adding to internal/api/handler/est.go):
|
||||
//
|
||||
// 1. internal/api/handler depends on internal/pkcs7 already; if the EST
|
||||
// mTLS hardening also pulled CMC parsing into handler we'd grow the
|
||||
// handler-side dep graph by another asn1 surface that has nothing
|
||||
// specific to HTTP.
|
||||
//
|
||||
// 2. Channel-binding extraction is testable in isolation — the unit
|
||||
// tests construct a *tls.ConnectionState with raw exporter bytes and
|
||||
// a *x509.CertificateRequest with the CMC unsignedAttribute already
|
||||
// filled in. No HTTP plumbing required to verify the contract.
|
||||
//
|
||||
// 3. Future EST extensions (RFC 7030 §3.5 fullCMC, RFC 9148 EST-coaps)
|
||||
// are likely to land here too — keep them out of net/http land.
|
||||
//
|
||||
// EST RFC 7030 hardening master bundle Phase 2.4.
|
||||
package cms
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ----- RFC 9266 §3 — TLS exporter extraction -----
|
||||
|
||||
// TLSExporterLabel is the EXPORTER label registered by RFC 9266 §3.1
// for use as a TLS-1.3 channel binding. Callers reference this constant
// instead of retyping the string, so a misspelled identifier fails to
// compile rather than silently deriving a different exporter value.
|
||||
const TLSExporterLabel = "EXPORTER-Channel-Binding"
|
||||
|
||||
// TLSExporterLength is the 32-byte exporter length pinned by RFC 9266 §3.1
|
||||
// (matches the SHA-256 output size; clients and servers MUST agree on the
|
||||
// length to make the comparison meaningful).
|
||||
const TLSExporterLength = 32
|
||||
|
||||
// ErrChannelBindingMissing is returned when the EST mTLS handler requires
|
||||
// channel binding (per-profile ChannelBindingRequired=true) but the device's
|
||||
// CSR has no id-aa-est-tls-exporter unsignedAttribute or the attribute is
|
||||
// the wrong shape.
|
||||
var ErrChannelBindingMissing = errors.New("cms: channel binding required but absent or malformed in CSR")
|
||||
|
||||
// ErrChannelBindingMismatch is returned when the device's CSR carried a
|
||||
// channel-binding attribute but its bytes do not match the TLS-1.3 exporter
|
||||
// extracted from the live connection. This is the signal of an MITM that
|
||||
// terminates TLS in front of certctl: the device computed exporter X
|
||||
// against the attacker, certctl sees exporter Y against itself, X≠Y.
|
||||
var ErrChannelBindingMismatch = errors.New("cms: channel binding in CSR does not match TLS exporter")
|
||||
|
||||
// ErrChannelBindingNotTLS13 is returned when the connection is older than
|
||||
// TLS 1.3 and the per-profile config still requires channel binding.
|
||||
// RFC 9266's tls-exporter is a TLS-1.3 binding; pre-1.3 connections would
|
||||
// need RFC 5929 tls-unique, which we deliberately don't support
|
||||
// (certctl pins TLS-1.3 server-side).
|
||||
var ErrChannelBindingNotTLS13 = errors.New("cms: tls-exporter channel binding requires TLS 1.3")
|
||||
|
||||
// ExtractTLSExporter pulls the 32-byte RFC 9266 channel-binding value from
// the TLS connection state. The connection must be TLS 1.3 with a completed
// handshake; anything else returns a typed sentinel error that the handler
// can map to an HTTP status (400 for a missing binding, 426 for a pre-1.3
// connection).
|
||||
//
|
||||
// Stateless on purpose: callers handle storage + comparison.
|
||||
//
|
||||
// Robustness note: stdlib's ConnectionState.ExportKeyingMaterial nil-derefs
|
||||
// when the underlying secret-derivation closure is unset (i.e. the state
|
||||
// was hand-constructed by a test fixture rather than produced by a real
|
||||
// TLS handshake). The recover() below converts that panic into the same
|
||||
// typed error a missing-binding state would surface, so synthetic test
|
||||
// states + production TLS-1.3 connections share a single failure mode.
|
||||
func ExtractTLSExporter(state *tls.ConnectionState) (out []byte, err error) {
|
||||
if state == nil {
|
||||
return nil, fmt.Errorf("%w: nil ConnectionState", ErrChannelBindingMissing)
|
||||
}
|
||||
if !state.HandshakeComplete {
|
||||
return nil, fmt.Errorf("%w: handshake incomplete", ErrChannelBindingMissing)
|
||||
}
|
||||
// Compare against the named constant (tls.VersionTLS13 == 0x0304);
// crypto/tls is already imported for the *tls.ConnectionState type.
if state.Version < tls.VersionTLS13 {
|
||||
return nil, fmt.Errorf("%w: negotiated 0x%04x", ErrChannelBindingNotTLS13, state.Version)
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
out = nil
|
||||
err = fmt.Errorf("%w: ExportKeyingMaterial unavailable on this connection state (panic=%v)", ErrChannelBindingMissing, r)
|
||||
}
|
||||
}()
|
||||
out, err = state.ExportKeyingMaterial(TLSExporterLabel, nil, TLSExporterLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cms: ExportKeyingMaterial: %w", err)
|
||||
}
|
||||
if len(out) != TLSExporterLength {
|
||||
return nil, fmt.Errorf("cms: exporter returned %d bytes, want %d", len(out), TLSExporterLength)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ----- RFC 7030 §3.5 / RFC 9266 §4.1 — CSR-side channel binding -----
|
||||
|
||||
// OIDChannelBindingTLSExporter is the id-aa-est-tls-exporter OID from
|
||||
// RFC 9266 §4.1 (registered under id-aa = 1.2.840.113549.1.9.16.2 with
|
||||
// arc 56 by RFC 9266). Devices that signed channel binding into their
|
||||
// CSR add a CMC unsignedAttribute with this OID + an OCTET STRING value.
|
||||
//
|
||||
// Note: the EST RFC 7030 §3.5 historical OID for tls-unique is
|
||||
// id-aa-cmc-binding (1.2.840.113549.1.9.16.2.43). RFC 9266 §4.1 added
|
||||
// arc 56 for tls-exporter. We accept BOTH OIDs on the read path so a
|
||||
// device using a slightly older library that still emits the §3.5 OID
|
||||
// continues to work — the value bytes are still the 32-byte exporter
|
||||
// (the OID identifies the binding scheme, not the underlying wire
|
||||
// format).
|
||||
var (
|
||||
OIDChannelBindingTLSExporter = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 16, 2, 56}
|
||||
OIDCMCEnrollmentBinding = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 16, 2, 43}
|
||||
)
|
||||
|
||||
// ExtractCSRChannelBinding looks for the RFC 9266 channel-binding
|
||||
// attribute (or the legacy RFC 7030 §3.5 binding attribute) in the CSR's
|
||||
// raw attributes block. Returns the raw 32-byte exporter value if
|
||||
// present.
|
||||
//
|
||||
// Why we walk csr.RawTBSCertificateRequest manually instead of using
|
||||
// csr.Attributes:
|
||||
//
|
||||
// - csr.Attributes is typed as []pkix.AttributeTypeAndValueSET, where
|
||||
// the inner Value is [][]pkix.AttributeTypeAndValue. That shape only
|
||||
// fits attributes whose AttributeValue is itself a SEQUENCE { OID,
|
||||
// ANY } (e.g. the requestedExtensions attribute). RFC 9266's
|
||||
// TLSExporterValue is `OCTET STRING` — a primitive, not a SEQUENCE
|
||||
// — so the stdlib parse path either drops the attribute silently or
|
||||
// fails the whole CSR parse depending on encoding.
|
||||
//
|
||||
// - The PKCS#10 challengePassword path in scep.go works by accident
|
||||
// because PrintableString happens to round-trip through the
|
||||
// stdlib's interface{}-typed AttributeTypeAndValue.Value. OCTET
|
||||
// STRING does not — it's not in the small list of primitive types
|
||||
// the stdlib's reflect-based unmarshaller handles for `any`.
|
||||
//
|
||||
// - Walking the raw TBS is ~30 lines of asn1.Unmarshal calls and
|
||||
// gives us a stable contract independent of stdlib quirks.
|
||||
//
|
||||
// Returns (value, true, nil) on success; (nil, false, nil) when the
|
||||
// attribute is absent (caller decides whether absence is acceptable per
|
||||
// the per-profile ChannelBindingRequired flag); (nil, false, err) on
|
||||
// malformed attribute (always fatal — a present-but-wrong attribute
|
||||
// signals an attacker rewriting the binding into garbage).
|
||||
func ExtractCSRChannelBinding(csr *x509.CertificateRequest) ([]byte, bool, error) {
|
||||
if csr == nil {
|
||||
return nil, false, fmt.Errorf("cms: nil CSR")
|
||||
}
|
||||
if len(csr.RawTBSCertificateRequest) == 0 {
|
||||
// Stdlib fills RawTBSCertificateRequest on every parse path, so an
|
||||
// empty value here means the caller hand-crafted the struct. Tests
|
||||
// can do that — but real handler-side calls always have raw bytes.
|
||||
return nil, false, nil
|
||||
}
|
||||
return walkCSRAttributesForBinding(csr.RawTBSCertificateRequest)
|
||||
}
|
||||
|
||||
// walkCSRAttributesForBinding parses just enough of TBSCertificationRequestInfo
|
||||
// to reach the [0] IMPLICIT Attributes field, then iterates each Attribute
|
||||
// looking for the channel-binding OID. The body is intentionally low-level
|
||||
// so we can keep the asn1 footprint contained to this one helper.
|
||||
//
|
||||
// TBSCertificationRequestInfo per RFC 2986 §4.1:
|
||||
//
|
||||
// TBSCertificationRequestInfo ::= SEQUENCE {
|
||||
// version INTEGER (0),
|
||||
// subject Name,
|
||||
// subjectPKInfo SubjectPublicKeyInfo,
|
||||
// attributes [0] IMPLICIT Attributes (SET OF Attribute)
|
||||
// }
|
||||
func walkCSRAttributesForBinding(tbs []byte) ([]byte, bool, error) {
|
||||
// 1. Crack the outer SEQUENCE wrapper.
|
||||
var inner asn1.RawValue
|
||||
if rest, err := asn1.Unmarshal(tbs, &inner); err != nil {
|
||||
return nil, false, fmt.Errorf("cms: TBS outer parse: %w", err)
|
||||
} else if len(rest) > 0 {
|
||||
return nil, false, fmt.Errorf("cms: TBS trailing bytes: %d", len(rest))
|
||||
}
|
||||
if inner.Tag != asn1.TagSequence {
|
||||
return nil, false, fmt.Errorf("cms: TBS outer tag %d not SEQUENCE", inner.Tag)
|
||||
}
|
||||
rest := inner.Bytes
|
||||
|
||||
// 2. Skip version (INTEGER), subject (SEQUENCE = Name), subjectPKInfo
|
||||
// (SEQUENCE). asn1.Unmarshal into asn1.RawValue advances the cursor
|
||||
// without parsing the body — perfect for skipping fields we don't care
|
||||
// about.
|
||||
for i, label := range []string{"version", "subject", "subjectPKInfo"} {
|
||||
var rv asn1.RawValue
|
||||
next, err := asn1.Unmarshal(rest, &rv)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("cms: skip TBS field %d (%s): %w", i, label, err)
|
||||
}
|
||||
rest = next
|
||||
}
|
||||
|
||||
// 3. Attributes is [0] IMPLICIT — the on-wire tag is 0xA0 with class
|
||||
// CONTEXT-SPECIFIC. asn1.Unmarshal into a RawValue accepts arbitrary
|
||||
// tags; we then walk its Bytes as a SET OF Attribute.
|
||||
var attrsField asn1.RawValue
|
||||
if _, err := asn1.Unmarshal(rest, &attrsField); err != nil {
|
||||
// Nothing parseable after subjectPKInfo. RFC 2986 defines the [0]
// attributes field as mandatory (an empty set is encoded as A0 00), but
// some encoders omit it. Treat as "no binding present", not as an error.
|
||||
return nil, false, nil
|
||||
}
|
||||
if attrsField.Class != asn1.ClassContextSpecific || attrsField.Tag != 0 {
|
||||
// Some non-attribute-shaped trailing field: not what we expected
|
||||
// but not strictly a corruption signal — skip silently.
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// 4. Walk each Attribute in the SET. Each Attribute is
|
||||
// SEQUENCE { OID, SET OF ANY }.
|
||||
attrBytes := attrsField.Bytes
|
||||
for len(attrBytes) > 0 {
|
||||
var oneAttr asn1.RawValue
|
||||
next, err := asn1.Unmarshal(attrBytes, &oneAttr)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("cms: walk attributes: %w", err)
|
||||
}
|
||||
attrBytes = next
|
||||
if oneAttr.Tag != asn1.TagSequence {
|
||||
continue
|
||||
}
|
||||
// Inner: OID, then SET.
|
||||
var oid asn1.ObjectIdentifier
|
||||
afterOID, err := asn1.Unmarshal(oneAttr.Bytes, &oid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if !oid.Equal(OIDChannelBindingTLSExporter) && !oid.Equal(OIDCMCEnrollmentBinding) {
|
||||
continue
|
||||
}
|
||||
// Now afterOID is the SET wrapper. Crack it and pull the OCTET
|
||||
// STRING out of the SET's first element.
|
||||
var setWrap asn1.RawValue
|
||||
if _, err := asn1.Unmarshal(afterOID, &setWrap); err != nil {
|
||||
return nil, false, fmt.Errorf("cms: binding SET parse: %w (%w)", err, ErrChannelBindingMissing)
|
||||
}
|
||||
if setWrap.Tag != asn1.TagSet {
|
||||
return nil, false, fmt.Errorf("cms: binding outer tag %d not SET (%w)", setWrap.Tag, ErrChannelBindingMissing)
|
||||
}
|
||||
var octet asn1.RawValue
|
||||
if _, err := asn1.Unmarshal(setWrap.Bytes, &octet); err != nil {
|
||||
return nil, false, fmt.Errorf("cms: binding inner parse: %w (%w)", err, ErrChannelBindingMissing)
|
||||
}
|
||||
if octet.Tag != asn1.TagOctetString {
|
||||
return nil, false, fmt.Errorf("cms: binding inner tag %d not OCTET STRING (%w)", octet.Tag, ErrChannelBindingMissing)
|
||||
}
|
||||
if len(octet.Bytes) != TLSExporterLength {
|
||||
return nil, false, fmt.Errorf("cms: binding length %d, want %d (%w)",
|
||||
len(octet.Bytes), TLSExporterLength, ErrChannelBindingMissing)
|
||||
}
|
||||
return octet.Bytes, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// VerifyChannelBinding is the convenience composite the EST mTLS handler
|
||||
// calls per request: extract the exporter from the live TLS connection,
|
||||
// pull the matching value from the CSR, compare in constant time.
|
||||
//
|
||||
// Returns:
|
||||
// - nil when the binding is present + matches.
|
||||
// - ErrChannelBindingMissing when the CSR has no binding attribute.
|
||||
// - ErrChannelBindingMismatch when both sides have a value but they
|
||||
// differ (the MITM signal).
|
||||
// - Any error from the exporter extraction (TLS state is wrong, etc).
|
||||
//
|
||||
// The required flag controls absence-handling: when required=false a
|
||||
// missing attribute returns nil (channel binding is optional for this
|
||||
// profile); when required=true a missing attribute returns
|
||||
// ErrChannelBindingMissing.
|
||||
func VerifyChannelBinding(state *tls.ConnectionState, csr *x509.CertificateRequest, required bool) error {
|
||||
live, err := ExtractTLSExporter(state)
|
||||
if err != nil {
|
||||
// If the profile doesn't require channel binding AND the CSR carries no
// binding attribute, we let the request through: the binding is opt-in
// per profile. If the CSR does carry a binding attribute, the device
// clearly intended to bind, so the TLS-state failure is a genuine error.
// A malformed attribute is always fatal, so its parse error is
// propagated rather than discarded.
if !required {
_, present, csrErr := ExtractCSRChannelBinding(csr)
if csrErr != nil {
return csrErr
}
if !present {
return nil
}
}
|
||||
return err
|
||||
}
|
||||
csrBinding, present, err := ExtractCSRChannelBinding(csr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !present {
|
||||
if required {
|
||||
return ErrChannelBindingMissing
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if subtle.ConstantTimeCompare(live, csrBinding) != 1 {
|
||||
return ErrChannelBindingMismatch
|
||||
}
|
||||
// Sanity: matching values are byte-identical. This bytes.Equal check is
// unreachable once subtle.ConstantTimeCompare has returned 1; it stays
// only to make the contract obvious to readers and to pin the symmetry
// test asserting that ExtractCSRChannelBinding is byte-equivalent to
// ExtractTLSExporter when the device behaved correctly.
|
||||
if !bytes.Equal(live, csrBinding) {
|
||||
return ErrChannelBindingMismatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbedChannelBindingAttribute is the test helper inverse of
|
||||
// ExtractCSRChannelBinding: given an exporter value, returns the DER
|
||||
// bytes of the Attribute (SEQUENCE { OID, SET { OCTET STRING } }) that
|
||||
// the caller can splice into the [0] IMPLICIT Attributes field of
|
||||
// TBSCertificationRequestInfo. Used by the EST channel-binding tests
|
||||
// AND by any external caller that wants to forge a CSR with a known
|
||||
// binding for fixture generation.
|
||||
func EmbedChannelBindingAttribute(exporter []byte) ([]byte, error) {
|
||||
if len(exporter) != TLSExporterLength {
|
||||
return nil, fmt.Errorf("cms: exporter length %d, want %d", len(exporter), TLSExporterLength)
|
||||
}
|
||||
octet, err := asn1.Marshal(exporter) // marshal []byte as OCTET STRING
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cms: marshal exporter octet: %w", err)
|
||||
}
|
||||
// Wrap in SET OF.
|
||||
setBody := octet
|
||||
setEnvelope, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassUniversal,
|
||||
Tag: asn1.TagSet,
|
||||
IsCompound: true,
|
||||
Bytes: setBody,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cms: marshal SET: %w", err)
|
||||
}
|
||||
oid, err := asn1.Marshal(OIDChannelBindingTLSExporter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cms: marshal OID: %w", err)
|
||||
}
|
||||
// Wrap as SEQUENCE { OID, SET }.
|
||||
seqBody := append(append([]byte{}, oid...), setEnvelope...)
|
||||
seqEnvelope, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassUniversal,
|
||||
Tag: asn1.TagSequence,
|
||||
IsCompound: true,
|
||||
Bytes: seqBody,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cms: marshal SEQUENCE: %w", err)
|
||||
}
|
||||
return seqEnvelope, nil
|
||||
}
|
||||
@@ -0,0 +1,394 @@
|
||||
package cms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// EST RFC 7030 hardening master bundle Phase 2.4 tests.
|
||||
|
||||
// ----- ExtractTLSExporter -----
|
||||
|
||||
func TestExtractTLSExporter_NilState(t *testing.T) {
|
||||
if _, err := ExtractTLSExporter(nil); !errors.Is(err, ErrChannelBindingMissing) {
|
||||
t.Errorf("nil state should return ErrChannelBindingMissing, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTLSExporter_HandshakeNotComplete(t *testing.T) {
|
||||
state := &tls.ConnectionState{HandshakeComplete: false, Version: 0x0304}
|
||||
if _, err := ExtractTLSExporter(state); !errors.Is(err, ErrChannelBindingMissing) {
|
||||
t.Errorf("incomplete handshake should return ErrChannelBindingMissing, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTLSExporter_PreTLS13Rejected(t *testing.T) {
|
||||
state := &tls.ConnectionState{HandshakeComplete: true, Version: 0x0303} // TLS 1.2
|
||||
if _, err := ExtractTLSExporter(state); !errors.Is(err, ErrChannelBindingNotTLS13) {
|
||||
t.Errorf("TLS 1.2 should return ErrChannelBindingNotTLS13, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractTLSExporter_TLS13EndToEnd is the only test that builds a real
// TLS 1.3 session. The exporter is computed from the connection's secret
// state, so the ConnectionState can't be faked. We spin up a localhost TLS
// listener, complete a handshake, and call ExtractTLSExporter on the
// client-side state. We're not testing TLS itself, just that
// ExtractTLSExporter pulls a 32-byte value from a real 1.3 state.
|
||||
func TestExtractTLSExporter_TLS13EndToEnd(t *testing.T) {
|
||||
cert, key := freshSelfSignedTLSCert(t)
|
||||
tlsCert := tls.Certificate{Certificate: [][]byte{cert.Raw}, PrivateKey: key}
|
||||
|
||||
cfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
clientCfg := &tls.Config{
|
||||
InsecureSkipVerify: true, //nolint:gosec // hermetic test cert; not for production use
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
ln, err := tls.Listen("tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("tls.Listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
// Finish the handshake on the server side.
|
||||
_ = conn.(*tls.Conn).HandshakeContext(context.Background())
|
||||
// Hold the connection open until the client side completes its read.
|
||||
buf := make([]byte, 1)
|
||||
_, _ = conn.Read(buf)
|
||||
}()
|
||||
|
||||
conn, err := tls.Dial("tcp", ln.Addr().String(), clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("tls.Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := conn.HandshakeContext(context.Background()); err != nil {
|
||||
t.Fatalf("client handshake: %v", err)
|
||||
}
|
||||
state := conn.ConnectionState()
|
||||
|
||||
out, err := ExtractTLSExporter(&state)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractTLSExporter: %v", err)
|
||||
}
|
||||
if len(out) != TLSExporterLength {
|
||||
t.Errorf("len(out) = %d, want %d", len(out), TLSExporterLength)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- ExtractCSRChannelBinding -----
|
||||
|
||||
func TestExtractCSRChannelBinding_NilCSR(t *testing.T) {
|
||||
if _, _, err := ExtractCSRChannelBinding(nil); err == nil {
|
||||
t.Fatal("nil CSR should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCSRChannelBinding_AbsentReturnsFalse(t *testing.T) {
|
||||
csr := freshCSRNoBinding(t)
|
||||
val, present, err := ExtractCSRChannelBinding(csr)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractCSRChannelBinding: %v", err)
|
||||
}
|
||||
if present {
|
||||
t.Errorf("present=true on a CSR without the binding attribute (val=%x)", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCSRChannelBinding_PresentReturnsExporter(t *testing.T) {
|
||||
exporter := repeatByte(0x42, TLSExporterLength)
|
||||
csr := freshCSRWithBinding(t, exporter, OIDChannelBindingTLSExporter)
|
||||
val, present, err := ExtractCSRChannelBinding(csr)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractCSRChannelBinding: %v", err)
|
||||
}
|
||||
if !present {
|
||||
t.Fatal("present=false on a CSR that carries the binding")
|
||||
}
|
||||
if !bytesEq(val, exporter) {
|
||||
t.Errorf("exporter = %x, want %x", val, exporter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCSRChannelBinding_LegacyOIDAccepted(t *testing.T) {
|
||||
exporter := repeatByte(0xAA, TLSExporterLength)
|
||||
csr := freshCSRWithBinding(t, exporter, OIDCMCEnrollmentBinding)
|
||||
val, present, err := ExtractCSRChannelBinding(csr)
|
||||
if err != nil {
|
||||
t.Fatalf("legacy-OID path failed: %v", err)
|
||||
}
|
||||
if !present || !bytesEq(val, exporter) {
|
||||
t.Errorf("legacy-OID extraction: got present=%v val=%x, want present=true val=%x", present, val, exporter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCSRChannelBinding_WrongLengthRejected(t *testing.T) {
|
||||
short := repeatByte(0x55, 16) // half the required length
|
||||
csr := freshCSRWithBinding(t, short, OIDChannelBindingTLSExporter)
|
||||
_, _, err := ExtractCSRChannelBinding(csr)
|
||||
if !errors.Is(err, ErrChannelBindingMissing) {
|
||||
t.Errorf("wrong-length binding should wrap ErrChannelBindingMissing, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- VerifyChannelBinding (composite) -----
|
||||
|
||||
func TestVerifyChannelBinding_NotRequired_NoBinding_Passes(t *testing.T) {
|
||||
csr := freshCSRNoBinding(t)
|
||||
if err := VerifyChannelBinding(nil, csr, false); err != nil {
|
||||
t.Errorf("required=false + no binding should pass; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyChannelBinding_Required_NilState_Errors(t *testing.T) {
|
||||
csr := freshCSRNoBinding(t)
|
||||
if err := VerifyChannelBinding(nil, csr, true); err == nil {
|
||||
t.Fatal("required=true + nil state must error")
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: a synthetic *tls.ConnectionState{HandshakeComplete:true, Version:0x0304}
// would seem like the obvious VerifyChannelBinding(required=true) negative-case
// fixture, but stdlib's ExportKeyingMaterial nil-derefs when the underlying
// secret state is unset (see crypto/tls/common.go:330). ExtractTLSExporter
// recovers that panic and reports ErrChannelBindingMissing, so a synthetic
// state can never reach the comparison branches; a successful export is only
// reachable via a real connection (TestExtractTLSExporter_TLS13EndToEnd
// above), so we don't try to fake it here. The
// TestVerifyChannelBinding_NotRequired_NoBinding_Passes +
// TestVerifyChannelBinding_Required_NilState_Errors tests cover the policy
// branches; production code paths only ever pass r.TLS from a live request.
|
||||
|
||||
// ----- helpers -----
|
||||
|
||||
// freshCSRNoBinding returns a CSR with no extra attributes.
|
||||
func freshCSRNoBinding(t *testing.T) *x509.CertificateRequest {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||
}
|
||||
tmpl := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "no-binding-test"}}
|
||||
der, err := x509.CreateCertificateRequest(rand.Reader, tmpl, key)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificateRequest: %v", err)
|
||||
}
|
||||
csr, err := x509.ParseCertificateRequest(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificateRequest: %v", err)
|
||||
}
|
||||
return csr
|
||||
}
|
||||
|
||||
// freshCSRWithBinding builds a CSR whose TBS carries the channel-binding
// attribute. The stdlib's CreateCertificateRequest doesn't support arbitrary
// attributes (only ExtraExtensions), so we hand-craft the TBS by parsing
// what stdlib produced and splicing our attribute into the [0] IMPLICIT
// Attributes block. The result is not re-signed (see step 4 below).
|
||||
func freshCSRWithBinding(t *testing.T, exporter []byte, oid asn1.ObjectIdentifier) *x509.CertificateRequest {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||
}
|
||||
// 1. Get a baseline CSR with no attributes — we steal its TBS shape.
|
||||
tmpl := &x509.CertificateRequest{Subject: pkix.Name{CommonName: "binding-test"}}
|
||||
derBaseline, err := x509.CreateCertificateRequest(rand.Reader, tmpl, key)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificateRequest: %v", err)
|
||||
}
|
||||
baseline, err := x509.ParseCertificateRequest(derBaseline)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificateRequest: %v", err)
|
||||
}
|
||||
|
||||
// 2. Build the channel-binding attribute (SEQUENCE { OID, SET { OCTET STRING }}).
|
||||
octet, err := asn1.Marshal(exporter)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal octet: %v", err)
|
||||
}
|
||||
setEnv, err := asn1.Marshal(asn1.RawValue{Class: asn1.ClassUniversal, Tag: asn1.TagSet, IsCompound: true, Bytes: octet})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal set: %v", err)
|
||||
}
|
||||
oidBytes, err := asn1.Marshal(oid)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal oid: %v", err)
|
||||
}
|
||||
attrSeq, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassUniversal,
|
||||
Tag: asn1.TagSequence,
|
||||
IsCompound: true,
|
||||
Bytes: append(append([]byte{}, oidBytes...), setEnv...),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal attribute SEQUENCE: %v", err)
|
||||
}
|
||||
|
||||
// 3. Splice attribute into a [0] IMPLICIT Attributes block and rebuild
|
||||
// the TBS by hand. The TBS structure is:
|
||||
// SEQUENCE { version INTEGER, subject Name, subjectPKInfo SubjectPublicKeyInfo,
|
||||
// attributes [0] IMPLICIT SET OF Attribute }
|
||||
// We re-extract the first three fields from the baseline TBS and
|
||||
// re-marshal with our attribute appended.
|
||||
var outer asn1.RawValue
|
||||
if _, err := asn1.Unmarshal(baseline.RawTBSCertificateRequest, &outer); err != nil {
|
||||
t.Fatalf("baseline TBS unmarshal: %v", err)
|
||||
}
|
||||
rest := outer.Bytes
|
||||
var version, subject, spki asn1.RawValue
|
||||
for _, target := range []*asn1.RawValue{&version, &subject, &spki} {
|
||||
next, err := asn1.Unmarshal(rest, target)
|
||||
if err != nil {
|
||||
t.Fatalf("baseline TBS skip: %v", err)
|
||||
}
|
||||
rest = next
|
||||
}
|
||||
versionDER, _ := asn1.Marshal(version)
|
||||
subjectDER, _ := asn1.Marshal(subject)
|
||||
spkiDER, _ := asn1.Marshal(spki)
|
||||
// Build the [0] IMPLICIT Attributes wrapper.
|
||||
attrsField, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassContextSpecific,
|
||||
Tag: 0,
|
||||
IsCompound: true,
|
||||
Bytes: attrSeq,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal attrs field: %v", err)
|
||||
}
|
||||
tbsBody := append(append(append(append([]byte{}, versionDER...), subjectDER...), spkiDER...), attrsField...)
|
||||
newTBS, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassUniversal,
|
||||
Tag: asn1.TagSequence,
|
||||
IsCompound: true,
|
||||
Bytes: tbsBody,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("re-marshal TBS: %v", err)
|
||||
}
|
||||
|
||||
// 4. Wrap the new TBS in a CertificateRequest by hand. No re-signing is
// needed for these tests: ExtractCSRChannelBinding doesn't verify the
// signature; it walks RawTBSCertificateRequest only.
|
||||
csr := &x509.CertificateRequest{
|
||||
RawTBSCertificateRequest: newTBS,
|
||||
Subject: baseline.Subject,
|
||||
PublicKey: baseline.PublicKey,
|
||||
}
|
||||
return csr
|
||||
}
|
||||
|
||||
// freshSelfSignedTLSCert produces a tls.Certificate-compatible cert+key for
|
||||
// the TLS-1.3 round-trip test.
|
||||
func freshSelfSignedTLSCert(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||
}
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "tls-test"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate: %v", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificate: %v", err)
|
||||
}
|
||||
return cert, key
|
||||
}
|
||||
|
||||
// repeatByte returns a slice of length n filled with b. Used for fixture
|
||||
// exporter values where we need a deterministic test pattern.
|
||||
func repeatByte(b byte, n int) []byte {
|
||||
out := make([]byte, n)
|
||||
for i := range out {
|
||||
out[i] = b
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func bytesEq(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// EmbedChannelBindingAttribute round-trip — pins the spec contract that
|
||||
// what we marshal can be parsed back by ExtractCSRChannelBinding without
|
||||
// going through the freshCSRWithBinding splice helper.
|
||||
func TestEmbedChannelBindingAttribute_RoundTrip(t *testing.T) {
|
||||
exporter := repeatByte(0x77, TLSExporterLength)
|
||||
attrDER, err := EmbedChannelBindingAttribute(exporter)
|
||||
if err != nil {
|
||||
t.Fatalf("EmbedChannelBindingAttribute: %v", err)
|
||||
}
|
||||
// Wrap the single attribute in a [0] IMPLICIT SET OF Attribute block
|
||||
// and a TBS-lookalike SEQUENCE so we can feed it through the same path
|
||||
// the parser uses — the parser doesn't care that version+subject+spki
|
||||
// are absent because it walks structurally.
|
||||
attrsField, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassContextSpecific,
|
||||
Tag: 0,
|
||||
IsCompound: true,
|
||||
Bytes: attrDER,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal attrs field: %v", err)
|
||||
}
|
||||
// Synthetic TBS with three placeholder asn1.RawValue fields then attrsField.
|
||||
placeholder, _ := asn1.Marshal(asn1.RawValue{Class: asn1.ClassUniversal, Tag: asn1.TagInteger, Bytes: []byte{0x00}})
|
||||
body := append(append(append(append([]byte{}, placeholder...), placeholder...), placeholder...), attrsField...)
|
||||
tbs, err := asn1.Marshal(asn1.RawValue{
|
||||
Class: asn1.ClassUniversal,
|
||||
Tag: asn1.TagSequence,
|
||||
IsCompound: true,
|
||||
Bytes: body,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal TBS: %v", err)
|
||||
}
|
||||
got, present, err := walkCSRAttributesForBinding(tbs)
|
||||
if err != nil {
|
||||
t.Fatalf("walkCSRAttributesForBinding: %v", err)
|
||||
}
|
||||
if !present {
|
||||
t.Fatal("present=false on round-trip")
|
||||
}
|
||||
if !bytesEq(got, exporter) {
|
||||
t.Errorf("round-trip mismatch: got %x, want %x", got, exporter)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
// Package ratelimit provides shared rate-limit primitives used by
|
||||
// authenticated-but-shared-credential code paths (SCEP/Intune
|
||||
// per-device challenge enrollment, EST per-principal CSR enrollment,
|
||||
// EST HTTP-Basic source-IP failed-auth limiter) where the threat
|
||||
// model is "single legitimate identity could mint enrollments
|
||||
// faster than any human/fleet workflow would."
|
||||
//
|
||||
// Origin: this package was extracted from
|
||||
// internal/scep/intune/rate_limit.go in the EST RFC 7030 hardening
|
||||
// master bundle Phase 4.1 — EST is the third caller after the
|
||||
// Intune dispatcher (per-device-GUID cap on enrollment) and the EST
|
||||
// per-principal cap (Phase 4.2). The original Intune-package type +
// constructor are preserved as a thin wrapper, and ErrRateLimited as
// an aliased sentinel, at internal/scep/intune/rate_limit.go so
// existing call sites compile unchanged. New callers SHOULD use this
// package directly.
|
||||
//
|
||||
// Algorithm: sliding window log. Each key maps to a bucket holding
|
||||
// timestamps within the configured window. On Allow, the bucket
|
||||
// prunes timestamps older than (now - window) and either appends +
|
||||
// returns nil, or rejects + returns ErrRateLimited when the
|
||||
// post-prune count is already at the cap. Exact (no token-leak
|
||||
// rounding); O(N_per_key) per-call but N is bounded by the cap, so
|
||||
// effectively O(1).
|
||||
//
|
||||
// Concurrency: safe for concurrent Allow calls. Internal map guarded
|
||||
// by sync.Mutex; per-key slices mutated only while the mutex is
|
||||
// held.
|
||||
//
|
||||
// Memory: bounded by the per-instance map cap (default 100,000 keys;
// configurable). At-cap eviction drops the entry whose newest
// timestamp is oldest (an O(N_keys) scan). It rarely fires in practice
// because the prune-on-Allow path keeps most buckets short-lived.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRateLimited is returned by SlidingWindowLimiter.Allow when the
|
||||
// bucket for the given key is already at the cap. Callers can
|
||||
// errors.Is against this sentinel; the underlying message is stable
|
||||
// across the package's lifetime so test assertions can match on it.
|
||||
var ErrRateLimited = errors.New("ratelimit: per-key cap exceeded for the configured window")
|
||||
|
||||
// SlidingWindowLimiter is the sliding-window-log rate limiter.
|
||||
//
|
||||
// Construct via NewSlidingWindowLimiter. The zero value is NOT
|
||||
// usable — the buckets map needs initialisation.
|
||||
type SlidingWindowLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string][]time.Time // key → sliding window of timestamps
|
||||
maxN int // max enrollments per window
|
||||
window time.Duration // window length (default 24h)
|
||||
cap int // max keys before LRU eviction kicks in
|
||||
disabled bool // maxN <= 0 → all Allow calls return nil
|
||||
}
|
||||
|
||||
// NewSlidingWindowLimiter returns a limiter with the given per-key
|
||||
// cap + window. maxN <= 0 disables the limiter (all Allow calls
|
||||
// return nil); this is operator opt-out for the rare case where the
|
||||
// per-key cap is undesirable (test harnesses, sketchpad deploys).
|
||||
//
|
||||
// Window defaults to 24h when zero. Map cap defaults to 100,000 when
|
||||
// zero (matches the SCEP/Intune replay cache cap).
|
||||
func NewSlidingWindowLimiter(maxN int, window time.Duration, mapCap int) *SlidingWindowLimiter {
|
||||
if window <= 0 {
|
||||
window = 24 * time.Hour
|
||||
}
|
||||
if mapCap <= 0 {
|
||||
mapCap = 100_000
|
||||
}
|
||||
return &SlidingWindowLimiter{
|
||||
buckets: make(map[string][]time.Time),
|
||||
maxN: maxN,
|
||||
window: window,
|
||||
cap: mapCap,
|
||||
disabled: maxN <= 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow reports whether an event keyed by `key` is permitted right
|
||||
// now. Returns nil when allowed (and records the timestamp in the
|
||||
// bucket) or ErrRateLimited when the bucket is at maxN.
|
||||
//
|
||||
// Empty key is treated as "skip the limiter" — the caller's
|
||||
// validation should have rejected an empty-key event already; this
|
||||
// is belt-and-suspenders so a single empty-key bucket doesn't
|
||||
// become a chokepoint for every empty-key event. SCEP/Intune
|
||||
// callers compose the key as `subject + "|" + issuer`; EST callers
|
||||
// compose `cn + "|" + sourceIP` or `sourceIP`-alone for the
|
||||
// failed-auth limiter.
|
||||
func (l *SlidingWindowLimiter) Allow(key string, now time.Time) error {
|
||||
if l.disabled {
|
||||
return nil
|
||||
}
|
||||
if key == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// At-cap eviction: when the map is full, drop the oldest entry
|
||||
// by finding the bucket whose newest timestamp is the smallest.
|
||||
// O(N_keys) but rarely fires; the prune-on-Allow path keeps
|
||||
// most buckets short-lived.
|
||||
if len(l.buckets) >= l.cap {
|
||||
l.evictOldestLocked()
|
||||
}
|
||||
|
||||
bucket := l.buckets[key]
|
||||
bucket = pruneOlderThan(bucket, now.Add(-l.window))
|
||||
|
||||
if len(bucket) >= l.maxN {
|
||||
// Don't append; over the limit. Persist the pruned bucket so
|
||||
// the next call sees the most-recently-pruned state.
|
||||
l.buckets[key] = bucket
|
||||
return ErrRateLimited
|
||||
}
|
||||
|
||||
bucket = append(bucket, now)
|
||||
l.buckets[key] = bucket
|
||||
return nil
|
||||
}
|
||||
|
||||
// pruneOlderThan returns the slice with all entries strictly before
|
||||
// `cutoff` removed. Preserves order (timestamps are appended in
|
||||
// increasing time, so a single linear scan from the front suffices).
|
||||
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
|
||||
i := 0
|
||||
for i < len(b) && b[i].Before(cutoff) {
|
||||
i++
|
||||
}
|
||||
if i == 0 {
|
||||
return b
|
||||
}
|
||||
// Copy-shrink to release the underlying-array memory eventually
|
||||
// (otherwise the slice would hold a reference to the older
|
||||
// entries indefinitely until a re-allocation).
|
||||
out := make([]time.Time, len(b)-i)
|
||||
copy(out, b[i:])
|
||||
return out
|
||||
}
|
||||
|
||||
// evictOldestLocked drops the map entry whose newest timestamp is
|
||||
// the oldest. Called under l.mu. O(N_keys) per eviction; at-cap is
|
||||
// rare in practice (caps are sized for steady-state).
|
||||
func (l *SlidingWindowLimiter) evictOldestLocked() {
|
||||
var (
|
||||
oldestKey string
|
||||
oldestTs time.Time
|
||||
first = true
|
||||
)
|
||||
for k, b := range l.buckets {
|
||||
if len(b) == 0 {
|
||||
// Empty bucket — drop it immediately, no candidate scan needed.
|
||||
delete(l.buckets, k)
|
||||
return
|
||||
}
|
||||
newest := b[len(b)-1]
|
||||
if first || newest.Before(oldestTs) {
|
||||
oldestKey = k
|
||||
oldestTs = newest
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestKey != "" {
|
||||
delete(l.buckets, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the approximate number of distinct keys currently
|
||||
// tracked. For observability + tests; not load-stable under
|
||||
// concurrent Allow calls.
|
||||
func (l *SlidingWindowLimiter) Len() int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return len(l.buckets)
|
||||
}
|
||||
|
||||
// Disabled reports whether the limiter is in opt-out mode (maxN <= 0).
|
||||
// Useful for handler-side gating + admin-endpoint observability.
|
||||
func (l *SlidingWindowLimiter) Disabled() bool {
|
||||
return l.disabled
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EST RFC 7030 hardening master bundle Phase 4.1: this test file holds the
// white-box tests for the SlidingWindowLimiter primitives that used to live in
// internal/scep/intune/rate_limit_test.go (TestPruneOlderThan,
// TestPerDeviceRateLimiter_DefaultCapsHonored, and
// TestPruneOlderThan_NoOpWhenNothingToPrune). The behavioral coverage in
// intune/rate_limit_test.go stays: it exercises the wrapper's (subject, issuer)
// composition contract, the empty-subject short-circuit, and race-freedom.
|
||||
|
||||
func TestSlidingWindowLimiter_AllowsUpToCap(t *testing.T) {
|
||||
l := NewSlidingWindowLimiter(3, 24*time.Hour, 10)
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := l.Allow("k", now.Add(time.Duration(i)*time.Minute)); err != nil {
|
||||
t.Fatalf("call %d should be allowed: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
if err := l.Allow("k", now.Add(4*time.Minute)); !errors.Is(err, ErrRateLimited) {
|
||||
t.Fatalf("4th call should be rate-limited; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_DistinctKeysIndependent(t *testing.T) {
|
||||
l := NewSlidingWindowLimiter(1, 24*time.Hour, 10)
|
||||
now := time.Now()
|
||||
|
||||
if err := l.Allow("k-1", now); err != nil {
|
||||
t.Fatalf("first allow: %v", err)
|
||||
}
|
||||
if err := l.Allow("k-2", now); err != nil {
|
||||
t.Fatalf("different key must have its own bucket: %v", err)
|
||||
}
|
||||
if err := l.Allow("k-1", now.Add(1*time.Second)); !errors.Is(err, ErrRateLimited) {
|
||||
t.Fatalf("repeat key should be limited; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_WindowExpiry(t *testing.T) {
|
||||
l := NewSlidingWindowLimiter(2, 1*time.Hour, 10)
|
||||
now := time.Now()
|
||||
|
||||
if err := l.Allow("k", now); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := l.Allow("k", now.Add(30*time.Minute)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Inside window — limited.
|
||||
if err := l.Allow("k", now.Add(45*time.Minute)); !errors.Is(err, ErrRateLimited) {
|
||||
t.Fatalf("inside-window 3rd call should be limited: %v", err)
|
||||
}
|
||||
// Past window — slots reopen.
|
||||
if err := l.Allow("k", now.Add(2*time.Hour)); err != nil {
|
||||
t.Fatalf("past-window call should be allowed (window reset): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_DisabledBypass(t *testing.T) {
|
||||
l := NewSlidingWindowLimiter(0, 24*time.Hour, 10) // maxN=0 → disabled
|
||||
if !l.Disabled() {
|
||||
t.Fatal("limiter with maxN=0 must report Disabled()=true")
|
||||
}
|
||||
now := time.Now()
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := l.Allow("k", now); err != nil {
|
||||
t.Fatalf("disabled limiter must allow everything: %v", err)
|
||||
}
|
||||
}
|
||||
if got := l.Len(); got != 0 {
|
||||
t.Errorf("disabled limiter Len() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_NegativeCapDisabled(t *testing.T) {
|
||||
l := NewSlidingWindowLimiter(-1, 24*time.Hour, 10)
|
||||
if !l.Disabled() {
|
||||
t.Fatal("negative maxN must produce a disabled limiter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_EmptyKeyShortCircuits(t *testing.T) {
|
||||
// Empty key is the caller's defense-in-depth case — caller's validation
|
||||
// upstream should reject empty-key events first. Limiter must not build
|
||||
// a single shared bucket keyed by empty-key — that would be a chokepoint
|
||||
// for every empty-key event.
|
||||
l := NewSlidingWindowLimiter(1, 24*time.Hour, 10)
|
||||
now := time.Now()
|
||||
for i := 0; i < 50; i++ {
|
||||
if err := l.Allow("", now); err != nil {
|
||||
t.Fatalf("empty key must short-circuit (call %d): %v", i, err)
|
||||
}
|
||||
}
|
||||
if got := l.Len(); got != 0 {
|
||||
t.Errorf("Len after 50 empty-key calls = %d, want 0 (no bucket created)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_DefaultCapsHonored(t *testing.T) {
|
||||
// White-box test: exercises the constructor's default-fill branches.
|
||||
// Lives here (not in the intune wrapper test) because the fields
|
||||
// (window + cap) are package-private to ratelimit.
|
||||
l := NewSlidingWindowLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
|
||||
if l.window != 24*time.Hour {
|
||||
t.Errorf("default window = %v, want 24h", l.window)
|
||||
}
|
||||
if l.cap != 100_000 {
|
||||
t.Errorf("default cap = %d, want 100000", l.cap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_MapCapEvictsOldest(t *testing.T) {
|
||||
// Cap of 3 keys to exercise the eviction branch deterministically.
|
||||
l := NewSlidingWindowLimiter(2, 1*time.Hour, 3)
|
||||
now := time.Now()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
key := fmt.Sprintf("k-%d", i)
|
||||
if err := l.Allow(key, now.Add(time.Duration(i)*time.Minute)); err != nil {
|
||||
t.Fatalf("insert %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
if l.Len() != 3 {
|
||||
t.Fatalf("Len = %d, want 3", l.Len())
|
||||
}
|
||||
|
||||
// 4th key forces eviction of k-0 (its newest timestamp is oldest).
|
||||
if err := l.Allow("k-3", now.Add(10*time.Minute)); err != nil {
|
||||
t.Fatalf("4th-key insert: %v", err)
|
||||
}
|
||||
if l.Len() != 3 {
|
||||
t.Errorf("Len after at-cap insert = %d, want 3 (cap honored)", l.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter_ConcurrentRaceFree(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("race-style test under -short")
|
||||
}
|
||||
l := NewSlidingWindowLimiter(50, 24*time.Hour, 10000)
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < 20; g++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
now := time.Now()
|
||||
key := fmt.Sprintf("k-%d", id)
|
||||
for i := 0; i < 30; i++ {
|
||||
_ = l.Allow(key, now)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
if got := l.Len(); got != 20 {
|
||||
t.Errorf("expected 20 distinct keys; got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// White-box tests for the unexported pruneOlderThan helper. Live in this
|
||||
// package because the helper is package-private to ratelimit. The test
|
||||
// surface used to live in intune/rate_limit_test.go before the Phase 4.1
|
||||
// extraction.
|
||||
func TestPruneOlderThan(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
in := []time.Time{
|
||||
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
|
||||
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
|
||||
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
|
||||
t0.Add(-30 * time.Minute), // survives
|
||||
t0, // survives
|
||||
}
|
||||
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
|
||||
}
|
||||
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
|
||||
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
in := []time.Time{t0.Add(-1 * time.Minute), t0}
|
||||
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
|
||||
// Nothing predates the cutoff, so the input is returned as-is; only length is asserted.
|
||||
if len(out) != len(in) {
|
||||
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
|
||||
}
|
||||
}
|
||||
@@ -1,193 +1,87 @@
|
||||
package intune
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/ratelimit"
|
||||
)
|
||||
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 8.6.
|
||||
//
|
||||
// PerDeviceRateLimiter is the second line of defense behind the replay cache
|
||||
// from Phase 7. The replay cache catches the same challenge being submitted
|
||||
// twice (within the challenge TTL); this rate limiter catches a compromised
|
||||
// Connector signing key (or a stolen key+cert pair) issuing many DIFFERENT
|
||||
// valid challenges for the same device subject in a short window.
|
||||
// PerDeviceRateLimiter is the second line of defense behind the replay
|
||||
// cache from Phase 7. The replay cache catches the same challenge being
|
||||
// submitted twice (within the challenge TTL); this rate limiter catches a
|
||||
// compromised Connector signing key (or a stolen key+cert pair) issuing
|
||||
// many DIFFERENT valid challenges for the same device subject in a short
|
||||
// window.
|
||||
//
|
||||
// Threat model:
|
||||
//
|
||||
// - Replay cache (Phase 7): nonce-keyed; catches duplicate submission.
|
||||
// - This limiter: (Subject, Issuer)-keyed; catches enrollment-flooding.
|
||||
//
|
||||
// Default: 3 enrollments per (device GUID, Connector identity) per 24h.
|
||||
// EST RFC 7030 hardening master bundle Phase 4.1: the implementation that
|
||||
// used to live in this file was extracted to internal/ratelimit (where it
|
||||
// can be shared with EST per-principal + EST HTTP-Basic source-IP rate
|
||||
// limiters). PerDeviceRateLimiter is now a thin wrapper around
|
||||
// ratelimit.SlidingWindowLimiter that preserves the original
|
||||
// (subject, issuer) → key composition in the Allow signature so existing
|
||||
// SCEP/Intune callers don't have to change.
|
||||
//
|
||||
// Sizing: 100,000 distinct device entries (matches the replay cache cap).
|
||||
// At-cap: oldest entry evicted (small janitor pass) to avoid unbounded
|
||||
// memory growth on a fleet that grows past the cap.
|
||||
//
|
||||
// Why a hand-rolled token bucket instead of pulling in golang.org/x/time/rate:
|
||||
// the rate package is in go.sum as an indirect transitive but NOT a direct
|
||||
// dep. Adding it would create a new direct dep relationship for ~30 LoC of
|
||||
// state machine. The hand-rolled version below uses only stdlib (sync.Mutex
|
||||
// + time.Time arithmetic) and is small enough to fit on one screen.
|
||||
//
|
||||
// Algorithm: each (Subject, Issuer) key maps to a bucket holding a window's
|
||||
// worth of recent enrollment timestamps. On Allow, the bucket prunes
|
||||
// timestamps older than (now - window) and either appends the current
|
||||
// timestamp + returns true, or rejects + returns false when the post-prune
|
||||
// count is already at the cap. This is the "sliding window log" rate
|
||||
// limiter — exact (no token-leak rounding); O(N_per_key) per-call but N is
|
||||
// bounded by the cap (3 by default), so effectively O(1).
|
||||
// New callers SHOULD use ratelimit.SlidingWindowLimiter directly. The
|
||||
// EST RFC 7030 Phase 4.2 EST per-principal cap uses the shared package.
|
||||
|
||||
// ErrRateLimited is the typed error returned when the per-device rate limit
|
||||
// fires. The handler maps this to a CertRep FAILURE with badRequest failInfo
|
||||
// + the `rate_limited` metric label.
|
||||
var ErrRateLimited = errors.New("intune: per-device rate limit exceeded for this (subject, issuer) within the configured window")
|
||||
// ErrRateLimited is the typed error returned when the per-device rate
|
||||
// limit fires. Aliased to ratelimit.ErrRateLimited so errors.Is matches
|
||||
// against either name (the SCEP audit closure already pinned the
|
||||
// "rate_limited" metric label against this sentinel; the alias preserves
|
||||
// sentinel identity across the package boundary).
|
||||
var ErrRateLimited = ratelimit.ErrRateLimited
|
||||
|
||||
// PerDeviceRateLimiter is a sliding-window-log rate limiter keyed by
|
||||
// (Subject, Issuer) tuples derived from a parsed challenge claim.
|
||||
//
|
||||
// Concurrency: the limiter is safe for concurrent Allow calls. The internal
|
||||
// map is guarded by a mutex; the per-key slices are mutated only while the
|
||||
// mutex is held.
|
||||
// PerDeviceRateLimiter wraps ratelimit.SlidingWindowLimiter with the
|
||||
// (subject, issuer)-composed-key Allow signature the Intune dispatcher
|
||||
// uses. Concurrency-safe (the underlying limiter holds the mutex).
|
||||
type PerDeviceRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string][]time.Time // key → sliding window of timestamps
|
||||
maxN int // max enrollments per window
|
||||
window time.Duration // window length (default 24h)
|
||||
cap int // max keys before LRU eviction kicks in
|
||||
disabled bool // maxN == 0 → all Allow calls return nil
|
||||
inner *ratelimit.SlidingWindowLimiter
|
||||
}
|
||||
|
||||
// NewPerDeviceRateLimiter returns a limiter with the given per-key cap +
|
||||
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil); this
|
||||
// is operator opt-out for the rare case where the per-device cap is
|
||||
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil);
|
||||
// this is operator opt-out for the rare case where the per-device cap is
|
||||
// undesirable (e.g. test harnesses, sketchpad deploys).
|
||||
//
|
||||
// Window defaults to 24h when zero. Map cap defaults to 100,000 when zero
|
||||
// (matches the replay cache cap; see internal/scep/intune/replay.go).
|
||||
func NewPerDeviceRateLimiter(maxN int, window time.Duration, mapCap int) *PerDeviceRateLimiter {
|
||||
if window <= 0 {
|
||||
window = 24 * time.Hour
|
||||
}
|
||||
if mapCap <= 0 {
|
||||
mapCap = 100_000
|
||||
}
|
||||
return &PerDeviceRateLimiter{
|
||||
buckets: make(map[string][]time.Time),
|
||||
maxN: maxN,
|
||||
window: window,
|
||||
cap: mapCap,
|
||||
disabled: maxN <= 0,
|
||||
}
|
||||
return &PerDeviceRateLimiter{inner: ratelimit.NewSlidingWindowLimiter(maxN, window, mapCap)}
|
||||
}
|
||||
|
||||
// Allow checks whether an enrollment for the given (subject, issuer) tuple
|
||||
// is permitted right now. Returns nil when allowed (and records the timestamp
|
||||
// in the bucket) or ErrRateLimited when the bucket is at maxN.
|
||||
// Allow checks whether an enrollment for the given (subject, issuer)
|
||||
// tuple is permitted right now. Returns nil when allowed (and records
|
||||
// the timestamp in the bucket) or ErrRateLimited when the bucket is at
|
||||
// maxN.
|
||||
//
|
||||
// Empty subject is treated as "skip the limiter" — the caller's claim
|
||||
// validation should have rejected an empty-subject claim already; this is
|
||||
// belt-and-suspenders to prevent a single empty-subject bucket from
|
||||
// becoming a fleet-wide chokepoint. The Connector emits non-empty subject
|
||||
// (device GUID) on every legitimate challenge.
|
||||
// validation should have rejected an empty-subject claim already; this
|
||||
// is belt-and-suspenders to prevent a single empty-subject bucket from
|
||||
// becoming a fleet-wide chokepoint.
|
||||
func (l *PerDeviceRateLimiter) Allow(subject, issuer string, now time.Time) error {
|
||||
if l.disabled {
|
||||
return nil
|
||||
}
|
||||
if subject == "" {
|
||||
// Caller's claim validation should reject empty-subject upstream;
|
||||
// this short-circuit is defense-in-depth so a misconfigured
|
||||
// Connector can't DoS us via the rate-limit path.
|
||||
// Empty-subject early return preserved from the pre-Phase-4.1
|
||||
// behavior: ratelimit.SlidingWindowLimiter also short-circuits
|
||||
// on empty key, but the explicit check here documents the
|
||||
// (subject, issuer) → empty-key contract and saves one call
|
||||
// frame in the hot path.
|
||||
return nil
|
||||
}
|
||||
key := subject + "|" + issuer
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// At-cap eviction: when the map is full, drop the oldest entry by
|
||||
// finding the bucket whose newest timestamp is the smallest. O(N) but
|
||||
// rarely fires; the prune-on-Allow path keeps most buckets short-lived.
|
||||
if len(l.buckets) >= l.cap {
|
||||
l.evictOldestLocked(now)
|
||||
}
|
||||
|
||||
bucket := l.buckets[key]
|
||||
bucket = pruneOlderThan(bucket, now.Add(-l.window))
|
||||
|
||||
if len(bucket) >= l.maxN {
|
||||
// Don't append; over the limit. Persist the pruned bucket so the
|
||||
// next call sees the most-recently-pruned state.
|
||||
l.buckets[key] = bucket
|
||||
return ErrRateLimited
|
||||
}
|
||||
|
||||
bucket = append(bucket, now)
|
||||
l.buckets[key] = bucket
|
||||
return nil
|
||||
}
|
||||
|
||||
// pruneOlderThan returns the slice with all entries strictly before
// `cutoff` removed. Preserves order (timestamps are appended in increasing
// time, so a single linear scan from the front suffices).
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
	i := 0
	for i < len(b) && b[i].Before(cutoff) {
		i++
	}
	if i == 0 {
		return b
	}
	// Copy-shrink to release the underlying-array memory eventually
	// (otherwise the slice would hold a reference to the older entries
	// indefinitely until a re-allocation).
	out := make([]time.Time, len(b)-i)
	copy(out, b[i:])
	return out
}

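The prune-then-count flow shown above can be sketched as a standalone program. This is a minimal illustration, not the repository's implementation: `windowLimiter`, `newWindowLimiter`, `prune`, `errLimited`, and `limiterDemo` are hypothetical names, and the map-size cap and eviction logic are deliberately left out.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"
)

// errLimited is a stand-in for the ErrRateLimited sentinel in the diff.
var errLimited = errors.New("rate limited")

// windowLimiter is a hypothetical, simplified limiter: at most maxN
// events per key inside a trailing window. It omits the map-size cap
// and eviction that the code in the diff has.
type windowLimiter struct {
	mu      sync.Mutex
	maxN    int
	window  time.Duration
	buckets map[string][]time.Time
}

func newWindowLimiter(maxN int, window time.Duration) *windowLimiter {
	return &windowLimiter{maxN: maxN, window: window, buckets: map[string][]time.Time{}}
}

// prune drops timestamps strictly before cutoff. Entries are appended in
// increasing time order, so one scan from the front is enough.
func prune(b []time.Time, cutoff time.Time) []time.Time {
	i := 0
	for i < len(b) && b[i].Before(cutoff) {
		i++
	}
	if i == 0 {
		return b
	}
	out := make([]time.Time, len(b)-i)
	copy(out, b[i:])
	return out
}

// allow records one event for key at time now, or returns errLimited
// when the key already has maxN events inside the window.
func (l *windowLimiter) allow(key string, now time.Time) error {
	if key == "" {
		return nil // an empty key is never limited
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	b := prune(l.buckets[key], now.Add(-l.window))
	if len(b) >= l.maxN {
		l.buckets[key] = b // persist the pruned state, do not append
		return errLimited
	}
	l.buckets[key] = append(b, now)
	return nil
}

// limiterDemo replays four calls for one key and reports each outcome.
func limiterDemo() string {
	l := newWindowLimiter(2, time.Hour)
	t0 := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	offsets := []time.Duration{0, time.Minute, 2 * time.Minute, 61 * time.Minute}
	var results []string
	for _, off := range offsets {
		if err := l.allow("device|issuer", t0.Add(off)); err != nil {
			results = append(results, "limited")
		} else {
			results = append(results, "ok")
		}
	}
	return strings.Join(results, " ")
}

func main() {
	fmt.Println(limiterDemo())
}
```

With a limit of two per hour, the third call is refused, and the fourth succeeds because the first timestamp has aged out of the window by then.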
// evictOldestLocked drops the map entry whose newest timestamp is the
|
||||
// oldest. Called under l.mu. O(N_keys) per eviction; at-cap is rare in
|
||||
// practice (caps are sized for fleet steady-state).
|
||||
func (l *PerDeviceRateLimiter) evictOldestLocked(now time.Time) {
|
||||
var (
|
||||
oldestKey string
|
||||
oldestTs time.Time
|
||||
first = true
|
||||
)
|
||||
for k, b := range l.buckets {
|
||||
if len(b) == 0 {
|
||||
// Empty bucket — drop it immediately, no candidate scan needed.
|
||||
delete(l.buckets, k)
|
||||
return
|
||||
}
|
||||
newest := b[len(b)-1]
|
||||
if first || newest.Before(oldestTs) {
|
||||
oldestKey = k
|
||||
oldestTs = newest
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestKey != "" {
|
||||
delete(l.buckets, oldestKey)
|
||||
}
|
||||
// Suppress unused-parameter warning for `now` in case the eviction
|
||||
// strategy changes (e.g. swap to LRU keyed by time of last Allow).
|
||||
_ = now
|
||||
return l.inner.Allow(key, now)
|
||||
}
|
||||
|
||||
// Len returns the approximate number of distinct (subject, issuer) keys
|
||||
// currently tracked. For observability + tests; not load-stable under
|
||||
// concurrent Allow calls.
|
||||
func (l *PerDeviceRateLimiter) Len() int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return len(l.buckets)
|
||||
}
|
||||
// currently tracked. For observability + tests.
|
||||
func (l *PerDeviceRateLimiter) Len() int { return l.inner.Len() }
|
||||
|
||||
// Disabled reports whether the limiter is in opt-out mode (maxN ≤ 0).
|
||||
// Useful for handler-side gating + admin-endpoint observability.
|
||||
func (l *PerDeviceRateLimiter) Disabled() bool {
|
||||
return l.disabled
|
||||
}
|
||||
func (l *PerDeviceRateLimiter) Disabled() bool { return l.inner.Disabled() }
|
||||
|
||||
@@ -103,15 +103,11 @@ func TestPerDeviceRateLimiter_EmptySubjectShortCircuits(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_DefaultCapsHonored(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
|
||||
if l.window != 24*time.Hour {
|
||||
t.Errorf("default window = %v, want 24h", l.window)
|
||||
}
|
||||
if l.cap != 100_000 {
|
||||
t.Errorf("default cap = %d, want 100000", l.cap)
|
||||
}
|
||||
}
|
||||
// TestPerDeviceRateLimiter_DefaultCapsHonored — moved to
|
||||
// internal/ratelimit/sliding_window_test.go::TestSlidingWindowLimiter_DefaultCapsHonored
|
||||
// in EST RFC 7030 hardening Phase 4.1 (the white-box test reads private
|
||||
// fields that no longer exist on the wrapper). The shared package owns
|
||||
// the field-default contract.
|
||||
|
||||
func TestPerDeviceRateLimiter_MapCapEvictsOldest(t *testing.T) {
|
||||
// Cap of 3 keys to exercise the eviction branch deterministically.
|
||||
@@ -161,30 +157,8 @@ func TestPerDeviceRateLimiter_ConcurrentRaceFree(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneOlderThan(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
in := []time.Time{
|
||||
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
|
||||
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
|
||||
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
|
||||
t0.Add(-30 * time.Minute), // survives
|
||||
t0, // survives
|
||||
}
|
||||
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
|
||||
}
|
||||
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
|
||||
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
in := []time.Time{t0.Add(-1 * time.Minute), t0}
|
||||
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
|
||||
// Same slice header (no copy needed).
|
||||
if len(out) != len(in) {
|
||||
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
|
||||
}
|
||||
}
|
||||
// TestPruneOlderThan + TestPruneOlderThan_NoOpWhenNothingToPrune — moved
|
||||
// to internal/ratelimit/sliding_window_test.go in EST RFC 7030 hardening
|
||||
// Phase 4.1. pruneOlderThan is now an unexported helper of the shared
|
||||
// ratelimit package (the implementation moved there); the white-box
|
||||
// tests follow.
|
||||
|
||||
@@ -1,73 +1,45 @@
|
||||
package intune
|
||||
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 7.2 (originally) +
// EST RFC 7030 hardening master bundle Phase 2.1 (extraction).
//
// LoadTrustAnchor + parseTrustAnchorPEM were extracted to
// internal/trustanchor.LoadBundle + parseBundlePEM so the EST mTLS
// sibling route (Phase 2 of the EST hardening bundle), the Intune
// dispatcher, and any future per-profile-trust-bundle caller can share
// the same PEM-bundle loader + SIGHUP-reload semantics. The shim below
// preserves the original public surface so existing intune callers
// (cmd/server/main.go, scep_intune_e2e_test.go, scep_profile_counter_
// isolation_test.go, scep_intune.go service) compile unchanged.
//
// New callers SHOULD import internal/trustanchor directly — the
// trustanchor.Holder + trustanchor.LoadBundle are the modern API.
//
// Note: the legacy intune error messages ("intune: trust anchor cert
// in %q expired ...") are NOT preserved verbatim across the extraction;
// the shared trustanchor package emits "trustanchor: ..." messages
// instead. The operator-facing log line at cmd/server/main.go's
// preflightSCEPIntuneTrustAnchor wraps the error in its own outer
// ("SCEP profile (PathID=...) INTUNE trust anchor load failed: ...")
// so the prefix change is invisible to log-grep runbooks that filter
// on the outer message.

import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||
)
|
||||
|
||||
// LoadTrustAnchor reads a PEM bundle of one or more Intune Connector
|
||||
// signing certificates from the configured path. Returns the slice of
|
||||
// parsed certs that the validator will accept as challenge issuers.
|
||||
// signing certificates from the configured path. Delegates to the
|
||||
// shared trustanchor.LoadBundle (extracted in EST RFC 7030 hardening
|
||||
// Phase 2.1) so the EST mTLS sibling route + the Intune dispatcher
|
||||
// + any future per-profile trust-bundle caller share the same
|
||||
// loader semantics (path-empty refusal, expired-cert refusal,
|
||||
// non-CERTIFICATE-block tolerance).
|
||||
//
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 7.2.
|
||||
//
|
||||
// Behavior:
|
||||
//
|
||||
// - File must exist + be readable.
|
||||
// - PEM-decodes the file; non-CERTIFICATE blocks are skipped (so an
|
||||
// operator can paste a chain that includes a private key by mistake
|
||||
// without breaking the load — the priv key is just ignored).
|
||||
// - Returns an error if zero CERTIFICATE blocks parse.
|
||||
// - Returns an error if any cert is past NotAfter (a stale trust
|
||||
// anchor would silently reject every Intune challenge at runtime;
|
||||
// fail loud at startup instead).
|
||||
//
|
||||
// Operators rotate Connector signing certs periodically; the trust
|
||||
// anchor file is reloaded on SIGHUP (handled by the existing config
|
||||
// watch loop in cmd/server/main.go — see cmd/server/tls.go::watchSIGHUP
|
||||
// for the precedent).
|
||||
// Preserved here as a wrapper so existing intune callers compile
|
||||
// unchanged. New callers SHOULD use trustanchor.LoadBundle directly.
|
||||
func LoadTrustAnchor(path string) ([]*x509.Certificate, error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("intune: trust anchor path is empty")
|
||||
}
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("intune: read trust anchor %q: %w", path, err)
|
||||
}
|
||||
return parseTrustAnchorPEM(body, path, time.Now())
|
||||
}
|
||||
|
||||
// parseTrustAnchorPEM is the file-IO-free core of LoadTrustAnchor. Split
|
||||
// out so unit tests can hand it byte slices without writing temp files.
|
||||
// `now` is taken as a parameter so expiry tests can pin a deterministic
|
||||
// clock.
|
||||
func parseTrustAnchorPEM(body []byte, sourceLabel string, now time.Time) ([]*x509.Certificate, error) {
|
||||
var out []*x509.Certificate
|
||||
rest := body
|
||||
for {
|
||||
var block *pem.Block
|
||||
block, rest = pem.Decode(rest)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" {
|
||||
continue
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("intune: parse trust anchor cert in %q: %w", sourceLabel, err)
|
||||
}
|
||||
if now.After(cert.NotAfter) {
|
||||
return nil, fmt.Errorf("intune: trust anchor cert in %q expired at %s (subject=%q) — operator must rotate the Connector signing cert before restart",
|
||||
sourceLabel, cert.NotAfter.Format(time.RFC3339), cert.Subject.CommonName)
|
||||
}
|
||||
out = append(out, cert)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, fmt.Errorf("intune: trust anchor %q contains no CERTIFICATE PEM blocks", sourceLabel)
|
||||
}
|
||||
return out, nil
|
||||
return trustanchor.LoadBundle(path)
|
||||
}
|
||||
|
||||
@@ -1,143 +1,58 @@
|
||||
package intune
|
||||
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 8.5 (originally) +
// EST RFC 7030 hardening master bundle Phase 2.1 (extraction).
//
// TrustAnchorHolder + NewTrustAnchorHolder were extracted to
// internal/trustanchor.Holder + trustanchor.New so the EST mTLS sibling
// route (Phase 2 of the EST hardening bundle) and the Intune dispatcher
// can share the same SIGHUP-reloadable PEM bundle primitive. A single
// SIGHUP now rotates: server TLS cert (cmd/server/tls.go), every Intune
// trust anchor (this package's existing wiring), AND every EST mTLS
// per-profile client-CA bundle (the new sibling route) — exactly the
// design contract documented in the trustanchor package doc.
//
// The aliases below preserve every existing intune call site unchanged:
//   - cmd/server/main.go declares `intuneTrustHolders []*intune.TrustAnchorHolder`
//     + invokes `intune.NewTrustAnchorHolder(path, logger)`
//   - internal/service/scep.go's SCEPService struct field
//     `intuneTrust *intune.TrustAnchorHolder` (the type alias keeps this
//     pointer-compatible with the original)
//   - internal/scep/intune/trust_anchor_holder_test.go + the e2e tests
//     that construct a holder via NewTrustAnchorHolder
//
// New callers SHOULD import internal/trustanchor directly — the
// trustanchor.Holder + trustanchor.New are the modern API. The intune
// aliases are preserved indefinitely for back-compat (no deprecation
// timeline; the cost of the two-line shim is trivial).

import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"github.com/shankar0123/certctl/internal/trustanchor"
|
||||
)
|
||||
|
||||
// TrustAnchorHolder is the SIGHUP-reloadable wrapper around a per-profile
|
||||
// Intune Connector trust anchor pool.
|
||||
//
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 8.5.
|
||||
//
|
||||
// Mirrors the shape established by `cmd/server/tls.go::certHolder` for the
|
||||
// server TLS cert: an RWMutex-guarded pool, a Get accessor that's safe for
|
||||
// concurrent callers from the request path, a Reload that re-reads the file
|
||||
// and atomically swaps the slice on success (failure leaves the OLD pool in
|
||||
// place so a bad reload doesn't take Intune enrollment down), and a
|
||||
// watchSIGHUP goroutine that responds to the same SIGHUP the operator uses
|
||||
// to rotate the server TLS cert.
|
||||
//
|
||||
// Why SIGHUP specifically (vs fsnotify or a polling loop): SIGHUP is the
|
||||
// repo-established convention (see cmd/server/tls.go). fsnotify would add a
|
||||
// new direct dep + complicate the cleanup story. The operator's Connector-
|
||||
// rotation script writes the new PEM bundle then sends SIGHUP — the same
|
||||
// signal that already rotates the server TLS cert — and both swap atomically.
|
||||
//
|
||||
// Concurrency contract:
|
||||
// - Get returns the pool slice header by value; the slice itself is
|
||||
// immutable per-snapshot (Reload swaps a fresh slice rather than
|
||||
// mutating the existing one). Callers may iterate the returned slice
|
||||
// without holding any lock.
|
||||
// - Reload acquires a write lock briefly for the swap. Concurrent Get
|
||||
// calls block only for that swap window (microseconds).
|
||||
// - watchSIGHUP runs at most one Reload at a time per holder.
|
||||
type TrustAnchorHolder struct {
|
||||
mu sync.RWMutex
|
||||
certs []*x509.Certificate
|
||||
path string
|
||||
logger *slog.Logger
|
||||
}
|
||||
// Aliased to trustanchor.Holder (extracted in EST RFC 7030 hardening
|
||||
// Phase 2.1) so the EST mTLS sibling route + the Intune dispatcher share
|
||||
// the same primitive. Existing callers compile unchanged because Go type
|
||||
// aliases are pointer-compatible.
|
||||
type TrustAnchorHolder = trustanchor.Holder
|
||||
|
||||
// NewTrustAnchorHolder loads the trust bundle and returns a holder. Returns
|
||||
// the same fail-loud error LoadTrustAnchor does on initial load — the
|
||||
// startup gate at cmd/server/main.go is supposed to refuse boot when this
|
||||
// fails. Subsequent Reload errors are non-fatal (logged + old pool retained).
|
||||
// NewTrustAnchorHolder loads the trust bundle and returns a holder.
|
||||
// Aliased to trustanchor.New (extracted in EST RFC 7030 hardening
|
||||
// Phase 2.1). Returns the same fail-loud error LoadTrustAnchor does on
|
||||
// initial load — the startup gate at cmd/server/main.go is supposed to
|
||||
// refuse boot when this fails. Subsequent Reload errors are non-fatal
|
||||
// (logged + old pool retained).
|
||||
//
|
||||
// The logger is required (never nil); the caller passes a per-profile
|
||||
// scoped logger so SIGHUP-reload events show the PathID for triage.
|
||||
func NewTrustAnchorHolder(path string, logger *slog.Logger) (*TrustAnchorHolder, error) {
|
||||
if logger == nil {
|
||||
return nil, errors.New("intune: TrustAnchorHolder requires a non-nil logger")
|
||||
}
|
||||
certs, err := LoadTrustAnchor(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TrustAnchorHolder{
|
||||
certs: certs,
|
||||
path: path,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get returns the current trust anchor pool. Safe for concurrent callers;
|
||||
// the slice header is returned by value and the underlying slice is
|
||||
// immutable per-snapshot (Reload swaps a fresh slice, doesn't mutate in
|
||||
// place — see Reload).
|
||||
func (h *TrustAnchorHolder) Get() []*x509.Certificate {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.certs
|
||||
}
|
||||
|
||||
// Path returns the on-disk path the holder reloads from. Useful for
|
||||
// observability (admin endpoints, log lines) without exposing the cert
|
||||
// pool itself.
|
||||
func (h *TrustAnchorHolder) Path() string {
|
||||
return h.path
|
||||
}
|
||||
|
||||
// Reload re-reads the trust anchor file at h.path and atomically swaps the
|
||||
// pool. Returns the parse error if the new file is invalid; the OLD pool
|
||||
// stays in place so a bad reload doesn't take Intune enrollment down.
|
||||
//
|
||||
// Same fail-safe pattern as cmd/server/tls.go::(*certHolder).Reload — a
|
||||
// rotation that writes a half-file (operator overwrites the bundle while
|
||||
// only some of the new certs are in it) would otherwise crash the
|
||||
// service mid-rotation. Logging + retaining the old pool gives the
|
||||
// operator a bounded window to fix and re-SIGHUP.
|
||||
func (h *TrustAnchorHolder) Reload() error {
|
||||
certs, err := LoadTrustAnchor(h.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.certs = certs
|
||||
h.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// WatchSIGHUP installs a signal handler that calls Reload on each SIGHUP.
|
||||
// The returned stop function closes the internal done channel and stops
|
||||
// signal delivery so the goroutine can exit cleanly during shutdown.
|
||||
//
|
||||
// Errors from Reload are logged but do not terminate the watcher — the
|
||||
// operator can fix the files and send another SIGHUP. Mirrors the
|
||||
// (*certHolder).watchSIGHUP contract exactly.
|
||||
//
|
||||
// Multiple holders can coexist: each registers its own goroutine on the
|
||||
// same SIGHUP signal. signal.Notify multicasts to every registered
|
||||
// channel, so a single SIGHUP reloads every per-profile Intune trust
|
||||
// anchor PLUS the server TLS cert in one operator action — exactly the
|
||||
// design requirement (one SIGHUP rotates everything).
|
||||
func (h *TrustAnchorHolder) WatchSIGHUP() (stop func()) {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGHUP)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
if err := h.Reload(); err != nil {
|
||||
h.logger.Error("Intune trust anchor reload failed; continuing with previous pool",
|
||||
"error", err,
|
||||
"path", h.path)
|
||||
continue
|
||||
}
|
||||
h.logger.Info("Intune trust anchor reloaded via SIGHUP",
|
||||
"path", h.path,
|
||||
"certs_loaded", len(h.Get()))
|
||||
case <-done:
|
||||
signal.Stop(ch)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
// Note: the original intune.NewTrustAnchorHolder set the holder's
|
||||
// internal log label to "Intune trust anchor"; the extracted
|
||||
// trustanchor.New defaults to "trust anchor". Existing intune callers
|
||||
// that need the original label should call .SetLabelForLog("intune
|
||||
// trust anchor (PathID=…)") on the returned holder. cmd/server/main.go
|
||||
// does this in the per-profile Intune startup loop.
|
||||
var NewTrustAnchorHolder = trustanchor.New
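The back-compat shim relies on two Go features: a type alias and a package-level variable bound to a function. The sketch below shows both in a single file. All names (`holder`, `newHolder`, `legacyHolder`, `newLegacyHolder`, `describe`, `aliasDemo`) and the path are hypothetical stand-ins, since the real code spans two packages.

```go
package main

import (
	"errors"
	"fmt"
)

// holder stands in for the shared type (trustanchor.Holder in the diff).
type holder struct{ path string }

// newHolder stands in for the shared constructor (trustanchor.New).
func newHolder(path string) (*holder, error) {
	if path == "" {
		return nil, errors.New("path is empty")
	}
	return &holder{path: path}, nil
}

// legacyHolder is a type alias (note the "="): it names the same type as
// holder rather than defining a new one, so *legacyHolder and *holder
// are identical types.
type legacyHolder = holder

// newLegacyHolder is a package-level variable bound to the constructor,
// so legacy call sites keep calling the old name.
var newLegacyHolder = newHolder

// describe is written against the shared type only.
func describe(h *holder) string { return "holder for " + h.path }

// aliasDemo builds a holder through the legacy names and passes it to a
// function that expects the shared type, with no conversion.
func aliasDemo() string {
	var legacy *legacyHolder
	legacy, err := newLegacyHolder("/tmp/example-anchors.pem") // hypothetical path, never opened
	if err != nil {
		return "error: " + err.Error()
	}
	return describe(legacy)
}

func main() {
	fmt.Println(aliasDemo())
}
```

Had `legacyHolder` been declared without the `=`, it would be a distinct named type and the call to `describe` would not compile without an explicit conversion.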
|
||||
|
||||
@@ -16,6 +16,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// EST RFC 7030 hardening master bundle Phase 2.1: the white-box parser
|
||||
// tests (TestParseTrustAnchorPEM_*) moved to internal/trustanchor/holder_test.go
|
||||
// where parseBundlePEM now lives. The intune package retains a thin
|
||||
// public-surface test of LoadTrustAnchor — the back-compat shim that
|
||||
// existing intune callers use — so a future refactor that breaks the
|
||||
// shim's wire-up to trustanchor.LoadBundle is caught here.
|
||||
|
||||
// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
|
||||
func pemEncodeCert(t *testing.T, der []byte) []byte {
|
||||
t.Helper()
|
||||
@@ -24,7 +31,9 @@ func pemEncodeCert(t *testing.T, der []byte) []byte {
|
||||
|
||||
// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
|
||||
// + the matching key. Lifetime is parameterised so the same factory drives
|
||||
// both the happy-path and expired-cert cases.
|
||||
// both happy-path and expired-cert cases. Kept in this file (not deleted with
|
||||
// the white-box tests) because trust_anchor_holder_test.go's freshHolderCert
|
||||
// returns *x509.Certificate while LoadTrustAnchor tests need raw DER + key.
|
||||
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
@@ -44,96 +53,6 @@ func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.Pri
|
||||
return der, key
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_HappyPath_SingleCert(t *testing.T) {
|
||||
der, _ := freshConnectorCertDER(t, time.Now().Add(365*24*time.Hour))
|
||||
body := pemEncodeCert(t, der)
|
||||
|
||||
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("parseTrustAnchorPEM: %v", err)
|
||||
}
|
||||
if len(certs) != 1 {
|
||||
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
||||
}
|
||||
if certs[0].Subject.CommonName != "intune-connector-test" {
|
||||
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_HappyPath_MultiCert(t *testing.T) {
|
||||
d1, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||
d2, _ := freshConnectorCertDER(t, time.Now().Add(60*24*time.Hour))
|
||||
body := append(pemEncodeCert(t, d1), pemEncodeCert(t, d2)...)
|
||||
|
||||
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("parseTrustAnchorPEM: %v", err)
|
||||
}
|
||||
if len(certs) != 2 {
|
||||
t.Fatalf("len(certs) = %d, want 2", len(certs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_SkipsNonCertBlocks(t *testing.T) {
|
||||
der, key := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalECPrivateKey: %v", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
body := append(keyPEM, pemEncodeCert(t, der)...) // priv key first, cert second
|
||||
|
||||
certs, err := parseTrustAnchorPEM(body, "test", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("parseTrustAnchorPEM should ignore non-CERTIFICATE blocks: %v", err)
|
||||
}
|
||||
if len(certs) != 1 {
|
||||
t.Fatalf("len(certs) = %d, want 1 (priv key block must be skipped)", len(certs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_EmptyBundleRejected(t *testing.T) {
|
||||
_, err := parseTrustAnchorPEM([]byte("nothing here"), "test", time.Now())
|
||||
if err == nil || !strings.Contains(err.Error(), "no CERTIFICATE PEM blocks") {
|
||||
t.Fatalf("expected 'no CERTIFICATE PEM blocks' error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_OnlyKeyBlocksRejected(t *testing.T) {
|
||||
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
keyDER, _ := x509.MarshalECPrivateKey(key)
|
||||
body := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
_, err := parseTrustAnchorPEM(body, "test", time.Now())
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for bundle with no certs, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_ExpiredCertRejected(t *testing.T) {
|
||||
der, _ := freshConnectorCertDER(t, time.Now().Add(-1*time.Hour)) // already expired
|
||||
body := pemEncodeCert(t, der)
|
||||
|
||||
_, err := parseTrustAnchorPEM(body, "expired-bundle", time.Now())
|
||||
if err == nil || !strings.Contains(err.Error(), "expired") {
|
||||
t.Fatalf("expected expiry error, got %v", err)
|
||||
}
|
||||
// Operator-actionable message must include the subject so the audit
|
||||
// log says exactly which cert to rotate.
|
||||
if !strings.Contains(err.Error(), "intune-connector-test") {
|
||||
t.Errorf("error must include subject CN for operator action: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustAnchorPEM_MalformedCertRejected(t *testing.T) {
|
||||
bad := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-real-asn1-cert")})
|
||||
|
||||
_, err := parseTrustAnchorPEM(bad, "test", time.Now())
|
||||
if err == nil {
|
||||
t.Fatalf("expected x509 parse error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadTrustAnchor_FromDisk(t *testing.T) {
|
||||
der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
|
||||
body := pemEncodeCert(t, der)
|
||||
@@ -150,6 +69,9 @@ func TestLoadTrustAnchor_FromDisk(t *testing.T) {
|
||||
if len(certs) != 1 {
|
||||
t.Fatalf("len(certs) = %d, want 1", len(certs))
|
||||
}
|
||||
if certs[0].Subject.CommonName != "intune-connector-test" {
|
||||
t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadTrustAnchor_EmptyPath(t *testing.T) {
|
||||
@@ -164,7 +86,6 @@ func TestLoadTrustAnchor_MissingFile(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatalf("expected file-not-found error, got nil")
|
||||
}
|
||||
// Don't string-assert on the OS error — just make sure it's surfaced.
|
||||
if errors.Is(err, nil) {
|
||||
t.Fatalf("error must be non-nil")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
// Package trustanchor provides a SIGHUP-reloadable PEM-bundle trust pool
// shared by the SCEP/Intune dispatcher (per-profile Microsoft Intune
// Connector signing-cert anchor), the EST mTLS sibling route (per-profile
// client-CA trust bundle for /.well-known/est-mtls/<pathID>/), and any
// future caller that needs the same pattern (operator rotates an on-disk
// PEM bundle, sends SIGHUP, certctl swaps the in-memory pool atomically
// without a restart).
//
// EST RFC 7030 hardening master bundle Phase 2.1: extracted from
// internal/scep/intune/trust_anchor_holder.go where it originally lived.
// The intune package preserves a thin alias-style wrapper for back-compat
// (existing intune.TrustAnchorHolder + NewTrustAnchorHolder + LoadTrustAnchor
// callers compile unchanged); new callers SHOULD import this package
// directly.
//
// Concurrency contract:
//
//   - Get returns the pool slice header by value; the slice itself is
//     immutable per-snapshot (Reload swaps a fresh slice rather than
//     mutating the existing one). Callers may iterate the returned slice
//     without holding any lock.
//   - Reload acquires a write lock briefly for the swap. Concurrent Get
//     calls block only for that swap window (microseconds).
//   - WatchSIGHUP runs at most one Reload at a time per holder.
//
// Threat model: the rationale for SIGHUP-as-reload-trigger (vs fsnotify
// or polling) is that the existing certctl rotation playbook (server TLS
// cert at cmd/server/tls.go::certHolder) already uses SIGHUP. Operators
// running the standard "rotate file, kill -HUP" workflow get every
// holder reloaded with one signal: server TLS + Intune trust anchors +
// EST mTLS trust bundles all swap atomically.
package trustanchor

import (
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

// Holder is the SIGHUP-reloadable wrapper around a PEM-bundle trust
// pool. Construct via New. The zero value is NOT usable.
type Holder struct {
	mu     sync.RWMutex
	certs  []*x509.Certificate
	path   string
	logger *slog.Logger
	// labelForLog is used only in error / info log lines so an operator
	// running multiple holders (per-profile EST mTLS, per-profile Intune,
	// server TLS) can distinguish which one fired. Defaults to "trust
	// anchor" when not set; callers SHOULD set this to a descriptive
	// string like "intune trust anchor (PathID=corp)" or "EST mTLS
	// client CA bundle (PathID=corp)".
	labelForLog string
}

// New loads the trust bundle and returns a holder. Returns the same
// fail-loud error LoadBundle does on initial load — the startup gate at
// cmd/server/main.go is supposed to refuse boot when this fails.
// Subsequent Reload errors are non-fatal (logged + old pool retained).
//
// The logger is required (never nil); the caller passes a per-profile
// scoped logger so SIGHUP-reload events show the PathID for triage.
func New(path string, logger *slog.Logger) (*Holder, error) {
	if logger == nil {
		return nil, errors.New("trustanchor: New requires a non-nil logger")
	}
	certs, err := LoadBundle(path)
	if err != nil {
		return nil, err
	}
	return &Holder{certs: certs, path: path, logger: logger, labelForLog: "trust anchor"}, nil
}

// SetLabelForLog records a descriptive label that future reload log
// lines use to distinguish this holder from others (e.g. "intune trust
// anchor (PathID=corp)"). Idempotent + safe for concurrent callers
// (the field is read only by the SIGHUP watcher goroutine after
// WatchSIGHUP starts).
func (h *Holder) SetLabelForLog(label string) {
	if label == "" {
		return
	}
	h.mu.Lock()
	h.labelForLog = label
	h.mu.Unlock()
}

// Get returns the current trust anchor pool. Safe for concurrent
// callers; the slice header is returned by value and the underlying
// slice is immutable per-snapshot.
func (h *Holder) Get() []*x509.Certificate {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.certs
}

// Path returns the on-disk path the holder reloads from.
func (h *Holder) Path() string {
	return h.path
}

// Pool returns a fresh *x509.CertPool populated with the holder's
// current certs. Helper for callers that need a pool instead of a
// slice (the EST mTLS handler verifies client cert chains via
// cert.Verify(VerifyOptions{Roots: pool}); the Intune dispatcher uses
// the slice directly for signature-walk).
func (h *Holder) Pool() *x509.CertPool {
	pool := x509.NewCertPool()
	for _, c := range h.Get() {
		pool.AddCert(c)
	}
	return pool
}

// Reload re-reads the trust anchor file at h.path and atomically swaps
// the pool. Returns the parse error if the new file is invalid; the
// OLD pool stays in place so a bad reload doesn't take dependent
// dispatch paths down. Same fail-safe pattern as cmd/server/tls.go::
// (*certHolder).Reload — a rotation that writes a half-file would
// otherwise crash the service mid-rotation.
func (h *Holder) Reload() error {
	certs, err := LoadBundle(h.path)
	if err != nil {
		return err
	}
	h.mu.Lock()
	h.certs = certs
	h.mu.Unlock()
	return nil
}

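The load-then-swap pattern in Reload can be exercised without certificates or files. The sketch below is an illustration under stated assumptions: `snapshotHolder`, its `load` field, and `reloadDemo` are hypothetical, with strings standing in for certificates and a closure standing in for the on-disk bundle read.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// snapshotHolder is a hypothetical stand-in for the holder: strings
// replace certificates and load replaces the on-disk bundle read.
type snapshotHolder struct {
	mu    sync.RWMutex
	items []string
	load  func() ([]string, error)
}

// get returns the current snapshot. The slice is never mutated in place,
// so callers can iterate it without holding the lock.
func (h *snapshotHolder) get() []string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.items
}

// reload loads first and takes the write lock only for the swap. A load
// error returns early, which leaves the previous snapshot untouched.
func (h *snapshotHolder) reload() error {
	items, err := h.load()
	if err != nil {
		return err
	}
	h.mu.Lock()
	h.items = items
	h.mu.Unlock()
	return nil
}

// reloadDemo performs one failing reload and one successful reload and
// reports what get returns after each.
func reloadDemo() string {
	fail := true
	h := &snapshotHolder{items: []string{"anchor-old"}}
	h.load = func() ([]string, error) {
		if fail {
			return nil, errors.New("bad bundle")
		}
		return []string{"anchor-new"}, nil
	}
	errBad := h.reload()
	afterBad := h.get()[0]
	fail = false
	errGood := h.reload()
	afterGood := h.get()[0]
	return fmt.Sprintf("%t %s %t %s", errBad != nil, afterBad, errGood != nil, afterGood)
}

func main() {
	fmt.Println(reloadDemo())
}
```

Because the load happens before the lock is taken, readers are blocked only for the pointer-sized swap, and a failed load never reaches the swap at all.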
// WatchSIGHUP installs a signal handler that calls Reload on each
// SIGHUP. The returned stop function closes the internal done channel
// and stops signal delivery so the goroutine can exit cleanly during
// shutdown.
//
// Errors from Reload are logged but do not terminate the watcher — the
// operator can fix the files and send another SIGHUP. Mirrors the
// (*certHolder).watchSIGHUP contract from cmd/server/tls.go exactly.
//
// Multiple holders coexist: each registers its own goroutine on the
// same SIGHUP signal. signal.Notify multicasts to every registered
// channel, so a single SIGHUP reloads every per-profile trust anchor
// + the server TLS cert in one operator action.
func (h *Holder) WatchSIGHUP() (stop func()) {
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGHUP)
	done := make(chan struct{})
	go func() {
		for {
			select {
			case <-ch:
				if err := h.Reload(); err != nil {
					h.logger.Error(h.labelForLog+" reload failed; continuing with previous pool",
						"error", err,
						"path", h.path)
					continue
				}
				h.logger.Info(h.labelForLog+" reloaded via SIGHUP",
					"path", h.path,
					"certs_loaded", len(h.Get()))
			case <-done:
				signal.Stop(ch)
				return
			}
		}
	}()
	return func() { close(done) }
}

// LoadBundle reads a PEM bundle from disk + returns the parsed cert
// slice. Refuses empty bundles (zero CERTIFICATE blocks); refuses any
// bundle containing a cert past NotAfter (fail loud at boot rather than
// silently rejecting every request at runtime).
//
// Non-CERTIFICATE PEM blocks are skipped (so an operator can paste a
// chain that includes a private key by mistake without breaking the
// load; the private key is just ignored). Operators rotating signing
// certs typically want this tolerance.
func LoadBundle(path string) ([]*x509.Certificate, error) {
	if path == "" {
		return nil, errors.New("trustanchor: bundle path is empty")
	}
	body, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("trustanchor: read bundle %q: %w", path, err)
	}
	return parseBundlePEM(body, path, time.Now())
}

// parseBundlePEM is the file-IO-free core of LoadBundle. Split out so
// unit tests can hand it byte slices without writing temp files. `now`
// is taken as a parameter so expiry tests can pin a deterministic clock.
func parseBundlePEM(body []byte, sourceLabel string, now time.Time) ([]*x509.Certificate, error) {
	var out []*x509.Certificate
	rest := body
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			break
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("trustanchor: parse cert in %q: %w", sourceLabel, err)
		}
		if now.After(cert.NotAfter) {
			return nil, fmt.Errorf("trustanchor: cert in %q expired at %s (subject=%q) — operator must rotate the trust bundle before restart",
				sourceLabel, cert.NotAfter.Format(time.RFC3339), cert.Subject.CommonName)
		}
		out = append(out, cert)
	}
	if len(out) == 0 {
		return nil, fmt.Errorf("trustanchor: %q contains no CERTIFICATE PEM blocks", sourceLabel)
	}
	return out, nil
}
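
The decode loop in parseBundlePEM can be tried outside the package. What follows is a minimal sketch, not code from this commit: `parseCerts`, `selfSignedPEM` and `demoParse` are hypothetical names, and the fixture certs are throwaway self-signed ones. It shows the three behaviours the function documents: non-CERTIFICATE blocks are skipped, certs are returned in file order, and an injected clock past NotAfter causes a rejection.

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"math/big"
	"time"
)

// parseCerts is a hypothetical stand-in for parseBundlePEM. It keeps
// CERTIFICATE blocks, skips other block types, and rejects any cert whose
// NotAfter is earlier than the injected clock.
func parseCerts(body []byte, now time.Time) ([]*x509.Certificate, error) {
	var out []*x509.Certificate
	rest := body
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			break // no further PEM blocks
		}
		if block.Type != "CERTIFICATE" {
			continue // e.g. a private key pasted by mistake
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("parse cert: %w", err)
		}
		if now.After(cert.NotAfter) {
			return nil, fmt.Errorf("cert %q expired at %s",
				cert.Subject.CommonName, cert.NotAfter.Format(time.RFC3339))
		}
		out = append(out, cert)
	}
	if len(out) == 0 {
		return nil, errors.New("no CERTIFICATE PEM blocks")
	}
	return out, nil
}

// selfSignedPEM mints a throwaway self-signed P-256 cert (fixture only).
func selfSignedPEM(cn string, notAfter time.Time) ([]byte, error) {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: cn},
		NotBefore:    notAfter.Add(-48 * time.Hour),
		NotAfter:     notAfter,
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return nil, err
	}
	return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), nil
}

// demoParse builds a bundle of one non-certificate block plus two certs,
// then parses it twice: once with the clock inside the validity window and
// once with the clock pinned past NotAfter.
func demoParse() (kept int, rejectedWhenExpired bool, err error) {
	notAfter := time.Now().Add(24 * time.Hour)
	bundle := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("ignored")})
	for _, cn := range []string{"ca-1", "ca-2"} {
		c, cerr := selfSignedPEM(cn, notAfter)
		if cerr != nil {
			return 0, false, cerr
		}
		bundle = append(bundle, c...)
	}
	certs, err := parseCerts(bundle, time.Now())
	if err != nil {
		return 0, false, err
	}
	_, expiredErr := parseCerts(bundle, notAfter.Add(time.Hour))
	return len(certs), expiredErr != nil, nil
}
```

Calling `demoParse()` from a `main` function or a test exercises both paths. Passing the clock as a parameter, as parseBundlePEM does, is what lets the expiry case run without minting an already-expired cert.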
@@ -0,0 +1,432 @@
package trustanchor

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"io"
	"log/slog"
	"math/big"
	"os"
	"path/filepath"
	"strings"
	"syscall"
	"testing"
	"time"
)

// EST RFC 7030 hardening master bundle Phase 2.1: this test file holds the
// white-box tests for the trust-anchor primitives (parseBundlePEM + LoadBundle
// + Holder) that used to live in internal/scep/intune/{trust_anchor_test.go,
// trust_anchor_holder_test.go}. The intune package retains a thin
// public-surface test of LoadTrustAnchor (the back-compat shim); the
// detailed tests live here so the EST mTLS sibling route + any future
// trustanchor.Holder caller share the same contract pinning.

// silentLogger drops everything; the SIGHUP watcher emits Info logs we don't
// want fouling test output.
func silentLogger() *slog.Logger {
	return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError + 10}))
}

// pemEncodeCert is a small DRY helper for the PEM bundle fixtures.
func pemEncodeCert(t *testing.T, der []byte) []byte {
	t.Helper()
	return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
}

// freshConnectorCertDER returns a freshly-minted EC P-256 cert as raw DER
// + the matching key. Lifetime is parameterised so the same factory drives
// both the happy-path and expired-cert cases.
func freshConnectorCertDER(t *testing.T, notAfter time.Time) ([]byte, *ecdsa.PrivateKey) {
	t.Helper()
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("ecdsa.GenerateKey: %v", err)
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(time.Now().UnixNano()),
		Subject:      pkix.Name{CommonName: "trustanchor-test"},
		NotBefore:    time.Now().Add(-1 * time.Hour),
		NotAfter:     notAfter,
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("x509.CreateCertificate: %v", err)
	}
	return der, key
}

// writeTestBundle writes a PEM bundle of the given certs at path with mode 0600.
func writeTestBundle(t *testing.T, path string, certs []*x509.Certificate) {
	t.Helper()
	body := []byte{}
	for _, c := range certs {
		body = append(body, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.Raw})...)
	}
	if err := os.WriteFile(path, body, 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
}

// freshHolderCert is a small factory for a self-signed EC cert with a
// caller-controlled CN + lifetime. Used by Reload tests that swap the
// on-disk pool between calls.
func freshHolderCert(t *testing.T, cn string, notAfter time.Time) *x509.Certificate {
	t.Helper()
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("ecdsa.GenerateKey: %v", err)
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(time.Now().UnixNano()),
		Subject:      pkix.Name{CommonName: cn},
		NotBefore:    time.Now().Add(-1 * time.Hour),
		NotAfter:     notAfter,
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("x509.CreateCertificate: %v", err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("x509.ParseCertificate: %v", err)
	}
	return cert
}

// ----- parseBundlePEM (white-box) -----

func TestParseBundlePEM_HappyPath_SingleCert(t *testing.T) {
	der, _ := freshConnectorCertDER(t, time.Now().Add(365*24*time.Hour))
	body := pemEncodeCert(t, der)

	certs, err := parseBundlePEM(body, "test", time.Now())
	if err != nil {
		t.Fatalf("parseBundlePEM: %v", err)
	}
	if len(certs) != 1 {
		t.Fatalf("len(certs) = %d, want 1", len(certs))
	}
	if certs[0].Subject.CommonName != "trustanchor-test" {
		t.Errorf("Subject.CommonName = %q", certs[0].Subject.CommonName)
	}
}

func TestParseBundlePEM_HappyPath_MultiCert(t *testing.T) {
	d1, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
	d2, _ := freshConnectorCertDER(t, time.Now().Add(60*24*time.Hour))
	body := append(pemEncodeCert(t, d1), pemEncodeCert(t, d2)...)

	certs, err := parseBundlePEM(body, "test", time.Now())
	if err != nil {
		t.Fatalf("parseBundlePEM: %v", err)
	}
	if len(certs) != 2 {
		t.Fatalf("len(certs) = %d, want 2", len(certs))
	}
}

func TestParseBundlePEM_SkipsNonCertBlocks(t *testing.T) {
	der, key := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("MarshalECPrivateKey: %v", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	body := append(keyPEM, pemEncodeCert(t, der)...) // priv key first, cert second

	certs, err := parseBundlePEM(body, "test", time.Now())
	if err != nil {
		t.Fatalf("parseBundlePEM should ignore non-CERTIFICATE blocks: %v", err)
	}
	if len(certs) != 1 {
		t.Fatalf("len(certs) = %d, want 1 (priv key block must be skipped)", len(certs))
	}
}

func TestParseBundlePEM_EmptyBundleRejected(t *testing.T) {
	_, err := parseBundlePEM([]byte("nothing here"), "test", time.Now())
	if err == nil || !strings.Contains(err.Error(), "no CERTIFICATE PEM blocks") {
		t.Fatalf("expected 'no CERTIFICATE PEM blocks' error, got %v", err)
	}
}

func TestParseBundlePEM_OnlyKeyBlocksRejected(t *testing.T) {
	key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	keyDER, _ := x509.MarshalECPrivateKey(key)
	body := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})

	_, err := parseBundlePEM(body, "test", time.Now())
	if err == nil {
		t.Fatalf("expected error for bundle with no certs, got nil")
	}
}

func TestParseBundlePEM_ExpiredCertRejected(t *testing.T) {
	der, _ := freshConnectorCertDER(t, time.Now().Add(-1*time.Hour)) // already expired
	body := pemEncodeCert(t, der)

	_, err := parseBundlePEM(body, "expired-bundle", time.Now())
	if err == nil || !strings.Contains(err.Error(), "expired") {
		t.Fatalf("expected expiry error, got %v", err)
	}
	// Operator-actionable message must include the subject so the audit
	// log says exactly which cert to rotate.
	if !strings.Contains(err.Error(), "trustanchor-test") {
		t.Errorf("error must include subject CN for operator action: %v", err)
	}
}

func TestParseBundlePEM_MalformedCertRejected(t *testing.T) {
	bad := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("not-a-real-asn1-cert")})

	_, err := parseBundlePEM(bad, "test", time.Now())
	if err == nil {
		t.Fatalf("expected x509 parse error, got nil")
	}
}

// ----- LoadBundle (filesystem-side) -----

func TestLoadBundle_FromDisk(t *testing.T) {
	der, _ := freshConnectorCertDER(t, time.Now().Add(30*24*time.Hour))
	body := pemEncodeCert(t, der)

	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	if err := os.WriteFile(path, body, 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	certs, err := LoadBundle(path)
	if err != nil {
		t.Fatalf("LoadBundle: %v", err)
	}
	if len(certs) != 1 {
		t.Fatalf("len(certs) = %d, want 1", len(certs))
	}
}

func TestLoadBundle_EmptyPath(t *testing.T) {
	_, err := LoadBundle("")
	if err == nil || !strings.Contains(err.Error(), "empty") {
		t.Fatalf("expected empty-path error, got %v", err)
	}
}

func TestLoadBundle_MissingFile(t *testing.T) {
	// A path inside a fresh temp dir is guaranteed not to exist.
	_, err := LoadBundle(filepath.Join(t.TempDir(), "does-not-exist-trustanchor.pem"))
	if err == nil {
		t.Fatalf("expected file-not-found error, got nil")
	}
	// LoadBundle wraps the os.ReadFile error with %w, so the cause stays
	// inspectable via errors.Is.
	if !errors.Is(err, os.ErrNotExist) {
		t.Fatalf("error must wrap os.ErrNotExist, got %v", err)
	}
}

// ----- Holder -----

func TestHolder_NewLoadsBundle(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	cert := freshHolderCert(t, "initial", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{cert})

	holder, err := New(path, silentLogger())
	if err != nil {
		t.Fatalf("New: %v", err)
	}
	got := holder.Get()
	if len(got) != 1 || got[0].Subject.CommonName != "initial" {
		t.Fatalf("Get returned %#v, want one cert with CN=initial", got)
	}
	if holder.Path() != path {
		t.Errorf("Path = %q, want %q", holder.Path(), path)
	}
}

func TestHolder_NewRequiresLogger(t *testing.T) {
	if _, err := New("/nonexistent", nil); err == nil {
		t.Fatal("nil logger must error")
	}
}

func TestHolder_NewSurfacesLoadError(t *testing.T) {
	if _, err := New("/path/that/does/not/exist.pem", silentLogger()); err == nil {
		t.Fatal("missing file must error")
	}
}

func TestHolder_PoolReturnsAllCerts(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	c1 := freshHolderCert(t, "ca-1", time.Now().Add(30*24*time.Hour))
	c2 := freshHolderCert(t, "ca-2", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{c1, c2})

	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}
	pool := h.Pool()
	if pool == nil {
		t.Fatal("Pool returned nil")
	}
	// pool.Subjects() is deprecated because it omits the system roots for
	// pools obtained from x509.SystemCertPool. We've built this pool
	// ourselves with exactly the two certs from h.Get(), so it would be a
	// safe use, but the linter doesn't know that. Rather than disable the
	// lint, we cross-check via Equal() over the underlying cert slice used
	// to build the pool.
	got := h.Get()
	if len(got) != 2 {
		t.Fatalf("Get() len = %d, want 2", len(got))
	}
	if !got[0].Equal(c1) || !got[1].Equal(c2) {
		t.Errorf("Get() did not return the bundle certs in file order")
	}
}

func TestHolder_SetLabelForLogIgnoresEmpty(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	writeTestBundle(t, path, []*x509.Certificate{
		freshHolderCert(t, "label-test", time.Now().Add(30*24*time.Hour)),
	})
	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}
	h.SetLabelForLog("") // no-op; default "trust anchor" preserved
	h.SetLabelForLog("est mTLS client CA bundle")
	// No public getter for label; just exercise without crashing. The race
	// detector covers the locking contract under -race.
}

func TestHolder_ReloadHappyPath(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	c1 := freshHolderCert(t, "rev-1", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{c1})

	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}

	// Rotate on disk and call Reload.
	c2 := freshHolderCert(t, "rev-2", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{c2})
	if err := h.Reload(); err != nil {
		t.Fatalf("Reload: %v", err)
	}
	got := h.Get()
	if len(got) != 1 || got[0].Subject.CommonName != "rev-2" {
		t.Errorf("after Reload Get = %#v, want one cert CN=rev-2", got)
	}
}

func TestHolder_ReloadKeepsOldOnFailure(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	good := freshHolderCert(t, "stable", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{good})

	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}

	// Overwrite with content that LoadBundle will reject (no PEM blocks).
	if err := os.WriteFile(path, []byte("garbage"), 0o600); err != nil {
		t.Fatal(err)
	}
	if err := h.Reload(); err == nil {
		t.Fatal("Reload from garbage file must error")
	}

	// Old pool still served.
	got := h.Get()
	if len(got) != 1 || got[0].Subject.CommonName != "stable" {
		t.Errorf("after failed Reload Get should still be the pre-Reload pool; got %#v", got)
	}
}

func TestHolder_ReloadKeepsOldOnExpired(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	good := freshHolderCert(t, "still-valid", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{good})

	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}

	expired := freshHolderCert(t, "expired-conn", time.Now().Add(-1*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{expired})

	if err := h.Reload(); err == nil {
		t.Fatal("Reload with expired cert must error")
	}
	if !strings.Contains(h.Get()[0].Subject.CommonName, "still-valid") {
		t.Errorf("after expired-cert Reload, holder should retain old pool")
	}
}

func TestHolder_WatchSIGHUPReloadsPool(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	c1 := freshHolderCert(t, "rev-pre-sighup", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{c1})

	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}
	stop := h.WatchSIGHUP()
	defer stop()

	c2 := freshHolderCert(t, "rev-post-sighup", time.Now().Add(30*24*time.Hour))
	writeTestBundle(t, path, []*x509.Certificate{c2})
	if err := syscall.Kill(syscall.Getpid(), syscall.SIGHUP); err != nil {
		t.Fatalf("send SIGHUP: %v", err)
	}

	deadline := time.Now().Add(2 * time.Second)
	for {
		got := h.Get()
		if len(got) == 1 && got[0].Subject.CommonName == "rev-post-sighup" {
			return
		}
		if time.Now().After(deadline) {
			t.Fatalf("post-SIGHUP pool not swapped in 2s; current CN=%q", got[0].Subject.CommonName)
		}
		time.Sleep(20 * time.Millisecond)
	}
}

func TestHolder_WatchSIGHUPStopIsClean(t *testing.T) {
	// We do NOT fire a SIGHUP after stop(): once signal.Stop has removed our
	// handler the default action on SIGHUP is to terminate the process, so
	// it would kill the test runner. stop() only closes the done channel and
	// the goroutine exits asynchronously, so pin "stop() is safe" by closing
	// the watcher, giving the goroutine a moment to exit, and verifying the
	// holder still serves the original cert without panic.
	dir := t.TempDir()
	path := filepath.Join(dir, "trust.pem")
	writeTestBundle(t, path, []*x509.Certificate{
		freshHolderCert(t, "stop-test", time.Now().Add(30*24*time.Hour)),
	})

	h, err := New(path, silentLogger())
	if err != nil {
		t.Fatal(err)
	}
	stop := h.WatchSIGHUP()
	stop()
	time.Sleep(50 * time.Millisecond)

	if cn := h.Get()[0].Subject.CommonName; cn != "stop-test" {
		t.Errorf("after stop CN = %q, want unchanged stop-test", cn)
	}
}
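
The Reload tests above pin one contract: a failed reload returns an error and leaves the previously served pool untouched. The pattern behind that can be sketched without any certificate machinery. This is an illustration under stated assumptions, not the package's code: `holder`, `newHolder` and `demoReload` are hypothetical names, and a slice of strings stands in for the parsed certs.

```go
package main

import (
	"errors"
	"sync"
)

// holder is a hypothetical, simplified analogue of trustanchor.Holder:
// it serves a snapshot and swaps it only after a load has fully succeeded.
type holder struct {
	mu    sync.RWMutex
	items []string
	load  func() ([]string, error)
}

// newHolder performs the initial load and fails loudly if it is invalid.
func newHolder(load func() ([]string, error)) (*holder, error) {
	items, err := load()
	if err != nil {
		return nil, err
	}
	return &holder{items: items, load: load}, nil
}

// Reload loads the new snapshot before taking the lock. On error it
// returns early, so readers keep seeing the previous snapshot.
func (h *holder) Reload() error {
	items, err := h.load()
	if err != nil {
		return err
	}
	h.mu.Lock()
	h.items = items
	h.mu.Unlock()
	return nil
}

// Get returns a copy of the current snapshot.
func (h *holder) Get() []string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return append([]string(nil), h.items...)
}

// demoReload drives one failing reload followed by one successful reload
// and reports what Get served after each.
func demoReload() (afterBad, afterGood string, badErr error) {
	next := []string{"rev-1"}
	var fail error
	h, err := newHolder(func() ([]string, error) {
		if fail != nil {
			return nil, fail
		}
		return next, nil
	})
	if err != nil {
		return "", "", err
	}

	fail = errors.New("half-written file")
	badErr = h.Reload() // must fail and leave "rev-1" in place
	afterBad = h.Get()[0]

	fail = nil
	next = []string{"rev-2"}
	if err := h.Reload(); err != nil {
		return afterBad, "", err
	}
	afterGood = h.Get()[0]
	return afterBad, afterGood, badErr
}
```

The design choice worth noting is the ordering: validation happens outside the critical section, so the write lock is held only for the pointer-sized swap and a bad file can never be observed by readers.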