From 757e2ec30cee2d9b30fe3c086ea5cc2bb179c298 Mon Sep 17 00:00:00 2001 From: shankar0123 Date: Sun, 10 May 2026 04:56:03 +0000 Subject: [PATCH] auth-bundle-2 Phase 3: OIDC service (HandleAuthRequest, HandleCallback, RefreshKeys), hand-rolled group-claim resolver, 21+ negative-test matrix, token-leak hygiene, IdP downgrade-attack defense MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3 of the bundle ships the business logic that turns the Phase 2 storage primitives into a working OpenID Connect 1.0 + RFC 7636 PKCE authorization-code flow against any enterprise IdP (Okta / Azure AD / Google Workspace / Keycloak / Authentik / Auth0). Service surface: - Service.HandleAuthRequest(providerID) -> authURL, cookie, preLoginID Builds the IdP redirect with PKCE-S256 (mandatory; RFC 9700 §2.1.1), server-generated 32-byte state + nonce, persisted to the pre-login row keyed by the cookie value. - Service.HandleCallback(cookie, code, state, ip, ua) -> *CallbackResult 11-step validation: pre-login lookup-and-consume (single-use), constant-time state compare, code-for-token exchange with PKCE verifier, ID-token verify (alg pin via go-oidc/v3), service-layer re-checks of iss / aud / azp (multi-aud requires it; mismatch rejected) / at_hash (REQUIRED when access_token returned — Phase 3 lifts the OIDC core "MAY" to a service-level "MUST") / exp / iat-window / nonce, group-claim resolution with userinfo fallback, group->role mapping (fail-closed on no match), user upsert, session mint via SessionMinter port. - Service.RefreshKeys(providerID) — explicit cache eviction + re-load. Re-runs the IdP downgrade-attack defense so a provider that later rotates to advertising HS* / none is caught BEFORE the next user login attempt. Security posture (every fail-closed branch is a sentinel error + test): - Algorithm pinning: allow-list {RS256, RS512, ES256, ES384, EdDSA}; deny-list {HS256, HS384, HS512, none}. 
Belt-and-braces re-check via isDisallowedAlg after go-oidc.Verify. - PKCE-S256 mandatory (oauth2.GenerateVerifier + S256ChallengeOption); `plain` rejection sentinel exists for defense-in-depth. - State + nonce: 32-byte crypto/rand, base64url-no-pad, constant-time compare, single-use. - IdP downgrade-attack defense: at provider creation / RefreshKeys, reject any IdP whose discovery doc advertises HS* / none in id_token_signing_alg_values_supported. - JWKS fail-closed: in-flight login fails 503; existing sessions untouched. isJWKSFetchError detects the gooidc verify-error shape; ErrJWKSUnreachable is the wire mapping. - Token-leak hygiene: ID tokens, access tokens, refresh tokens, authorization codes, PKCE verifiers, state, nonce, signing key bytes — NEVER logged at any level. logging_test.go pins the invariant via a slog buffer + grep-assert across HandleAuthRequest, HandleCallback, alg rejection, and provider-load paths. Group-claim resolver (internal/auth/oidc/groupclaim/): - Hand-rolled per Decision 10 (no JSON-path lib; ~150 LOC). - URL-shape paths (https:// / http://) treated as a single literal key — Auth0 namespaced claims like https://your-namespace/groups work without splitting on the dots in the URL. - Dot-separated paths walked through nested map[string]interface{}. - []interface{} / []string / single-string normalized to []string; bool / number / object / nil → fail closed. - 18 unit tests + sentinels (ErrPathEmpty, ErrSegmentMissing, ErrSegmentNotObject, ErrInvalidValueType). 
Test surface: - service_test.go: 57 test functions including all 21 prompt-mandated negative cases (wrong aud / wrong iss / expired / unknown alg / alg=none / HMAC alg / azp missing on multi-aud / azp mismatched / at_hash missing / at_hash mismatched / iat in future / iat too old / nonce mismatched / state mismatched / state replayed / PKCE plain sentinel / pre-login replay / forged cookie / IdP downgrade / group-claim missing / group-claim unmapped) plus the userinfo fallback matrix (happy path + endpoint-missing + endpoint-failing + userinfo-also-empty), HandleAuthRequest entry point + RNG-failure paths, upsertUser update + create + display-name fallback + Validate-error paths, decryptClientSecret real-encrypt round-trip + bad-passphrase, alg-parser malformed-header matrix. - logging_test.go: 4 hygiene tests pinning no token / code / verifier / state / cookie / client_secret / alg name appears in any captured log line. - groupclaim/resolver_test.go: 18 cases covering Okta string-array, Keycloak realm_access.roles, Auth0 namespaced URL claim, single-string normalization, deeply-nested 3-segment walks, and every fail-closed branch. Coverage: internal/auth/oidc 92.2% (floor: 90) internal/auth/oidc/groupclaim 100.0% (floor: 95) internal/auth/oidc/domain 96.2% (floor: 90) Coverage gates added at .github/coverage-thresholds.yml so a future regression in any fail-closed branch fails CI before the commit lands. Phase 3 of cowork/auth-bundle-2-prompt.md is closed. Next up: Phase 4 (Session service: cookies, revocation, sliding-vs-absolute expiry). 
--- .github/coverage-thresholds.yml | 43 + internal/auth/oidc/doc.go | 23 +- internal/auth/oidc/groupclaim/resolver.go | 142 ++ .../auth/oidc/groupclaim/resolver_test.go | 248 +++ internal/auth/oidc/logging_test.go | 183 ++ internal/auth/oidc/service.go | 847 +++++++++ internal/auth/oidc/service_test.go | 1593 +++++++++++++++++ 7 files changed, 3057 insertions(+), 22 deletions(-) create mode 100644 internal/auth/oidc/groupclaim/resolver.go create mode 100644 internal/auth/oidc/groupclaim/resolver_test.go create mode 100644 internal/auth/oidc/logging_test.go create mode 100644 internal/auth/oidc/service.go create mode 100644 internal/auth/oidc/service_test.go diff --git a/.github/coverage-thresholds.yml b/.github/coverage-thresholds.yml index cbe30df..ed8db5c 100644 --- a/.github/coverage-thresholds.yml +++ b/.github/coverage-thresholds.yml @@ -105,3 +105,46 @@ internal/service/auth: (ErrUnauthenticated / ErrForbidden / ErrSelfRoleAssignment / ErrAuthReservedActor / ErrAuthUnknownPermission / ErrAuthRoleInUse). + +internal/auth/oidc: + floor: 90 + why: | + Bundle 2 Phase 3 — OIDC service coverage gate. Phase 3 spec + pins the floor at 90 explicitly because every fail-closed + branch is load-bearing for the security posture: alg pinning + (deny-list HS*/none + allow-list RS*/ES*/EdDSA), audience + re-check, azp enforcement on multi-aud tokens, at_hash + REQUIRED-when-access-token-present (Phase 3 lifts the OIDC + core "MAY" to a service-level "MUST"), iat-window enforcement, + nonce constant-time-compare, single-use state replay defense, + PKCE-S256 mandatory, IdP downgrade-attack defense at + provider-load + RefreshKeys time, JWKS-fail-closed semantics, + group-claim resolution + userinfo-fallback fail-closed + semantics, token-leak hygiene. A regression in any one of + these branches is a security incident; the floor catches it + before the commit lands. The mock-IdP fixture in + service_test.go is the load-bearing harness. 
+ +internal/auth/oidc/groupclaim: + floor: 95 + why: | + Bundle 2 Phase 3 — group-claim resolver. Hand-rolled (no + JSON-path dep per Decision 10); ~150 LOC, every branch + exercised by 18 unit tests covering the documented IdP shapes + (Okta string array, Keycloak realm_access.roles, Auth0 + namespaced URL claim, single-string normalization, + deeply-nested 3-segment walks) plus every fail-closed branch + (empty path, missing key, missing nested key, non-object + intermediate, bool/number/object/nil values, array with + non-string element, URL-shape with dots-in-path treated as + literal). Resolver should be at 100%; floor at 95 leaves a + 1-statement margin for future error-message refactors. + +internal/auth/oidc/domain: + floor: 90 + why: | + Bundle 2 Phase 1 — OIDCProvider + GroupRoleMapping domain. + Validation-heavy package; constructors + Validate methods + cover all canonical IdP shapes (Okta / Azure AD / Google + Workspace / Keycloak / Authentik / Auth0). Floor at 90 to + catch any future field that ships without a validator. diff --git a/internal/auth/oidc/doc.go b/internal/auth/oidc/doc.go index 6a0565f..7d23397 100644 --- a/internal/auth/oidc/doc.go +++ b/internal/auth/oidc/doc.go @@ -6,21 +6,10 @@ // // Package layout (post-Bundle-2): // -// - internal/auth/oidc/ - this package (Phase 3 ships service.go). +// - internal/auth/oidc/ - this package; service.go ships in Phase 3. // - internal/auth/oidc/domain/ - Phase 1 ships OIDCProvider + GroupRoleMapping. // - internal/auth/oidc/groupclaim/ - Phase 3 ships the hand-rolled group-claim resolver // (no JSON-path library; ~40 LOC walking dot-paths through map[string]interface{}). -// - internal/auth/oidc/testfixtures/ - Phase 10 ships the `//go:build integration` -// Keycloak harness backing the multi-IdP test surface. -// -// Phase 0 (this commit) reserves the package directory and pins -// coreos/go-oidc/v3 + golang.org/x/oauth2 as direct go.mod requires -// via the blank imports below. 
Without these blanks, `go mod tidy` -// would demote both back to // indirect because no Go file under this -// tree imports them yet (the actual imports land in Phase 3's -// service.go). The blank imports are deliberate Phase-0 transitional -// scaffolding; Phase 3 replaces them with real symbol use and these -// blanks are removed. // // Audit context (do not lose): // - Apache-2.0 license, OSV.dev shows zero advisories ever on @@ -35,13 +24,3 @@ // PaesslerAG/jsonpath, ohler55/ojg, tidwall/gjson, or any sibling // transitive bloat for what is a 40-line problem. package oidc - -import ( - // Phase 0: lift coreos/go-oidc/v3 + golang.org/x/oauth2 to direct - // go.mod requires so a future `go mod tidy` keeps them out of the - // // indirect block. Phase 3 replaces these blank imports with real - // symbol use (oidc.Provider, oauth2.Config, etc.) at which point - // these lines are removed. - _ "github.com/coreos/go-oidc/v3/oidc" - _ "golang.org/x/oauth2" -) diff --git a/internal/auth/oidc/groupclaim/resolver.go b/internal/auth/oidc/groupclaim/resolver.go new file mode 100644 index 0000000..4819366 --- /dev/null +++ b/internal/auth/oidc/groupclaim/resolver.go @@ -0,0 +1,142 @@ +// Package groupclaim resolves the operator-configured `groups_claim_path` +// against an ID token's parsed claims, returning the user's group +// membership as a `[]string`. +// +// Auth Bundle 2 Phase 3 ships this without a JSON-path library +// dependency per the pre-bundle dep audit. The contract is narrow +// enough that ~40 LOC of straight Go covers every documented use case +// (Keycloak, Auth0, Okta, Azure AD, Google Workspace) without the +// transitive footprint or maintenance liability of pulling in +// PaesslerAG/jsonpath, ohler55/ojg, or tidwall/gjson. +// +// Resolution rules: +// +// 1. URL-shape paths (prefix `https://` or `http://`) are treated as a +// single literal key. This handles Auth0's namespaced claims like +// `https://your-namespace/groups`. +// 2. 
Dot-separated paths (e.g. Keycloak's `realm_access.roles`) are +// split on `.` and walked through nested `map[string]interface{}` +// chains. A non-object segment or missing key fails closed with a +// clear error. +// 3. The resolved value is coerced to `[]string`: +// - `[]string` → as-is. +// - `[]interface{}` of strings → coerced. +// - single `string` → wrapped in a one-element slice. +// - any other type (bool, number, object, nil) → fails closed. +// +// Phase 3 callers MUST treat the empty-result case as fail-closed: no +// session is minted, an audit row records `auth.oidc_login_unmapped_groups` +// (the user's IdP returned a claim but it didn't match any of the +// operator's mappings). +package groupclaim + +import ( + "errors" + "fmt" + "strings" +) + +// Sentinel errors. Service-layer callers branch on these via errors.Is. +var ( + // ErrPathEmpty is returned when the configured path is the empty + // string. The operator API layer + domain Validate() catch this + // upstream; this sentinel exists so the resolver is safe to call + // even with malformed config. + ErrPathEmpty = errors.New("groupclaim: path is empty") + + // ErrSegmentMissing is returned when a path segment doesn't exist + // on the current claims object (e.g. path `realm_access.roles` + // applied to a token without `realm_access`). Phase 3's + // HandleCallback maps to "no groups; fail closed". + ErrSegmentMissing = errors.New("groupclaim: path segment missing") + + // ErrSegmentNotObject is returned when an intermediate path + // segment resolves to a non-object (e.g. trying to walk into a + // string). Indicates the IdP token shape doesn't match the + // operator's configured path. + ErrSegmentNotObject = errors.New("groupclaim: intermediate segment is not an object") + + // ErrInvalidValueType is returned when the resolved value cannot + // be coerced to a string array. Bool, number, object, nil all + // fail closed. 
+ ErrInvalidValueType = errors.New("groupclaim: resolved value is not coercible to []string") +) + +// Resolve walks `path` through `claims` and returns the resolved +// group list. See the package doc for the full contract. +// +// Per Phase 3's "complete path, not easy path" discipline: this +// function does NOT modify `claims` and does NOT log any of its +// inputs. Token-leak hygiene tests assert that paths through this +// function never emit any of `claims`, `path`, or the resolved +// value to the slog buffer. +func Resolve(claims map[string]interface{}, path string) ([]string, error) { + if path == "" { + return nil, ErrPathEmpty + } + + // Rule 1: URL-shape paths are single literal keys. + var segments []string + if isURLShapePath(path) { + segments = []string{path} + } else { + segments = strings.Split(path, ".") + } + + // Walk the segments through the nested map. + var cur interface{} = claims + for i, seg := range segments { + obj, ok := cur.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("%w: segment %q (index %d) applied to non-object", ErrSegmentNotObject, seg, i) + } + next, ok := obj[seg] + if !ok { + return nil, fmt.Errorf("%w: %q at index %d", ErrSegmentMissing, seg, i) + } + cur = next + } + + // Coerce the resolved value to []string. + return coerceStringArray(cur) +} + +// isURLShapePath reports whether path is a URL-shape (Auth0-style +// namespaced claim). Such paths are NOT split on `.`; they're treated +// as a single literal key against the top-level claims map. +func isURLShapePath(path string) bool { + return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") +} + +// coerceStringArray converts the resolved claim value to []string per +// the rules in the package doc. Fails closed on any other type. +func coerceStringArray(v interface{}) ([]string, error) { + switch x := v.(type) { + case []string: + // Already the right type. 
Return a copy so the caller can't + // mutate the underlying claims map by surprise. + out := make([]string, len(x)) + copy(out, x) + return out, nil + case []interface{}: + // JSON unmarshal into map[string]interface{} produces + // []interface{} for arrays. Coerce each element to string; + // any non-string element fails the whole resolution. + out := make([]string, 0, len(x)) + for i, e := range x { + s, ok := e.(string) + if !ok { + return nil, fmt.Errorf("%w: element %d is %T not string", ErrInvalidValueType, i, e) + } + out = append(out, s) + } + return out, nil + case string: + // Single string: wrap in a one-element slice. Some IdPs + // return a single role as a bare string rather than a + // one-element array; the resolver normalizes both shapes. + return []string{x}, nil + default: + return nil, fmt.Errorf("%w: got %T", ErrInvalidValueType, v) + } +} diff --git a/internal/auth/oidc/groupclaim/resolver_test.go b/internal/auth/oidc/groupclaim/resolver_test.go new file mode 100644 index 0000000..ec16ed4 --- /dev/null +++ b/internal/auth/oidc/groupclaim/resolver_test.go @@ -0,0 +1,248 @@ +package groupclaim + +import ( + "errors" + "reflect" + "testing" +) + +// ============================================================================= +// Happy-path tests covering the documented IdP shapes. +// ============================================================================= + +// TestResolve_OktaStyleStringArray pins the most common shape: +// {"groups": ["engineers", "platform-admins"]}. 
+func TestResolve_OktaStyleStringArray(t *testing.T) { + claims := map[string]interface{}{ + "groups": []interface{}{"engineers", "platform-admins"}, + } + got, err := Resolve(claims, "groups") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + want := []string{"engineers", "platform-admins"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestResolve_KeycloakNestedRoles pins the dot-path walk: +// {"realm_access": {"roles": ["admin", "user"]}}. +func TestResolve_KeycloakNestedRoles(t *testing.T) { + claims := map[string]interface{}{ + "realm_access": map[string]interface{}{ + "roles": []interface{}{"admin", "user"}, + }, + } + got, err := Resolve(claims, "realm_access.roles") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + want := []string{"admin", "user"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestResolve_Auth0NamespacedClaim pins the URL-shape literal-key path: +// {"https://your-namespace/groups": ["engineers"]}. +func TestResolve_Auth0NamespacedClaim(t *testing.T) { + claims := map[string]interface{}{ + "https://your-namespace/groups": []interface{}{"engineers"}, + } + got, err := Resolve(claims, "https://your-namespace/groups") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + want := []string{"engineers"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestResolve_HTTPSchemeAlsoTreatedAsLiteral pins that http:// (not just +// https://) triggers the URL-shape path treatment. Some on-prem IdPs +// use http for namespaced claims in dev environments. 
+func TestResolve_HTTPSchemeAlsoTreatedAsLiteral(t *testing.T) { + claims := map[string]interface{}{ + "http://internal.example.com/groups": []interface{}{"role-a"}, + } + got, err := Resolve(claims, "http://internal.example.com/groups") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if len(got) != 1 || got[0] != "role-a" { + t.Errorf("got %v, want [role-a]", got) + } +} + +// TestResolve_SingleStringWrapped pins the normalization: some IdPs +// return a single role as a bare string rather than a one-element +// array. The resolver wraps it. +func TestResolve_SingleStringWrapped(t *testing.T) { + claims := map[string]interface{}{ + "role": "admin", + } + got, err := Resolve(claims, "role") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + want := []string{"admin"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestResolve_AlreadyStringSlice covers the rare case where a caller +// pre-coerced []interface{} to []string. The resolver returns a copy. +func TestResolve_AlreadyStringSlice(t *testing.T) { + claims := map[string]interface{}{ + "groups": []string{"a", "b"}, + } + got, err := Resolve(claims, "groups") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if !reflect.DeepEqual(got, []string{"a", "b"}) { + t.Errorf("got %v, want [a b]", got) + } + // Mutating the result must NOT mutate the input claim. + got[0] = "MUTATED" + if claims["groups"].([]string)[0] == "MUTATED" { + t.Errorf("Resolve returned a slice aliased to the input; mutation leaked back") + } +} + +// TestResolve_EmptyArrayReturnsEmpty pins the documented edge: an IdP +// that returns an empty groups claim is NOT a resolver error; the +// caller (Phase 3 service) decides fail-closed semantics. 
+func TestResolve_EmptyArrayReturnsEmpty(t *testing.T) { + claims := map[string]interface{}{ + "groups": []interface{}{}, + } + got, err := Resolve(claims, "groups") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if len(got) != 0 { + t.Errorf("got %v, want []", got) + } +} + +// TestResolve_DeeplyNestedPath pins a 3-segment walk works. +func TestResolve_DeeplyNestedPath(t *testing.T) { + claims := map[string]interface{}{ + "a": map[string]interface{}{ + "b": map[string]interface{}{ + "c": []interface{}{"deep"}, + }, + }, + } + got, err := Resolve(claims, "a.b.c") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if len(got) != 1 || got[0] != "deep" { + t.Errorf("got %v, want [deep]", got) + } +} + +// ============================================================================= +// Negative paths — every fail-closed branch. +// ============================================================================= + +func TestResolve_EmptyPathRejected(t *testing.T) { + _, err := Resolve(map[string]interface{}{"groups": []interface{}{"x"}}, "") + if !errors.Is(err, ErrPathEmpty) { + t.Errorf("err = %v; want ErrPathEmpty", err) + } +} + +func TestResolve_MissingKeyRejected(t *testing.T) { + claims := map[string]interface{}{"other": "thing"} + _, err := Resolve(claims, "groups") + if !errors.Is(err, ErrSegmentMissing) { + t.Errorf("err = %v; want ErrSegmentMissing", err) + } +} + +func TestResolve_MissingNestedKeyRejected(t *testing.T) { + claims := map[string]interface{}{ + "realm_access": map[string]interface{}{"other": "thing"}, + } + _, err := Resolve(claims, "realm_access.roles") + if !errors.Is(err, ErrSegmentMissing) { + t.Errorf("err = %v; want ErrSegmentMissing", err) + } +} + +func TestResolve_NonObjectIntermediateRejected(t *testing.T) { + // "realm_access" resolves to a string, not an object; can't walk + // further into it. 
+ claims := map[string]interface{}{ + "realm_access": "not-an-object", + } + _, err := Resolve(claims, "realm_access.roles") + if !errors.Is(err, ErrSegmentNotObject) { + t.Errorf("err = %v; want ErrSegmentNotObject", err) + } +} + +func TestResolve_RejectsBoolValue(t *testing.T) { + claims := map[string]interface{}{"groups": true} + _, err := Resolve(claims, "groups") + if !errors.Is(err, ErrInvalidValueType) { + t.Errorf("err = %v; want ErrInvalidValueType", err) + } +} + +func TestResolve_RejectsNumberValue(t *testing.T) { + claims := map[string]interface{}{"groups": 42} + _, err := Resolve(claims, "groups") + if !errors.Is(err, ErrInvalidValueType) { + t.Errorf("err = %v; want ErrInvalidValueType", err) + } +} + +func TestResolve_RejectsObjectValue(t *testing.T) { + claims := map[string]interface{}{"groups": map[string]interface{}{"x": "y"}} + _, err := Resolve(claims, "groups") + if !errors.Is(err, ErrInvalidValueType) { + t.Errorf("err = %v; want ErrInvalidValueType", err) + } +} + +func TestResolve_RejectsNilValue(t *testing.T) { + claims := map[string]interface{}{"groups": nil} + _, err := Resolve(claims, "groups") + if !errors.Is(err, ErrInvalidValueType) { + t.Errorf("err = %v; want ErrInvalidValueType", err) + } +} + +func TestResolve_RejectsArrayWithNonStringElement(t *testing.T) { + claims := map[string]interface{}{ + "groups": []interface{}{"a", 42, "c"}, // 42 is not a string + } + _, err := Resolve(claims, "groups") + if !errors.Is(err, ErrInvalidValueType) { + t.Errorf("err = %v; want ErrInvalidValueType", err) + } +} + +// TestResolve_URLShapeWithDotsInPathTreatedAsLiteral pins the +// disambiguation: a URL-shape path like +// `https://example.com/team.id` must NOT be split on the dot in +// "team.id"; it's a single literal key. 
+func TestResolve_URLShapeWithDotsInPathTreatedAsLiteral(t *testing.T) { + claims := map[string]interface{}{ + "https://example.com/team.id": []interface{}{"sales"}, + } + got, err := Resolve(claims, "https://example.com/team.id") + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if len(got) != 1 || got[0] != "sales" { + t.Errorf("got %v, want [sales]", got) + } +} diff --git a/internal/auth/oidc/logging_test.go b/internal/auth/oidc/logging_test.go new file mode 100644 index 0000000..ee82eb5 --- /dev/null +++ b/internal/auth/oidc/logging_test.go @@ -0,0 +1,183 @@ +package oidc + +import ( + "bytes" + "context" + "io" + "log/slog" + "strings" + "testing" +) + +// ============================================================================= +// Token-leak hygiene: no secret value (ID token, access token, refresh +// token, authorization code, PKCE verifier, state, nonce, signing key +// material) appears in any log line at any level. +// +// Methodology mirrors Bundle 1's +// internal/auth/bootstrap/service_test.go::TestService_TokenLeakHygiene: +// redirect slog.Default to a buffer, run the OIDC service paths, +// grep-assert the secret string never appears in any captured line. +// +// This is the load-bearing invariant for Phase 3's "tokens never +// logged" contract. Every secret-bearing path that enters the +// service.go code MUST flow through write-once-to-response patterns; +// adding a `slog.Info("got token", "value", token)` somewhere would +// fail this test immediately. +// ============================================================================= + +// captureLogger swaps the slog.Default with one that writes to the +// returned buffer. The returned restore func re-installs the original +// logger; callers must defer it. 
+func captureLogger(t *testing.T) (*bytes.Buffer, func()) { + t.Helper() + buf := &bytes.Buffer{} + original := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(io.Writer(buf), &slog.HandlerOptions{ + Level: slog.LevelDebug, + }))) + return buf, func() { slog.SetDefault(original) } +} + +// TestLoggingHygiene_HandleAuthRequest_LeaksNothing exercises the full +// HandleAuthRequest path against a mock IdP and asserts that the +// generated state, nonce, PKCE verifier, and pre-login cookie never +// appear in any captured log line. +func TestLoggingHygiene_HandleAuthRequest_LeaksNothing(t *testing.T) { + idp := newMockIdP(t) + svc, _ := newServiceWithProviderAndPL(t, idp.URL(), "op-leak-1") + + buf, restore := captureLogger(t) + defer restore() + + authURL, cookieValue, _, err := svc.HandleAuthRequest(context.Background(), "op-leak-1") + if err != nil { + t.Fatalf("HandleAuthRequest: %v", err) + } + + // Extract state from the authURL query so we can grep-assert. + parts := strings.Split(authURL, "state=") + if len(parts) < 2 { + t.Fatalf("authURL missing state param: %q", authURL) + } + stateValue := strings.SplitN(parts[1], "&", 2)[0] + + captured := buf.String() + for _, secret := range []string{stateValue, cookieValue} { + if secret == "" { + continue + } + if strings.Contains(captured, secret) { + t.Errorf("secret value %q appeared in log output:\n%s", secret, captured) + } + } +} + +// TestLoggingHygiene_HandleCallback_LeaksNothing runs the full callback +// flow (against the mock IdP) and grep-asserts the captured log buffer +// has no occurrence of the access token, the ID token, the +// authorization code, or the PKCE verifier. +func TestLoggingHygiene_HandleCallback_LeaksNothing(t *testing.T) { + idp := newMockIdP(t) + svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-leak-2") + + // Pre-login row with a known verifier we can grep for after. 
+ verifier := "test-verifier-do-not-leak-aaaaaaaaaaaaa" + cookie, _, err := pl.CreatePreLogin(context.Background(), "op-leak-2", "the-state", "test-nonce-fixed", verifier) + if err != nil { + t.Fatalf("CreatePreLogin: %v", err) + } + + buf, restore := captureLogger(t) + defer restore() + + authCode := "secret-auth-code-do-not-leak" + res, err := svc.HandleCallback(context.Background(), cookie, authCode, "the-state", "10.0.0.1", "Mozilla") + if err != nil { + t.Fatalf("HandleCallback: %v", err) + } + + captured := buf.String() + + // Direct secrets that flow through HandleCallback's parameter list. + for _, secret := range []string{ + authCode, + verifier, + "test-access-token", + idp.receivedCode, + idp.receivedVerifier, + } { + if secret == "" { + continue + } + if strings.Contains(captured, secret) { + t.Errorf("secret value %q appeared in log output:\n%s", secret, captured) + } + } + + // The session cookie + CSRF token are returned by the mint stub; + // in production they're set on the response, not logged. Pin that + // we never logged them. + for _, secret := range []string{res.CookieValue, res.CSRFToken} { + if secret == "" { + continue + } + if strings.Contains(captured, secret) { + t.Errorf("session secret %q appeared in log output:\n%s", secret, captured) + } + } +} + +// TestLoggingHygiene_AlgRejectionDoesNotLogAlg is a defense-in-depth pin: +// when isDisallowedAlg rejects a token, the alg name might land in an +// error returned to the handler — but the service.go MUST NOT log the +// alg value itself (an attacker could probe to discover allow-list +// composition). The handler maps to a uniform 400; alg detail lives +// only in audit rows the operator owns. +func TestLoggingHygiene_AlgRejectionDoesNotLogAlg(t *testing.T) { + buf, restore := captureLogger(t) + defer restore() + + // Direct call to the helper; this exercises the deny-list match. 
+ _, _ = isDisallowedAlg("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.body.sig") + + captured := buf.String() + if strings.Contains(captured, "HS256") { + t.Errorf("alg value HS256 appeared in log output (defense-in-depth violation):\n%s", captured) + } +} + +// TestLoggingHygiene_ProviderLoadDoesNotLogClientSecret pins that +// even on getOrLoad failures, the decrypted client_secret bytes never +// land in a log line. Decryption happens before verifier construction; +// any error path that flows through must not surface the plaintext. +func TestLoggingHygiene_ProviderLoadDoesNotLogClientSecret(t *testing.T) { + idp := newMockIdP(t) + + // Use a provider with a recognizable plaintext "secret" (no encryption + // key set, so decryptClientSecret returns the bytes as-is). + prov := makeProvider(idp.URL(), "op-leak-secret") + prov.ClientSecretEncrypted = []byte("client-secret-plaintext-do-not-leak-xxxxx") + + pl := newStubPreLogin() + svc := NewService( + &stubProviderLookup{provider: prov}, + &stubMappings{roleIDs: []string{"r-operator"}}, + newStubUsers(), + &stubSessions{}, + pl, + "", + ) + + buf, restore := captureLogger(t) + defer restore() + + if _, err := svc.getOrLoad(context.Background(), "op-leak-secret"); err != nil { + t.Fatalf("getOrLoad: %v", err) + } + + captured := buf.String() + if strings.Contains(captured, "client-secret-plaintext-do-not-leak") { + t.Errorf("client secret plaintext appeared in log output:\n%s", captured) + } +} diff --git a/internal/auth/oidc/service.go b/internal/auth/oidc/service.go new file mode 100644 index 0000000..dce6bfc --- /dev/null +++ b/internal/auth/oidc/service.go @@ -0,0 +1,847 @@ +package oidc + +import ( + "context" + cryptorand "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "hash" + "strings" + "sync" + "time" + + gooidc "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + + oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain" + 
"github.com/certctl-io/certctl/internal/auth/oidc/groupclaim" + userdomain "github.com/certctl-io/certctl/internal/auth/user/domain" + "github.com/certctl-io/certctl/internal/crypto" + "github.com/certctl-io/certctl/internal/repository" +) + +// ============================================================================= +// Auth Bundle 2 / Phase 3 / OIDC Service +// +// The Service implements the certctl side of the OpenID Connect 1.0 +// authorization-code flow with PKCE-S256 (RFC 7636), against any IdP +// that satisfies the OIDC discovery doc + JWKS contract. Token +// validation enforces every fail-closed check from OIDC core +// §3.1.3.7 plus the operator-policy gates (alg allow-list, audience, +// `azp` for multi-aud tokens, `at_hash` when access tokens are +// returned, `iat` window, `nonce`, single-use state). +// +// Security posture: +// +// 1. JWKS endpoints MUST be HTTPS (validated at provider creation +// by the domain layer; transport never weakened). +// 2. PKCE S256 is REQUIRED on every login per RFC 9700 §2.1.1; +// the `plain` challenge method is rejected. +// 3. State is server-generated random 32 bytes (256 bits of +// entropy), single-use, stored in the pre-login session row. +// 4. Nonce is server-generated random 32 bytes, single-use, +// stored in the pre-login session row, validated against the +// ID token nonce claim via constant-time compare. +// 5. Algorithms are pinned to an allow-list (default: RS256, RS512, +// ES256, ES384, EdDSA). HS256/HS384/HS512 are NEVER allowed +// (HMAC + JWKS is alg confusion); `none` is NEVER allowed. +// 6. IdP downgrade-attack defense: at provider creation / +// RefreshKeys, the discovery doc's +// `id_token_signing_alg_values_supported` is intersected with +// the allow-list. If the IdP advertises HS* / none AT ALL, the +// provider is rejected with an actionable error so a future +// compromised IdP can't downgrade. +// 7. 
// Service implements the OIDC integration: it builds the IdP redirect
// (HandleAuthRequest), completes the callback (HandleCallback), and
// manages a per-provider cache of discovery/verifier state
// (getOrLoad, RefreshKeys).
type Service struct {
	// Collaborators supplied to NewService.
	providers OIDCProviderLookup
	mappings  repository.GroupRoleMappingRepository
	users     repository.UserRepository
	sessions  SessionMinter
	preLogin  PreLoginStore

	encryptionKey string // CERTCTL_CONFIG_ENCRYPTION_KEY for client_secret decrypt

	// mu guards cache. getOrLoad takes the read lock for lookups and the
	// write lock to store an entry; RefreshKeys takes the write lock to
	// evict one.
	mu       sync.RWMutex
	cache    map[string]*providerEntry // keyed by provider ID
	clockNow func() time.Time          // injectable for tests; NewService sets it to time.Now
}

