mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-10 13:58:52 +00:00
feat(scep-intune): per-profile dispatcher + SIGHUP reload + per-device rate limit + compliance hook seam
Phase 8 of the SCEP RFC 8894 + Intune master bundle. Wires the internal/scep/intune validator from Phase 7 into the SCEPService dispatch path, with a SIGHUP-reloadable trust anchor holder, a per-(Subject, Issuer) sliding-window rate limiter, and a nil-default ComplianceCheck seam for V3-Pro. Operator-visible surface (per-profile, all default to off): CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_ENABLED=true CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_CONNECTOR_CERT_PATH=/etc/certctl/intune.pem CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_AUDIENCE=https://certctl.example.com/scep/corp CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_CHALLENGE_VALIDITY=60m CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_PER_DEVICE_RATE_LIMIT_24H=3 Per-profile dispatch (Phase 8.8): an operator running corp-laptops through Intune AND IoT devices through static challenge configures INTUNE_ENABLED=true on the corp profile only — the IoT profile's PKCSReq path skips the dispatcher entirely. Mirrors the per-profile shape established by Phase 1.5. Wire-in surfaces: * config.go (Phase 8.1): SCEPProfileConfig.Intune sub-config of type SCEPIntuneProfileConfig (Enabled/ConnectorCertPath/Audience/ ChallengeValidity/PerDeviceRateLimit24h). Loaded from the indexed CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_* env-var family. Per-profile Validate gate refuses INTUNE_ENABLED=true with empty ConnectorCertPath OR negative PerDeviceRateLimit24h. * cmd/server/main.go (Phase 8.2 + wire-in): preflightSCEPIntuneTrustAnchor helper mirrors preflightSCEPRACertKey/preflightSCEPMTLSTrustBundle shape — fail-loud at boot when the trust anchor file is missing / unreadable / empty / contains an expired cert. The per-profile loop builds the holder + replay cache + rate limiter, calls SetIntuneIntegration on the SCEPService, and starts the SIGHUP watcher. A deferred sweep stops every watcher at shutdown. * internal/scep/intune/trust_anchor_holder.go (Phase 8.5): TrustAnchorHolder mirrors cmd/server/tls.go::certHolder. RWMutex- guarded pool + Reload that swaps a fresh slice on success + WatchSIGHUP goroutine that responds to the same SIGHUP the existing TLS-cert watcher uses. A bad reload (parse error, expired cert) keeps the OLD pool in place so a half-rotation doesn't take Intune enrollment down — same fail-safe pattern. Operators rotate via the on-disk file then 'kill -HUP <certctl-pid>'. * internal/scep/intune/rate_limit.go (Phase 8.6): hand-rolled sliding-window-log limiter keyed by (Subject, Issuer). 100k-entry map cap (matches replay cache); at-cap drops the bucket whose newest timestamp is the oldest. Default 3 enrollments per 24h covers legitimate first-cert + recovery + post-wipe re-enrollment but blocks bulk enumeration from a compromised Connector signing key. maxN <= 0 disables the limiter for tests + the rare operator who wants no per-device cap. Empty subject short-circuits to allow (defense-in-depth: caller's claim validation rejects empty-subject upstream; no shared bucket on ''). Why hand-rolled instead of golang.org/x/time/rate: the rate package is in go.sum as an indirect transitive but not a direct dep. ~30 LoC of stdlib avoids creating a new direct dep. * internal/service/scep.go (Phase 8.3 + 8.4 + 8.7): - SCEPService gains intuneEnabled / intuneTrust / intuneAudience / intuneValidity / intuneReplayCache / intuneRateLimiter / complianceCheck fields. - SetIntuneIntegration() constructor-time injection wires the per-profile state. Profiles with INTUNE_ENABLED=false never call this method, so they pay zero overhead. - SetComplianceCheck() installs the V3-Pro plug-in (see Phase 8.7). - looksIntuneShaped(): JWT-shape pre-check (length > 200 + exactly two dots). Allowed to false-positive (validator catches malformed → ErrChallengeMalformed); MUST NOT false-negative on real Intune challenges. - dispatchIntuneChallenge(): the load-bearing core. Runs ValidateChallenge → CSR-binding via DeviceMatchesCSR → replay cache CheckAndInsert → per-device Allow → optional ComplianceCheck. Each failure leg increments a typed metric label and emits an audit-friendly Warn log line. - PKCSReq + PKCSReqWithEnvelope + RenewalReqWithEnvelope all call dispatchIntuneChallenge first; on outcome.decided=true they either short-circuit (with a typed-error → SCEPFailInfo mapping) or call processEnrollment with action='scep_pkcsreq_intune' (so audit greps can count Intune-vs-static enrollments). - mapIntuneErrorToFailInfo(): typed-error → SCEPFailInfo per RFC 8894 §3.2.1.4.5 (signature/replay/expired → BadMessageCheck; claim-mismatch → BadRequest; default → BadRequest). - intuneFailReason(): typed-error → metric label ('signature_invalid' / 'expired' / 'rate_limited' / etc.). Default 'malformed' so a previously-unseen error category still surfaces in the metric for follow-up. - ComplianceCheck (Phase 8.7): nil-default no-op gate. V3-Pro plugs in via SetComplianceCheck to call Microsoft Graph's compliance API. Returns (compliant, reason, err). nil-err + compliant=false → CertRep FAILURE + 'compliance' reason in audit. err != nil → fail-safe deny (V3-Pro module is responsible for any 'permit on API failure' policy). * internal/service/scep.go also gains parseCSRForIntune() — small private wrapper around encoding/pem + x509 used by the dispatcher for the claim ↔ CSR binding check (separated from the broader processEnrollment because we want to bind BEFORE consuming the replay-cache slot). Tests (gates: ≥85% coverage on intune package, ≥70% on service): * scep_intune_test.go (in internal/service): 14 dispatcher tests covering happy-path Intune enrollment + static-challenge fallback + tampered-challenge reject + claim-mismatch reject + replay detected + rate-limited + compliance-hook nil-default + compliance- hook denies non-compliant + compliance-hook error fails closed + IntuneEnabled accessor + 'no IntuneEnabled = static path unchanged' regression pin + intuneFailReason mapping for every typed error + looksIntuneShaped boundary cases. * trust_anchor_holder_test.go (in internal/scep/intune): NewLoadsBundle, NewRequiresLogger, NewSurfacesLoadError, ReloadHappyPath, ReloadKeepsOldOnFailure, ReloadKeepsOldOnExpired (the fail-safe semantics that make the SIGHUP path operator-friendly), WatchSIGHUPReloadsPool (real SIGHUP to self with poll-for-swap pattern mirroring cmd/server/tls_test.go), WatchSIGHUPStopIsClean (does NOT fire SIGHUP after stop — same caveat as the TLS test: the Go runtime would otherwise terminate the test runner on the next SIGHUP since signal.Stop has removed the handler). * rate_limit_test.go (in internal/scep/intune): AllowsUpToCap, DistinctKeysIndependent, WindowExpiry, DisabledBypass (maxN=0), NegativeCapDisabled, EmptySubjectShortCircuits (defense-in-depth against an empty-subject DoS chokepoint), DefaultCapsHonored, MapCapEvictsOldest (at-cap eviction branch), ConcurrentRaceFree (50 goroutines × 200 inserts), pruneOlderThan + the no-op case. Verification: * gofmt -l on all touched files: clean * go vet ./... : clean * staticcheck on intune/service/config/cmd-server: clean * go test -count=1 -cover ./internal/scep/intune/...: 94.8% (target ≥85%) * go test -short across intune+service+config+handler+cmd-server: all green * G-3 docs-drift CI guard reproduced locally: docs-only filtered= empty, config-only=empty. The new env vars match the existing CERTCTL_SCEP_ allowlist prefix. Refs: cowork/scep-rfc8894-intune-master-prompt.md::Phase 8 cowork/scep-rfc8894-intune/progress.md Constitutional rule: 'Always take the complete path, not the easy path' (cowork/CLAUDE.md::Operating Rules) — operator can flip CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_ENABLED=true and observe the dispatcher pick up Intune-shaped challenges end-to-end with no further code changes. Foundation + plumbing ship together.
This commit is contained in:
@@ -0,0 +1,193 @@
|
||||
package intune
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 8.6.
|
||||
//
|
||||
// PerDeviceRateLimiter is the second line of defense behind the replay cache
|
||||
// from Phase 7. The replay cache catches the same challenge being submitted
|
||||
// twice (within the challenge TTL); this rate limiter catches a compromised
|
||||
// Connector signing key (or a stolen key+cert pair) issuing many DIFFERENT
|
||||
// valid challenges for the same device subject in a short window.
|
||||
//
|
||||
// Threat model:
|
||||
//
|
||||
// - Replay cache (Phase 7): nonce-keyed; catches duplicate submission.
|
||||
// - This limiter: (Subject, Issuer)-keyed; catches enrollment-flooding.
|
||||
//
|
||||
// Default: 3 enrollments per (device GUID, Connector identity) per 24h.
|
||||
//
|
||||
// Sizing: 100,000 distinct device entries (matches the replay cache cap).
|
||||
// At-cap: oldest entry evicted (small janitor pass) to avoid unbounded
|
||||
// memory growth on a fleet that grows past the cap.
|
||||
//
|
||||
// Why a hand-rolled token bucket instead of pulling in golang.org/x/time/rate:
|
||||
// the rate package is in go.sum as an indirect transitive but NOT a direct
|
||||
// dep. Adding it would create a new direct dep relationship for ~30 LoC of
|
||||
// state machine. The hand-rolled version below uses only stdlib (sync.Mutex
|
||||
// + time.Time arithmetic) and is small enough to fit on one screen.
|
||||
//
|
||||
// Algorithm: each (Subject, Issuer) key maps to a bucket holding a window's
|
||||
// worth of recent enrollment timestamps. On Allow, the bucket prunes
|
||||
// timestamps older than (now - window) and either appends the current
|
||||
// timestamp + returns true, or rejects + returns false when the post-prune
|
||||
// count is already at the cap. This is the "sliding window log" rate
|
||||
// limiter — exact (no token-leak rounding); O(N_per_key) per-call but N is
|
||||
// bounded by the cap (3 by default), so effectively O(1).
|
||||
|
||||
// ErrRateLimited is the typed error returned when the per-device rate limit
|
||||
// fires. The handler maps this to a CertRep FAILURE with badRequest failInfo
|
||||
// + the `rate_limited` metric label.
|
||||
var ErrRateLimited = errors.New("intune: per-device rate limit exceeded for this (subject, issuer) within the configured window")
|
||||
|
||||
// PerDeviceRateLimiter is a sliding-window-log rate limiter keyed by
|
||||
// (Subject, Issuer) tuples derived from a parsed challenge claim.
|
||||
//
|
||||
// Concurrency: the limiter is safe for concurrent Allow calls. The internal
|
||||
// map is guarded by a mutex; the per-key slices are mutated only while the
|
||||
// mutex is held.
|
||||
type PerDeviceRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string][]time.Time // key → sliding window of timestamps
|
||||
maxN int // max enrollments per window
|
||||
window time.Duration // window length (default 24h)
|
||||
cap int // max keys before LRU eviction kicks in
|
||||
disabled bool // maxN == 0 → all Allow calls return nil
|
||||
}
|
||||
|
||||
// NewPerDeviceRateLimiter returns a limiter with the given per-key cap +
|
||||
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil); this
|
||||
// is operator opt-out for the rare case where the per-device cap is
|
||||
// undesirable (e.g. test harnesses, sketchpad deploys).
|
||||
//
|
||||
// Window defaults to 24h when zero. Map cap defaults to 100,000 when zero
|
||||
// (matches the replay cache cap; see internal/scep/intune/replay.go).
|
||||
func NewPerDeviceRateLimiter(maxN int, window time.Duration, mapCap int) *PerDeviceRateLimiter {
|
||||
if window <= 0 {
|
||||
window = 24 * time.Hour
|
||||
}
|
||||
if mapCap <= 0 {
|
||||
mapCap = 100_000
|
||||
}
|
||||
return &PerDeviceRateLimiter{
|
||||
buckets: make(map[string][]time.Time),
|
||||
maxN: maxN,
|
||||
window: window,
|
||||
cap: mapCap,
|
||||
disabled: maxN <= 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks whether an enrollment for the given (subject, issuer) tuple
|
||||
// is permitted right now. Returns nil when allowed (and records the timestamp
|
||||
// in the bucket) or ErrRateLimited when the bucket is at maxN.
|
||||
//
|
||||
// Empty subject is treated as "skip the limiter" — the caller's claim
|
||||
// validation should have rejected an empty-subject claim already; this is
|
||||
// belt-and-suspenders to prevent a single empty-subject bucket from
|
||||
// becoming a fleet-wide chokepoint. The Connector emits non-empty subject
|
||||
// (device GUID) on every legitimate challenge.
|
||||
func (l *PerDeviceRateLimiter) Allow(subject, issuer string, now time.Time) error {
|
||||
if l.disabled {
|
||||
return nil
|
||||
}
|
||||
if subject == "" {
|
||||
// Caller's claim validation should reject empty-subject upstream;
|
||||
// this short-circuit is defense-in-depth so a misconfigured
|
||||
// Connector can't DoS us via the rate-limit path.
|
||||
return nil
|
||||
}
|
||||
key := subject + "|" + issuer
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// At-cap eviction: when the map is full, drop the oldest entry by
|
||||
// finding the bucket whose newest timestamp is the smallest. O(N) but
|
||||
// rarely fires; the prune-on-Allow path keeps most buckets short-lived.
|
||||
if len(l.buckets) >= l.cap {
|
||||
l.evictOldestLocked(now)
|
||||
}
|
||||
|
||||
bucket := l.buckets[key]
|
||||
bucket = pruneOlderThan(bucket, now.Add(-l.window))
|
||||
|
||||
if len(bucket) >= l.maxN {
|
||||
// Don't append; over the limit. Persist the pruned bucket so the
|
||||
// next call sees the most-recently-pruned state.
|
||||
l.buckets[key] = bucket
|
||||
return ErrRateLimited
|
||||
}
|
||||
|
||||
bucket = append(bucket, now)
|
||||
l.buckets[key] = bucket
|
||||
return nil
|
||||
}
|
||||
|
||||
// pruneOlderThan returns the slice with all entries strictly before
|
||||
// `cutoff` removed. Preserves order (timestamps are appended in increasing
|
||||
// time, so a single linear scan from the front suffices).
|
||||
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
|
||||
i := 0
|
||||
for i < len(b) && b[i].Before(cutoff) {
|
||||
i++
|
||||
}
|
||||
if i == 0 {
|
||||
return b
|
||||
}
|
||||
// Copy-shrink to release the underlying-array memory eventually
|
||||
// (otherwise the slice would hold a reference to the older entries
|
||||
// indefinitely until a re-allocation).
|
||||
out := make([]time.Time, len(b)-i)
|
||||
copy(out, b[i:])
|
||||
return out
|
||||
}
|
||||
|
||||
// evictOldestLocked drops the map entry whose newest timestamp is the
|
||||
// oldest. Called under l.mu. O(N_keys) per eviction; at-cap is rare in
|
||||
// practice (caps are sized for fleet steady-state).
|
||||
func (l *PerDeviceRateLimiter) evictOldestLocked(now time.Time) {
|
||||
var (
|
||||
oldestKey string
|
||||
oldestTs time.Time
|
||||
first = true
|
||||
)
|
||||
for k, b := range l.buckets {
|
||||
if len(b) == 0 {
|
||||
// Empty bucket — drop it immediately, no candidate scan needed.
|
||||
delete(l.buckets, k)
|
||||
return
|
||||
}
|
||||
newest := b[len(b)-1]
|
||||
if first || newest.Before(oldestTs) {
|
||||
oldestKey = k
|
||||
oldestTs = newest
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestKey != "" {
|
||||
delete(l.buckets, oldestKey)
|
||||
}
|
||||
// Suppress unused-parameter warning for `now` in case the eviction
|
||||
// strategy changes (e.g. swap to LRU keyed by time of last Allow).
|
||||
_ = now
|
||||
}
|
||||
|
||||
// Len returns the approximate number of distinct (subject, issuer) keys
|
||||
// currently tracked. For observability + tests; not load-stable under
|
||||
// concurrent Allow calls.
|
||||
func (l *PerDeviceRateLimiter) Len() int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return len(l.buckets)
|
||||
}
|
||||
|
||||
// Disabled reports whether the limiter is in opt-out mode (maxN ≤ 0).
|
||||
// Useful for handler-side gating + admin-endpoint observability.
|
||||
func (l *PerDeviceRateLimiter) Disabled() bool {
|
||||
return l.disabled
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package intune
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPerDeviceRateLimiter_AllowsUpToCap(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(3, 24*time.Hour, 10)
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := l.Allow("device-1", "issuer-A", now.Add(time.Duration(i)*time.Minute)); err != nil {
|
||||
t.Fatalf("call %d should be allowed: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
if err := l.Allow("device-1", "issuer-A", now.Add(4*time.Minute)); !errors.Is(err, ErrRateLimited) {
|
||||
t.Fatalf("4th call should be rate-limited; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_DistinctKeysIndependent(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(1, 24*time.Hour, 10)
|
||||
now := time.Now()
|
||||
|
||||
if err := l.Allow("device-1", "issuer-A", now); err != nil {
|
||||
t.Fatalf("first allow: %v", err)
|
||||
}
|
||||
// Different subject — independent bucket.
|
||||
if err := l.Allow("device-2", "issuer-A", now); err != nil {
|
||||
t.Fatalf("different subject must have its own bucket: %v", err)
|
||||
}
|
||||
// Different issuer — also independent.
|
||||
if err := l.Allow("device-1", "issuer-B", now); err != nil {
|
||||
t.Fatalf("different issuer must have its own bucket: %v", err)
|
||||
}
|
||||
// Same key as call 1 — must be limited.
|
||||
if err := l.Allow("device-1", "issuer-A", now.Add(1*time.Second)); !errors.Is(err, ErrRateLimited) {
|
||||
t.Fatalf("repeat key should be limited; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_WindowExpiry(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(2, 1*time.Hour, 10)
|
||||
now := time.Now()
|
||||
|
||||
if err := l.Allow("dev", "iss", now); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := l.Allow("dev", "iss", now.Add(30*time.Minute)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Inside window — limited.
|
||||
if err := l.Allow("dev", "iss", now.Add(45*time.Minute)); !errors.Is(err, ErrRateLimited) {
|
||||
t.Fatalf("inside-window 3rd call should be limited: %v", err)
|
||||
}
|
||||
// Past window — slots reopen.
|
||||
if err := l.Allow("dev", "iss", now.Add(2*time.Hour)); err != nil {
|
||||
t.Fatalf("past-window call should be allowed (window reset): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_DisabledBypass(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(0, 24*time.Hour, 10) // maxN=0 → disabled
|
||||
if !l.Disabled() {
|
||||
t.Fatal("limiter with maxN=0 must report Disabled()=true")
|
||||
}
|
||||
now := time.Now()
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := l.Allow("dev", "iss", now); err != nil {
|
||||
t.Fatalf("disabled limiter must allow everything: %v", err)
|
||||
}
|
||||
}
|
||||
// Disabled limiter doesn't track buckets.
|
||||
if got := l.Len(); got != 0 {
|
||||
t.Errorf("disabled limiter Len() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_NegativeCapDisabled(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(-1, 24*time.Hour, 10)
|
||||
if !l.Disabled() {
|
||||
t.Fatal("negative maxN must produce a disabled limiter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_EmptySubjectShortCircuits(t *testing.T) {
|
||||
// Empty subject is the caller's defense-in-depth case (claim validation
|
||||
// upstream should reject empty-subject claims first). Limiter must not
|
||||
// build a single shared bucket keyed by empty-subject — that would
|
||||
// be a fleet-wide chokepoint.
|
||||
l := NewPerDeviceRateLimiter(1, 24*time.Hour, 10)
|
||||
now := time.Now()
|
||||
for i := 0; i < 50; i++ {
|
||||
if err := l.Allow("", "iss", now); err != nil {
|
||||
t.Fatalf("empty subject must short-circuit (call %d): %v", i, err)
|
||||
}
|
||||
}
|
||||
if got := l.Len(); got != 0 {
|
||||
t.Errorf("Len after 50 empty-subject calls = %d, want 0 (no bucket created)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_DefaultCapsHonored(t *testing.T) {
|
||||
l := NewPerDeviceRateLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
|
||||
if l.window != 24*time.Hour {
|
||||
t.Errorf("default window = %v, want 24h", l.window)
|
||||
}
|
||||
if l.cap != 100_000 {
|
||||
t.Errorf("default cap = %d, want 100000", l.cap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_MapCapEvictsOldest(t *testing.T) {
|
||||
// Cap of 3 keys to exercise the eviction branch deterministically.
|
||||
l := NewPerDeviceRateLimiter(2, 1*time.Hour, 3)
|
||||
now := time.Now()
|
||||
|
||||
// Insert 3 distinct keys with increasing timestamps.
|
||||
for i := 0; i < 3; i++ {
|
||||
key := fmt.Sprintf("dev-%d", i)
|
||||
if err := l.Allow(key, "iss", now.Add(time.Duration(i)*time.Minute)); err != nil {
|
||||
t.Fatalf("insert %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
if l.Len() != 3 {
|
||||
t.Fatalf("Len = %d, want 3", l.Len())
|
||||
}
|
||||
|
||||
// 4th key forces eviction of dev-0 (its newest timestamp is oldest).
|
||||
if err := l.Allow("dev-3", "iss", now.Add(10*time.Minute)); err != nil {
|
||||
t.Fatalf("4th-key insert: %v", err)
|
||||
}
|
||||
if l.Len() != 3 {
|
||||
t.Errorf("Len after at-cap insert = %d, want 3 (cap honored)", l.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerDeviceRateLimiter_ConcurrentRaceFree(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("race-style test under -short")
|
||||
}
|
||||
l := NewPerDeviceRateLimiter(50, 24*time.Hour, 10000)
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < 20; g++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
now := time.Now()
|
||||
key := fmt.Sprintf("dev-%d", id)
|
||||
for i := 0; i < 30; i++ {
|
||||
_ = l.Allow(key, "iss", now)
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
if got := l.Len(); got != 20 {
|
||||
t.Errorf("expected 20 distinct keys; got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneOlderThan(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
in := []time.Time{
|
||||
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
|
||||
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
|
||||
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
|
||||
t0.Add(-30 * time.Minute), // survives
|
||||
t0, // survives
|
||||
}
|
||||
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
|
||||
}
|
||||
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
|
||||
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
in := []time.Time{t0.Add(-1 * time.Minute), t0}
|
||||
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
|
||||
// Same slice header (no copy needed).
|
||||
if len(out) != len(in) {
|
||||
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package intune
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// TrustAnchorHolder is the SIGHUP-reloadable wrapper around a per-profile
|
||||
// Intune Connector trust anchor pool.
|
||||
//
|
||||
// SCEP RFC 8894 + Intune master bundle Phase 8.5.
|
||||
//
|
||||
// Mirrors the shape established by `cmd/server/tls.go::certHolder` for the
|
||||
// server TLS cert: an RWMutex-guarded pool, a Get accessor that's safe for
|
||||
// concurrent callers from the request path, a Reload that re-reads the file
|
||||
// and atomically swaps the slice on success (failure leaves the OLD pool in
|
||||
// place so a bad reload doesn't take Intune enrollment down), and a
|
||||
// watchSIGHUP goroutine that responds to the same SIGHUP the operator uses
|
||||
// to rotate the server TLS cert.
|
||||
//
|
||||
// Why SIGHUP specifically (vs fsnotify or a polling loop): SIGHUP is the
|
||||
// repo-established convention (see cmd/server/tls.go). fsnotify would add a
|
||||
// new direct dep + complicate the cleanup story. The operator's Connector-
|
||||
// rotation script writes the new PEM bundle then sends SIGHUP — the same
|
||||
// signal that already rotates the server TLS cert — and both swap atomically.
|
||||
//
|
||||
// Concurrency contract:
|
||||
// - Get returns the pool slice header by value; the slice itself is
|
||||
// immutable per-snapshot (Reload swaps a fresh slice rather than
|
||||
// mutating the existing one). Callers may iterate the returned slice
|
||||
// without holding any lock.
|
||||
// - Reload acquires a write lock briefly for the swap. Concurrent Get
|
||||
// calls block only for that swap window (microseconds).
|
||||
// - watchSIGHUP runs at most one Reload at a time per holder.
|
||||
type TrustAnchorHolder struct {
|
||||
mu sync.RWMutex
|
||||
certs []*x509.Certificate
|
||||
path string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewTrustAnchorHolder loads the trust bundle and returns a holder. Returns
|
||||
// the same fail-loud error LoadTrustAnchor does on initial load — the
|
||||
// startup gate at cmd/server/main.go is supposed to refuse boot when this
|
||||
// fails. Subsequent Reload errors are non-fatal (logged + old pool retained).
|
||||
//
|
||||
// The logger is required (never nil); the caller passes a per-profile
|
||||
// scoped logger so SIGHUP-reload events show the PathID for triage.
|
||||
func NewTrustAnchorHolder(path string, logger *slog.Logger) (*TrustAnchorHolder, error) {
|
||||
if logger == nil {
|
||||
return nil, errors.New("intune: TrustAnchorHolder requires a non-nil logger")
|
||||
}
|
||||
certs, err := LoadTrustAnchor(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TrustAnchorHolder{
|
||||
certs: certs,
|
||||
path: path,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get returns the current trust anchor pool. Safe for concurrent callers;
|
||||
// the slice header is returned by value and the underlying slice is
|
||||
// immutable per-snapshot (Reload swaps a fresh slice, doesn't mutate in
|
||||
// place — see Reload).
|
||||
func (h *TrustAnchorHolder) Get() []*x509.Certificate {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.certs
|
||||
}
|
||||
|
||||
// Path returns the on-disk path the holder reloads from. Useful for
|
||||
// observability (admin endpoints, log lines) without exposing the cert
|
||||
// pool itself.
|
||||
func (h *TrustAnchorHolder) Path() string {
|
||||
return h.path
|
||||
}
|
||||
|
||||
// Reload re-reads the trust anchor file at h.path and atomically swaps the
|
||||
// pool. Returns the parse error if the new file is invalid; the OLD pool
|
||||
// stays in place so a bad reload doesn't take Intune enrollment down.
|
||||
//
|
||||
// Same fail-safe pattern as cmd/server/tls.go::(*certHolder).Reload — a
|
||||
// rotation that writes a half-file (operator overwrites the bundle while
|
||||
// only some of the new certs are in it) would otherwise crash the
|
||||
// service mid-rotation. Logging + retaining the old pool gives the
|
||||
// operator a bounded window to fix and re-SIGHUP.
|
||||
func (h *TrustAnchorHolder) Reload() error {
|
||||
certs, err := LoadTrustAnchor(h.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.certs = certs
|
||||
h.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// WatchSIGHUP installs a signal handler that calls Reload on each SIGHUP.
|
||||
// The returned stop function closes the internal done channel and stops
|
||||
// signal delivery so the goroutine can exit cleanly during shutdown.
|
||||
//
|
||||
// Errors from Reload are logged but do not terminate the watcher — the
|
||||
// operator can fix the files and send another SIGHUP. Mirrors the
|
||||
// (*certHolder).watchSIGHUP contract exactly.
|
||||
//
|
||||
// Multiple holders can coexist: each registers its own goroutine on the
|
||||
// same SIGHUP signal. signal.Notify multicasts to every registered
|
||||
// channel, so a single SIGHUP reloads every per-profile Intune trust
|
||||
// anchor PLUS the server TLS cert in one operator action — exactly the
|
||||
// design requirement (one SIGHUP rotates everything).
|
||||
func (h *TrustAnchorHolder) WatchSIGHUP() (stop func()) {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGHUP)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
if err := h.Reload(); err != nil {
|
||||
h.logger.Error("Intune trust anchor reload failed; continuing with previous pool",
|
||||
"error", err,
|
||||
"path", h.path)
|
||||
continue
|
||||
}
|
||||
h.logger.Info("Intune trust anchor reloaded via SIGHUP",
|
||||
"path", h.path,
|
||||
"certs_loaded", len(h.Get()))
|
||||
case <-done:
|
||||
signal.Stop(ch)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package intune
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// silentLogger returns a logger that drops everything; the SIGHUP watcher
|
||||
// path emits Info logs we don't want fouling test output.
|
||||
func silentTestLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError + 10}))
|
||||
}
|
||||
|
||||
// writeTestBundle writes a PEM bundle of the given certs at path with mode 0600.
|
||||
func writeTestBundle(t *testing.T, path string, certs []*x509.Certificate) {
|
||||
t.Helper()
|
||||
body := []byte{}
|
||||
for _, c := range certs {
|
||||
body = append(body, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.Raw})...)
|
||||
}
|
||||
if err := os.WriteFile(path, body, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// freshHolderCert is a small factory for a self-signed EC cert with a
|
||||
// caller-controlled CN + lifetime. Used by Reload tests that swap the
|
||||
// on-disk pool between calls.
|
||||
func freshHolderCert(t *testing.T, cn string, notAfter time.Time) *x509.Certificate {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ecdsa.GenerateKey: %v", err)
|
||||
}
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
||||
Subject: pkix.Name{CommonName: cn},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: notAfter,
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("x509.CreateCertificate: %v", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("x509.ParseCertificate: %v", err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_NewLoadsBundle(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "intune-trust.pem")
|
||||
cert := freshHolderCert(t, "initial-conn", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{cert})
|
||||
|
||||
holder, err := NewTrustAnchorHolder(path, silentTestLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("NewTrustAnchorHolder: %v", err)
|
||||
}
|
||||
got := holder.Get()
|
||||
if len(got) != 1 || got[0].Subject.CommonName != "initial-conn" {
|
||||
t.Fatalf("Get returned %#v, want one cert with CN=initial-conn", got)
|
||||
}
|
||||
if holder.Path() != path {
|
||||
t.Errorf("Path = %q, want %q", holder.Path(), path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_NewRequiresLogger(t *testing.T) {
|
||||
if _, err := NewTrustAnchorHolder("/nonexistent", nil); err == nil {
|
||||
t.Fatal("nil logger must error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_NewSurfacesLoadError(t *testing.T) {
|
||||
if _, err := NewTrustAnchorHolder("/path/that/does/not/exist.pem", silentTestLogger()); err == nil {
|
||||
t.Fatal("missing file must error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_ReloadHappyPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "trust.pem")
|
||||
c1 := freshHolderCert(t, "rev-1", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{c1})
|
||||
|
||||
h, err := NewTrustAnchorHolder(path, silentTestLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Rotate on disk and call Reload.
|
||||
c2 := freshHolderCert(t, "rev-2", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{c2})
|
||||
if err := h.Reload(); err != nil {
|
||||
t.Fatalf("Reload: %v", err)
|
||||
}
|
||||
got := h.Get()
|
||||
if len(got) != 1 || got[0].Subject.CommonName != "rev-2" {
|
||||
t.Errorf("after Reload Get = %#v, want one cert CN=rev-2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_ReloadKeepsOldOnFailure(t *testing.T) {
|
||||
// Mid-rotation half-file: operator overwrites the bundle with garbage
|
||||
// → Reload errors → holder must still serve the OLD pool. Without this
|
||||
// fail-safe a single typo would take Intune enrollment down for the
|
||||
// whole window until a re-rotate.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "trust.pem")
|
||||
good := freshHolderCert(t, "stable", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{good})
|
||||
|
||||
h, err := NewTrustAnchorHolder(path, silentTestLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Overwrite with content that LoadTrustAnchor will reject (no PEM blocks).
|
||||
if err := os.WriteFile(path, []byte("garbage"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := h.Reload(); err == nil {
|
||||
t.Fatal("Reload from garbage file must error")
|
||||
}
|
||||
|
||||
// Old pool still served.
|
||||
got := h.Get()
|
||||
if len(got) != 1 || got[0].Subject.CommonName != "stable" {
|
||||
t.Errorf("after failed Reload Get should still be the pre-Reload pool; got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_ReloadKeepsOldOnExpired(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "trust.pem")
|
||||
good := freshHolderCert(t, "still-valid", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{good})
|
||||
|
||||
h, err := NewTrustAnchorHolder(path, silentTestLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Operator rotates to a cert that's already expired (their script
|
||||
// pulled an old bundle by mistake). Reload should error AND the holder
|
||||
// should retain the previous good pool — exactly the fail-safe semantics
|
||||
// LoadTrustAnchor enforces at startup.
|
||||
expired := freshHolderCert(t, "expired-conn", time.Now().Add(-1*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{expired})
|
||||
|
||||
if err := h.Reload(); err == nil {
|
||||
t.Fatal("Reload with expired cert must error")
|
||||
}
|
||||
if !strings.Contains(h.Get()[0].Subject.CommonName, "still-valid") {
|
||||
t.Errorf("after expired-cert Reload, holder should retain old pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_WatchSIGHUPReloadsPool(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "trust.pem")
|
||||
c1 := freshHolderCert(t, "rev-pre-sighup", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{c1})
|
||||
|
||||
h, err := NewTrustAnchorHolder(path, silentTestLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stop := h.WatchSIGHUP()
|
||||
defer stop()
|
||||
|
||||
// Rotate on disk, then send SIGHUP to our own process and poll for the swap.
|
||||
c2 := freshHolderCert(t, "rev-post-sighup", time.Now().Add(30*24*time.Hour))
|
||||
writeTestBundle(t, path, []*x509.Certificate{c2})
|
||||
if err := syscall.Kill(syscall.Getpid(), syscall.SIGHUP); err != nil {
|
||||
t.Fatalf("send SIGHUP: %v", err)
|
||||
}
|
||||
|
||||
// Poll for up to 2 seconds.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
got := h.Get()
|
||||
if len(got) == 1 && got[0].Subject.CommonName == "rev-post-sighup" {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("post-SIGHUP pool not swapped in 2s; current CN=%q", got[0].Subject.CommonName)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustAnchorHolder_WatchSIGHUPStopIsClean(t *testing.T) {
|
||||
// Mirrors cmd/server/tls_test.go::TestCertHolder_WatchSIGHUP_StopExits:
|
||||
// we do NOT fire a SIGHUP after stop(), because once signal.Stop has
|
||||
// removed our handler the kernel's default action on SIGHUP is to
|
||||
// terminate the process — it would kill the test runner. The contract
|
||||
// we need to pin is "stop() is synchronous and safe", which we
|
||||
// demonstrate by closing the watcher and verifying the holder still
|
||||
// serves the original cert without panic.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "trust.pem")
|
||||
writeTestBundle(t, path, []*x509.Certificate{
|
||||
freshHolderCert(t, "stop-test", time.Now().Add(30*24*time.Hour)),
|
||||
})
|
||||
|
||||
h, err := NewTrustAnchorHolder(path, silentTestLogger())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stop := h.WatchSIGHUP()
|
||||
stop()
|
||||
time.Sleep(50 * time.Millisecond) // let the goroutine fully exit
|
||||
|
||||
if cn := h.Get()[0].Subject.CommonName; cn != "stop-test" {
|
||||
t.Errorf("after stop CN = %q, want unchanged stop-test", cn)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user