feat(scep-intune): per-profile dispatcher + SIGHUP reload + per-device rate limit + compliance hook seam

Phase 8 of the SCEP RFC 8894 + Intune master bundle. Wires the
internal/scep/intune validator from Phase 7 into the SCEPService
dispatch path, with a SIGHUP-reloadable trust anchor holder, a
per-(Subject, Issuer) sliding-window rate limiter, and a nil-default
ComplianceCheck seam for V3-Pro.

Operator-visible surface (per-profile, all default to off):

  CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_ENABLED=true
  CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_CONNECTOR_CERT_PATH=/etc/certctl/intune.pem
  CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_AUDIENCE=https://certctl.example.com/scep/corp
  CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_CHALLENGE_VALIDITY=60m
  CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_PER_DEVICE_RATE_LIMIT_24H=3

Per-profile dispatch (Phase 8.8): an operator running corp-laptops
through Intune AND IoT devices through static challenge configures
INTUNE_ENABLED=true on the corp profile only — the IoT profile's
PKCSReq path skips the dispatcher entirely. Mirrors the per-profile
shape established by Phase 1.5.

Wire-in surfaces:

  * config.go (Phase 8.1): SCEPProfileConfig.Intune sub-config of
    type SCEPIntuneProfileConfig (Enabled/ConnectorCertPath/Audience/
    ChallengeValidity/PerDeviceRateLimit24h). Loaded from the indexed
    CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_* env-var family. Per-profile
    Validate gate refuses INTUNE_ENABLED=true with empty ConnectorCertPath
    OR negative PerDeviceRateLimit24h.

  * cmd/server/main.go (Phase 8.2 + wire-in): preflightSCEPIntuneTrustAnchor
    helper mirrors preflightSCEPRACertKey/preflightSCEPMTLSTrustBundle
    shape — fail-loud at boot when the trust anchor file is missing /
    unreadable / empty / contains an expired cert. The per-profile loop
    builds the holder + replay cache + rate limiter, calls
    SetIntuneIntegration on the SCEPService, and starts the SIGHUP
    watcher. A deferred sweep stops every watcher at shutdown.

  * internal/scep/intune/trust_anchor_holder.go (Phase 8.5):
    TrustAnchorHolder mirrors cmd/server/tls.go::certHolder. RWMutex-
    guarded pool + Reload that swaps a fresh slice on success +
    WatchSIGHUP goroutine that responds to the same SIGHUP the existing
    TLS-cert watcher uses. A bad reload (parse error, expired cert)
    keeps the OLD pool in place so a half-rotation doesn't take Intune
    enrollment down — same fail-safe pattern. Operators rotate via the
    on-disk file then 'kill -HUP <certctl-pid>'.

  * internal/scep/intune/rate_limit.go (Phase 8.6): hand-rolled
    sliding-window-log limiter keyed by (Subject, Issuer). 100k-entry
    map cap (matches replay cache); at-cap drops the bucket whose
    newest timestamp is the oldest. Default 3 enrollments per 24h
    covers legitimate first-cert + recovery + post-wipe re-enrollment
    but blocks bulk enumeration from a compromised Connector signing
    key. maxN <= 0 disables the limiter for tests + the rare operator
    who wants no per-device cap. Empty subject short-circuits to allow
    (defense-in-depth: caller's claim validation rejects empty-subject
    upstream; no shared bucket on '').

    Why hand-rolled instead of golang.org/x/time/rate: the rate
    package is in go.sum as an indirect transitive but not a direct
    dep. ~30 LoC of stdlib avoids creating a new direct dep.

  * internal/service/scep.go (Phase 8.3 + 8.4 + 8.7):
    - SCEPService gains intuneEnabled / intuneTrust / intuneAudience /
      intuneValidity / intuneReplayCache / intuneRateLimiter /
      complianceCheck fields.
    - SetIntuneIntegration() constructor-time injection wires the
      per-profile state. Profiles with INTUNE_ENABLED=false never
      call this method, so they pay zero overhead.
    - SetComplianceCheck() installs the V3-Pro plug-in (see Phase 8.7).
    - looksIntuneShaped(): JWT-shape pre-check (length > 200 + exactly
      two dots). Allowed to false-positive (validator catches malformed
      → ErrChallengeMalformed); MUST NOT false-negative on real Intune
      challenges.
    - dispatchIntuneChallenge(): the load-bearing core. Runs
      ValidateChallenge → CSR-binding via DeviceMatchesCSR → replay
      cache CheckAndInsert → per-device Allow → optional ComplianceCheck.
      Each failure leg increments a typed metric label and emits an
      audit-friendly Warn log line.
    - PKCSReq + PKCSReqWithEnvelope + RenewalReqWithEnvelope all call
      dispatchIntuneChallenge first; on outcome.decided=true they
      either short-circuit (with a typed-error → SCEPFailInfo mapping)
      or call processEnrollment with action='scep_pkcsreq_intune'
      (so audit greps can count Intune-vs-static enrollments).
    - mapIntuneErrorToFailInfo(): typed-error → SCEPFailInfo per
      RFC 8894 §3.2.1.4.5 (signature/replay/expired → BadMessageCheck;
      claim-mismatch → BadRequest; default → BadRequest).
    - intuneFailReason(): typed-error → metric label
      ('signature_invalid' / 'expired' / 'rate_limited' / etc.). Default
      'malformed' so a previously-unseen error category still surfaces
      in the metric for follow-up.
    - ComplianceCheck (Phase 8.7): nil-default no-op gate. V3-Pro plugs
      in via SetComplianceCheck to call Microsoft Graph's compliance
      API. Returns (compliant, reason, err). nil-err + compliant=false
      → CertRep FAILURE + 'compliance' reason in audit. err != nil →
      fail-safe deny (V3-Pro module is responsible for any 'permit on
      API failure' policy).

  * internal/service/scep.go also gains parseCSRForIntune() — small
    private wrapper around encoding/pem + x509 used by the dispatcher
    for the claim ↔ CSR binding check (separated from the broader
    processEnrollment because we want to bind BEFORE consuming the
    replay-cache slot).

Tests (gates: ≥85% coverage on intune package, ≥70% on service):

  * scep_intune_test.go (in internal/service): 14 dispatcher tests
    covering happy-path Intune enrollment + static-challenge fallback
    + tampered-challenge reject + claim-mismatch reject + replay
    detected + rate-limited + compliance-hook nil-default + compliance-
    hook denies non-compliant + compliance-hook error fails closed +
    IntuneEnabled accessor + 'no IntuneEnabled = static path
    unchanged' regression pin + intuneFailReason mapping for every
    typed error + looksIntuneShaped boundary cases.

  * trust_anchor_holder_test.go (in internal/scep/intune): NewLoadsBundle,
    NewRequiresLogger, NewSurfacesLoadError, ReloadHappyPath,
    ReloadKeepsOldOnFailure, ReloadKeepsOldOnExpired (the fail-safe
    semantics that make the SIGHUP path operator-friendly),
    WatchSIGHUPReloadsPool (real SIGHUP to self with poll-for-swap
    pattern mirroring cmd/server/tls_test.go), WatchSIGHUPStopIsClean
    (does NOT fire SIGHUP after stop — same caveat as the TLS test:
    the Go runtime would otherwise terminate the test runner on the
    next SIGHUP since signal.Stop has removed the handler).

  * rate_limit_test.go (in internal/scep/intune): AllowsUpToCap,
    DistinctKeysIndependent, WindowExpiry, DisabledBypass (maxN=0),
    NegativeCapDisabled, EmptySubjectShortCircuits (defense-in-depth
    against an empty-subject DoS chokepoint), DefaultCapsHonored,
    MapCapEvictsOldest (at-cap eviction branch), ConcurrentRaceFree
    (50 goroutines × 200 inserts), pruneOlderThan + the no-op case.

Verification:

  * gofmt -l on all touched files: clean
  * go vet ./... : clean
  * staticcheck on intune/service/config/cmd-server: clean
  * go test -count=1 -cover ./internal/scep/intune/...: 94.8%
    (target ≥85%)
  * go test -short across intune+service+config+handler+cmd-server:
    all green
  * G-3 docs-drift CI guard reproduced locally: docs-only filtered=
    empty, config-only=empty. The new env vars match the existing
    CERTCTL_SCEP_ allowlist prefix.

Refs: cowork/scep-rfc8894-intune-master-prompt.md::Phase 8
      cowork/scep-rfc8894-intune/progress.md
      Constitutional rule: 'Always take the complete path, not the
      easy path' (cowork/CLAUDE.md::Operating Rules) — operator can
      flip CERTCTL_SCEP_PROFILE_<NAME>_INTUNE_ENABLED=true and observe
      the dispatcher pick up Intune-shaped challenges end-to-end with
      no further code changes. Foundation + plumbing ship together.
This commit is contained in:
Shankar
2026-04-29 15:34:19 +00:00
parent 0861aa9482
commit 2263e2886b
10 changed files with 1918 additions and 4 deletions
+193
View File
@@ -0,0 +1,193 @@
package intune
import (
"errors"
"sync"
"time"
)
// SCEP RFC 8894 + Intune master bundle Phase 8.6.
//
// PerDeviceRateLimiter is the second line of defense behind the replay cache
// from Phase 7. The replay cache catches the same challenge being submitted
// twice (within the challenge TTL); this rate limiter catches a compromised
// Connector signing key (or a stolen key+cert pair) issuing many DIFFERENT
// valid challenges for the same device subject in a short window.
//
// Threat model:
//
// - Replay cache (Phase 7): nonce-keyed; catches duplicate submission.
// - This limiter: (Subject, Issuer)-keyed; catches enrollment-flooding.
//
// Default: 3 enrollments per (device GUID, Connector identity) per 24h.
//
// Sizing: 100,000 distinct device entries (matches the replay cache cap).
// At-cap: oldest entry evicted (small janitor pass) to avoid unbounded
// memory growth on a fleet that grows past the cap.
//
// Why a hand-rolled token bucket instead of pulling in golang.org/x/time/rate:
// the rate package is in go.sum as an indirect transitive but NOT a direct
// dep. Adding it would create a new direct dep relationship for ~30 LoC of
// state machine. The hand-rolled version below uses only stdlib (sync.Mutex
// + time.Time arithmetic) and is small enough to fit on one screen.
//
// Algorithm: each (Subject, Issuer) key maps to a bucket holding a window's
// worth of recent enrollment timestamps. On Allow, the bucket prunes
// timestamps older than (now - window) and either appends the current
// timestamp + returns true, or rejects + returns false when the post-prune
// count is already at the cap. This is the "sliding window log" rate
// limiter — exact (no token-leak rounding); O(N_per_key) per-call but N is
// bounded by the cap (3 by default), so effectively O(1).
// ErrRateLimited is the typed error returned when the per-device rate limit
// fires. The handler maps this to a CertRep FAILURE with badRequest failInfo
// + the `rate_limited` metric label.
var ErrRateLimited = errors.New("intune: per-device rate limit exceeded for this (subject, issuer) within the configured window")
// PerDeviceRateLimiter is a sliding-window-log rate limiter keyed by
// (Subject, Issuer) tuples derived from a parsed challenge claim.
//
// Concurrency: the limiter is safe for concurrent Allow calls. The internal
// map is guarded by a mutex; the per-key slices are mutated only while the
// mutex is held.
type PerDeviceRateLimiter struct {
mu sync.Mutex
buckets map[string][]time.Time // key → sliding window of timestamps
maxN int // max enrollments per window
window time.Duration // window length (default 24h)
cap int // max keys before LRU eviction kicks in
disabled bool // maxN == 0 → all Allow calls return nil
}
// NewPerDeviceRateLimiter returns a limiter with the given per-key cap +
// window. maxN ≤ 0 disables the limiter (all Allow calls return nil); this
// is operator opt-out for the rare case where the per-device cap is
// undesirable (e.g. test harnesses, sketchpad deploys).
//
// Window defaults to 24h when zero. Map cap defaults to 100,000 when zero
// (matches the replay cache cap; see internal/scep/intune/replay.go).
func NewPerDeviceRateLimiter(maxN int, window time.Duration, mapCap int) *PerDeviceRateLimiter {
if window <= 0 {
window = 24 * time.Hour
}
if mapCap <= 0 {
mapCap = 100_000
}
return &PerDeviceRateLimiter{
buckets: make(map[string][]time.Time),
maxN: maxN,
window: window,
cap: mapCap,
disabled: maxN <= 0,
}
}
// Allow checks whether an enrollment for the given (subject, issuer) tuple
// is permitted right now. Returns nil when allowed (and records the timestamp
// in the bucket) or ErrRateLimited when the bucket is at maxN.
//
// Empty subject is treated as "skip the limiter" — the caller's claim
// validation should have rejected an empty-subject claim already; this is
// belt-and-suspenders to prevent a single empty-subject bucket from
// becoming a fleet-wide chokepoint. The Connector emits non-empty subject
// (device GUID) on every legitimate challenge.
func (l *PerDeviceRateLimiter) Allow(subject, issuer string, now time.Time) error {
if l.disabled {
return nil
}
if subject == "" {
// Caller's claim validation should reject empty-subject upstream;
// this short-circuit is defense-in-depth so a misconfigured
// Connector can't DoS us via the rate-limit path.
return nil
}
key := subject + "|" + issuer
l.mu.Lock()
defer l.mu.Unlock()
// At-cap eviction: when the map is full, drop the oldest entry by
// finding the bucket whose newest timestamp is the smallest. O(N) but
// rarely fires; the prune-on-Allow path keeps most buckets short-lived.
if len(l.buckets) >= l.cap {
l.evictOldestLocked(now)
}
bucket := l.buckets[key]
bucket = pruneOlderThan(bucket, now.Add(-l.window))
if len(bucket) >= l.maxN {
// Don't append; over the limit. Persist the pruned bucket so the
// next call sees the most-recently-pruned state.
l.buckets[key] = bucket
return ErrRateLimited
}
bucket = append(bucket, now)
l.buckets[key] = bucket
return nil
}
// pruneOlderThan returns the slice with all entries strictly before
// `cutoff` removed. Preserves order (timestamps are appended in increasing
// time, so a single linear scan from the front suffices).
func pruneOlderThan(b []time.Time, cutoff time.Time) []time.Time {
i := 0
for i < len(b) && b[i].Before(cutoff) {
i++
}
if i == 0 {
return b
}
// Copy-shrink to release the underlying-array memory eventually
// (otherwise the slice would hold a reference to the older entries
// indefinitely until a re-allocation).
out := make([]time.Time, len(b)-i)
copy(out, b[i:])
return out
}
// evictOldestLocked drops the map entry whose newest timestamp is the
// oldest. Called under l.mu. O(N_keys) per eviction; at-cap is rare in
// practice (caps are sized for fleet steady-state).
func (l *PerDeviceRateLimiter) evictOldestLocked(now time.Time) {
var (
oldestKey string
oldestTs time.Time
first = true
)
for k, b := range l.buckets {
if len(b) == 0 {
// Empty bucket — drop it immediately, no candidate scan needed.
delete(l.buckets, k)
return
}
newest := b[len(b)-1]
if first || newest.Before(oldestTs) {
oldestKey = k
oldestTs = newest
first = false
}
}
if oldestKey != "" {
delete(l.buckets, oldestKey)
}
// Suppress unused-parameter warning for `now` in case the eviction
// strategy changes (e.g. swap to LRU keyed by time of last Allow).
_ = now
}
// Len returns the approximate number of distinct (subject, issuer) keys
// currently tracked. For observability + tests; not load-stable under
// concurrent Allow calls.
func (l *PerDeviceRateLimiter) Len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.buckets)
}
// Disabled reports whether the limiter is in opt-out mode (maxN ≤ 0).
// Useful for handler-side gating + admin-endpoint observability.
func (l *PerDeviceRateLimiter) Disabled() bool {
return l.disabled
}
+190
View File
@@ -0,0 +1,190 @@
package intune
import (
"errors"
"fmt"
"sync"
"testing"
"time"
)
func TestPerDeviceRateLimiter_AllowsUpToCap(t *testing.T) {
l := NewPerDeviceRateLimiter(3, 24*time.Hour, 10)
now := time.Now()
for i := 0; i < 3; i++ {
if err := l.Allow("device-1", "issuer-A", now.Add(time.Duration(i)*time.Minute)); err != nil {
t.Fatalf("call %d should be allowed: %v", i+1, err)
}
}
if err := l.Allow("device-1", "issuer-A", now.Add(4*time.Minute)); !errors.Is(err, ErrRateLimited) {
t.Fatalf("4th call should be rate-limited; got %v", err)
}
}
func TestPerDeviceRateLimiter_DistinctKeysIndependent(t *testing.T) {
l := NewPerDeviceRateLimiter(1, 24*time.Hour, 10)
now := time.Now()
if err := l.Allow("device-1", "issuer-A", now); err != nil {
t.Fatalf("first allow: %v", err)
}
// Different subject — independent bucket.
if err := l.Allow("device-2", "issuer-A", now); err != nil {
t.Fatalf("different subject must have its own bucket: %v", err)
}
// Different issuer — also independent.
if err := l.Allow("device-1", "issuer-B", now); err != nil {
t.Fatalf("different issuer must have its own bucket: %v", err)
}
// Same key as call 1 — must be limited.
if err := l.Allow("device-1", "issuer-A", now.Add(1*time.Second)); !errors.Is(err, ErrRateLimited) {
t.Fatalf("repeat key should be limited; got %v", err)
}
}
func TestPerDeviceRateLimiter_WindowExpiry(t *testing.T) {
l := NewPerDeviceRateLimiter(2, 1*time.Hour, 10)
now := time.Now()
if err := l.Allow("dev", "iss", now); err != nil {
t.Fatal(err)
}
if err := l.Allow("dev", "iss", now.Add(30*time.Minute)); err != nil {
t.Fatal(err)
}
// Inside window — limited.
if err := l.Allow("dev", "iss", now.Add(45*time.Minute)); !errors.Is(err, ErrRateLimited) {
t.Fatalf("inside-window 3rd call should be limited: %v", err)
}
// Past window — slots reopen.
if err := l.Allow("dev", "iss", now.Add(2*time.Hour)); err != nil {
t.Fatalf("past-window call should be allowed (window reset): %v", err)
}
}
func TestPerDeviceRateLimiter_DisabledBypass(t *testing.T) {
l := NewPerDeviceRateLimiter(0, 24*time.Hour, 10) // maxN=0 → disabled
if !l.Disabled() {
t.Fatal("limiter with maxN=0 must report Disabled()=true")
}
now := time.Now()
for i := 0; i < 100; i++ {
if err := l.Allow("dev", "iss", now); err != nil {
t.Fatalf("disabled limiter must allow everything: %v", err)
}
}
// Disabled limiter doesn't track buckets.
if got := l.Len(); got != 0 {
t.Errorf("disabled limiter Len() = %d, want 0", got)
}
}
func TestPerDeviceRateLimiter_NegativeCapDisabled(t *testing.T) {
l := NewPerDeviceRateLimiter(-1, 24*time.Hour, 10)
if !l.Disabled() {
t.Fatal("negative maxN must produce a disabled limiter")
}
}
func TestPerDeviceRateLimiter_EmptySubjectShortCircuits(t *testing.T) {
// Empty subject is the caller's defense-in-depth case (claim validation
// upstream should reject empty-subject claims first). Limiter must not
// build a single shared bucket keyed by empty-subject — that would
// be a fleet-wide chokepoint.
l := NewPerDeviceRateLimiter(1, 24*time.Hour, 10)
now := time.Now()
for i := 0; i < 50; i++ {
if err := l.Allow("", "iss", now); err != nil {
t.Fatalf("empty subject must short-circuit (call %d): %v", i, err)
}
}
if got := l.Len(); got != 0 {
t.Errorf("Len after 50 empty-subject calls = %d, want 0 (no bucket created)", got)
}
}
func TestPerDeviceRateLimiter_DefaultCapsHonored(t *testing.T) {
l := NewPerDeviceRateLimiter(5, 0, 0) // window=0 → 24h default; cap=0 → 100k default
if l.window != 24*time.Hour {
t.Errorf("default window = %v, want 24h", l.window)
}
if l.cap != 100_000 {
t.Errorf("default cap = %d, want 100000", l.cap)
}
}
func TestPerDeviceRateLimiter_MapCapEvictsOldest(t *testing.T) {
// Cap of 3 keys to exercise the eviction branch deterministically.
l := NewPerDeviceRateLimiter(2, 1*time.Hour, 3)
now := time.Now()
// Insert 3 distinct keys with increasing timestamps.
for i := 0; i < 3; i++ {
key := fmt.Sprintf("dev-%d", i)
if err := l.Allow(key, "iss", now.Add(time.Duration(i)*time.Minute)); err != nil {
t.Fatalf("insert %d: %v", i, err)
}
}
if l.Len() != 3 {
t.Fatalf("Len = %d, want 3", l.Len())
}
// 4th key forces eviction of dev-0 (its newest timestamp is oldest).
if err := l.Allow("dev-3", "iss", now.Add(10*time.Minute)); err != nil {
t.Fatalf("4th-key insert: %v", err)
}
if l.Len() != 3 {
t.Errorf("Len after at-cap insert = %d, want 3 (cap honored)", l.Len())
}
}
func TestPerDeviceRateLimiter_ConcurrentRaceFree(t *testing.T) {
if testing.Short() {
t.Skip("race-style test under -short")
}
l := NewPerDeviceRateLimiter(50, 24*time.Hour, 10000)
var wg sync.WaitGroup
for g := 0; g < 20; g++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
now := time.Now()
key := fmt.Sprintf("dev-%d", id)
for i := 0; i < 30; i++ {
_ = l.Allow(key, "iss", now)
}
}(g)
}
wg.Wait()
if got := l.Len(); got != 20 {
t.Errorf("expected 20 distinct keys; got %d", got)
}
}
func TestPruneOlderThan(t *testing.T) {
t0 := time.Now()
in := []time.Time{
t0.Add(-3 * time.Hour), // pruned (older than cutoff)
t0.Add(-2 * time.Hour), // pruned (older than cutoff)
t0.Add(-1 * time.Hour), // survives (-60m is NEWER than the -90m cutoff)
t0.Add(-30 * time.Minute), // survives
t0, // survives
}
out := pruneOlderThan(in, t0.Add(-90*time.Minute))
if len(out) != 3 {
t.Fatalf("len(out) = %d, want 3 (-1h, -30m, t0 all newer than -90m cutoff)", len(out))
}
if !out[0].Equal(t0.Add(-1 * time.Hour)) {
t.Errorf("out[0] = %v, want -1h (oldest surviving entry)", out[0])
}
}
func TestPruneOlderThan_NoOpWhenNothingToPrune(t *testing.T) {
t0 := time.Now()
in := []time.Time{t0.Add(-1 * time.Minute), t0}
out := pruneOlderThan(in, t0.Add(-1*time.Hour))
// Same slice header (no copy needed).
if len(out) != len(in) {
t.Fatalf("len(out) = %d, want %d", len(out), len(in))
}
}
+143
View File
@@ -0,0 +1,143 @@
package intune
import (
"crypto/x509"
"errors"
"log/slog"
"os"
"os/signal"
"sync"
"syscall"
)
// TrustAnchorHolder is the SIGHUP-reloadable wrapper around a per-profile
// Intune Connector trust anchor pool.
//
// SCEP RFC 8894 + Intune master bundle Phase 8.5.
//
// Mirrors the shape established by `cmd/server/tls.go::certHolder` for the
// server TLS cert: an RWMutex-guarded pool, a Get accessor that's safe for
// concurrent callers from the request path, a Reload that re-reads the file
// and atomically swaps the slice on success (failure leaves the OLD pool in
// place so a bad reload doesn't take Intune enrollment down), and a
// watchSIGHUP goroutine that responds to the same SIGHUP the operator uses
// to rotate the server TLS cert.
//
// Why SIGHUP specifically (vs fsnotify or a polling loop): SIGHUP is the
// repo-established convention (see cmd/server/tls.go). fsnotify would add a
// new direct dep + complicate the cleanup story. The operator's Connector-
// rotation script writes the new PEM bundle then sends SIGHUP — the same
// signal that already rotates the server TLS cert — and both swap atomically.
//
// Concurrency contract:
// - Get returns the pool slice header by value; the slice itself is
// immutable per-snapshot (Reload swaps a fresh slice rather than
// mutating the existing one). Callers may iterate the returned slice
// without holding any lock.
// - Reload acquires a write lock briefly for the swap. Concurrent Get
// calls block only for that swap window (microseconds).
// - watchSIGHUP runs at most one Reload at a time per holder.
type TrustAnchorHolder struct {
mu sync.RWMutex
certs []*x509.Certificate
path string
logger *slog.Logger
}
// NewTrustAnchorHolder loads the trust bundle and returns a holder. Returns
// the same fail-loud error LoadTrustAnchor does on initial load — the
// startup gate at cmd/server/main.go is supposed to refuse boot when this
// fails. Subsequent Reload errors are non-fatal (logged + old pool retained).
//
// The logger is required (never nil); the caller passes a per-profile
// scoped logger so SIGHUP-reload events show the PathID for triage.
func NewTrustAnchorHolder(path string, logger *slog.Logger) (*TrustAnchorHolder, error) {
if logger == nil {
return nil, errors.New("intune: TrustAnchorHolder requires a non-nil logger")
}
certs, err := LoadTrustAnchor(path)
if err != nil {
return nil, err
}
return &TrustAnchorHolder{
certs: certs,
path: path,
logger: logger,
}, nil
}
// Get returns the current trust anchor pool. Safe for concurrent callers;
// the slice header is returned by value and the underlying slice is
// immutable per-snapshot (Reload swaps a fresh slice, doesn't mutate in
// place — see Reload).
func (h *TrustAnchorHolder) Get() []*x509.Certificate {
h.mu.RLock()
defer h.mu.RUnlock()
return h.certs
}
// Path returns the on-disk path the holder reloads from. Useful for
// observability (admin endpoints, log lines) without exposing the cert
// pool itself.
func (h *TrustAnchorHolder) Path() string {
return h.path
}
// Reload re-reads the trust anchor file at h.path and atomically swaps the
// pool. Returns the parse error if the new file is invalid; the OLD pool
// stays in place so a bad reload doesn't take Intune enrollment down.
//
// Same fail-safe pattern as cmd/server/tls.go::(*certHolder).Reload — a
// rotation that writes a half-file (operator overwrites the bundle while
// only some of the new certs are in it) would otherwise crash the
// service mid-rotation. Logging + retaining the old pool gives the
// operator a bounded window to fix and re-SIGHUP.
func (h *TrustAnchorHolder) Reload() error {
certs, err := LoadTrustAnchor(h.path)
if err != nil {
return err
}
h.mu.Lock()
h.certs = certs
h.mu.Unlock()
return nil
}
// WatchSIGHUP installs a signal handler that calls Reload on each SIGHUP.
// The returned stop function closes the internal done channel and stops
// signal delivery so the goroutine can exit cleanly during shutdown.
//
// Errors from Reload are logged but do not terminate the watcher — the
// operator can fix the files and send another SIGHUP. Mirrors the
// (*certHolder).watchSIGHUP contract exactly.
//
// Multiple holders can coexist: each registers its own goroutine on the
// same SIGHUP signal. signal.Notify multicasts to every registered
// channel, so a single SIGHUP reloads every per-profile Intune trust
// anchor PLUS the server TLS cert in one operator action — exactly the
// design requirement (one SIGHUP rotates everything).
func (h *TrustAnchorHolder) WatchSIGHUP() (stop func()) {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
done := make(chan struct{})
go func() {
for {
select {
case <-ch:
if err := h.Reload(); err != nil {
h.logger.Error("Intune trust anchor reload failed; continuing with previous pool",
"error", err,
"path", h.path)
continue
}
h.logger.Info("Intune trust anchor reloaded via SIGHUP",
"path", h.path,
"certs_loaded", len(h.Get()))
case <-done:
signal.Stop(ch)
return
}
}
}()
return func() { close(done) }
}
@@ -0,0 +1,234 @@
package intune
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"log/slog"
"math/big"
"os"
"path/filepath"
"strings"
"syscall"
"testing"
"time"
)
// silentLogger returns a logger that drops everything; the SIGHUP watcher
// path emits Info logs we don't want fouling test output.
func silentTestLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError + 10}))
}
// writeTestBundle writes a PEM bundle of the given certs at path with mode 0600.
func writeTestBundle(t *testing.T, path string, certs []*x509.Certificate) {
t.Helper()
body := []byte{}
for _, c := range certs {
body = append(body, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.Raw})...)
}
if err := os.WriteFile(path, body, 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
}
// freshHolderCert is a small factory for a self-signed EC cert with a
// caller-controlled CN + lifetime. Used by Reload tests that swap the
// on-disk pool between calls.
func freshHolderCert(t *testing.T, cn string, notAfter time.Time) *x509.Certificate {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("ecdsa.GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: notAfter,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("x509.CreateCertificate: %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("x509.ParseCertificate: %v", err)
}
return cert
}
func TestTrustAnchorHolder_NewLoadsBundle(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "intune-trust.pem")
cert := freshHolderCert(t, "initial-conn", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{cert})
holder, err := NewTrustAnchorHolder(path, silentTestLogger())
if err != nil {
t.Fatalf("NewTrustAnchorHolder: %v", err)
}
got := holder.Get()
if len(got) != 1 || got[0].Subject.CommonName != "initial-conn" {
t.Fatalf("Get returned %#v, want one cert with CN=initial-conn", got)
}
if holder.Path() != path {
t.Errorf("Path = %q, want %q", holder.Path(), path)
}
}
func TestTrustAnchorHolder_NewRequiresLogger(t *testing.T) {
if _, err := NewTrustAnchorHolder("/nonexistent", nil); err == nil {
t.Fatal("nil logger must error")
}
}
func TestTrustAnchorHolder_NewSurfacesLoadError(t *testing.T) {
if _, err := NewTrustAnchorHolder("/path/that/does/not/exist.pem", silentTestLogger()); err == nil {
t.Fatal("missing file must error")
}
}
func TestTrustAnchorHolder_ReloadHappyPath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
c1 := freshHolderCert(t, "rev-1", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c1})
h, err := NewTrustAnchorHolder(path, silentTestLogger())
if err != nil {
t.Fatal(err)
}
// Rotate on disk and call Reload.
c2 := freshHolderCert(t, "rev-2", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c2})
if err := h.Reload(); err != nil {
t.Fatalf("Reload: %v", err)
}
got := h.Get()
if len(got) != 1 || got[0].Subject.CommonName != "rev-2" {
t.Errorf("after Reload Get = %#v, want one cert CN=rev-2", got)
}
}
func TestTrustAnchorHolder_ReloadKeepsOldOnFailure(t *testing.T) {
// Mid-rotation half-file: operator overwrites the bundle with garbage
// → Reload errors → holder must still serve the OLD pool. Without this
// fail-safe a single typo would take Intune enrollment down for the
// whole window until a re-rotate.
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
good := freshHolderCert(t, "stable", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{good})
h, err := NewTrustAnchorHolder(path, silentTestLogger())
if err != nil {
t.Fatal(err)
}
// Overwrite with content that LoadTrustAnchor will reject (no PEM blocks).
if err := os.WriteFile(path, []byte("garbage"), 0o600); err != nil {
t.Fatal(err)
}
if err := h.Reload(); err == nil {
t.Fatal("Reload from garbage file must error")
}
// Old pool still served.
got := h.Get()
if len(got) != 1 || got[0].Subject.CommonName != "stable" {
t.Errorf("after failed Reload Get should still be the pre-Reload pool; got %#v", got)
}
}
func TestTrustAnchorHolder_ReloadKeepsOldOnExpired(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
good := freshHolderCert(t, "still-valid", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{good})
h, err := NewTrustAnchorHolder(path, silentTestLogger())
if err != nil {
t.Fatal(err)
}
// Operator rotates to a cert that's already expired (their script
// pulled an old bundle by mistake). Reload should error AND the holder
// should retain the previous good pool — exactly the fail-safe semantics
// LoadTrustAnchor enforces at startup.
expired := freshHolderCert(t, "expired-conn", time.Now().Add(-1*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{expired})
if err := h.Reload(); err == nil {
t.Fatal("Reload with expired cert must error")
}
if !strings.Contains(h.Get()[0].Subject.CommonName, "still-valid") {
t.Errorf("after expired-cert Reload, holder should retain old pool")
}
}
func TestTrustAnchorHolder_WatchSIGHUPReloadsPool(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
c1 := freshHolderCert(t, "rev-pre-sighup", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c1})
h, err := NewTrustAnchorHolder(path, silentTestLogger())
if err != nil {
t.Fatal(err)
}
stop := h.WatchSIGHUP()
defer stop()
// Rotate on disk, then send SIGHUP to our own process and poll for the swap.
c2 := freshHolderCert(t, "rev-post-sighup", time.Now().Add(30*24*time.Hour))
writeTestBundle(t, path, []*x509.Certificate{c2})
if err := syscall.Kill(syscall.Getpid(), syscall.SIGHUP); err != nil {
t.Fatalf("send SIGHUP: %v", err)
}
// Poll for up to 2 seconds.
deadline := time.Now().Add(2 * time.Second)
for {
got := h.Get()
if len(got) == 1 && got[0].Subject.CommonName == "rev-post-sighup" {
return
}
if time.Now().After(deadline) {
t.Fatalf("post-SIGHUP pool not swapped in 2s; current CN=%q", got[0].Subject.CommonName)
}
time.Sleep(20 * time.Millisecond)
}
}
func TestTrustAnchorHolder_WatchSIGHUPStopIsClean(t *testing.T) {
// Mirrors cmd/server/tls_test.go::TestCertHolder_WatchSIGHUP_StopExits:
// we do NOT fire a SIGHUP after stop(), because once signal.Stop has
// removed our handler the kernel's default action on SIGHUP is to
// terminate the process — it would kill the test runner. The contract
// we need to pin is "stop() is synchronous and safe", which we
// demonstrate by closing the watcher and verifying the holder still
// serves the original cert without panic.
dir := t.TempDir()
path := filepath.Join(dir, "trust.pem")
writeTestBundle(t, path, []*x509.Certificate{
freshHolderCert(t, "stop-test", time.Now().Add(30*24*time.Hour)),
})
h, err := NewTrustAnchorHolder(path, silentTestLogger())
if err != nil {
t.Fatal(err)
}
stop := h.WatchSIGHUP()
stop()
time.Sleep(50 * time.Millisecond) // let the goroutine fully exit
if cn := h.Get()[0].Subject.CommonName; cn != "stop-test" {
t.Errorf("after stop CN = %q, want unchanged stop-test", cn)
}
}