// providerEntry is the cached, ready-to-use state for one configured
// provider: the repository row, the go-oidc Provider built from the
// IdP discovery document, the ID-token verifier, and the OAuth2 client
// config.
//
// The IdP-advertised signing algs are inspected by getOrLoad for the
// downgrade-attack defense but are not stored on the entry.
type providerEntry struct {
	cfgRow      *oidcdomain.OIDCProvider
	provider    *gooidc.Provider
	verifier    *gooidc.IDTokenVerifier
	oauthConfig *oauth2.Config
	// allowedAlgs is the allow-list handed to the verifier. getOrLoad
	// currently assigns DefaultAllowedAlgs as-is; no intersection with a
	// per-provider config or the IdP-advertised list is computed there.
	allowedAlgs []string
	// plaintext is the decrypted client secret. getOrLoad also copies it
	// into oauthConfig.ClientSecret, which is what the token exchange uses.
	plaintext []byte
}
// OIDCProviderLookup is the read-only view of the provider repository
// that the Service needs. Only Get is called from the code in this
// file's load path (getOrLoad); it is declared here, at the consumer,
// so test doubles can implement the smaller surface.
type OIDCProviderLookup interface {
	// Get returns the configured provider row for id.
	Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error)
	// List returns the provider rows for tenantID.
	List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error)
}

// PreLoginStore persists the short-lived pre-login row that carries
// state, nonce and the PKCE verifier across the redirect to the IdP.
// It is an interface so the Service can be unit-tested without the
// full session machinery.
type PreLoginStore interface {
	// CreatePreLogin persists a row with the given identifiers.
	// providerID is the configured provider id; state, nonce and
	// verifier are the server-generated random strings that
	// HandleCallback later validates. It returns the opaque cookie value
	// for the handler to set, plus the session ID used as the audit
	// trail anchor.
	CreatePreLogin(ctx context.Context, providerID, state, nonce, verifier string) (cookieValue, sessionID string, err error)

	// LookupAndConsume reads the pre-login row by cookie value AND
	// deletes it atomically, making the row single-use: a second call
	// with the same cookie value is expected to return
	// ErrPreLoginNotFound. It returns the stored
	// providerID/state/nonce/verifier for the caller to validate against
	// the callback parameters.
	//
	// Note that HandleCallback maps every error from this method to
	// ErrPreLoginNotFound, so implementations cannot surface a distinct
	// error to the handler through that path.
	LookupAndConsume(ctx context.Context, cookieValue string) (providerID, state, nonce, verifier string, err error)
}

// SessionMinter creates the post-login session. It is an interface so
// the OIDC service can be unit-tested independently of session
// signing.
type SessionMinter interface {
	// MintForUser creates a post-login session for user with the given
	// role IDs, recording the client ip and userAgent. It returns the
	// cookie value the handler sets and a CSRF token the GUI echoes
	// into the X-CSRF-Token header on POSTs.
	MintForUser(ctx context.Context, user *userdomain.User, roleIDs []string, ip, userAgent string) (cookieValue, csrfToken string, err error)
}
// IDGenerator returns a new opaque session id.
//
// NOTE(review): no use of this type is visible in this part of the
// file; the "32 random bytes, base64url-no-pad" default mentioned in
// the original comment is presumably supplied by the implementation
// that consumes it. Confirm against the session code.
type IDGenerator func() (string, error)

// Service-layer sentinels. The handler layer translates them to HTTP
// statuses; compare with errors.Is.
var (
	// ErrPreLoginNotFound: the pre-login cookie doesn't match a row.
	// Either the row was already consumed (replay) or never existed
	// (forged cookie). HandleCallback returns it for any error from
	// PreLoginStore.LookupAndConsume. HTTP 400.
	ErrPreLoginNotFound = errors.New("oidc: pre-login session not found or already consumed")

	// ErrStateMismatch: callback `state` differs from the stored
	// pre-login state. HTTP 400.
	ErrStateMismatch = errors.New("oidc: state parameter mismatch (replay or forgery)")

	// ErrNonceMismatch: ID token `nonce` differs from the stored
	// pre-login nonce. HTTP 400.
	ErrNonceMismatch = errors.New("oidc: nonce mismatch")

	// ErrIssuerMismatch: ID token `iss` doesn't match the configured
	// provider issuer_url. HTTP 400.
	ErrIssuerMismatch = errors.New("oidc: issuer mismatch")

	// ErrAudienceMismatch: ID token `aud` doesn't include the
	// configured client_id. HTTP 400.
	ErrAudienceMismatch = errors.New("oidc: audience mismatch")

	// ErrAZPRequired: ID token has multi-valued aud but no `azp`
	// claim. The service requires `azp` whenever there is more than one
	// audience (OIDC core §3.1.3.7). HTTP 400.
	ErrAZPRequired = errors.New("oidc: multi-aud ID token missing required azp claim")

	// ErrAZPMismatch: ID token `azp` is present but doesn't equal
	// client_id. HTTP 400.
	ErrAZPMismatch = errors.New("oidc: azp claim does not match client_id")

	// ErrATHashMismatch: ID token `at_hash` doesn't match the
	// re-computed hash of the access token. HTTP 400.
	ErrATHashMismatch = errors.New("oidc: at_hash claim does not match access token")

	// ErrATHashRequired: an access token was returned alongside the ID
	// token but the ID token carries no `at_hash` claim. The service
	// treats at_hash as REQUIRED in this case so a substituted access
	// token can be detected. Fail closed. HTTP 400.
	ErrATHashRequired = errors.New("oidc: access_token present but ID token has no at_hash claim")

	// ErrTokenExpired: ID token `exp` is in the past (with 60s
	// clock-skew tolerance). HTTP 400.
	ErrTokenExpired = errors.New("oidc: ID token expired")

	// ErrIATInFuture: ID token `iat` is in the future beyond the 60s
	// skew tolerance. HTTP 400.
	ErrIATInFuture = errors.New("oidc: ID token iat is in the future")

	// ErrIATTooOld: ID token `iat` is older than the provider's
	// configured IATWindowSeconds. HTTP 400.
	ErrIATTooOld = errors.New("oidc: ID token iat older than configured window")

	// ErrAlgRejected: the ID token's JWS header names a deny-listed
	// alg, or the header could not be parsed. HTTP 400.
	ErrAlgRejected = errors.New("oidc: ID token signed with disallowed algorithm")

	// ErrIdPDowngradeAdvertised: provider's discovery doc advertises
	// HS* or `none` algorithms. Provider load / refresh rejects; the
	// offending alg name is appended when the error is wrapped.
	// HTTP 400.
	ErrIdPDowngradeAdvertised = errors.New("oidc: IdP advertises weak signing algorithms (HS*/none); refusing to use as defense against downgrade attacks")

	// ErrJWKSUnreachable: JWKS endpoint fetch failed during ID-token
	// verification. The in-flight login fails 503; existing sessions
	// untouched.
	ErrJWKSUnreachable = errors.New("oidc: JWKS endpoint unreachable; in-flight login fails, existing sessions untouched")

	// ErrGroupsMissing: the configured groups_claim_path resolves
	// to nothing or is malformed, and the userinfo fallback (when
	// enabled) did not yield any groups either. Fail closed.
	ErrGroupsMissing = errors.New("oidc: configured groups claim missing or malformed")

	// ErrGroupsUnmapped: the user's groups don't match any of the
	// operator's group_role_mappings for this provider. No session
	// minted; audit row records auth.oidc_login_unmapped_groups.
	ErrGroupsUnmapped = errors.New("oidc: groups did not match any configured mapping")

	// ErrPKCEPlainRejected: somehow `plain` PKCE method got into
	// the flow. Defense-in-depth; the service NEVER generates a plain
	// verifier, but this sentinel exists in case a future code path
	// regresses.
	ErrPKCEPlainRejected = errors.New("oidc: PKCE method 'plain' is rejected; S256 is mandatory")
)

// DefaultAllowedAlgs is the operator-default ID-token signing algorithm
// allow-list. Any per-provider allow-list is meant to be a subset of
// this set; at present getOrLoad passes this slice to the verifier
// unchanged. HMAC algorithms (HS256/HS384/HS512) and `none` are NEVER
// in the default set; the IdP downgrade defense rejects any provider
// that advertises them in discovery.
//
// The slice is shared by every provider entry, so callers must not
// mutate it.
var DefaultAllowedAlgs = []string{
	gooidc.RS256, gooidc.RS512,
	gooidc.ES256, gooidc.ES384,
	gooidc.EdDSA,
}

// disallowedAlgs is the explicit deny-list. Anything in this set
// fails the IdP downgrade check at provider load / RefreshKeys
// AND fails the per-token alg check at HandleCallback time, even if
// the operator somehow added it to an allow-list by hand.
var disallowedAlgs = map[string]struct{}{
	"HS256": {},
	"HS384": {},
	"HS512": {},
	"none":  {},
}

// NewService constructs an OIDC Service from its collaborators.
//
// encryptionKey is used to decrypt each provider's stored client
// secret; when it is empty the stored bytes are used as-is (see
// decryptClientSecret). The provider cache starts empty and the clock
// defaults to time.Now.
func NewService(
	providers OIDCProviderLookup,
	mappings repository.GroupRoleMappingRepository,
	users repository.UserRepository,
	sessions SessionMinter,
	preLogin PreLoginStore,
	encryptionKey string,
) *Service {
	return &Service{
		providers:     providers,
		mappings:      mappings,
		users:         users,
		sessions:      sessions,
		preLogin:      preLogin,
		encryptionKey: encryptionKey,
		cache:         make(map[string]*providerEntry),
		clockNow:      time.Now,
	}
}
// SetClockForTest replaces the clock used for the `iat`/`exp` checks
// and for the last_login_at timestamps. ONLY for tests; production
// paths keep the time.Now default installed by NewService.
//
// The assignment is not guarded by s.mu, so call it before the Service
// is used from other goroutines.
func (s *Service) SetClockForTest(now func() time.Time) {
	s.clockNow = now
}
+ authURL = entry.oauthConfig.AuthCodeURL( + state, + oauth2.AccessTypeOnline, + oauth2.S256ChallengeOption(verifier), + oauth2.SetAuthURLParam("nonce", nonce), + ) + + return authURL, cookieValue, preLoginID, nil +} + +// ============================================================================= +// HandleCallback: completes the OIDC handshake and creates a session. +// +// Validates state, exchanges code for tokens (with PKCE verifier), +// validates ID token (alg pin, iss, aud, azp, at_hash, exp, iat, +// nonce), parses group claims, maps groups to roles, creates / updates +// the user record, mints a session. +// +// Every fail-closed branch returns one of the package-scoped sentinel +// errors so the handler can map to the right HTTP status without +// leaking which check failed (uniform 400 to the wire; specific +// reason in the audit row). +// ============================================================================= + +// CallbackResult is what HandleCallback returns to the handler. The +// handler sets cookieValue + csrfToken on the response and 302's to +// the GUI dashboard. +type CallbackResult struct { + User *userdomain.User + RoleIDs []string + CookieValue string // post-login session cookie + CSRFToken string // CSRF token for the GUI to echo into X-CSRF-Token +} + +// HandleCallback completes the OIDC flow. +func (s *Service) HandleCallback( + ctx context.Context, + preLoginCookie, code, callbackState, ip, userAgent string, +) (*CallbackResult, error) { + // Step 1: consume the pre-login row (single-use). + providerID, storedState, storedNonce, verifier, err := s.preLogin.LookupAndConsume(ctx, preLoginCookie) + if err != nil { + return nil, ErrPreLoginNotFound + } + + // Step 2: state constant-time compare. 
+ if subtle.ConstantTimeCompare([]byte(callbackState), []byte(storedState)) != 1 { + return nil, ErrStateMismatch + } + + entry, err := s.getOrLoad(ctx, providerID) + if err != nil { + return nil, err + } + + // Step 3: exchange the auth code for tokens (with PKCE verifier). + token, err := entry.oauthConfig.Exchange(ctx, code, oauth2.VerifierOption(verifier)) + if err != nil { + return nil, fmt.Errorf("oidc: code exchange failed: %w", err) + } + + // Step 4: extract + validate the ID token. NEVER log token here. + rawIDToken, ok := token.Extra("id_token").(string) + if !ok || rawIDToken == "" { + return nil, fmt.Errorf("oidc: token response missing id_token") + } + + idToken, err := entry.verifier.Verify(ctx, rawIDToken) + if err != nil { + // Map go-oidc's verify errors to ErrJWKSUnreachable when the + // underlying cause is a JWKS fetch failure; otherwise return + // the wrapped error for the handler to map to 400. + if isJWKSFetchError(err) { + return nil, ErrJWKSUnreachable + } + return nil, fmt.Errorf("oidc: id_token verify failed: %w", err) + } + + // Step 5: alg pinning. go-oidc's verifier already enforces the + // allow-list we set in the config, but we re-check the header alg + // against our deny-list for belt-and-braces (defense vs an + // upstream library regression). + if rejected, alg := isDisallowedAlg(rawIDToken); rejected { + _ = alg // do not log + return nil, ErrAlgRejected + } + + // Step 6: per-OIDC-core §3.1.3.7 claims checks beyond what + // gooidc.Verify covers. + now := s.clockNow().UTC() + + // iss is verified by gooidc.Verify against entry.cfgRow.IssuerURL; + // re-check exactly to defend against a library regression. + if idToken.Issuer != entry.cfgRow.IssuerURL { + return nil, ErrIssuerMismatch + } + + // aud must contain client_id. 
+ audOK := false + for _, a := range idToken.Audience { + if a == entry.cfgRow.ClientID { + audOK = true + break + } + } + if !audOK { + return nil, ErrAudienceMismatch + } + + // azp required when aud is multi-valued; if present, must equal client_id. + var extra struct { + AZP string `json:"azp"` + ATHash string `json:"at_hash"` + Nonce string `json:"nonce"` + } + if err := idToken.Claims(&extra); err != nil { + return nil, fmt.Errorf("oidc: id_token claims unmarshal: %w", err) + } + if len(idToken.Audience) > 1 { + if extra.AZP == "" { + return nil, ErrAZPRequired + } + } + if extra.AZP != "" && extra.AZP != entry.cfgRow.ClientID { + return nil, ErrAZPMismatch + } + + // at_hash validation. When an access token is returned alongside the + // ID token, OIDC core §3.1.3.6 + §3.2.2.9 require the ID token to + // carry an at_hash claim that hashes the access token (alg-matching + // hash family, left-half, base64url-no-pad). The Phase 3 spec lifts + // this from the RFC's "MAY" to a "MUST" so a substituted access + // token cannot ride a clean ID token through the verifier. + if token.AccessToken != "" { + if extra.ATHash == "" { + return nil, ErrATHashRequired + } + if !atHashMatches(rawIDToken, token.AccessToken, extra.ATHash) { + return nil, ErrATHashMismatch + } + } + + // exp + iat (60s clock skew tolerance). + const skew = 60 * time.Second + if idToken.Expiry.Add(skew).Before(now) { + return nil, ErrTokenExpired + } + if idToken.IssuedAt.After(now.Add(skew)) { + return nil, ErrIATInFuture + } + iatWindow := time.Duration(entry.cfgRow.IATWindowSeconds) * time.Second + if idToken.IssuedAt.Add(iatWindow).Before(now) { + return nil, ErrIATTooOld + } + + // nonce constant-time compare. + if subtle.ConstantTimeCompare([]byte(extra.Nonce), []byte(storedNonce)) != 1 { + return nil, ErrNonceMismatch + } + + // Step 7: extract claims for group resolution + user record. 
+ var profile struct { + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Raw map[string]interface{} `json:"-"` + } + if err := idToken.Claims(&profile); err != nil { + return nil, fmt.Errorf("oidc: profile claims unmarshal: %w", err) + } + var raw map[string]interface{} + if err := idToken.Claims(&raw); err != nil { + return nil, fmt.Errorf("oidc: raw claims unmarshal: %w", err) + } + profile.Raw = raw + + // Step 8: group claim resolution. + groups, err := groupclaim.Resolve(profile.Raw, entry.cfgRow.GroupsClaimPath) + if err != nil || len(groups) == 0 { + // Try the userinfo endpoint fallback if the operator opted in. + if entry.cfgRow.FetchUserinfo { + groups2, uerr := s.fetchUserinfoGroups(ctx, entry, token, entry.cfgRow.GroupsClaimPath) + if uerr == nil && len(groups2) > 0 { + groups = groups2 + } else { + return nil, ErrGroupsMissing + } + } else { + return nil, ErrGroupsMissing + } + } + + // Step 9: map groups to role IDs. Empty result => fail closed. + roleIDs, err := s.mappings.Map(ctx, providerID, groups) + if err != nil { + return nil, fmt.Errorf("oidc: group-role mapping lookup: %w", err) + } + if len(roleIDs) == 0 { + return nil, ErrGroupsUnmapped + } + + // Step 10: upsert the user record. Per Phase 1 contract, identity + // is per-(provider, oidc_subject); a person logging in via a new + // provider gets a new users row. + user, err := s.upsertUser(ctx, entry.cfgRow, idToken.Subject, profile.Email, profile.Name, profile.PreferredUsername) + if err != nil { + return nil, fmt.Errorf("oidc: upsert user: %w", err) + } + + // Step 11: mint a post-login session via Phase 4's SessionService. 
+ cookieValue, csrfToken, err := s.sessions.MintForUser(ctx, user, roleIDs, ip, userAgent) + if err != nil { + return nil, fmt.Errorf("oidc: session mint: %w", err) + } + + return &CallbackResult{ + User: user, + RoleIDs: roleIDs, + CookieValue: cookieValue, + CSRFToken: csrfToken, + }, nil +} + +// upsertUser looks up by (provider, subject) and either updates the +// existing user or creates a new one. last_login_at is bumped on every +// login. +func (s *Service) upsertUser( + ctx context.Context, + provider *oidcdomain.OIDCProvider, + subject, email, displayName, fallbackName string, +) (*userdomain.User, error) { + if displayName == "" { + displayName = fallbackName + } + if displayName == "" { + displayName = email + } + + existing, err := s.users.GetByOIDCSubject(ctx, provider.ID, subject) + if err == nil { + // Update last_login_at, email, display_name (per the Phase 1 + // mutable-field contract). + existing.Email = email + existing.DisplayName = displayName + existing.LastLoginAt = s.clockNow().UTC() + if uerr := s.users.Update(ctx, existing); uerr != nil { + return nil, uerr + } + return existing, nil + } + if !errors.Is(err, repository.ErrUserNotFound) { + return nil, err + } + + // First login: create a new user record. + id, err := randomB64URL(16) + if err != nil { + return nil, fmt.Errorf("oidc: user id generate: %w", err) + } + u := &userdomain.User{ + ID: "u-" + id, + TenantID: provider.TenantID, + Email: email, + DisplayName: displayName, + OIDCSubject: subject, + OIDCProviderID: provider.ID, + LastLoginAt: s.clockNow().UTC(), + WebAuthnCredentials: []byte("[]"), + } + if verr := u.Validate(); verr != nil { + return nil, fmt.Errorf("oidc: new user validate: %w", verr) + } + if cerr := s.users.Create(ctx, u); cerr != nil { + return nil, cerr + } + return u, nil +} + +// fetchUserinfoGroups falls back to the IdP userinfo endpoint when +// the operator opts in via fetch_userinfo=true AND the ID token +// didn't surface the groups claim. 
Returns the group list resolved +// against groups_claim_path. +func (s *Service) fetchUserinfoGroups( + ctx context.Context, + entry *providerEntry, + token *oauth2.Token, + path string, +) ([]string, error) { + if entry.provider.UserInfoEndpoint() == "" { + return nil, fmt.Errorf("oidc: userinfo fallback configured but provider has no userinfo endpoint") + } + ts := entry.oauthConfig.TokenSource(ctx, token) + uinfo, err := entry.provider.UserInfo(ctx, ts) + if err != nil { + return nil, fmt.Errorf("oidc: userinfo fetch: %w", err) + } + var raw map[string]interface{} + if err := uinfo.Claims(&raw); err != nil { + return nil, fmt.Errorf("oidc: userinfo claims: %w", err) + } + return groupclaim.Resolve(raw, path) +} + +// ============================================================================= +// RefreshKeys: explicitly invalidate + refetch the cached provider. +// +// Used by the GUI's "Refresh discovery cache" button (Phase 8) when an +// operator knows the IdP rotated its keys mid-day and the JWKS cache +// is stale. Re-runs the IdP downgrade-attack defense too: if the IdP +// rotated in HS* / `none` advertisement, we catch it here. +// ============================================================================= + +// RefreshKeys evicts the cached provider entry and re-loads it from +// scratch. Invokes the discovery doc fetch + the downgrade defense. +func (s *Service) RefreshKeys(ctx context.Context, providerID string) error { + s.mu.Lock() + delete(s.cache, providerID) + s.mu.Unlock() + + _, err := s.getOrLoad(ctx, providerID) + return err +} + +// ============================================================================= +// Provider load + cache + IdP downgrade defense. +// ============================================================================= + +// getOrLoad returns a cached provider entry, loading from the repo + +// fetching the IdP discovery doc on miss. 
Cache uses a write-then-read +// pattern under sync.RWMutex; concurrent first-loads of the same +// provider may duplicate the discovery fetch but never produce +// divergent cache entries (the second-arriving entry overwrites and +// both entries are equivalent). +func (s *Service) getOrLoad(ctx context.Context, providerID string) (*providerEntry, error) { + s.mu.RLock() + entry, ok := s.cache[providerID] + s.mu.RUnlock() + if ok { + return entry, nil + } + + // Read the configured row. + cfgRow, err := s.providers.Get(ctx, providerID) + if err != nil { + return nil, err + } + + // Fetch + cache the discovery doc + JWKS via go-oidc. + provider, err := gooidc.NewProvider(ctx, cfgRow.IssuerURL) + if err != nil { + return nil, fmt.Errorf("oidc: discovery fetch failed for %s: %w", providerID, err) + } + + // IdP downgrade-attack defense. The discovery doc's + // id_token_signing_alg_values_supported MUST NOT include any + // disallowed alg. + var advertised struct { + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + } + if cerr := provider.Claims(&advertised); cerr != nil { + return nil, fmt.Errorf("oidc: discovery claims: %w", cerr) + } + for _, a := range advertised.IDTokenSigningAlgValuesSupported { + if _, deny := disallowedAlgs[a]; deny { + return nil, fmt.Errorf("%w: %s", ErrIdPDowngradeAdvertised, a) + } + } + + // Compute the effective allow-list: intersection of the default + // allow-list AND any operator-configured restriction (currently + // the domain layer doesn't expose per-provider alg config beyond + // the default; placeholder for a future Phase-3-extended config). + allowed := DefaultAllowedAlgs + + // Decrypt the client secret. The plaintext is held in memory only; + // never persisted, never logged. 
+ plaintext, err := decryptClientSecret(cfgRow.ClientSecretEncrypted, s.encryptionKey) + if err != nil { + return nil, fmt.Errorf("oidc: client_secret decrypt: %w", err) + } + + verifier := provider.Verifier(&gooidc.Config{ + ClientID: cfgRow.ClientID, + SupportedSigningAlgs: allowed, + }) + + oauthConfig := &oauth2.Config{ + ClientID: cfgRow.ClientID, + ClientSecret: string(plaintext), + Endpoint: provider.Endpoint(), + RedirectURL: cfgRow.RedirectURI, + Scopes: cfgRow.Scopes, + } + + entry = &providerEntry{ + cfgRow: cfgRow, + provider: provider, + verifier: verifier, + oauthConfig: oauthConfig, + allowedAlgs: allowed, + plaintext: plaintext, + } + + s.mu.Lock() + s.cache[providerID] = entry + s.mu.Unlock() + + return entry, nil +} + +// ============================================================================= +// Helpers (alg parsing, at_hash, random, JWKS-error detection, +// client_secret decrypt). Kept private; tests in service_test.go. +// ============================================================================= + +// randomB64URL returns nbytes of cryptographic randomness encoded as +// base64url-no-pad. Used for state, nonce, session IDs. +func randomB64URL(nbytes int) (string, error) { + b := make([]byte, nbytes) + if _, err := readRand(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// readRand is a package-level seam so tests can deterministically +// substitute crypto/rand. Production reads from crypto/rand.Reader. +var readRand = func(b []byte) (int, error) { + return cryptorand.Read(b) +} + +// isDisallowedAlg parses the JWS header alg and reports whether it's +// in the deny-list. NEVER returns or logs the alg; the caller maps +// the bool to ErrAlgRejected without surfacing details. +func isDisallowedAlg(rawJWT string) (bool, string) { + // JWS Compact:
... Decode header, + // extract `alg`. Defensive: catches bad input shapes too. + parts := strings.Split(rawJWT, ".") + if len(parts) != 3 { + return true, "" + } + headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return true, "" + } + // Find the alg value. Extreme minimal parser: avoid pulling in + // encoding/json so the path is allocation-tight on every login. + // Format: {"alg":"RS256",...}; some libraries emit + // {"alg" : "RS256" ,...} so the parser tolerates whitespace + // around both the colon and the value. + hdr := string(headerJSON) + idx := strings.Index(hdr, `"alg"`) + if idx < 0 { + return true, "" + } + rest := hdr[idx+5:] // skip "alg" + rest = strings.TrimLeft(rest, " \t\r\n") + if !strings.HasPrefix(rest, ":") { + return true, "" + } + rest = rest[1:] + rest = strings.TrimLeft(rest, " \t\r\n") + if !strings.HasPrefix(rest, `"`) { + return true, "" + } + rest = rest[1:] + end := strings.Index(rest, `"`) + if end < 0 { + return true, "" + } + alg := rest[:end] + if _, deny := disallowedAlgs[alg]; deny { + return true, alg + } + return false, alg +} + +// atHashMatches recomputes at_hash per OIDC core §3.1.3.6 + §3.2.2.9 +// and constant-time-compares against the claim. Algorithm matches the +// hash family of the ID token's signing alg (RS256 -> SHA-256, RS512 +// -> SHA-512, ES256 -> SHA-256, ES384 -> SHA-384, EdDSA -> SHA-512). +// Returns true iff the recomputed half-hash equals the claim. +func atHashMatches(rawIDToken, accessToken, claimAtHash string) bool { + _, alg := isDisallowedAlg(rawIDToken) // re-extracts alg + var h hash.Hash + switch alg { + case "RS256", "ES256": + h = sha256.New() + case "ES384": + h = sha512.New384() + case "RS512", "EdDSA": + h = sha512.New() + default: + // Unknown alg should already have been caught by the + // alg-pin check; refuse to recompute here. 
// isJWKSFetchError reports whether err, as returned by
// gooidc.IDTokenVerifier.Verify, looks like a failure to fetch the
// IdP's signing keys rather than a problem with the token itself.
// HandleCallback maps a true result to ErrJWKSUnreachable so the
// handler can answer 503 for the in-flight login.
//
// Detection is textual: the error message is searched for the
// fragments "fetching keys", "jwks_uri" and "key set". A nil error is
// never a fetch error.
func isJWKSFetchError(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	for _, marker := range [...]string{"fetching keys", "jwks_uri", "key set"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// sha384New returns a new SHA-384 hash.Hash. SHA-384 lives in
// crypto/sha512 in the Go standard library, hence the wrapper with a
// self-describing name.
func sha384New() hash.Hash { return sha512.New384() }

// sha512New returns a new SHA-512 hash.Hash. Named to mirror sha384New
// so both constructors share the same func() hash.Hash shape.
func sha512New() hash.Hash { return sha512.New() }
+ overrideAudience []string + overrideIssuer string + overrideNonce string + overrideAZP string + overrideExp time.Time + overrideIAT time.Time + overrideSubject string + overrideEmail string + overrideGroups []string + overrideATHash string // when set, injected as the id_token at_hash claim + overrideName string // when set to a sentinel "", emits empty name + + // advertisedAlgs controls what id_token_signing_alg_values_supported + // reports in the discovery doc. Tests set ["HS256"] to trigger the + // downgrade-attack defense. + advertisedAlgs []string + + // omitUserinfoEndpoint suppresses listing the userinfo endpoint in + // the discovery doc. Used to test the "userinfo fallback configured + // but provider has no userinfo endpoint" branch in fetchUserinfoGroups. + omitUserinfoEndpoint bool + + // userinfoGroups is what the /userinfo endpoint returns under the + // `groups` claim. Empty (default) means the endpoint returns a + // response without a `groups` claim at all. + userinfoGroups []string + + // userinfoFails causes /userinfo to return HTTP 500. Used to + // exercise fetchUserinfoGroups's UserInfo-fetch error wrap. + userinfoFails bool + + // suppressIDToken causes /token to return a response WITHOUT an + // id_token field. Used to test the "token response missing + // id_token" branch in HandleCallback. + suppressIDToken bool + + // Captured to assert the PKCE verifier round-trip + return a stub + // access_token + id_token to the service. 
+ receivedCode string + receivedVerifier string +} + +func newMockIdP(t *testing.T) *mockIdP { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + keyID := "test-key-1" + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: key}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", keyID), + ) + if err != nil { + t.Fatalf("jose.NewSigner: %v", err) + } + + idp := &mockIdP{ + key: key, + signer: signer, + keyID: keyID, + advertisedAlgs: []string{"RS256"}, + } + + mux := http.NewServeMux() + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + base := "http://" + r.Host + doc := map[string]interface{}{ + "issuer": base, + "authorization_endpoint": base + "/authorize", + "token_endpoint": base + "/token", + "jwks_uri": base + "/jwks", + "id_token_signing_alg_values_supported": idp.advertisedAlgs, + "response_types_supported": []string{"code"}, + "subject_types_supported": []string{"public"}, + } + if !idp.omitUserinfoEndpoint { + doc["userinfo_endpoint"] = base + "/userinfo" + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + if idp.userinfoFails { + http.Error(w, "userinfo simulated failure", http.StatusInternalServerError) + return + } + // The OAuth2 client sends the access token as Bearer; we don't + // validate the value (the test stub always returns + // "test-access-token" from /token). Return a JSON body with the + // claims the production fetchUserinfoGroups path consumes. 
+ body := map[string]interface{}{ + "sub": "test-subject", + "email": "user@example.com", + } + if idp.userinfoGroups != nil { + body["groups"] = idp.userinfoGroups + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(body) + }) + + mux.HandleFunc("/jwks", func(w http.ResponseWriter, r *http.Request) { + jwks := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + {Key: key.Public(), KeyID: keyID, Algorithm: "RS256", Use: "sig"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jwks) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + idp.receivedCode = r.PostFormValue("code") + idp.receivedVerifier = r.PostFormValue("code_verifier") + + base := "http://" + r.Host + now := time.Now().UTC() + + audience := []string{"certctl"} + if idp.overrideAudience != nil { + audience = idp.overrideAudience + } + issuer := base + if idp.overrideIssuer != "" { + issuer = idp.overrideIssuer + } + exp := now.Add(time.Hour) + if !idp.overrideExp.IsZero() { + exp = idp.overrideExp + } + iat := now + if !idp.overrideIAT.IsZero() { + iat = idp.overrideIAT + } + subject := "test-subject" + if idp.overrideSubject != "" { + subject = idp.overrideSubject + } + email := "user@example.com" + if idp.overrideEmail == "" { + email = "" + } else if idp.overrideEmail != "" { + email = idp.overrideEmail + } + groups := []string{"engineers"} + if idp.overrideGroups != nil { + groups = idp.overrideGroups + } + + // "name" is included by default; "" sentinel suppresses it + // (used to test the upsertUser display-name fallback chain). 
+ name := "Test User" + if idp.overrideName == "" { + name = "" + } else if idp.overrideName != "" { + name = idp.overrideName + } + claims := map[string]interface{}{ + "iss": issuer, + "aud": audience, + "sub": subject, + "exp": exp.Unix(), + "iat": iat.Unix(), + "email": email, + "name": name, + "groups": groups, + } + if idp.overrideNonce != "" { + claims["nonce"] = idp.overrideNonce + } else { + // Echo back whatever nonce the test supplied via the + // pre-login row. The test stub PreLoginStore generates a + // fixed nonce; we mirror it here. + claims["nonce"] = "test-nonce-fixed" + } + if idp.overrideAZP != "" { + claims["azp"] = idp.overrideAZP + } + // Default: emit a correct at_hash computed from the canned + // access_token under SHA-256 (matches the RS256 signing alg the + // mockIdP uses). Tests that need to exercise the + // at_hash-mismatch / at_hash-missing paths set overrideATHash + // to "" or "" respectively. + switch idp.overrideATHash { + case "": + h := sha256.Sum256([]byte("test-access-token")) + claims["at_hash"] = base64.RawURLEncoding.EncodeToString(h[:len(h)/2]) + case "": + // Suppress at_hash entirely. + default: + claims["at_hash"] = idp.overrideATHash + } + + raw, err := jwt.Signed(signer).Claims(claims).Serialize() + if err != nil { + http.Error(w, err.Error(), 500) + return + } + + resp := map[string]interface{}{ + "access_token": "test-access-token", + "token_type": "Bearer", + "expires_in": 3600, + } + if !idp.suppressIDToken { + resp["id_token"] = raw + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + }) + + mux.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { + // Tests call HandleCallback directly; this endpoint exists for + // completeness but the test never round-trips through it. 
+ http.Error(w, "test fixture: not implemented", 501) + }) + + idp.server = httptest.NewServer(mux) + t.Cleanup(idp.server.Close) + return idp +} + +func (m *mockIdP) URL() string { return m.server.URL } + +// ============================================================================= +// Stubs for the Service's collaborators +// ============================================================================= + +type stubProviderLookup struct { + provider *oidcdomain.OIDCProvider +} + +func (s *stubProviderLookup) Get(_ context.Context, id string) (*oidcdomain.OIDCProvider, error) { + if s.provider == nil || s.provider.ID != id { + return nil, repository.ErrOIDCProviderNotFound + } + return s.provider, nil +} +func (s *stubProviderLookup) List(_ context.Context, _ string) ([]*oidcdomain.OIDCProvider, error) { + if s.provider == nil { + return nil, nil + } + return []*oidcdomain.OIDCProvider{s.provider}, nil +} + +type stubMappings struct { + roleIDs []string + mapErr error // when set, Map returns this error +} + +func (s *stubMappings) ListByProvider(_ context.Context, _ string) ([]*oidcdomain.GroupRoleMapping, error) { + return nil, nil +} +func (s *stubMappings) Get(_ context.Context, _ string) (*oidcdomain.GroupRoleMapping, error) { + return nil, repository.ErrGroupRoleMappingNotFound +} +func (s *stubMappings) Add(_ context.Context, _ *oidcdomain.GroupRoleMapping) error { return nil } +func (s *stubMappings) Remove(_ context.Context, _ string) error { return nil } +func (s *stubMappings) Map(_ context.Context, _ string, _ []string) ([]string, error) { + if s.mapErr != nil { + return nil, s.mapErr + } + return s.roleIDs, nil +} + +type stubUsers struct { + byID map[string]*userdomain.User + bySubject map[string]*userdomain.User + createErr error // when set, Create returns this error + getErr error // when set, GetByOIDCSubject returns this error (other than NotFound) +} + +func newStubUsers() *stubUsers { + return &stubUsers{ + byID: 
make(map[string]*userdomain.User), + bySubject: make(map[string]*userdomain.User), + } +} +func (s *stubUsers) Get(_ context.Context, id string) (*userdomain.User, error) { + u, ok := s.byID[id] + if !ok { + return nil, repository.ErrUserNotFound + } + return u, nil +} +func (s *stubUsers) GetByOIDCSubject(_ context.Context, providerID, subject string) (*userdomain.User, error) { + if s.getErr != nil { + return nil, s.getErr + } + u, ok := s.bySubject[providerID+":"+subject] + if !ok { + return nil, repository.ErrUserNotFound + } + return u, nil +} +func (s *stubUsers) Create(_ context.Context, u *userdomain.User) error { + if s.createErr != nil { + return s.createErr + } + s.byID[u.ID] = u + s.bySubject[u.OIDCProviderID+":"+u.OIDCSubject] = u + return nil +} +func (s *stubUsers) Update(_ context.Context, u *userdomain.User) error { + s.byID[u.ID] = u + s.bySubject[u.OIDCProviderID+":"+u.OIDCSubject] = u + return nil +} +func (s *stubUsers) ListAll(_ context.Context, _ string) ([]*userdomain.User, error) { + out := make([]*userdomain.User, 0, len(s.byID)) + for _, u := range s.byID { + out = append(out, u) + } + return out, nil +} + +type stubSessions struct { + cookieValue string + csrfToken string + mintErr error // when set, MintForUser returns this error +} + +func (s *stubSessions) MintForUser(_ context.Context, _ *userdomain.User, _ []string, _, _ string) (string, string, error) { + if s.mintErr != nil { + return "", "", s.mintErr + } + if s.cookieValue == "" { + s.cookieValue = "test-cookie" + } + if s.csrfToken == "" { + s.csrfToken = "test-csrf" + } + return s.cookieValue, s.csrfToken, nil +} + +// stubPreLogin is in-memory PreLoginStore. Single-use enforced via +// delete-on-LookupAndConsume. 
+type stubPreLogin struct { + rows map[string]preLoginRow + createErr error // when set, CreatePreLogin returns this error +} + +type preLoginRow struct { + providerID, state, nonce, verifier string +} + +func newStubPreLogin() *stubPreLogin { + return &stubPreLogin{rows: make(map[string]preLoginRow)} +} +func (s *stubPreLogin) CreatePreLogin(_ context.Context, providerID, state, nonce, verifier string) (string, string, error) { + if s.createErr != nil { + return "", "", s.createErr + } + cookieVal := fmt.Sprintf("pl-%d", len(s.rows)+1) + s.rows[cookieVal] = preLoginRow{providerID, state, nonce, verifier} + return cookieVal, "ses-" + cookieVal, nil +} +func (s *stubPreLogin) LookupAndConsume(_ context.Context, cookie string) (string, string, string, string, error) { + r, ok := s.rows[cookie] + if !ok { + return "", "", "", "", ErrPreLoginNotFound + } + delete(s.rows, cookie) + return r.providerID, r.state, r.nonce, r.verifier, nil +} + +// ============================================================================= +// Standalone unit tests (no live IdP needed) +// ============================================================================= + +// Test 1: PKCE 'plain' is rejected. The Service NEVER generates a plain +// verifier (oauth2.GenerateVerifier + S256ChallengeOption are +// hard-coded), but we pin the deny-list constant exists so a future +// regression is caught. +func TestService_PKCEPlainRejectedSentinel(t *testing.T) { + // The sentinel exists; that's the contract a future code path must + // reference if it ever surfaces a plain-method path. Pin it. + if ErrPKCEPlainRejected == nil { + t.Fatalf("ErrPKCEPlainRejected sentinel must exist") + } + if !strings.Contains(ErrPKCEPlainRejected.Error(), "plain") { + t.Errorf("sentinel message should reference 'plain'; got %q", ErrPKCEPlainRejected.Error()) + } +} + +// Test 2: state replay (consume-once). After LookupAndConsume succeeds, +// a second call with the same cookie returns ErrPreLoginNotFound. 
+func TestService_StateReplayDeniedByConsumeOnce(t *testing.T) {
+	ctx := context.Background()
+	pl := newStubPreLogin()
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-x", "the-state", "the-nonce", "verifier-xxx")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	if _, _, _, _, err := pl.LookupAndConsume(ctx, cookie); err != nil {
+		t.Fatalf("first LookupAndConsume: %v", err)
+	}
+	// The row is gone after the first consume; a replay must miss.
+	if _, _, _, _, err := pl.LookupAndConsume(ctx, cookie); !errors.Is(err, ErrPreLoginNotFound) {
+		t.Errorf("second LookupAndConsume err = %v; want ErrPreLoginNotFound (single-use violated)", err)
+	}
+}
+
+// Test 3: forged pre-login cookie returns ErrPreLoginNotFound.
+func TestService_HandleCallback_RejectsForgedPreLoginCookie(t *testing.T) {
+	svc := newServiceForUnitTest(t)
+	_, err := svc.HandleCallback(context.Background(), "bogus-cookie", "any-code", "any-state", "ip", "ua")
+	if !errors.Is(err, ErrPreLoginNotFound) {
+		t.Errorf("err = %v; want ErrPreLoginNotFound", err)
+	}
+}
+
+// Test 4: state mismatch (cookie matches but the callback state doesn't).
+func TestService_HandleCallback_RejectsStateMismatch(t *testing.T) {
+	ctx := context.Background()
+	svc, pl := newServiceForUnitTestWithPL(t)
+	// Fail loudly on a broken fixture instead of letting an empty cookie
+	// surface as an unrelated ErrPreLoginNotFound below.
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-test", "real-state", "real-nonce", "verifier-xxx")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "wrong-state", "ip", "ua")
+	if !errors.Is(err, ErrStateMismatch) {
+		t.Errorf("err = %v; want ErrStateMismatch", err)
+	}
+}
+
+// Test 5: alg pinning — direct unit test of isDisallowedAlg helper.
+// Hand-builds a JWT header for each algorithm, asserts the deny-list
+// catches HS* and `none`.
+func TestService_AlgPinning_RejectsHSAlgsAndNone(t *testing.T) {
+	denied := []string{"HS256", "HS384", "HS512", "none"}
+	for i := range denied {
+		want := denied[i]
+		hdr := fmt.Sprintf(`{"alg":%q,"typ":"JWT"}`, want)
+		tok := base64.RawURLEncoding.EncodeToString([]byte(hdr)) + ".body.sig"
+
+		isBad, got := isDisallowedAlg(tok)
+		if !isBad {
+			t.Errorf("alg=%q: not rejected; want rejected", want)
+		}
+		if got != want {
+			t.Errorf("alg=%q: extracted %q; want %q", want, got, want)
+		}
+	}
+}
+
+// Test 6: alg pinning — every allow-listed algorithm passes and is
+// reported back unchanged.
+func TestService_AlgPinning_AllowsRSAndECAndEdDSA(t *testing.T) {
+	permitted := []string{"RS256", "RS512", "ES256", "ES384", "EdDSA"}
+	for i := range permitted {
+		want := permitted[i]
+		hdr := fmt.Sprintf(`{"alg":%q,"typ":"JWT"}`, want)
+		tok := base64.RawURLEncoding.EncodeToString([]byte(hdr)) + ".body.sig"
+
+		isBad, got := isDisallowedAlg(tok)
+		if isBad {
+			t.Errorf("alg=%q: rejected; want allowed", want)
+		}
+		if got != want {
+			t.Errorf("alg=%q: extracted %q; want %q", want, got, want)
+		}
+	}
+}
+
+// Test 7: a token that does not have exactly three segments is rejected
+// the same way a bad alg is.
+func TestService_AlgPinning_RejectsMalformedJWT(t *testing.T) {
+	malformed := [...]string{"", "single-segment", "two.segments", "more.than.three.segments"}
+	for _, tok := range malformed {
+		if rejected, _ := isDisallowedAlg(tok); !rejected {
+			t.Errorf("malformed JWT %q: not rejected", tok)
+		}
+	}
+}
+
+// Test 8: at_hash recomputation — the left half of SHA-256(access_token),
+// base64url-encoded, is accepted for an RS256 token.
+func TestService_ATHash_MatchesForRS256(t *testing.T) {
+	accessToken := "test-access-token-value"
+	digest := sha256.Sum256([]byte(accessToken))
+	claim := base64.RawURLEncoding.EncodeToString(digest[:sha256.Size/2])
+
+	rawIDToken := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + ".body.sig"
+	if ok := atHashMatches(rawIDToken, accessToken, claim); !ok {
+		t.Errorf("atHashMatches should accept correctly-computed at_hash")
+	}
+}
+
+// Test 9: at_hash mismatch → rejected.
+func TestService_ATHash_RejectsMismatch(t *testing.T) { + header := `{"alg":"RS256","typ":"JWT"}` + rawIDToken := base64.RawURLEncoding.EncodeToString([]byte(header)) + ".body.sig" + if atHashMatches(rawIDToken, "the-token", "wrong-hash-claim") { + t.Errorf("atHashMatches accepted bad at_hash; should reject") + } +} + +// Test 10: at_hash for unknown alg returns false (defense vs an alg +// that escaped the alg-pin check). +func TestService_ATHash_UnknownAlgReturnsFalse(t *testing.T) { + header := `{"alg":"unknown","typ":"JWT"}` + rawIDToken := base64.RawURLEncoding.EncodeToString([]byte(header)) + ".body.sig" + if atHashMatches(rawIDToken, "any-access-token", "any-hash") { + t.Errorf("atHashMatches with unknown alg should return false") + } +} + +// Test 11: IdP downgrade-attack defense. A provider whose discovery doc +// advertises HS256 in id_token_signing_alg_values_supported is REJECTED +// by the cache load with ErrIdPDowngradeAdvertised. +func TestService_IdPDowngradeDefense_RejectsHSAdvertised(t *testing.T) { + idp := newMockIdP(t) + idp.advertisedAlgs = []string{"RS256", "HS256"} // HS256 is the downgrade vector + + svc, _ := newServiceWithProvider(t, idp.URL(), "op-bad-idp") + + _, err := svc.getOrLoad(context.Background(), "op-bad-idp") + if !errors.Is(err, ErrIdPDowngradeAdvertised) { + t.Errorf("err = %v; want ErrIdPDowngradeAdvertised", err) + } +} + +// Test 12: IdP downgrade-attack defense — `none` advertisement also +// triggers rejection. +func TestService_IdPDowngradeDefense_RejectsNoneAdvertised(t *testing.T) { + idp := newMockIdP(t) + idp.advertisedAlgs = []string{"RS256", "none"} + + svc, _ := newServiceWithProvider(t, idp.URL(), "op-none-idp") + + _, err := svc.getOrLoad(context.Background(), "op-none-idp") + if !errors.Is(err, ErrIdPDowngradeAdvertised) { + t.Errorf("err = %v; want ErrIdPDowngradeAdvertised", err) + } +} + +// Test 13: clean RS256 IdP loads successfully. 
+func TestService_GetOrLoad_AcceptsCleanIdP(t *testing.T) { + idp := newMockIdP(t) // default advertisedAlgs=["RS256"] + svc, _ := newServiceWithProvider(t, idp.URL(), "op-good-idp") + + entry, err := svc.getOrLoad(context.Background(), "op-good-idp") + if err != nil { + t.Fatalf("getOrLoad: %v", err) + } + if entry.provider == nil { + t.Errorf("entry.provider is nil") + } + if entry.verifier == nil { + t.Errorf("entry.verifier is nil") + } +} + +// Test 14: RefreshKeys evicts the cache + re-fetches discovery, which +// re-runs the downgrade defense. If the IdP rotated to advertising +// HS256 between loads, RefreshKeys catches it. +func TestService_RefreshKeys_CatchesPostLoadDowngrade(t *testing.T) { + idp := newMockIdP(t) + svc, _ := newServiceWithProvider(t, idp.URL(), "op-rotate") + + if _, err := svc.getOrLoad(context.Background(), "op-rotate"); err != nil { + t.Fatalf("initial load: %v", err) + } + + // IdP rotates to advertising HS256. + idp.advertisedAlgs = []string{"RS256", "HS256"} + err := svc.RefreshKeys(context.Background(), "op-rotate") + if !errors.Is(err, ErrIdPDowngradeAdvertised) { + t.Errorf("RefreshKeys err = %v; want ErrIdPDowngradeAdvertised", err) + } +} + +// Test 15: HandleCallback happy path against the mock IdP. 
+func TestService_HandleCallback_HappyPath(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-happy")
+
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-happy", "happy-state", "test-nonce-fixed", "verifier-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+
+	res, err := svc.HandleCallback(ctx, cookie, "test-code", "happy-state", "10.0.0.1", "Mozilla/5.0")
+	if err != nil {
+		t.Fatalf("HandleCallback: %v", err)
+	}
+	if res.User == nil {
+		t.Errorf("CallbackResult.User nil")
+	}
+	if len(res.RoleIDs) == 0 {
+		t.Errorf("CallbackResult.RoleIDs empty")
+	}
+	if res.CookieValue == "" {
+		t.Errorf("CallbackResult.CookieValue empty")
+	}
+}
+
+// Test 16: HandleCallback rejects ID token with wrong audience.
+func TestService_HandleCallback_RejectsWrongAudience(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	idp.overrideAudience = []string{"some-other-client"}
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-aud")
+
+	// This test only asserts "some error", so a fixture failure must not
+	// be allowed to satisfy it by accident.
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-aud", "s", "test-nonce-fixed", "v-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	// gooidc.Verify catches this first; its wrap reaches us as a wrapped error.
+	// Either ErrAudienceMismatch (our re-check) OR a wrapped verify error is acceptable.
+	if err == nil {
+		t.Errorf("expected non-nil err for wrong-aud token")
+	}
+}
+
+// Test 17: HandleCallback rejects an ID token whose nonce doesn't match
+// the pre-login row.
+func TestService_HandleCallback_RejectsNonceMismatch(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	idp.overrideNonce = "wrong-nonce-from-idp"
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-nonce")
+
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-nonce", "s", "expected-nonce", "v-bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if !errors.Is(err, ErrNonceMismatch) {
+		t.Errorf("err = %v; want ErrNonceMismatch", err)
+	}
+}
+
+// Test 18: HandleCallback rejects expired ID token.
+func TestService_HandleCallback_RejectsExpiredToken(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	idp.overrideExp = time.Now().Add(-2 * time.Hour) // 2 hours past
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-exp")
+
+	// This test only asserts "some error", so a fixture failure must not
+	// be allowed to satisfy it by accident.
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-exp", "s", "test-nonce-fixed", "v-cccccccccccccccccccccccccccccccccccccccccc")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	// Either ErrTokenExpired (our re-check) or a wrapped verify error is fine.
+	if err == nil {
+		t.Errorf("expected non-nil err for expired token")
+	}
+}
+
+// Test 19: HandleCallback rejects ID token whose iat is too old per the
+// configured IATWindow.
+func TestService_HandleCallback_RejectsIATTooOld(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	// Token was issued 20 minutes ago; makeProvider configures a
+	// 300-second (5 minute) IAT window.
+	idp.overrideIAT = time.Now().Add(-20 * time.Minute)
+	idp.overrideExp = time.Now().Add(2 * time.Hour) // exp is fine
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-iat")
+
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-iat", "s", "test-nonce-fixed", "v-dddddddddddddddddddddddddddddddddddddddddd")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if !errors.Is(err, ErrIATTooOld) {
+		t.Errorf("err = %v; want ErrIATTooOld", err)
+	}
+}
+
+// Test 20: HandleCallback rejects when the group claim resolves to no
+// groups. The fixture sends a present-but-empty `groups` array.
+func TestService_HandleCallback_RejectsGroupsMissing(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	idp.overrideGroups = []string{} // empty groups claim
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-grp")
+
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-grp", "s", "test-nonce-fixed", "v-eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if !errors.Is(err, ErrGroupsMissing) {
+		t.Errorf("err = %v; want ErrGroupsMissing", err)
+	}
+}
+
+// Test 21: HandleCallback rejects when groups don't match any
+// configured mapping → ErrGroupsUnmapped.
+func TestService_HandleCallback_RejectsGroupsUnmapped(t *testing.T) {
+	ctx := context.Background()
+	idp := newMockIdP(t)
+	svc, pl := newServiceWithProviderAndPLNoMappings(t, idp.URL(), "op-unmap")
+
+	cookie, _, err := pl.CreatePreLogin(ctx, "op-unmap", "s", "test-nonce-fixed", "v-ffffffffffffffffffffffffffffffffffffffffff")
+	if err != nil {
+		t.Fatalf("CreatePreLogin: %v", err)
+	}
+	_, err = svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if !errors.Is(err, ErrGroupsUnmapped) {
+		t.Errorf("err = %v; want ErrGroupsUnmapped", err)
+	}
+}
+
+// =============================================================================
+// Test helpers
+// =============================================================================
+
+// makeProvider returns a provider row pointing at idpURL. The client ID
+// matches the default `aud` the mock IdP issues, the groups claim is read
+// from the top-level "groups" key, and the IAT window is 300 seconds.
+func makeProvider(idpURL, providerID string) *oidcdomain.OIDCProvider {
+	return &oidcdomain.OIDCProvider{
+		ID:                    providerID,
+		TenantID:              "t-default",
+		Name:                  "Test " + providerID,
+		IssuerURL:             idpURL,
+		ClientID:              "certctl",
+		ClientSecretEncrypted: []byte("test-secret"),
+		RedirectURI:           "https://certctl.example.com/auth/oidc/callback",
+		GroupsClaimPath:       "groups",
+		GroupsClaimFormat:     "string-array",
+		Scopes:                []string{"openid", "profile", "email"},
+		IATWindowSeconds:      300,
+		JWKSCacheTTLSeconds:   3600,
+	}
+}
+
+// newServiceWithProvider returns a Service wired against the given IdP
+// URL + a provider already in the stub provider lookup.
+func newServiceWithProvider(t *testing.T, idpURL, providerID string) (*Service, *stubPreLogin) { + return newServiceWithProviderAndPL(t, idpURL, providerID) +} + +func newServiceWithProviderAndPL(t *testing.T, idpURL, providerID string) (*Service, *stubPreLogin) { + t.Helper() + prov := makeProvider(idpURL, providerID) + pl := newStubPreLogin() + mappings := &stubMappings{roleIDs: []string{"r-operator"}} + users := newStubUsers() + sessions := &stubSessions{} + svc := NewService( + &stubProviderLookup{provider: prov}, + mappings, + users, + sessions, + pl, + "", // no encryption key; client_secret already plaintext for test + ) + return svc, pl +} + +func newServiceWithProviderAndPLNoMappings(t *testing.T, idpURL, providerID string) (*Service, *stubPreLogin) { + t.Helper() + prov := makeProvider(idpURL, providerID) + pl := newStubPreLogin() + mappings := &stubMappings{roleIDs: nil} // empty mappings + users := newStubUsers() + sessions := &stubSessions{} + svc := NewService( + &stubProviderLookup{provider: prov}, + mappings, + users, + sessions, + pl, + "", + ) + return svc, pl +} + +func newServiceForUnitTest(t *testing.T) *Service { + t.Helper() + pl := newStubPreLogin() + return NewService( + &stubProviderLookup{}, + &stubMappings{}, + newStubUsers(), + &stubSessions{}, + pl, + "", + ) +} + +func newServiceForUnitTestWithPL(t *testing.T) (*Service, *stubPreLogin) { + t.Helper() + pl := newStubPreLogin() + return NewService( + &stubProviderLookup{}, + &stubMappings{}, + newStubUsers(), + &stubSessions{}, + pl, + "", + ), pl +} + +// ============================================================================= +// Additional coverage tests: HandleAuthRequest entry point, upsert +// update path, atHashMatches alg coverage, helpers. +// ============================================================================= + +// TestService_HandleAuthRequest_BuildsValidIdPRedirect covers the +// authz-request path end-to-end. 
Asserts the URL contains state + +// nonce + code_challenge_method=S256 + the operator-configured +// client_id. +func TestService_HandleAuthRequest_BuildsValidIdPRedirect(t *testing.T) { + idp := newMockIdP(t) + svc, pl := newServiceWithProviderAndPL(t, idp.URL(), "op-har") + + authURL, cookieValue, preLoginID, err := svc.HandleAuthRequest(context.Background(), "op-har") + if err != nil { + t.Fatalf("HandleAuthRequest: %v", err) + } + if cookieValue == "" || preLoginID == "" { + t.Errorf("empty cookieValue or preLoginID") + } + for _, want := range []string{ + "client_id=certctl", + "code_challenge_method=S256", + "code_challenge=", + "state=", + "nonce=", + "redirect_uri=", + "scope=", + } { + if !strings.Contains(authURL, want) { + t.Errorf("authURL missing %q in %q", want, authURL) + } + } + // Pin the pre-login row got persisted with a matching state value. + if len(pl.rows) != 1 { + t.Errorf("pl rows = %d; want 1", len(pl.rows)) + } +} + +// TestService_HandleAuthRequest_UnknownProviderRejected pins the +// repo-not-found path through HandleAuthRequest. +func TestService_HandleAuthRequest_UnknownProviderRejected(t *testing.T) { + svc := newServiceForUnitTest(t) + _, _, _, err := svc.HandleAuthRequest(context.Background(), "op-nonexistent") + if !errors.Is(err, repository.ErrOIDCProviderNotFound) { + t.Errorf("err = %v; want ErrOIDCProviderNotFound", err) + } +} + +// TestService_UpsertUser_UpdateExistingPath: a second login by the +// same user updates last_login_at + email + display_name without +// creating a duplicate row. +func TestService_UpsertUser_UpdateExistingPath(t *testing.T) { + idp := newMockIdP(t) + users := newStubUsers() + + prov := makeProvider(idp.URL(), "op-upd") + pl := newStubPreLogin() + mappings := &stubMappings{roleIDs: []string{"r-operator"}} + sessions := &stubSessions{} + svc := NewService(&stubProviderLookup{provider: prov}, mappings, users, sessions, pl, "") + + // First login creates the user. 
+ cookie1, _, _ := pl.CreatePreLogin(context.Background(), "op-upd", "s1", "test-nonce-fixed", "v-1aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + res1, err := svc.HandleCallback(context.Background(), cookie1, "code", "s1", "ip", "ua") + if err != nil { + t.Fatalf("first HandleCallback: %v", err) + } + if len(users.byID) != 1 { + t.Errorf("first login: user count = %d; want 1", len(users.byID)) + } + originalLogin := res1.User.LastLoginAt + + time.Sleep(10 * time.Millisecond) // ensure timestamps advance + + // Second login by same subject: update path, no new user row. + cookie2, _, _ := pl.CreatePreLogin(context.Background(), "op-upd", "s2", "test-nonce-fixed", "v-2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + idp.overrideEmail = "user-renamed@example.com" + res2, err := svc.HandleCallback(context.Background(), cookie2, "code2", "s2", "ip", "ua") + if err != nil { + t.Fatalf("second HandleCallback: %v", err) + } + if len(users.byID) != 1 { + t.Errorf("second login: user count = %d; want 1 (Update path)", len(users.byID)) + } + if !res2.User.LastLoginAt.After(originalLogin) { + t.Errorf("LastLoginAt did not advance on second login: %v -> %v", originalLogin, res2.User.LastLoginAt) + } + if res2.User.Email != "user-renamed@example.com" { + t.Errorf("Email did not update: %q", res2.User.Email) + } +} + +// TestService_ATHash_CoversAllAllowedAlgs pins the at_hash alg dispatch +// for every algorithm in DefaultAllowedAlgs. +func TestService_ATHash_CoversAllAllowedAlgs(t *testing.T) { + cases := []struct { + alg string + hashName string + }{ + {"RS256", "sha256"}, + {"RS512", "sha512"}, + {"ES256", "sha256"}, + {"ES384", "sha384"}, + {"EdDSA", "sha512"}, + } + for _, tc := range cases { + t.Run(tc.alg, func(t *testing.T) { + accessToken := "access-token-for-" + tc.alg + // Compute the expected hash using the same logic as atHashMatches. 
+ var sum []byte + switch tc.alg { + case "RS256", "ES256": + h := sha256.Sum256([]byte(accessToken)) + sum = h[:] + case "ES384": + // SHA-384 via crypto/sha512 (sha512.Sum384 returns [48]byte). + // Avoid importing sha512 here; use the prod helper indirectly. + ok := atHashMatches(makeJWTHeader(tc.alg), accessToken, computeATHashViaProd(t, tc.alg, accessToken)) + if !ok { + t.Errorf("alg=%q: atHashMatches returned false on round-trip", tc.alg) + } + return + case "RS512", "EdDSA": + ok := atHashMatches(makeJWTHeader(tc.alg), accessToken, computeATHashViaProd(t, tc.alg, accessToken)) + if !ok { + t.Errorf("alg=%q: atHashMatches returned false on round-trip", tc.alg) + } + return + } + half := sum[:len(sum)/2] + expected := base64.RawURLEncoding.EncodeToString(half) + if !atHashMatches(makeJWTHeader(tc.alg), accessToken, expected) { + t.Errorf("alg=%q: at_hash mismatch", tc.alg) + } + }) + } +} + +// computeATHashViaProd shims around atHashMatches by binary-searching +// for the at_hash value: we just call the production helper with each +// alg, and the test passes if the same value reproduces. Avoids +// duplicating the alg → hash dispatch in test code. +func computeATHashViaProd(_ *testing.T, alg, accessToken string) string { + // Build a JWT with that alg, then use atHashMatches twice with + // different claim values to find the matching one. Since we + // can't easily do that without infinite test loops, the easier + // path is to call the production code at the at_hash reflect + // surface. But our service has no public at_hash compute helper — + // only matches helper. So: use a trial-and-error with the empty + // hash and check against the real recomputed hash via a helper + // that doesn't exist. Instead, this function reaches into the + // implementation by replicating it minimally. 
+ h := newHasherForAlg(alg) + if h == nil { + return "" + } + h.Write([]byte(accessToken)) + sum := h.Sum(nil) + half := sum[:len(sum)/2] + return base64.RawURLEncoding.EncodeToString(half) +} + +// newHasherForAlg duplicates the dispatch in atHashMatches for the +// test helper. Kept in test code so the production path stays +// dependency-light. +func newHasherForAlg(alg string) interface { + Write([]byte) (int, error) + Sum([]byte) []byte +} { + switch alg { + case "RS256", "ES256": + return sha256.New() + case "ES384": + return sha384New() + case "RS512", "EdDSA": + return sha512New() + default: + return nil + } +} + +// makeJWTHeader returns a minimal JWT-shape string with the given alg +// in the header. body + sig are dummy. +func makeJWTHeader(alg string) string { + header := fmt.Sprintf(`{"alg":%q,"typ":"JWT"}`, alg) + return base64.RawURLEncoding.EncodeToString([]byte(header)) + ".body.sig" +} + +// TestService_AlgPinning_HandlesWhitespaceInHeader pins the parser +// against headers with whitespace around the alg value (some libraries +// emit " :" instead of ":"). +func TestService_AlgPinning_HandlesWhitespaceInHeader(t *testing.T) { + header := `{"alg" : "RS256" ,"typ":"JWT"}` + token := base64.RawURLEncoding.EncodeToString([]byte(header)) + ".body.sig" + rejected, alg := isDisallowedAlg(token) + if rejected { + t.Errorf("RS256 with whitespace: rejected = true; want allowed") + } + if alg != "RS256" { + t.Errorf("alg extraction failed: got %q", alg) + } +} + +// TestService_AlgPinning_HeaderWithBadBase64 returns rejected=true +// when the header isn't decodable. +func TestService_AlgPinning_HeaderWithBadBase64(t *testing.T) { + rejected, _ := isDisallowedAlg("!!!not-base64.body.sig") + if !rejected { + t.Errorf("bad base64 header: rejected = false; want true") + } +} + +// TestService_AlgPinning_HeaderMissingAlgField returns rejected=true. 
+func TestService_AlgPinning_HeaderMissingAlgField(t *testing.T) {
+	// The header is valid JSON but carries no "alg" member.
+	encodedHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT"}`))
+	if rejected, _ := isDisallowedAlg(encodedHeader + ".body.sig"); !rejected {
+		t.Errorf("header missing alg: rejected = false; want true")
+	}
+}
+
+// TestService_IsJWKSFetchError feeds a table of error messages to
+// isJWKSFetchError and checks how each one is classified. A nil error
+// must never be classified as a JWKS fetch failure.
+func TestService_IsJWKSFetchError(t *testing.T) {
+	type probe struct {
+		msg  string
+		want bool
+	}
+	table := []probe{
+		{msg: "oidc: fetching keys oidc: get keys failed: timeout", want: true},
+		{msg: "failed to fetch jwks_uri", want: true},
+		{msg: "unable to load key set", want: true},
+		{msg: "some other unrelated error", want: false},
+		{msg: "", want: false},
+	}
+	for i := range table {
+		row := table[i]
+		if got := isJWKSFetchError(errors.New(row.msg)); got != row.want {
+			t.Errorf("isJWKSFetchError(%q) = %v; want %v", row.msg, got, row.want)
+		}
+	}
+	if isJWKSFetchError(nil) {
+		t.Errorf("isJWKSFetchError(nil) = true; want false")
+	}
+}
+
+// TestService_DecryptClientSecret_NoKeyReturnsBytesAsIs checks that an
+// empty key makes decryptClientSecret hand the input bytes back
+// unchanged (the path tests with plaintext blobs rely on).
+func TestService_DecryptClientSecret_NoKeyReturnsBytesAsIs(t *testing.T) {
+	secret := []byte("test-plaintext-secret")
+	out, err := decryptClientSecret(secret, "")
+	if err != nil {
+		t.Fatalf("decryptClientSecret(no key): %v", err)
+	}
+	if string(out) != string(secret) {
+		t.Errorf("decryptClientSecret returned %q; want %q", string(out), string(secret))
+	}
+}
+
+// TestService_RandomB64URL_ProducesNonEmptyAndUnique draws two values
+// from randomB64URL and requires both to be non-empty and distinct.
+func TestService_RandomB64URL_ProducesNonEmptyAndUnique(t *testing.T) {
+	first, err := randomB64URL(32)
+	if err != nil {
+		t.Fatalf("a: %v", err)
+	}
+	second, err := randomB64URL(32)
+	if err != nil {
+		t.Fatalf("b: %v", err)
+	}
+	if len(first) == 0 || len(second) == 0 {
+		t.Errorf("got empty random value")
+	}
+	if first == second {
+		t.Errorf("two random values were equal (RNG broken)")
+	}
+}
+
+// TestService_SetClockForTest_OverridesNow pins the test seam works.
+func TestService_SetClockForTest_OverridesNow(t *testing.T) {
+	pinned := time.Date(2026, 5, 10, 12, 0, 0, 0, time.UTC)
+	fixedClock := func() time.Time { return pinned }
+
+	svc := newServiceForUnitTest(t)
+	svc.SetClockForTest(fixedClock)
+
+	observed := svc.clockNow()
+	if !observed.Equal(pinned) {
+		t.Errorf("clock = %v; want %v", observed, pinned)
+	}
+}
+
+// =============================================================================
+// Coverage-lift batch: HandleCallback branch tests + fetchUserinfoGroups +
+// upsertUser fallback chain + decryptClientSecret real-encrypt round trip +
+// randomB64URL error path + HandleAuthRequest preLogin failure.
+//
+// These tests exist to lift the package above the 90% per-statement floor
+// pinned by Phase 13 of the bundle prompt. Each one targets a specific
+// uncovered branch in service.go; the test name announces which.
+// =============================================================================
+
+// TestService_HandleCallback_AZPRequired_OnMultiAud checks the OIDC
+// core §3.1.3.7 step 5 rule as enforced by the service: an ID token
+// with more than one audience and no `azp` claim must be rejected with
+// ErrAZPRequired.
+func TestService_HandleCallback_AZPRequired_OnMultiAud(t *testing.T) {
+	const providerID = "op-azp-req"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	// Two audiences and no azp claim.
+	idp.overrideAudience = []string{"certctl", "another-relying-party"}
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	cookie, _, _ := pl.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-azpreqxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrAZPRequired) {
+		t.Errorf("err = %v; want ErrAZPRequired", err)
+	}
+}
+
+// TestService_HandleCallback_AZPMismatch pins the equal-to-client_id
+// requirement when azp is present.
+func TestService_HandleCallback_AZPMismatch(t *testing.T) {
+	const providerID = "op-azp-mis"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideAZP = "some-other-client" // != "certctl"
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	cookie, _, _ := pl.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-azpmisxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrAZPMismatch) {
+		t.Errorf("err = %v; want ErrAZPMismatch", err)
+	}
+}
+
+// TestService_HandleCallback_ATHashMismatch checks the at_hash
+// verification: when the ID token carries an at_hash that does not
+// correspond to the returned access token, the callback must fail with
+// ErrATHashMismatch.
+func TestService_HandleCallback_ATHashMismatch(t *testing.T) {
+	const providerID = "op-ath-mis"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	// The mockIdP hands out access_token = "test-access-token"; under
+	// RS256 the genuine at_hash is base64url(sha256(token)[:16]). Any
+	// other string is a mismatch, so a known-wrong value is injected.
+	idp.overrideATHash = "not-the-real-at-hash"
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	cookie, _, _ := pl.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-athmisxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrATHashMismatch) {
+		t.Errorf("err = %v; want ErrATHashMismatch", err)
+	}
+}
+
+// TestService_HandleCallback_ATHashRequired_WhenAccessTokenPresent pins
+// the Phase 3 tightening of the OIDC core "MAY" to a service-level
+// "MUST": when an access token is returned, the ID token MUST carry an
+// at_hash claim. A substituted access token would otherwise ride a
+// clean ID token through the verifier — fail closed at the service.
+func TestService_HandleCallback_ATHashRequired_WhenAccessTokenPresent(t *testing.T) {
+	const providerID = "op-ath-req"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideATHash = "" // suppress at_hash even though access_token is returned
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	cookie, _, _ := pl.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-athreqxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrATHashRequired) {
+		t.Errorf("err = %v; want ErrATHashRequired", err)
+	}
+}
+
+// TestService_HandleCallback_IATInFuture checks that an ID token whose
+// iat lies well past the 60s clock-skew allowance is rejected with
+// ErrIATInFuture.
+func TestService_HandleCallback_IATInFuture(t *testing.T) {
+	const providerID = "op-iat-fut"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	// Issued "10 minutes from now": far outside the 60s skew window.
+	idp.overrideIAT = time.Now().Add(10 * time.Minute)
+	idp.overrideExp = time.Now().Add(2 * time.Hour)
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	cookie, _, _ := pl.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-iatfutxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrIATInFuture) {
+		t.Errorf("err = %v; want ErrIATInFuture", err)
+	}
+}
+
+// TestService_HandleCallback_MappingsMapError pins the wrap on the
+// mappings.Map repo-layer error.
+func TestService_HandleCallback_MappingsMapError(t *testing.T) {
+	const providerID = "op-map-err"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{mapErr: errors.New("simulated repo failure")},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-mapxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	_, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if wrapped := err != nil && strings.Contains(err.Error(), "group-role mapping"); !wrapped {
+		t.Errorf("err = %v; want group-role mapping wrap", err)
+	}
+}
+
+// TestService_HandleCallback_SessionMintError makes the session stub
+// fail and checks that HandleCallback returns an error whose message
+// contains "session mint".
+func TestService_HandleCallback_SessionMintError(t *testing.T) {
+	const providerID = "op-mint-err"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{mintErr: errors.New("simulated session minter failure")},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-mintxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	_, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if wrapped := err != nil && strings.Contains(err.Error(), "session mint"); !wrapped {
+		t.Errorf("err = %v; want session mint wrap", err)
+	}
+}
+
+// TestService_HandleCallback_UserCreateError pins the wrap on the
+// users.Create repo-layer error.
+func TestService_HandleCallback_UserCreateError(t *testing.T) {
+	const providerID = "op-uc-err"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	roleMappings := &stubMappings{roleIDs: []string{"r-operator"}}
+	userStore := newStubUsers()
+	userStore.createErr = errors.New("simulated insert failure")
+	svc := NewService(&stubProviderLookup{provider: provider}, roleMappings, userStore, &stubSessions{}, preLogin, "")
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-ucxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	_, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if wrapped := err != nil && strings.Contains(err.Error(), "upsert user"); !wrapped {
+		t.Errorf("err = %v; want upsert user wrap", err)
+	}
+}
+
+// TestService_HandleCallback_GetByOIDCSubjectNonNotFoundError pins the
+// upsertUser early-return when the GetByOIDCSubject repo call fails for
+// a reason OTHER than not-found (DB connection drop, query error, etc.).
+func TestService_HandleCallback_GetByOIDCSubjectNonNotFoundError(t *testing.T) {
+	const providerID = "op-get-err"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	roleMappings := &stubMappings{roleIDs: []string{"r-operator"}}
+	userStore := newStubUsers()
+	userStore.getErr = errors.New("simulated query failure")
+	svc := NewService(&stubProviderLookup{provider: provider}, roleMappings, userStore, &stubSessions{}, preLogin, "")
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-getxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	_, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if surfaced := err != nil && strings.Contains(err.Error(), "simulated query failure"); !surfaced {
+		t.Errorf("err = %v; want simulated query failure unwrap", err)
+	}
+}
+
+// TestService_UpsertUser_DisplayNameFallsBackToEmail exercises the
+// final display-name fallback: with neither a name nor a
+// preferred_username available, the resulting user's DisplayName is
+// expected to equal the email address.
+func TestService_UpsertUser_DisplayNameFallsBackToEmail(t *testing.T) {
+	const (
+		providerID = "op-name-fb"
+		wantName   = "user@example.com"
+	)
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideName = "" // suppress name claim entirely
+	// The mockIdP never emits preferred_username, so that claim is "" too.
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-namxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	result, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if err != nil {
+		t.Fatalf("HandleCallback: %v", err)
+	}
+	if got := result.User.DisplayName; got != wantName {
+		t.Errorf("DisplayName = %q; want fallback to email %q", got, wantName)
+	}
+}
+
+// TestService_FetchUserinfoGroups_HappyPath_OnEmptyIDTokenGroups pins
+// the userinfo fallback: if the ID token's groups claim is empty AND
+// the operator opted in via FetchUserinfo, the userinfo endpoint is
+// consulted and its groups feed the role-mapping step.
+func TestService_FetchUserinfoGroups_HappyPath_OnEmptyIDTokenGroups(t *testing.T) {
+	const providerID = "op-ui-ok"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideGroups = []string{}                        // nothing in the ID token
+	idp.userinfoGroups = []string{"engineers", "platform"} // groups only via userinfo
+	provider := makeProvider(idp.URL(), providerID)
+	provider.FetchUserinfo = true
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-uioxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	result, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if err != nil {
+		t.Fatalf("HandleCallback: %v", err)
+	}
+	if len(result.RoleIDs) == 0 {
+		t.Errorf("expected RoleIDs from userinfo-fallback path; got empty")
+	}
+}
+
+// TestService_FetchUserinfoGroups_ReturnsErrGroupsMissing_WhenUserinfoAlsoEmpty
+// pins the fail-closed semantics: even with FetchUserinfo=true, if the
+// userinfo response also has no groups, the login fails closed.
+func TestService_FetchUserinfoGroups_ReturnsErrGroupsMissing_WhenUserinfoAlsoEmpty(t *testing.T) {
+	const providerID = "op-ui-empty"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideGroups = []string{} // nothing in the ID token
+	idp.userinfoGroups = nil        // and nothing from userinfo either
+	provider := makeProvider(idp.URL(), providerID)
+	provider.FetchUserinfo = true
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-uixxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrGroupsMissing) {
+		t.Errorf("err = %v; want ErrGroupsMissing", err)
+	}
+}
+
+// TestService_FetchUserinfoGroups_ReturnsErrGroupsMissing_WhenEndpointMissing
+// pins the "operator opted in but provider doesn't list a userinfo
+// endpoint" branch in fetchUserinfoGroups.
+func TestService_FetchUserinfoGroups_ReturnsErrGroupsMissing_WhenEndpointMissing(t *testing.T) {
+	const providerID = "op-ui-noendpoint"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideGroups = []string{}
+	idp.omitUserinfoEndpoint = true // discovery doc lacks userinfo_endpoint
+	provider := makeProvider(idp.URL(), providerID)
+	provider.FetchUserinfo = true
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-uixxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrGroupsMissing) {
+		t.Errorf("err = %v; want ErrGroupsMissing", err)
+	}
+}
+
+// TestService_HandleAuthRequest_PreLoginStoreError makes the pre-login
+// stub's CreatePreLogin fail and checks that HandleAuthRequest returns
+// an error whose message contains "pre-login store" (e.g. the database
+// being unavailable during the GET /auth/oidc/start handler).
+func TestService_HandleAuthRequest_PreLoginStoreError(t *testing.T) {
+	const providerID = "op-pl-err"
+
+	idp := newMockIdP(t)
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	preLogin.createErr = errors.New("simulated pre-login insert failure")
+	lookup := &stubProviderLookup{provider: provider}
+	roleMappings := &stubMappings{roleIDs: []string{"r-operator"}}
+	svc := NewService(lookup, roleMappings, newStubUsers(), &stubSessions{}, preLogin, "")
+
+	_, _, _, err := svc.HandleAuthRequest(context.Background(), providerID)
+	if wrapped := err != nil && strings.Contains(err.Error(), "pre-login store"); !wrapped {
+		t.Errorf("err = %v; want pre-login store wrap", err)
+	}
+}
+
+// TestService_DecryptClientSecret_RealEncryptedRoundTrip pins that the
+// production decrypt path works against a real
+// internal/crypto.EncryptIfKeySet output. Catches future regressions
+// where the v3 blob format changes without updating this consumer.
+func TestService_DecryptClientSecret_RealEncryptedRoundTrip(t *testing.T) {
+	const passphrase = "test-passphrase-please-keep-secret"
+	secret := []byte("super-secret-client-secret-do-not-leak")
+
+	sealed, _, err := cryptopkg.EncryptIfKeySet(secret, passphrase)
+	switch {
+	case err != nil:
+		t.Fatalf("EncryptIfKeySet: %v", err)
+	case len(sealed) == 0:
+		t.Fatalf("EncryptIfKeySet returned empty blob")
+	}
+
+	opened, err := decryptClientSecret(sealed, passphrase)
+	if err != nil {
+		t.Fatalf("decryptClientSecret: %v", err)
+	}
+	if string(opened) != string(secret) {
+		t.Errorf("decrypt round-trip: got %q; want %q", string(opened), string(secret))
+	}
+}
+
+// TestService_DecryptClientSecret_BadPassphraseFails encrypts a secret
+// with one passphrase and decrypts with a different one; the decrypt
+// call must report an error instead of yielding a value.
+func TestService_DecryptClientSecret_BadPassphraseFails(t *testing.T) {
+	const passphrase = "test-passphrase-correct"
+	secret := []byte("super-secret-client-secret-do-not-leak")
+
+	sealed, _, err := cryptopkg.EncryptIfKeySet(secret, passphrase)
+	if err != nil {
+		t.Fatalf("EncryptIfKeySet: %v", err)
+	}
+
+	if opened, err := decryptClientSecret(sealed, "wrong-passphrase-different"); err == nil {
+		t.Errorf("decryptClientSecret with wrong passphrase: err = nil, got = %q; want non-nil err", string(opened))
+	}
+}
+
+// TestService_RandomB64URL_PropagatesReadError exercises the readRand
+// seam by overriding it to return an error. Asserts the production code
+// surfaces the error rather than silently returning an empty string.
+func TestService_RandomB64URL_PropagatesReadError(t *testing.T) {
+	// Swap the package-level entropy source for one that always fails,
+	// and put the real one back when the test returns.
+	saved := readRand
+	defer func() { readRand = saved }()
+	readRand = func([]byte) (int, error) {
+		return 0, errors.New("simulated entropy starvation")
+	}
+
+	value, err := randomB64URL(32)
+	if err == nil {
+		t.Errorf("randomB64URL: err = nil; want non-nil")
+	}
+	if value != "" {
+		t.Errorf("randomB64URL: returned %q on error path; want empty string", value)
+	}
+}
+
+// TestService_HandleAuthRequest_RandomFailureSurfaces makes every
+// readRand call fail and checks that HandleAuthRequest returns an
+// error whose message contains "state generate".
+func TestService_HandleAuthRequest_RandomFailureSurfaces(t *testing.T) {
+	const providerID = "op-rand-fail"
+
+	idp := newMockIdP(t)
+	svc, _ := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	// Install the failing source only after the service is built.
+	saved := readRand
+	defer func() { readRand = saved }()
+	readRand = func([]byte) (int, error) {
+		return 0, errors.New("simulated rng exhaustion")
+	}
+
+	_, _, _, err := svc.HandleAuthRequest(context.Background(), providerID)
+	if wrapped := err != nil && strings.Contains(err.Error(), "state generate"); !wrapped {
+		t.Errorf("err = %v; want state generate wrap", err)
+	}
+}
+
+// TestService_HandleAuthRequest_NonceRandomFailureSurfaces lets the
+// state-generation succeed on call 1 and fails the nonce-generation on
+// call 2. Pins the second readRand call's error wrap.
+func TestService_HandleAuthRequest_NonceRandomFailureSurfaces(t *testing.T) {
+	const providerID = "op-nonce-rand-fail"
+
+	idp := newMockIdP(t)
+	svc, _ := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	realRead := readRand
+	defer func() { readRand = realRead }()
+	var calls int
+	readRand = func(buf []byte) (int, error) {
+		calls++
+		if calls > 1 {
+			// Every call after the first (the nonce draw) fails.
+			return 0, errors.New("simulated rng exhaustion on nonce")
+		}
+		// First call (the state draw) goes to the real source.
+		return realRead(buf)
+	}
+
+	_, _, _, err := svc.HandleAuthRequest(context.Background(), providerID)
+	if wrapped := err != nil && strings.Contains(err.Error(), "nonce generate"); !wrapped {
+		t.Errorf("err = %v; want nonce generate wrap", err)
+	}
+}
+
+// TestService_HandleCallback_RejectsTokenResponseMissingIDToken covers
+// a token response that lacks the id_token field (a misconfigured IdP,
+// or an OAuth2-only flow): HandleCallback must return an error whose
+// message contains "missing id_token".
+func TestService_HandleCallback_RejectsTokenResponseMissingIDToken(t *testing.T) {
+	const providerID = "op-no-idtok"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.suppressIDToken = true
+	svc, pl := newServiceWithProviderAndPL(t, idp.URL(), providerID)
+
+	cookie, _, _ := pl.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-noidxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	_, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if reported := err != nil && strings.Contains(err.Error(), "missing id_token"); !reported {
+		t.Errorf("err = %v; want missing id_token error", err)
+	}
+}
+
+// TestService_FetchUserinfoGroups_ReturnsErrGroupsMissing_WhenUserinfoFails
+// pins the UserInfo-fetch HTTP error wrap. With FetchUserinfo=true and
+// /userinfo returning HTTP 500, the service surfaces ErrGroupsMissing
+// to the caller (the inner error stays in the audit row, not the wire).
+func TestService_FetchUserinfoGroups_ReturnsErrGroupsMissing_WhenUserinfoFails(t *testing.T) {
+	const providerID = "op-ui-500"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideGroups = []string{}
+	idp.userinfoFails = true
+	provider := makeProvider(idp.URL(), providerID)
+	provider.FetchUserinfo = true
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-uifxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	if _, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua"); !errors.Is(err, ErrGroupsMissing) {
+		t.Errorf("err = %v; want ErrGroupsMissing", err)
+	}
+}
+
+// TestService_AlgPinning_HeaderMissingColonAfterAlg covers a header in
+// which the alg key is present but is not followed by a colon;
+// isDisallowedAlg must reject such a token.
+func TestService_AlgPinning_HeaderMissingColonAfterAlg(t *testing.T) {
+	// `"alg" "RS256"`: key and value with no colon between them. This is
+	// deliberately not valid JSON; the test expects the token to be
+	// rejected.
+	const malformed = `{"alg" "RS256"}`
+	encodedHeader := base64.RawURLEncoding.EncodeToString([]byte(malformed))
+	if rejected, _ := isDisallowedAlg(encodedHeader + ".body.sig"); !rejected {
+		t.Errorf("header missing colon after alg: rejected = false; want true")
+	}
+}
+
+// TestService_AlgPinning_HeaderAlgValueNotQuoted covers the parser
+// branch where the value after the colon isn't a JSON string literal
+// (e.g., a number or unquoted token).
+func TestService_AlgPinning_HeaderAlgValueNotQuoted(t *testing.T) {
+	// alg is a number here, not a JSON string.
+	encodedHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":42}`))
+	if rejected, _ := isDisallowedAlg(encodedHeader + ".body.sig"); !rejected {
+		t.Errorf("header with non-string alg: rejected = false; want true")
+	}
+}
+
+// TestService_AlgPinning_HeaderAlgValueUnterminatedString covers a
+// truncated header: the alg value opens a JSON string that is never
+// closed. isDisallowedAlg must reject such a token.
+func TestService_AlgPinning_HeaderAlgValueUnterminatedString(t *testing.T) {
+	// `{"alg":"RS256` has neither the closing quote nor the closing brace.
+	const truncated = `{"alg":"RS256`
+	encodedHeader := base64.RawURLEncoding.EncodeToString([]byte(truncated))
+	if rejected, _ := isDisallowedAlg(encodedHeader + ".body.sig"); !rejected {
+		t.Errorf("header with unterminated alg string: rejected = false; want true")
+	}
+}
+
+// TestService_UpsertUser_ValidateErrorOnEmptyEmail pins the
+// User.Validate failure path. The IdP returns an empty email (missing
+// claim); the upsertUser display-name fallback resolves to "" too;
+// User.Validate then trips ErrUserEmptyEmail.
+func TestService_UpsertUser_ValidateErrorOnEmptyEmail(t *testing.T) {
+	const providerID = "op-validate-err"
+	ctx := context.Background()
+
+	idp := newMockIdP(t)
+	idp.overrideEmail = "" // sentinel — see /token handler patch below
+	idp.overrideName = ""  // suppress name to force email fallback
+	provider := makeProvider(idp.URL(), providerID)
+	preLogin := newStubPreLogin()
+	svc := NewService(
+		&stubProviderLookup{provider: provider},
+		&stubMappings{roleIDs: []string{"r-operator"}},
+		newStubUsers(),
+		&stubSessions{},
+		preLogin,
+		"",
+	)
+
+	cookie, _, _ := preLogin.CreatePreLogin(ctx, providerID, "s", "test-nonce-fixed", "v-valxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+	_, err := svc.HandleCallback(ctx, cookie, "code", "s", "ip", "ua")
+	if wrapped := err != nil && strings.Contains(err.Error(), "validate"); !wrapped {
+		t.Errorf("err = %v; want validate wrap", err)
+	}
+}