Compare commits

...

31 Commits

Author SHA1 Message Date
shankar0123 a53a4b845b fix(gui,api): close C-001 + C-002 — ownership + agent FK contract
C-001 — CreateCertificate was server-accepted with null owner_id,
team_id, renewal_policy_id because the GUI neither collected the fields
nor enforced them, even though the backend's ManagedCertificate schema
and handler contract treat them as required. Fix the contract at all
four layers:

  - web/src/pages/CertificatesPage.tsx: replace owner_id/team_id free-
    text inputs with <select> elements fed by getOwners/getTeams/
    getPolicies queries; mark all three required; gate the Create
    button on owner_id + team_id + renewal_policy_id being set.
  - internal/api/handler/certificates.go: ValidateRequired for
    owner_id, team_id, renewal_policy_id on CreateCertificate so the
    handler returns HTTP 400 with the offending field name before the
    service layer is reached.
  - internal/mcp/types.go: drop ',omitempty' from
    CreateCertificateInput.RenewalPolicyID so the MCP schema reflects
    the required contract; Update inputs keep partial-update semantics.
  - api/openapi.yaml: 'required: [name, common_name, renewal_policy_id,
    issuer_id, owner_id, team_id]' was already present on the Create
    schema; clarified DeploymentTarget.agent_id description to note the
    FK contract.

C-002 — CreateTargetWizard accepted an empty or bogus agent_id and the
service inserted directly, producing a Postgres 23503 FK-violation that
bubbled out as a generic HTTP 500. The FK itself (migration 000001 line
104: agent_id TEXT NOT NULL REFERENCES agents(id)) is correct; we keep
the schema strict and add validation at three layers:

  - internal/service/target.go: introduce
    ErrAgentNotFound sentinel and pre-validate agent_id in
    TargetService.CreateTarget — empty string returns
    'agent_id is required'; a nonexistent id returns the full
    'referenced agent does not exist: <id>' error. Both wrap
    ErrAgentNotFound via fmt.Errorf %w so callers can use errors.Is.
  - internal/api/handler/targets.go: ValidateRequired on agent_id; map
    errors.Is(err, service.ErrAgentNotFound) to HTTP 400 instead of
    letting it fall through to the generic 500 branch.
  - internal/mcp/types.go: drop ',omitempty' from
    CreateTargetInput.AgentID to match the required contract.
  - web/src/pages/TargetsPage.tsx: replace the free-text Agent ID input
    with a <select> populated from getAgents(); include agent in the
    canProceedToReview gate so Next is disabled until an agent is
    chosen.

Regression coverage (21 new subtests total):

  - TestCreateCertificate_MissingRequiredField_Returns400 — 6 subtests,
    one per required field, each proves the handler guard fires before
    the mock service is called.
  - TestCreateTarget_MissingAgentID_Returns400 — handler guard.
  - TestCreateTarget_NonexistentAgent_Returns400 — pins the
    ErrAgentNotFound -> 400 translation.
  - TestTargetService_CreateTarget_MissingAgentID — errors.Is sentinel.
  - TestTargetService_CreateTarget_NonexistentAgentID — errors.Is.
  - The existing TestTargetService_CreateTarget_Success, along with
    TestCreateTarget_{MissingName,MissingType,NameTooLong}_* handler
    tests, were updated to seed a real agent or include agent_id in
    the request body so the happy paths still run cleanly.

Gates (Phase 4):
  - go build/vet/test/race: green
  - go test -cover: internal/service 68.7% (gate 55%),
    internal/api/handler 78.9% (gate 60%)
  - golangci-lint on service+handler+mcp: 0 issues
  - govulncheck: no reachable vulns
  - tsc --noEmit: clean
  - vitest: 223/223 passing

See cowork/certctl-coverage-gap-audit.md entries C-001 and C-002.
2026-04-18 16:01:40 +00:00
shankar0123 9143da5fa8 Merge branch 'fix/d-008-policy-engine-drift' 2026-04-18 14:56:06 +00:00
shankar0123 b3cc7cbdb2 fix(policies): close the D-006 loop — TitleCase seed canonicals + severity-aware, config-consuming rule engine (D-008)
D-008 was a three-part drift in the policy engine that made the
D-005/D-006 remediation cosmetic below the DB layer:

  (a) migrations/seed.sql INSERTed rules with pre-D-005 lowercase
      types ('ownership', 'environment', 'lifetime', 'renewal_window')
      that the handler validator rejects on Create/Update but that
      raw SQL INSERTs bypassed entirely. At runtime evaluateRule's
      switch fell through to the default "unknown policy rule type"
      error branch on every demo rule × every cert × every cycle,
      flooding logs while emitting zero violations.

  (b) migrations/seed_demo.sql persisted lowercase severity values
      ('critical', 'error', 'warning') on policy_violations rows.
      INSERT succeeded because that column had no CHECK, but any
      frontend comparing against the canonical PolicySeverity enum
      mis-categorized every seeded violation.

  (c) evaluateRule hardcoded Severity: PolicySeverityWarning on
      every emitted violation and ignored rule.Config entirely —
      so the D-006 per-rule severity column (000013) and every
      per-arm Config JSON ({allowed_issuer_ids, allowed_domains,
      required_keys, allowed, lead_time_days, max_days}) was dead
      data below the evaluation layer.

This commit lands (a)+(b)+(c) atomically. Shipping any subset
leaves the feature half-working.

## Changes

Domain (internal/domain/policy.go):
  * Add PolicyTypeCertificateLifetime as the 6th TitleCase canonical.
    Pre-D-008 the seeded "max-certificate-lifetime" rule had no engine
    arm — routing it through RenewalLeadTime would conflate "how
    close to expiry before we renew" with "how long can the cert
    possibly be", two distinct semantics. The new type accepts
    config {"max_days": int} and flags certs whose
    NotAfter - NotBefore exceeds the cap.

Handler validator (internal/api/handler/validation.go):
  * ValidatePolicyType allowlist grown to 6 canonicals
    (AllowedIssuers, AllowedDomains, RequiredMetadata,
    AllowedEnvironments, RenewalLeadTime, CertificateLifetime).

OpenAPI (api/openapi.yaml):
  * PolicyType enum grown to match domain.

Frontend (web/src/api/types.ts, types.test.ts):
  * POLICY_TYPES tuple gains CertificateLifetime; pin test asserts
    all 6 canonicals and rejects casing drift.

Migration 000014 (policy_violations severity CHECK):
  * Named CHECK constraint (policy_violations_severity_check)
    mirroring 000013's allowlist, defense-in-depth at the DB layer
    against future drift from bypassed writes (migrations, psql
    sessions, future callers). Symmetric down migration drops by
    name.

Seed data:
  * migrations/seed.sql rewritten to emit TitleCase canonicals with
    per-arm config JSON that actually exercises the config-consuming
    paths (not the missing-field backstops):
      - pr-require-owner         → RequiredMetadata     {"required_keys":["owner"]}                        Warning
      - pr-allowed-environments  → AllowedEnvironments  {"allowed":["production","staging","development"]} Error
      - pr-max-certificate-lifetime → CertificateLifetime {"max_days":90}                                   Critical
      - pr-min-renewal-window    → RenewalLeadTime      {"lead_time_days":14}                              Warning
    Severities are now differentiated per rule (D-006 intent).
  * migrations/seed_demo.sql violation rows flipped to TitleCase
    severity ('Critical', 'Error', 'Warning') so migration 000014
    applies cleanly on upgrade paths.

Engine rewrite (internal/service/policy.go):
  * evaluateRule rewritten. All six arms now:
      1. Parse rule.Config into the per-arm typed struct.
      2. Bad JSON → log at ValidateCertificate boundary and skip
         this rule (no co-located poisoning of other rules in the
         same batch).
      3. Empty/null Config → emit the pre-D-008 missing-field
         violation (backwards compat invariant — operators who
         haven't reconfigured still see the same output).
      4. Violations emitted carry rule.Severity (no more hardcoded
         Warning); D-006 column is now load-bearing.
  * CertificateLifetime arm reads NotBefore/NotAfter from the
    certificate's latest version via CertRepo. Injected via
    PolicyService.SetCertRepo() setter — avoids churning ~36
    NewPolicyService call sites while keeping the lifetime arm
    optional (degrades to a log+skip if the setter is not wired).

Server wiring (cmd/server/main.go):
  * policyService.SetCertRepo(certRepo) wired after construction.

Tests (internal/service/policy_test.go):
  * 25 new subtests across 5 groups:
      - TestEvaluateRule_SeverityPassThrough (6): every rule type
        emits violations carrying rule.Severity, not hardcoded.
      - TestEvaluateRule_ConfigConsumed (12): every per-arm Config
        path exercised positive + negative.
      - TestEvaluateRule_EmptyConfig_BackCompat (3): empty/null
        Config still emits pre-D-008 missing-field violations.
      - TestEvaluateRule_BadConfig_SkipsRule: malformed JSON logs
        and skips cleanly without poisoning neighbors.
      - TestEvaluateRule_CertificateLifetime_RepoScenarios (3):
        ok when repo wired, log+skip when not, handles missing
        NotBefore/NotAfter edges.

Provenance: D-008 surfaced during D-005/D-006 remediation review
in eef1db0. That commit added persistence and CI pins for the
severity field but did not re-verify the evaluation layer
consumed it; this finding and fix close the audit-process gap.
2026-04-18 14:55:56 +00:00
shankar0123 eef1db0f0a fix(policies): stop 400ing the "+ New Policy" button + add per-rule severity (D-005, D-006)
Coverage Gap Audit findings D-005 (P0) + D-006 (P1) fixed together in a
single commit because they share the same root cause — policy CRUD sending
values the backend silently rejects — and splitting them would leave a
half-working UI between commits.

## D-005 (P0): PoliciesPage dropdown 400s every Create Policy

Root cause
----------
`web/src/pages/PoliciesPage.tsx` populated the Type `<select>` from a
hardcoded `['key_algorithm', 'ownership', 'allowed_issuers', ...]` array.
The backend's `internal/api/handler/validators.go::ValidatePolicyType`
enforces the TitleCase allowlist `AllowedIssuers`, `AllowedDomains`,
`RequiredMetadata`, `AllowedEnvironments`, `RenewalLeadTime` — defined in
`internal/domain/policy.go`. Every Create Policy request was rejected with
`400 invalid policy type`. The error surfaced only as a transient toast;
the modal closed anyway. Silent user-visible failure.

Fix
---
- `web/src/api/types.ts`: added `POLICY_TYPES` and `POLICY_SEVERITIES`
  tuples with `as const` and narrowed `PolicyRule.type`, `.severity`, and
  `PolicyViolation.severity` to the literal-union types. Dropdown is now
  sourced from the tuple; casing drift becomes a compile error.
- `web/src/pages/PoliciesPage.tsx`: rekeyed `severityStyles` /
  `severityDots` to the TitleCase values, added `humanize()` for display
  (AllowedIssuers → "Allowed Issuers"), removed the `badge-neutral`
  fallback that was papering over the mismatch.
- `web/src/api/types.test.ts` (new): pins both tuples exactly. If anyone
  edits one side of the frontend/backend contract without the other, CI
  fails with a clear assertion. Pure-TS vitest, no RTL dependency.

## D-006 (P1): `severity` field silently dropped on create/update

Root cause
----------
`PolicyRule` had no `Severity` field in `internal/domain/policy.go`. The
frontend has always sent `severity` on create/update, but Go's
`json.Decoder` (default settings, no `DisallowUnknownFields`) silently
dropped it. The value never reached PostgreSQL. Every rule rendered with
the same severity because there was no severity — just a display
computation downstream.

Fix: option (b), full-stack schema add (not delete-the-field)
-------------------------------------------------------------
- Migration `000013_policy_rule_severity` (up + down): adds
  `severity VARCHAR(50) NOT NULL DEFAULT 'Warning'` to `policy_rules` with
  CHECK constraint `severity IN ('Warning', 'Error', 'Critical')`. No
  index — three-value column on a low-thousands-rows table, planner will
  seq-scan regardless. PG 11+ metadata-only ADD COLUMN, safe on live data.
- `internal/domain/policy.go`: added `Severity PolicySeverity` field.
- `internal/repository/postgres/policy.go`: plumbed `severity` through
  ListRules SELECT + Scan, GetRule SELECT + Scan, CreateRule INSERT,
  UpdateRule UPDATE (4 queries).
- `internal/service/policy.go::UpdatePolicy`: if the client omits
  severity on a PUT (zero-value empty string), fetch the existing rule
  and preserve its severity. Without this, partial updates would trip the
  NOT NULL CHECK and 500. Preserves pre-existing behavior for Name/Type
  (out of scope).
- `internal/api/handler/policies.go::CreatePolicy`: default empty severity
  to `'Warning'`, then validate via `ValidatePolicySeverity`. 400 with
  clear message instead of 500 on CHECK violation. `UpdatePolicy`:
  validates severity only when provided.
- `internal/mcp/types.go` + `internal/mcp/tools.go`: added optional
  `severity` on the MCP `create_policy` / `update_policy` tool inputs so
  LLM callers stay in sync with the wire contract.
- `api/openapi.yaml`: added `severity` to the `PolicyRule` schema with
  the enum and default.

Acceptance criterion (user-defined)
-----------------------------------
"Create a rule with severity=Critical, reload the page, and still see
Critical — no silent drops." Verified end-to-end: frontend sends
`severity: "Critical"`, handler validates, service persists, DB stores,
GET returns, React renders the correct badge.

Seed data
---------
`migrations/seed.sql`: four demo rules now have differentiated severities
— `pr-require-owner` → Warning, `pr-allowed-environments` → Error,
`pr-max-certificate-lifetime` → Critical, `pr-min-renewal-window` →
Warning. The user called out that seeding all four at the same severity
makes the feature look decorative; differentiation demonstrates the
column carries real signal.

## Integration test fix (side effect of D-006)

`internal/integration/e2e_test.go::TestCrossResourceWorkflow/CreatePolicy`
was sending `"severity": "High"` — a value from the pre-audit severity
vocabulary that the new `ValidatePolicySeverity` correctly rejects with
400. Changed to `"Error"` (closest semantic match in the new TitleCase
allowlist). Only severity reference in the integration/ directory;
verified via grep.

## Out of scope, logged for follow-up (d/D-008)

Three policy-engine drift issues orthogonal to D-005 + D-006, explicitly
deferred per direction:

1. `migrations/seed.sql` policy_rules INSERTs use lowercase TYPE values
   (`'ownership'`, `'environment'`, `'lifetime'`, `'renewal_window'`).
   These are load-bearing on `internal/service/policy.go::evaluateRule`'s
   `switch rule.Type` (which also uses the lowercase strings). Migrating
   requires coordinated changes across seed + evaluation engine.
2. `migrations/seed_demo.sql:482-483` contains lowercase `'critical'`
   severity — will now fail the new CHECK constraint. Separate fix.
3. `evaluateRule` hardcodes `Severity: domain.PolicySeverityWarning` on
   emitted violations and ignores the configured `rule.Config`. The new
   severity column is read correctly on the CRUD path but not yet
   consulted during evaluation.

## Verification

Backend:
- `go build ./...` — clean
- `go vet ./...` — clean
- `go test -short ./...` — all packages green, including
  `internal/service` (policy service), `internal/api/handler` (policy +
  MCP handler tests), `internal/integration` (e2e_test.go after fix),
  `internal/domain`, `internal/repository/postgres`.

Frontend:
- `tsc --noEmit` — clean
- `vitest run` — 223/223 passing (4 new assertions in types.test.ts)
- `vite build` — clean (only the pre-existing chunk-size warning)
2026-04-18 13:02:04 +00:00
shankar0123 72f5246ce3 Merge branch 'fix/m11-cosign-v3-sign-blob-bundle': M-11 cosign v3 sign-blob migration 2026-04-18 09:29:25 +00:00
shankar0123 cb308bb4c7 ci(release): migrate cosign sign-blob to --bundle (cosign v3.0)
Cosign v3.0 (shipped by default with sigstore/cosign-installer@cad07c2e,
release v3.0.5) removed --output-signature and --output-certificate from
the sign-blob subcommand. The replacement is a single --bundle flag that
emits a unified Sigstore bundle (.sigstore.json) containing the
signature, certificate chain, and Rekor inclusion proof in one file.

This change migrates both sign-blob invocations in .github/workflows/
release.yml (per-binary matrix signing and aggregate checksums.txt
signing), updates the artefact upload paths, the artefact aggregation
case filter, the GitHub Release asset list, and the release-notes body
verify-blob example. The README cosign verification snippet and sidecar
description are also updated to the --bundle / .sigstore.json shape.

No cosign version pinning. No legacy fallback. OCI image signing
(cosign sign on image digest) is unchanged — only sign-blob flags
changed in v3.0. See M-11 in certctl-audit-report.md.

Verification gates:
- YAML parse: OK
- go vet ./...: exit 0
- go build ./...: exit 0
- grep 'cosign sign-blob' release.yml: 2 (expected: 2)
- grep '.sigstore.json' release.yml: 9 (expected: >=5)
- grep '.sig/.pem' release.yml non-comment: 0 (expected: 0)
- README legacy cosign refs: 0 (expected: 0)
- docs/ legacy cosign refs: 0 (expected: 0)

Coverage: unchanged (CI workflow edit + README — zero Go code touched).
2026-04-18 09:29:20 +00:00
shankar0123 ad93e99158 Merge branch 'fix/m10-openapi-spec-drift': M-10 OpenAPI spec drift reconciliation 2026-04-18 03:21:45 +00:00
shankar0123 9d0c3dfa15 docs(openapi): reconcile api/openapi.yaml with router routes (M-10)
Add 9 missing operations to api/openapi.yaml that exist in router.go but
were absent from the spec. Spec-only change with no runtime Go code
changes; all 106 pre-existing operationIds preserved byte-identical.

New operationIds:
  - testTargetConnection (POST /api/v1/targets/{id}/test)
  - verifyDeployment    (POST /api/v1/jobs/{id}/verify)
  - getJobVerification  (GET  /api/v1/jobs/{id}/verification)
  - estCACerts          (GET  /.well-known/est/cacerts)
  - estSimpleEnroll     (POST /.well-known/est/simpleenroll)
  - estSimpleReEnroll   (POST /.well-known/est/simplereenroll)
  - estCSRAttrs         (GET  /.well-known/est/csrattrs)
  - scepGet             (GET  /scep)
  - scepPost            (POST /scep)

Spec operations: 106 → 115 (matches 115 router routes exactly).

Verification:
  - openapi-spec-validator: OK
  - go build ./...: clean
  - go vet ./...:   clean
  - go test -race -count=1 -short ./...: 54 packages ok, 0 FAIL
  - golangci-lint run ./...: 0 issues
  - govulncheck ./...: 0 vulnerabilities in our code
  - tsc --noEmit: 0 errors
  - vitest run: 3 files, 218 tests passed

sha256 before: 7c14f77107a86f8de82fe91b7f5e16cca11206d1e1fab7b7bd77ff396620fdf3
sha256 after:  87bd92d0407d63643bec612d27261bf489563beb90d0791ea71cde26346f83d3
2026-04-18 03:21:40 +00:00
shankar0123 2c9602db71 Merge branch 'fix/m9-sentinel-discovery-log-levels': M-9 sentinel discovery log-level fix 2026-04-18 02:53:50 +00:00
shankar0123 ef670fa6da fix(m-9): aggregate per-endpoint scan errors in NetworkScanService
Before this fix, RunScan declared `scanErrors []string` but never
appended to it. As a result:

  - the summary Info log ("network target scan completed") always
    reported `"errors": 0`, regardless of how many endpoints failed
  - the DiscoveryReport's `Errors` field — stored on the scan record
    and surfaced in the GUI scan history — was always nil

Operators who needed to understand scan failures had to enable Debug
logging and grep through the noise of expected sweep-scan connection
refusals. The per-endpoint log level (Debug) is deliberate and correct
— scanning a /24 typically produces 200+ connection-refused results,
and logging each at Warn would create massive log spam at default
verbosity. The bug was the silent loss of the aggregate count.

This commit:

  - extracts the partitioning logic into `collectScanResults`, a pure
    method that splits per-endpoint results into discovered certificate
    entries and a list of endpoint error strings
  - populates the errors list with "<address>: <error>" so the scan
    record correlates failures back to specific endpoints
  - preserves the existing Debug-level per-endpoint log (sweep noise
    discipline) — no change to default-verbosity log output

The summary Info log's "errors" field and the DiscoveryReport's Errors
field now reflect the true failure count. Debug detail remains
available for operators diagnosing specific endpoints.

Audit scope note: the M-9 finding narrative implied broad Debug-level
hiding of real errors across AWS SM, Azure KV, GCP SM, and network
scan sentinel agents. On investigation, the three cloud-discovery
connectors (awssm, azurekv, gcpsm) already use appropriate Warn/Error
discipline for per-item and root-level failures. Only the network
scanner had a silent observability gap, and it was a missed append
rather than a misapplied log level. See audit resolution log for
full details.

CWE: CWE-778 (Insufficient Logging) — aggregate failure count lost.

Tests: 4 new unit tests on collectScanResults covering the
aggregation path (success + failure mix), all-success, all-failed,
and empty-input degenerate cases. All tests pass with -race.

Verification:
  - go build ./cmd/server/... ./cmd/agent/... ./cmd/mcp-server/... ./cmd/cli/...  exit 0
  - go vet ./...                                                                    exit 0
  - go test -race -count=1 -timeout 300s [full CI race path]                        exit 0
  - golangci-lint run ./... --timeout 5m (v2.11.4)                                  0 issues
  - govulncheck ./... (@latest)                                                     0 in-code vulnerabilities
  - go test -count=1 -cover ./internal/service/...                                  68.0% (> 55% threshold)

Invariants preserved:
  - collectScanResults signature: method on *NetworkScanService,
    input []domain.NetworkScanResult, return ([]DiscoveredCertEntry, []string)
  - Debug log key names unchanged ("address", "error")
  - DiscoveryReport schema unchanged (Errors field already existed)
  - Sentinel agent ID "server-scanner" unchanged
  - No migration, no API, no wire-format change

Refs: M-9 Medium finding; audit resolution log appended in follow-up
commit on workspace-level audit report.
2026-04-18 02:34:14 +00:00
shankar0123 5a6ec39cfd Merge branch 'fix/m2-pr-f-scheduler-contextcheck-audit-closeout' 2026-04-18 01:43:56 +00:00
shankar0123 e3196e7b50 M-2 PR-F: Middleware/ACME ctx-propagation + contextcheck linter + audit closeout
Final PR in the six-commit M-2 sequence (PR-A: CertificateService cluster
cdc9d03, PR-B: IssuerService+TargetService eb14236, PR-C: Policy/Profile/
Owner/Team 2497be4, PR-D: Job/Notification/Audit ccd89c3, PR-E: AgentService
283ec27, PR-F: this commit). PR-A through PR-E collapsed the service-layer
shim methods and deleted every in-production context.Background() /
context.TODO() call from internal/service/; this PR completes the sweep
across the non-service tiers (HTTP middleware + ACME connector) and wires
the contextcheck linter so regressions fail CI.

Three narrow edits land the D-3 pattern (context.WithoutCancel for
subsidiary async writes and deferred shutdown contexts):

  - internal/api/middleware/audit.go  -- async audit goroutine now runs
    on auditCtx := context.WithoutCancel(r.Context()) instead of
    context.Background(). Preserves request-scoped values (trace ID, auth)
    while detaching from the request's cancellation so the audit write
    does not get killed when the response completes. Goroutine is still
    tracked via a.wg (M-1 shutdown drain) so Flush(ctx) behaviour is
    unchanged. CWE-770 Missing Release (goroutine leak potential) +
    CWE-400 Resource Exhaustion (missed cancellation propagation).

  - internal/api/middleware/middleware.go -- Recovery panic path now
    logs via slog.ErrorContext(ctx, ...) instead of log.Printf. Request-
    scoped trace/auth metadata now carries through the panic log, matching
    every other request log. D-3 non-bypass: the context is r.Context()
    captured before the defer, so even a panic mid-handler propagates
    the ctx's trace ID into the ERROR log line.

  - internal/connector/issuer/acme/acme.go (HTTP-01 challenge server
    shutdown) -- defer shutdown context derived from
    context.WithTimeout(context.WithoutCancel(ctx), 5s) instead of
    context.Background(). Preserves parent ctx values, detaches from
    parent cancellation so Shutdown always gets its full 5-second
    budget even when the parent was cancelled. Matches the same pattern
    applied in ACME's solveAuthorizationsDNS01 and solveAuthorizationsDNSPersist01.

Linter wiring: .golangci.yml adds `contextcheck` to the enabled set.
golangci-lint v2.11.4 now fails CI on any function that takes a
context.Context parameter but calls into context.Background() or
context.TODO() instead of propagating -- regression guard for all five
prior PRs.

Verification (CI parity, GOCACHE=/tmp/gocache GOMODCACHE=/tmp/gomodcache
GOLANGCI_LINT_CACHE=/tmp/lintcache):

  - go build ./... -> 0
  - go vet ./... -> 0
  - golangci-lint run (contextcheck enabled) -> 0 issues
  - go test -race -short ./internal/api/middleware/... -> PASS
  - go test -race -short ./internal/scheduler/... -> PASS
  - go test -race -short ./internal/connector/issuer/acme/... -> PASS
  - go test -race -short ./internal/service/... -> PASS
  - rg "context\.(Background|TODO)\(\)" internal/service/ internal/scheduler/
    internal/connector/ internal/api/middleware/ -> 0 non-test hits
    (one pedagogical godoc reference in audit.go documenting why
    context.Background() would be wrong remains intentional)

Wire-format invariants preserved: 0 API routes, 0 SQL migrations, 0
frontend bytes, 0 OpenAPI bytes, 0 connector interface signature changes,
0 new env vars, 0 new external dependencies (pure context stdlib). The
AuditRecorder interface signature, the body-hash algorithm (SHA-256 16
hex chars), the excluded-path short-circuit, the actor-extraction path,
the responseWriter status-capture wrapper, the AuditServiceAdapter, and
all 116 API routes under /api/v1/, /.well-known/est/, /scep, /health,
/auth are byte-identical.

M-2 aggregate across PR-A through PR-F: 57 files, +635 / -613 (PR-A 12f
+227/-237, PR-B 9f +150/-146, PR-C 17f +156/-148, PR-D 11f +67/-63,
PR-E 4f +9/-15, PR-F 4f +26/-4). With M-2 closed, 8 of 10 Medium
findings resolved; M-9, M-10, L-1..L-4, I-1..I-8 remain post-v2.1.0
hardening batch.

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 01:43:47 +00:00
shankar0123 bea69efd12 Merge branch 'fix/m2-pr-e-agent-service'
PR-E of 6: AgentService ctx-first collapse.

Collapses the HeartbeatWithContext wrapper into a single Heartbeat
method. Handler-facing method name is preserved (D-4); the handler
service interface and mock already expected ctx-first, so this PR
touches only the service layer and its tests (4 files, 9+/15-).

Verification on the feature branch: build, vet, test (-short),
test -race, full-module test -short, and golangci-lint all clean.

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 01:25:30 +00:00
shankar0123 283ec27ca4 fix(m2-pr-e): collapse AgentService.HeartbeatWithContext into Heartbeat
PR-E of 6 in the M-2 end-to-end remediation sequence. Collapses the
HeartbeatWithContext wrapper into a single ctx-first Heartbeat method,
matching D-1 (ctx-only signatures, no dual forms). The handler-facing
method name is preserved (D-4) — internal/api/handler/agents.go already
declares `Heartbeat(ctx, ...)` on its local service interface, and the
handler mock at internal/api/handler/agent_handler_test.go already
takes `_ context.Context` as its first param, so no handler churn.

Changes
-------
internal/service/agent.go
  - Delete the zero-body Heartbeat wrapper that forwarded to
    HeartbeatWithContext with context.Background().
  - Rename HeartbeatWithContext → Heartbeat (ctx-bearing body
    folded directly into the canonical method).

internal/service/agent_test.go
  - TestHeartbeat (L95) and TestHeartbeat_NotFound (L128):
    agentService.HeartbeatWithContext(ctx, ...) → .Heartbeat(ctx, ...).

internal/service/concurrent_test.go
  - L162: agentSvc.HeartbeatWithContext(ctx, agentID, metadata)
    → .Heartbeat(ctx, agentID, metadata).

internal/service/context_test.go
  - L179 + L232: agentSvc.HeartbeatWithContext(ctx, ...) → .Heartbeat(...)
  - L185 + L238 t.Logf strings: "HeartbeatWithContext with ..." →
    "Heartbeat with ..." to match the collapsed method name.

Verification (Go 1.25.9 linux/arm64, CI-parity caches)
------------------------------------------------------
  go build ./...                 clean
  go vet ./...                   clean
  go test -short ./internal/service/... ./internal/api/handler/... \
    ./internal/integration/...   all ok
  go test -race -short same set  all ok
  go test -short ./...           all packages ok
  golangci-lint run ./...        0 issues

Locked decisions from the M-2 plan:
  D-1 ctx-only signatures (no dual forms)
  D-4 preserve handler method names facing the router
  D-5 domain types stay ctx-free

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 01:25:20 +00:00
shankar0123 a67a6b6c30 Merge branch 'fix/m2-pr-d-job-notification-audit'
PR-D: Thread ctx through Job + Notification + Audit service cluster.
Collapse CancelJobWithContext into CancelJob; eliminate 10
context.Background() hits.

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 01:20:58 +00:00
shankar0123 ccd89c348f fix(m2-pr-d): thread ctx through Job/Notification/Audit services
Collapse CancelJobWithContext into CancelJob; eliminate 10 context.Background()
hits across the Job+Notification+Audit service cluster by threading ctx
through their handler-facing service interfaces.

Services (ctx-first):
- service/job.go: ListJobs, GetJob, CancelJob, ApproveJob, RejectJob now
  accept ctx; the CancelJobWithContext wrapper is removed (handler callers
  continue to invoke CancelJob, now ctx-aware).
- service/notification.go: ListNotifications, GetNotification, MarkAsRead
  accept ctx.
- service/audit.go: ListAuditEvents, GetAuditEvent accept ctx.

Handlers (interface + callsites):
- handler/jobs.go, handler/notifications.go, handler/audit.go: local
  service interfaces updated, r.Context() threaded at every callsite.

Tests:
- Mock services updated to match the new interfaces (ctx accepted and
  ignored via '_ context.Context' first parameter; Fn closure fields
  unchanged).
- job_test.go / notification_test.go callsites thread context.Background()
  to match production shape.

Verification:
  go build ./...                 ok
  go vet ./...                   ok
  go test -short ./...           ok
  go test -race -short ./...     ok
  golangci-lint run ./...        0 issues

Locked decisions from the M-2 plan:
  D-1 ctx-only signatures (no dual forms)
  D-4 preserve handler method names facing the router
  D-5 domain types stay ctx-free

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 01:20:46 +00:00
shankar0123 478a141498 Merge branch 'fix/m2-pr-c-crud-cluster' 2026-04-18 01:10:10 +00:00
shankar0123 2497be496d M-2 PR-C: Collapse Policy/Profile/Owner/Team services to ctx-first signatures
- Add ctx first param to 21 service-layer handler-interface methods
  across policy.go (6), profile.go (5), owner.go (5), team.go (5)
- Replace 24 context.Background() call sites with received ctx; use
  context.WithoutCancel(ctx) for subsidiary audit-recording ops to
  preserve fire-and-forget audit semantics without inheriting caller
  cancellation
- Add ctx first param to 21 handler-interface method signatures across
  policies.go (6), profiles.go (5), owners.go (5), teams.go (5)
- Thread r.Context() through 21 HTTP handler sites (ListPolicies,
  GetPolicy, CreatePolicy, UpdatePolicy, DeletePolicy, ListViolations,
  ListProfiles, GetProfile, CreateProfile, UpdateProfile, DeleteProfile,
  ListOwners, GetOwner, CreateOwner, UpdateOwner, DeleteOwner,
  ListTeams, GetTeam, CreateTeam, UpdateTeam, DeleteTeam)
- Update MockPolicyService/MockProfileService/MockOwnerService/
  MockTeamService mock method impls with _ context.Context first param
  (Fn fields unchanged — closures do not need ctx); update mock impls
  in integration/lifecycle_test.go for all four services
- Update 12 service-layer test callsites (policy_test.go ×2,
  owner_test.go ×5, team_test.go ×5, profile_test.go ×13) to pass
  context.Background() at the call site

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 01:10:06 +00:00
shankar0123 25dd6c07f3 Merge branch 'fix/m2-pr-b-issuer-target' 2026-04-18 00:47:02 +00:00
shankar0123 eb14236166 M-2 PR-B: Collapse IssuerService + TargetService to ctx-first signatures
- Delete bare TestConnection wrapper in IssuerService; rename
  TestConnectionWithContext → TestConnection
- Delete TestTargetConnection delegate shim in TargetService (canonical
  TestConnection already ctx-first)
- Add ctx first param to 10 handler-interface methods
  (ListIssuers/GetIssuer/CreateIssuer/UpdateIssuer/DeleteIssuer and
  ListTargets/GetTarget/CreateTarget/UpdateTarget/DeleteTarget)
- Replace 16 context.Background() call sites with received ctx
- Thread r.Context() through 12 HTTP handler sites in issuers.go and
  targets.go (outer TargetHandler.TestTargetConnection HTTP method name
  preserved for router compatibility)
- Update MockIssuerService, MockTargetService, and mockTargetService
  (integration) for ctx-first forwarding; update test callsite literals

Audit complete. Commit: 1f6cf0eafa. Sections: 12. Findings: 2/7/10/4/6.
2026-04-18 00:46:58 +00:00
shankar0123 bbb628243f Merge branch 'fix/m2-pr-a-certificate-cluster' 2026-04-18 00:29:40 +00:00
shankar0123 cdc9d03d5b fix(m-2): thread context through CertificateService cluster
Collapses CertificateService, RevocationSvc, and CAOperationsSvc to
ctx-accepting method signatures. Removes context.Background() synthesis
at 24 internal call sites across certificate.go, revocation_svc.go, and
ca_operations.go.

- Primary repo calls inherit request cancellation via the passed ctx.
- Audit and notification dispatches use context.WithoutCancel(ctx) so
  they survive client disconnect.
- Collapses TriggerRenewal/TriggerRenewalWithActor,
  TriggerDeployment/TriggerDeploymentWithActor, and
  RevokeCertificate/RevokeCertificateWithActor sibling pairs into single
  canonical ctx-accepting methods (decisions D-1, D-2).

Handlers pass r.Context(). Mocks and tests updated to match new
signatures. No HTTP surface change, no OpenAPI change.

PR 1 of 6 in the M-2 remediation chain. Master green at this commit.

Refs: certctl-audit-report.md M-2 (L143, L224)
2026-04-18 00:29:37 +00:00
shankar0123 e951d319d0 Merge branch 'fix/m1-audit-shutdown-drain'
Resolves M-1 (Medium): Audit recorder shutdown drain.

The API audit middleware's detached recording goroutines now drain
during graceful shutdown via AuditMiddleware.Flush (sync.WaitGroup +
timeout-aware select), called between http.Server.Shutdown and
db.Close. Prevents silent audit-event loss on SIGTERM
(CWE-662 / CWE-400).
2026-04-17 17:29:54 +00:00
shankar0123 d14a45401b fix(audit): drain in-flight recording goroutines on shutdown (M-1)
Audit events spawned from the HTTP middleware ran in detached goroutines
using context.Background(). On SIGTERM the DB pool was closed before
those goroutines finished writing, silently dropping audit events
(CWE-662 Improper Synchronization / CWE-400 Uncontrolled Resource
Consumption).

NewAuditLog now returns an *AuditMiddleware struct that tracks every
spawned goroutine with sync.WaitGroup. Callers wire the middleware via
its Middleware method value (preserves the existing
func(http.Handler) http.Handler shape) and drain the WaitGroup with
Flush(ctx), which blocks until in-flight recordings complete or the
provided context is cancelled — mirroring scheduler.WaitForCompletion.

Flush is invoked in cmd/server/main.go between http.Server.Shutdown
(no new requests accepted) and db.Close (pool torn down), with a
timeout returning ErrAuditFlushTimeout wrapping ctx.Err().

Request-derived inputs (method, path, status) are snapshotted before
the goroutine spawn so the worker does not race with http.Server
reusing r after the handler returns.

Tests:
  TestAuditLog_FlushDrainsInFlightGoroutines
  TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout

Verification:
  go build ./...                            : 0
  go vet ./...                              : 0
  go test -race -short ./...                : 0 (all packages)
  go test -cover ./internal/api/middleware  : 81.4%
  golangci-lint run                         : 0 issues
  govulncheck ./...                         : 0 vulns in called code
2026-04-17 17:29:48 +00:00
shankar0123 655e2879e6 feat(frontend): add Owner field to OnboardingWizard Certificate step
The first-run onboarding wizard's Certificate step now surfaces an
Owner dropdown (required) alongside Issuer and Profile, matching the
ownership model introduced in M11b. Prevents newly-created certs from
being unowned and bypassing notification routing.

- web/src/pages/OnboardingWizard.tsx: getOwners query, ownerId state,
  Owner <select>, required-field guard (nextDisabled), empty-state link
  to /owners page when no owners exist yet.

Frontend-only change; no backend wiring or schema impact. Separated
from the M-6 sentinel-agent idempotency commit per scope-guard.
2026-04-17 16:55:44 +00:00
shankar0123 e757ef1471 Merge branch 'fix/m6-sentinel-idempotent-create'
Resolves M-6 (Medium): swallowed sentinel agent INSERT errors.
CWE-662 / CWE-209-adjacent.

Shape A: CreateIfNotExists helper + 4 sentinel call sites.
2026-04-17 16:32:12 +00:00
shankar0123 27afa4463d fix(repository): idempotent sentinel agent creation via ON CONFLICT (M-6)
Sentinel agents (server-scanner, cloud-aws-sm, cloud-azure-kv,
cloud-gcp-sm) were created on startup with a plain INSERT whose
duplicate-key error was swallowed unconditionally. That silenced every
other DB failure too (connectivity drop, permissions change, unrelated
constraint violation) — a restart after the first boot quietly
de-fanged cloud discovery and the network scanner (CWE-662, CWE-209-
adjacent).

Shape A: add AgentRepository.CreateIfNotExists using ON CONFLICT (id)
DO NOTHING RETURNING id + sql.ErrNoRows discrimination. This keeps the
strict Create semantics (duplicate-key is an error) intact for real
agent registration and gives sentinels their own idempotent path.

- repo: CreateIfNotExists returns (created bool, err error); false,nil
  on pre-existing row; false,wrapped err on anything else.
- interface: CreateIfNotExists added to AgentRepository.
- main.go: 4 sentinel sites log Error/Info/Debug distinctly.
- mocks: service + integration mocks implement the new method.
- tests: 4 new testcontainers integration tests cover first-insert,
  idempotent second-call, concurrent 16-goroutine race (exactly one
  creator, no duplicate-key panic), and pre-cancelled context
  surfacing.

Coverage gates (go test -cover): service 67.6%/55, handler 78.6%/60,
domain 92.7%/40, middleware 80.0%/30, crypto 86.7%/85. Race/vet/
golangci-lint v2.11.4 (0 issues)/govulncheck v1.2.0 clean across all
touched packages.
2026-04-17 16:32:07 +00:00
shankar0123 80450c7180 fix(repository): populate TargetIDs in certificate scan helper (M-7)
scanCertificate never queried the certificate_target_mappings junction
table, so Certificate.TargetIDs was always nil on reads. This silently
broke deployment lookups, bulk revocation filters, cert detail pages,
and any code path that iterated TargetIDs to dispatch target work.

Fix:
- Convert scanCertificate to a receiver method (r *CertificateRepository)
  so it has access to the DB for the secondary junction query.
- Get(): scan the row, then call r.getTargetIDs(ctx, certID) to populate
  TargetIDs with a single targeted query.
- List() and GetExpiringCertificates(): inline the scan loop so we can
  collect all certIDs first, then call getTargetIDsForCertificates once
  with pq.Array(certIDs) to avoid N+1 round-trips. Build a map and
  attach TargetIDs to each certificate in the result set.
- Default TargetIDs to []string{} (not nil) when a cert has no mappings
  so JSON marshals as [] rather than null.

Tests:
- New integration test file certificate_targetids_test.go with 5
  subtests exercising Get / List / GetExpiringCertificates single
  and multi-target cases plus the empty-slice vs nil contract.
- Uses the shared testcontainers-go setupTestDB infrastructure and
  skips under 'go test -short' so CI (which excludes ./internal/repository/...
  from coverage paths anyway) stays green.

Addresses M-7 from certctl-audit-report.md.
2026-04-17 15:41:08 +00:00
shankar0123 c655e0f8c5 fix(crypto/local-ca): reject expired or not-yet-valid sub-CA certificates on disk load (M-5)
loadCAFromDisk now validates the upstream sub-CA certificate's NotBefore
and NotAfter fields before accepting it, returning a fail-closed error
at server startup instead of silently loading an out-of-window CA.

Before this fix, loadCAFromDisk checked BasicConstraints.IsCA and
KeyUsage=CertSign but not the validity window. An expired enterprise
sub-CA (e.g. an ADCS subordinate whose rollover slipped) would load
without warning and the scheduler would mint child certs that every
RFC 5280 path validator rejects — outages show up at relying parties,
not at certctl, and only after thresholds trip.

CWE-672 (Operation on a Resource after Expiration or Release); secondary
CWE-295 (Improper Certificate Validation). Error strings include the CA
subject CommonName and both RFC3339 timestamps so the log line is
actionable in a 3am incident.

Tests: TestSubCAMode gains three subtests exercising the new gate —
SubCA_ExpiredCert_IsRejected (CA expired 1h ago → error mentions
'expired' and the CN), SubCA_NotYetValid_IsRejected (CA valid +1h →
error mentions 'not yet valid' and the CN), and SubCA_BarelyValid_IsAccepted
(CA valid [now-1m, now+1h] → issuance succeeds, proving no
over-rejection). Adds generateTestSubCAWithValidity helper; the
original generateTestSubCA wrapper preserves the [now, now+5y] default
for existing tests.

Package coverage: 67.7% -> 68.3%.

Verification: go build, go vet, go test -race, go test -cover all
green locally; golangci-lint v2.11.4 clean; govulncheck clean. All CI
coverage floors met with margin (service 67.6/55, handler 78.6/60,
domain 92.7/40, middleware 80.0/30, crypto 86.7/85).

Parent: 5abeeb8 (M-8 per-ciphertext salt).
Closes: audit finding M-5 in certctl-audit-report.md.
2026-04-17 14:10:23 +00:00
shankar0123 5abeeb882b fix(crypto): per-ciphertext PBKDF2 salt + v2 versioned format with v1 fallback (M-8) 2026-04-17 05:36:29 +00:00
shankar0123 b1df6dab27 ci(release): add CLI/MCP binaries, checksums, SBOM, Cosign, SLSA provenance (M-3) 2026-04-17 04:04:55 +00:00
97 changed files with 4699 additions and 1084 deletions
+13 -2
View File
@@ -45,11 +45,11 @@ jobs:
run: govulncheck ./... run: govulncheck ./...
- name: Race Detection - name: Race Detection
run: go test -race ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/scheduler/... ./internal/connector/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -timeout 300s run: go test -race ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/scheduler/... ./internal/connector/... ./internal/crypto/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -timeout 300s
- name: Go Test with Coverage - name: Go Test with Coverage
run: | run: |
go test ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/integration/... ./internal/connector/issuer/... ./internal/connector/target/... ./internal/connector/notifier/... ./internal/connector/discovery/... ./internal/mcp/... ./internal/cli/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -cover -coverprofile=coverage.out go test ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/integration/... ./internal/connector/issuer/... ./internal/connector/target/... ./internal/connector/notifier/... ./internal/connector/discovery/... ./internal/crypto/... ./internal/mcp/... ./internal/cli/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -cover -coverprofile=coverage.out
- name: Check Coverage Thresholds - name: Check Coverage Thresholds
run: | run: |
@@ -73,6 +73,13 @@ jobs:
MIDDLEWARE_COV=$(go tool cover -func=coverage.out | grep 'internal/api/middleware' | awk '{print $NF}' | sed 's/%//' | awk '{sum+=$1; n++} END {if(n>0) printf "%.1f", sum/n; else print "0"}') MIDDLEWARE_COV=$(go tool cover -func=coverage.out | grep 'internal/api/middleware' | awk '{print $NF}' | sed 's/%//' | awk '{sum+=$1; n++} END {if(n>0) printf "%.1f", sum/n; else print "0"}')
echo "Middleware layer coverage: ${MIDDLEWARE_COV}%" echo "Middleware layer coverage: ${MIDDLEWARE_COV}%"
# Check crypto package coverage (target: 85%+)
# M-8 rationale: encryption primitives are a security-critical gate.
# v2 format, key-derivation, fallback, and fail-closed sentinel paths
# all need exhaustive coverage to avoid silent regressions (CWE-916 / CWE-329).
CRYPTO_COV=$(go tool cover -func=coverage.out | grep 'internal/crypto' | awk '{print $NF}' | sed 's/%//' | awk '{sum+=$1; n++} END {if(n>0) printf "%.1f", sum/n; else print "0"}')
echo "Crypto package coverage: ${CRYPTO_COV}%"
# Fail if thresholds not met # Fail if thresholds not met
if [ "$(echo "$SERVICE_COV < 55" | bc -l)" -eq 1 ]; then if [ "$(echo "$SERVICE_COV < 55" | bc -l)" -eq 1 ]; then
echo "::error::Service layer coverage ${SERVICE_COV}% is below 55% threshold" echo "::error::Service layer coverage ${SERVICE_COV}% is below 55% threshold"
@@ -90,6 +97,10 @@ jobs:
echo "::error::Middleware layer coverage ${MIDDLEWARE_COV}% is below 30% threshold" echo "::error::Middleware layer coverage ${MIDDLEWARE_COV}% is below 30% threshold"
exit 1 exit 1
fi fi
if [ "$(echo "$CRYPTO_COV < 85" | bc -l)" -eq 1 ]; then
echo "::error::Crypto package coverage ${CRYPTO_COV}% is below 85% threshold"
exit 1
fi
echo "Coverage thresholds passed!" echo "Coverage thresholds passed!"
- name: Upload Coverage Report - name: Upload Coverage Report
+275 -43
View File
@@ -7,40 +7,30 @@ on:
env: env:
REGISTRY: ghcr.io REGISTRY: ghcr.io
GO_VERSION: '1.22' # Keep in lock-step with .github/workflows/ci.yml (M-3).
GO_VERSION: '1.25.9'
IMAGE_NAMESPACE: shankar0123
jobs: jobs:
# Cross-compile agent and server binaries for multiple platforms # ----------------------------------------------------------------------
# build-binaries (M-3): matrix build every (binary × OS × arch) tuple.
# For each tuple we produce: the binary, a SPDX-JSON SBOM, a keyless
# Cosign signature + certificate bundle, and a single-line sha256sum
# file. All artefacts are uploaded to a workflow-scoped artifact; the
# aggregate-checksums job fans them back in for release upload.
# ----------------------------------------------------------------------
build-binaries: build-binaries:
name: Build Cross-Platform Binaries name: Build ${{ matrix.binary }} (${{ matrix.os }}/${{ matrix.arch }})
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: read
id-token: write # Cosign keyless OIDC identity token
strategy: strategy:
fail-fast: false
matrix: matrix:
include: binary: [agent, server, cli, mcp-server]
# Agent binaries (4 platforms) os: [linux, darwin]
- os: linux arch: [amd64, arm64]
arch: amd64
binary: agent
- os: linux
arch: arm64
binary: agent
- os: darwin
arch: amd64
binary: agent
- os: darwin
arch: arm64
binary: agent
# Server binaries (2 platforms)
- os: linux
arch: amd64
binary: server
- os: linux
arch: arm64
binary: server
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -51,35 +41,174 @@ jobs:
- name: Extract version from tag - name: Extract version from tag
id: version id: version
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
- name: Build ${{ matrix.binary }} binary (${{ matrix.os }}-${{ matrix.arch }}) - name: Build binary
id: build
env: env:
GOOS: ${{ matrix.os }} GOOS: ${{ matrix.os }}
GOARCH: ${{ matrix.arch }} GOARCH: ${{ matrix.arch }}
CGO_ENABLED: 0 CGO_ENABLED: '0'
VERSION: ${{ steps.version.outputs.VERSION }}
run: | run: |
set -euo pipefail
OUTPUT_NAME="certctl-${{ matrix.binary }}-${{ matrix.os }}-${{ matrix.arch }}" OUTPUT_NAME="certctl-${{ matrix.binary }}-${{ matrix.os }}-${{ matrix.arch }}"
go build -ldflags="-w -s -X main.Version=${{ steps.version.outputs.VERSION }}" \ mkdir -p dist
go build \
-trimpath \
-ldflags="-w -s -X main.Version=${VERSION}" \
-o "dist/${OUTPUT_NAME}" \ -o "dist/${OUTPUT_NAME}" \
"./cmd/${{ matrix.binary }}" "./cmd/${{ matrix.binary }}"
ls -lh "dist/${OUTPUT_NAME}" ls -lh "dist/${OUTPUT_NAME}"
echo "output_name=${OUTPUT_NAME}" >> "$GITHUB_OUTPUT"
- name: Upload binaries to release - name: Generate SBOM (SPDX-JSON)
uses: anchore/sbom-action@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0
with:
file: dist/${{ steps.build.outputs.output_name }}
format: spdx-json
output-file: dist/${{ steps.build.outputs.output_name }}.sbom.spdx.json
upload-artifact: false
upload-release-assets: false
- name: Install Cosign
uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1
- name: Keyless-sign binary with Cosign
env:
OUTPUT_NAME: ${{ steps.build.outputs.output_name }}
run: |
set -euo pipefail
# Cosign v3.0 (shipped by cosign-installer@v4.1.1 default
# cosign-release=v3.0.5) removed --output-signature/--output-certificate
# on sign-blob. The replacement is --bundle, which emits a unified
# Sigstore bundle (signature + cert chain + Rekor inclusion proof) as
# a single .sigstore.json artefact. M-11.
cosign sign-blob \
--yes \
--bundle "dist/${OUTPUT_NAME}.sigstore.json" \
"dist/${OUTPUT_NAME}"
- name: Compute SHA-256 sidecar
env:
OUTPUT_NAME: ${{ steps.build.outputs.output_name }}
run: |
set -euo pipefail
cd dist
sha256sum "${OUTPUT_NAME}" > "${OUTPUT_NAME}.sha256"
cat "${OUTPUT_NAME}.sha256"
- name: Upload build artefacts
uses: actions/upload-artifact@v4
with:
name: binary-${{ steps.build.outputs.output_name }}
path: |
dist/${{ steps.build.outputs.output_name }}
dist/${{ steps.build.outputs.output_name }}.sigstore.json
dist/${{ steps.build.outputs.output_name }}.sbom.spdx.json
dist/${{ steps.build.outputs.output_name }}.sha256
if-no-files-found: error
retention-days: 7
# ----------------------------------------------------------------------
# aggregate-checksums (M-3): fan in every matrix artefact, produce a
# single checksums.txt (sha256sum format, compatible with `sha256sum
# -c`), sign it with Cosign, upload everything to the GitHub Release,
# and emit a base64-encoded hash manifest for the SLSA generator.
# ----------------------------------------------------------------------
aggregate-checksums:
name: Aggregate checksums & sign
runs-on: ubuntu-latest
needs: [build-binaries]
permissions:
contents: write
id-token: write # Cosign keyless OIDC identity token
outputs:
hashes: ${{ steps.hashes.outputs.hashes }}
steps:
- name: Download binary artefacts
uses: actions/download-artifact@v4
with:
pattern: binary-*
path: artifacts
merge-multiple: true
- name: Aggregate SHA-256 sums
id: hashes
run: |
set -euo pipefail
cd artifacts
: > checksums.txt
for f in certctl-*; do
case "$f" in
*.sigstore.json|*.sbom.spdx.json|*.sha256|checksums.txt)
continue ;;
esac
sha256sum "$f" >> checksums.txt
done
echo "=== checksums.txt ==="
cat checksums.txt
# base64 hashes (single line, no wrapping) for SLSA generator.
HASHES=$(base64 -w0 < checksums.txt)
echo "hashes=${HASHES}" >> "$GITHUB_OUTPUT"
- name: Install Cosign
uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1
- name: Keyless-sign checksums.txt
run: |
set -euo pipefail
cd artifacts
# Cosign v3.0 --bundle replaces the removed v2 flag pair
# --output-signature / --output-certificate. See M-11.
cosign sign-blob \
--yes \
--bundle checksums.txt.sigstore.json \
checksums.txt
- name: Upload artefacts to GitHub Release
uses: softprops/action-gh-release@v2 uses: softprops/action-gh-release@v2
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
with: with:
files: | files: |
dist/certctl-agent-* artifacts/certctl-*
dist/certctl-server-* artifacts/checksums.txt
artifacts/checksums.txt.sigstore.json
# Build and push Docker images # ----------------------------------------------------------------------
# provenance-binaries (M-3): SLSA Level 3 provenance for every binary.
# The SLSA generic generator reusable workflow runs in a hermetic
# workflow run, producing multiple.intoto.jsonl from the base64 hash
# manifest and uploading it as a release asset.
# ----------------------------------------------------------------------
provenance-binaries:
name: SLSA provenance (binaries)
needs: [aggregate-checksums]
permissions:
actions: read
id-token: write
contents: write
uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v2.1.0
with:
base64-subjects: "${{ needs.aggregate-checksums.outputs.hashes }}"
upload-assets: true
provenance-name: multiple.intoto.jsonl
# ----------------------------------------------------------------------
# build-and-push-docker: push container images to GHCR with native
# SLSA L3 provenance (mode=max) and SBOM attestations emitted by
# docker/build-push-action@v6, plus a keyless Cosign signature on the
# image digest for identity-bound verification. The M-4 proxy-propagation
# build-args block is retained verbatim — M-3 only adds supply-chain
# steps; it never touches M-4 wiring.
# ----------------------------------------------------------------------
build-and-push-docker: build-and-push-docker:
name: Build & Push Docker Images name: Build & Push Docker Images
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: write
packages: write packages: write
id-token: write # Cosign keyless OIDC identity token
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -93,20 +222,24 @@ jobs:
- name: Extract version from tag - name: Extract version from tag
id: version id: version
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Install Cosign
uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1
- name: Build and push server image - name: Build and push server image
id: server-push
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
context: . context: .
file: ./Dockerfile file: ./Dockerfile
push: true push: true
tags: | tags: |
${{ env.REGISTRY }}/shankar0123/certctl-server:${{ steps.version.outputs.VERSION }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-server:${{ steps.version.outputs.VERSION }}
${{ env.REGISTRY }}/shankar0123/certctl-server:latest ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-server:latest
# Proxy propagation (M-4, Issue #9) — forwards runner-level proxy # Proxy propagation (M-4, Issue #9) — forwards runner-level proxy
# secrets into the Docker build so self-hosted runners behind # secrets into the Docker build so self-hosted runners behind
# corporate proxies can reach public registries. GitHub-hosted # corporate proxies can reach public registries. GitHub-hosted
@@ -117,18 +250,31 @@ jobs:
HTTP_PROXY=${{ secrets.HTTP_PROXY }} HTTP_PROXY=${{ secrets.HTTP_PROXY }}
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }} HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
NO_PROXY=${{ secrets.NO_PROXY }} NO_PROXY=${{ secrets.NO_PROXY }}
# Supply-chain hardening (M-3): emit native SLSA L3 provenance
# and SBOM attestations bound to the image manifest.
provenance: mode=max
sbom: true
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max
- name: Keyless-sign server image with Cosign
env:
DIGEST: ${{ steps.server-push.outputs.digest }}
IMAGE: ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-server
run: |
set -euo pipefail
cosign sign --yes "${IMAGE}@${DIGEST}"
- name: Build and push agent image - name: Build and push agent image
id: agent-push
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
context: . context: .
file: ./Dockerfile.agent file: ./Dockerfile.agent
push: true push: true
tags: | tags: |
${{ env.REGISTRY }}/shankar0123/certctl-agent:${{ steps.version.outputs.VERSION }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-agent:${{ steps.version.outputs.VERSION }}
${{ env.REGISTRY }}/shankar0123/certctl-agent:latest ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-agent:latest
# Proxy propagation (M-4, Issue #9) — see server-image step for # Proxy propagation (M-4, Issue #9) — see server-image step for
# rationale. Empty secrets resolve to empty build args, leaving # rationale. Empty secrets resolve to empty build args, leaving
# the un-proxied code path byte-identical to the pre-fix tree. # the un-proxied code path byte-identical to the pre-fix tree.
@@ -136,14 +282,30 @@ jobs:
HTTP_PROXY=${{ secrets.HTTP_PROXY }} HTTP_PROXY=${{ secrets.HTTP_PROXY }}
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }} HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
NO_PROXY=${{ secrets.NO_PROXY }} NO_PROXY=${{ secrets.NO_PROXY }}
# Supply-chain hardening (M-3): emit native SLSA L3 provenance
# and SBOM attestations bound to the image manifest.
provenance: mode=max
sbom: true
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max
# Create release notes with all artifacts - name: Keyless-sign agent image with Cosign
env:
DIGEST: ${{ steps.agent-push.outputs.digest }}
IMAGE: ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-agent
run: |
set -euo pipefail
cosign sign --yes "${IMAGE}@${DIGEST}"
# ----------------------------------------------------------------------
# create-release: stamp the release body. The actual asset uploads are
# handled by aggregate-checksums (binaries, SBOMs, sigs, certs,
# checksums.txt + signature) and the SLSA generator (multiple.intoto.jsonl).
# ----------------------------------------------------------------------
create-release: create-release:
name: Create Release Notes name: Create Release Notes
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [build-binaries, build-and-push-docker] needs: [build-binaries, aggregate-checksums, provenance-binaries, build-and-push-docker]
permissions: permissions:
contents: write contents: write
@@ -152,7 +314,7 @@ jobs:
- name: Extract version from tag - name: Extract version from tag
id: version id: version
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
- name: Create release with notes - name: Create release with notes
uses: softprops/action-gh-release@v2 uses: softprops/action-gh-release@v2
@@ -214,6 +376,76 @@ jobs:
- **Linux x86_64**: `certctl-server-linux-amd64` - **Linux x86_64**: `certctl-server-linux-amd64`
- **Linux ARM64**: `certctl-server-linux-arm64` - **Linux ARM64**: `certctl-server-linux-arm64`
- **macOS x86_64**: `certctl-server-darwin-amd64`
- **macOS ARM64 (Apple Silicon)**: `certctl-server-darwin-arm64`
## CLI & MCP Server Binaries
The `certctl-cli` (REST API wrapper) and `certctl-mcp-server` (Model Context
Protocol bridge) binaries ship for all four platforms as well:
- `certctl-cli-{linux,darwin}-{amd64,arm64}`
- `certctl-mcp-server-{linux,darwin}-{amd64,arm64}`
## Verifying this release
Every binary, `checksums.txt`, and container image is signed with Cosign
keyless OIDC. Each binary ships with a SPDX-JSON SBOM. Binaries are covered
by SLSA Level 3 provenance; container images carry native SLSA L3 provenance
and SBOM attestations (docker/build-push-action `provenance: mode=max`,
`sbom: true`) in addition to a Cosign signature on the digest.
**1. Verify SHA-256 checksums:**
```bash
sha256sum -c checksums.txt
```
**2. Verify the Cosign signature on checksums.txt (keyless OIDC):**
```bash
cosign verify-blob \
--bundle checksums.txt.sigstore.json \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
checksums.txt
```
Replace `checksums.txt` with any individual binary name to verify that
artefact directly (each binary ships with its own `.sigstore.json`
bundle, e.g. `cosign verify-blob --bundle certctl-agent-linux-amd64.sigstore.json …`).
**3. Verify SLSA Level 3 provenance (binaries):**
```bash
slsa-verifier verify-artifact \
--provenance-path multiple.intoto.jsonl \
--source-uri github.com/shankar0123/certctl \
--source-tag ${{ steps.version.outputs.VERSION }} \
certctl-agent-linux-amd64
```
**4. Verify container image signature and attestations:**
```bash
IMAGE=ghcr.io/shankar0123/certctl-server:${{ steps.version.outputs.VERSION }}
cosign verify \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
"$IMAGE"
# SBOM attestation (SPDX-JSON) emitted by docker/build-push-action
cosign verify-attestation --type spdxjson \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
"$IMAGE"
# SLSA provenance attestation (mode=max)
cosign verify-attestation --type slsaprovenance \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
"$IMAGE"
```
## Helm Chart ## Helm Chart
+1
View File
@@ -6,6 +6,7 @@ run:
linters: linters:
default: none default: none
enable: enable:
- contextcheck
- govet - govet
- staticcheck - staticcheck
- unused - unused
+68
View File
@@ -237,6 +237,74 @@ docker pull shankar0123.docker.scarf.sh/certctl-server
docker pull shankar0123.docker.scarf.sh/certctl-agent docker pull shankar0123.docker.scarf.sh/certctl-agent
``` ```
## Verifying this release
Every `v*` tag publishes signed, attested release artefacts. Binaries
(`certctl-agent`, `certctl-server`, `certctl-cli`, `certctl-mcp-server` for
`linux|darwin × amd64|arm64`) ship alongside a `checksums.txt`, per-binary
SPDX-JSON SBOMs, Cosign signatures, and SLSA Level 3 provenance. Container
images on `ghcr.io/shankar0123/certctl-{server,agent}` are built with
`docker/build-push-action` `provenance: mode=max` + `sbom: true` and are
additionally signed with Cosign at the image digest.
All signatures use Cosign keyless OIDC; the signing identity is the
release workflow running on a signed tag.
**1. Verify SHA-256 checksums:**
```bash
sha256sum -c checksums.txt
```
**2. Verify the Cosign signature on `checksums.txt`:**
```bash
cosign verify-blob \
--bundle checksums.txt.sigstore.json \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
checksums.txt
```
Every individual binary ships with its own `.sigstore.json` bundle
(unified Sigstore bundle containing signature, certificate chain, and
Rekor inclusion proof). Swap `checksums.txt` for any binary name and
point `--bundle` at the matching `<binary>.sigstore.json` to verify it
directly.
**3. Verify SLSA Level 3 provenance on a binary:**
```bash
slsa-verifier verify-artifact \
--provenance-path multiple.intoto.jsonl \
--source-uri github.com/shankar0123/certctl \
--source-tag v2.1.0 \
certctl-agent-linux-amd64
```
**4. Verify a container image signature and its SBOM / provenance attestations:**
```bash
IMAGE=ghcr.io/shankar0123/certctl-server:v2.1.0
cosign verify \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
"$IMAGE"
# SBOM attestation (SPDX-JSON, emitted by docker/build-push-action)
cosign verify-attestation --type spdxjson \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
"$IMAGE"
# SLSA provenance attestation (docker/build-push-action `provenance: mode=max`)
cosign verify-attestation --type slsaprovenance \
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
"$IMAGE"
```
## Examples ## Examples
Pick the scenario closest to your setup and have it running in 2 minutes. Pick the scenario closest to your setup and have it running in 2 minutes.
+375
View File
@@ -66,6 +66,12 @@ tags:
description: Continuous TLS endpoint health checks with status tracking and probe history description: Continuous TLS endpoint health checks with status tracking and probe history
- name: Digest - name: Digest
description: Scheduled certificate digest email notifications description: Scheduled certificate digest email notifications
- name: Verification
description: Post-deployment TLS endpoint fingerprint verification
- name: EST
description: Enrollment over Secure Transport (RFC 7030)
- name: SCEP
description: Simple Certificate Enrollment Protocol (RFC 8894)
paths: paths:
# ─── Health & Auth ─────────────────────────────────────────────────── # ─── Health & Auth ───────────────────────────────────────────────────
@@ -816,6 +822,28 @@ paths:
"500": "500":
$ref: "#/components/responses/InternalError" $ref: "#/components/responses/InternalError"
/api/v1/targets/{id}/test:
post:
tags: [Targets]
summary: Test target connection
description: |
Checks target connectivity by verifying the assigned agent's heartbeat status
(agent reported within the last 5 minutes). Always returns HTTP 200 — the
connectivity result is reflected in the response body's `status` field
(`success` when the agent is reachable, `failed` otherwise).
operationId: testTargetConnection
parameters:
- $ref: "#/components/parameters/resourceId"
responses:
"200":
description: Connection test result (success or failed in body)
content:
application/json:
schema:
$ref: "#/components/schemas/StatusMessageResponse"
"400":
$ref: "#/components/responses/BadRequest"
# ─── Agents ────────────────────────────────────────────────────────── # ─── Agents ──────────────────────────────────────────────────────────
/api/v1/agents: /api/v1/agents:
get: get:
@@ -1177,6 +1205,66 @@ paths:
"500": "500":
$ref: "#/components/responses/InternalError" $ref: "#/components/responses/InternalError"
/api/v1/jobs/{id}/verify:
post:
tags: [Verification]
summary: Record post-deployment verification result
description: |
Agents submit the result of probing a deployed certificate's live TLS endpoint.
Compares the served certificate's SHA-256 fingerprint against the expected
fingerprint. Best-effort: failures are recorded on the job but do not roll
back the deployment.
operationId: verifyDeployment
parameters:
- $ref: "#/components/parameters/resourceId"
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/VerifyDeploymentRequest"
responses:
"200":
description: Verification result recorded
content:
application/json:
schema:
type: object
properties:
job_id:
type: string
verified:
type: boolean
verified_at:
type: string
format: date-time
"400":
$ref: "#/components/responses/BadRequest"
"500":
$ref: "#/components/responses/InternalError"
/api/v1/jobs/{id}/verification:
get:
tags: [Verification]
summary: Get post-deployment verification status
description: |
Returns the stored verification result for a deployment job — expected
and observed SHA-256 fingerprints, verified flag, and timestamp.
operationId: getJobVerification
parameters:
- $ref: "#/components/parameters/resourceId"
responses:
"200":
description: Verification result for the job
content:
application/json:
schema:
$ref: "#/components/schemas/VerificationResult"
"400":
$ref: "#/components/responses/BadRequest"
"500":
$ref: "#/components/responses/InternalError"
# ─── Policies ──────────────────────────────────────────────────────── # ─── Policies ────────────────────────────────────────────────────────
/api/v1/policies: /api/v1/policies:
get: get:
@@ -2718,6 +2806,238 @@ paths:
"500": "500":
$ref: "#/components/responses/InternalError" $ref: "#/components/responses/InternalError"
# ─── EST (RFC 7030) ────────────────────────────────────────────────
/.well-known/est/cacerts:
get:
tags: [EST]
summary: EST CA certificates distribution
description: |
Returns the CA certificate chain used to verify certctl-issued certificates.
Response is a base64-encoded degenerate PKCS#7 SignedData (certs-only) per
RFC 7030 §4.1.3.
operationId: estCACerts
security: []
responses:
"200":
description: Base64-encoded PKCS#7 certs-only structure
headers:
Content-Transfer-Encoding:
schema:
type: string
example: base64
content:
application/pkcs7-mime:
schema:
type: string
format: byte
description: "Base64-encoded PKCS#7 (smime-type=certs-only)"
"500":
$ref: "#/components/responses/InternalError"
/.well-known/est/simpleenroll:
post:
tags: [EST]
summary: EST simple enrollment
description: |
Enrolls a new certificate from a PKCS#10 CSR per RFC 7030 §4.2.1.
The CSR MAY be supplied as base64-encoded DER (EST standard wire format)
or as PEM for convenience. Returns a base64-encoded PKCS#7 certs-only
structure containing the issued certificate.
operationId: estSimpleEnroll
security: []
requestBody:
required: true
description: "Base64-encoded DER PKCS#10 CSR, or PEM-encoded CSR"
content:
application/pkcs10:
schema:
type: string
format: byte
responses:
"200":
description: Base64-encoded PKCS#7 cert-only response with issued certificate
headers:
Content-Transfer-Encoding:
schema:
type: string
example: base64
content:
application/pkcs7-mime:
schema:
type: string
format: byte
description: "Base64-encoded PKCS#7 (smime-type=certs-only)"
"400":
$ref: "#/components/responses/BadRequest"
"405":
description: Method not allowed (only POST accepted)
"500":
$ref: "#/components/responses/InternalError"
/.well-known/est/simplereenroll:
post:
tags: [EST]
summary: EST simple re-enrollment
description: |
Re-enrolls an existing certificate (same as simpleenroll in certctl's
implementation — re-enrollment is treated as a fresh issuance) per
RFC 7030 §4.2.2.
operationId: estSimpleReEnroll
security: []
requestBody:
required: true
description: "Base64-encoded DER PKCS#10 CSR, or PEM-encoded CSR"
content:
application/pkcs10:
schema:
type: string
format: byte
responses:
"200":
description: Base64-encoded PKCS#7 cert-only response with re-issued certificate
headers:
Content-Transfer-Encoding:
schema:
type: string
example: base64
content:
application/pkcs7-mime:
schema:
type: string
format: byte
description: "Base64-encoded PKCS#7 (smime-type=certs-only)"
"400":
$ref: "#/components/responses/BadRequest"
"405":
description: Method not allowed (only POST accepted)
"500":
$ref: "#/components/responses/InternalError"
/.well-known/est/csrattrs:
get:
tags: [EST]
summary: EST CSR attributes
description: |
Returns attributes the EST client should include in its CSR per
RFC 7030 §4.5. certctl currently returns an empty attribute set
(HTTP 204) — profile-based constraints are enforced server-side
during enrollment rather than advertised here.
operationId: estCSRAttrs
security: []
responses:
"200":
description: Base64-encoded CsrAttrs (when non-empty)
headers:
Content-Transfer-Encoding:
schema:
type: string
example: base64
content:
application/csrattrs:
schema:
type: string
format: byte
"204":
description: No CSR attributes defined (empty response)
"500":
$ref: "#/components/responses/InternalError"
# ─── SCEP (RFC 8894) ──────────────────────────────────────────────
/scep:
get:
tags: [SCEP]
summary: SCEP operation dispatch (GET)
description: |
Single SCEP entry point dispatched by the `operation` query parameter
per RFC 8894. GET is used for capability discovery (`GetCACaps`) and
CA certificate retrieval (`GetCACert`).
operationId: scepGet
security: []
parameters:
- name: operation
in: query
required: true
schema:
type: string
enum: [GetCACaps, GetCACert, PKIOperation]
description: SCEP operation selector
- name: message
in: query
required: false
schema:
type: string
description: Optional SCEP message parameter (base64-encoded for GET PKIOperation)
responses:
"200":
description: |
Success. Content-Type varies by operation:
- `GetCACaps` → `text/plain` capability list
- `GetCACert` (single cert) → `application/x-x509-ca-cert` (raw DER)
- `GetCACert` (chain) → `application/x-x509-ca-ra-cert` (PKCS#7)
- `PKIOperation` → `application/x-pki-message` (PKCS#7 SignedData)
content:
text/plain:
schema:
type: string
description: "SCEP capabilities (GetCACaps only)"
application/x-x509-ca-cert:
schema:
type: string
format: binary
description: "CA certificate DER (GetCACert single)"
application/x-x509-ca-ra-cert:
schema:
type: string
format: binary
description: "CA chain PKCS#7 (GetCACert chain)"
application/x-pki-message:
schema:
type: string
format: binary
description: "PKCS#7 SignedData response (PKIOperation)"
"400":
$ref: "#/components/responses/BadRequest"
"500":
$ref: "#/components/responses/InternalError"
post:
tags: [SCEP]
summary: SCEP PKIOperation (POST)
description: |
SCEP enrollment / renewal / revocation request per RFC 8894.
Request body is a PKCS#7 SignedData envelope wrapping the PKCS#10 CSR
or a degenerate raw CSR (fallback). The challenge password in the CSR
attributes is validated against `CERTCTL_SCEP_CHALLENGE_PASSWORD` when
configured.
operationId: scepPost
security: []
parameters:
- name: operation
in: query
required: true
schema:
type: string
enum: [PKIOperation]
requestBody:
required: true
description: PKCS#7 SignedData envelope wrapping a PKCS#10 CSR (or raw CSR as fallback)
content:
application/x-pki-message:
schema:
type: string
format: binary
responses:
"200":
description: PKCS#7 SignedData PKIMessage response
content:
application/x-pki-message:
schema:
type: string
format: binary
"400":
$ref: "#/components/responses/BadRequest"
"500":
$ref: "#/components/responses/InternalError"
# ═══════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════
components: components:
securitySchemes: securitySchemes:
@@ -3006,6 +3326,7 @@ components:
DeploymentTarget: DeploymentTarget:
type: object type: object
required: [name, type, agent_id]
properties: properties:
id: id:
type: string type: string
@@ -3015,6 +3336,12 @@ components:
$ref: "#/components/schemas/TargetType" $ref: "#/components/schemas/TargetType"
agent_id: agent_id:
type: string type: string
description: |
ID of the agent that manages this target. Required because
deployment_targets.agent_id is a NOT NULL foreign key to agents(id)
(migration 000001). Empty or nonexistent agent IDs are rejected
with HTTP 400 by the service layer (see C-002 in the coverage-gap
audit).
config: config:
type: object type: object
description: Target-specific configuration (varies by type) description: Target-specific configuration (varies by type)
@@ -3141,6 +3468,7 @@ components:
- RequiredMetadata - RequiredMetadata
- AllowedEnvironments - AllowedEnvironments
- RenewalLeadTime - RenewalLeadTime
- CertificateLifetime
PolicySeverity: PolicySeverity:
type: string type: string
@@ -3160,6 +3488,9 @@ components:
description: Policy-specific configuration (varies by type) description: Policy-specific configuration (varies by type)
enabled: enabled:
type: boolean type: boolean
severity:
$ref: "#/components/schemas/PolicySeverity"
description: Severity level applied to violations of this rule. Defaults to Warning on create when omitted.
created_at: created_at:
type: string type: string
format: date-time format: date-time
@@ -3805,3 +4136,47 @@ components:
type: string type: string
format: date-time format: date-time
description: Timestamp of this probe description: Timestamp of this probe
# ─── Verification (M25) ──────────────────────────────────────────
VerifyDeploymentRequest:
type: object
required: [target_id, expected_fingerprint, actual_fingerprint, verified]
properties:
target_id:
type: string
description: Deployment target the agent probed
expected_fingerprint:
type: string
description: SHA-256 fingerprint of the certificate that should be served (hex, lowercase)
actual_fingerprint:
type: string
description: SHA-256 fingerprint observed on the live TLS endpoint (hex, lowercase)
verified:
type: boolean
description: True when expected and actual fingerprints match
error:
type: string
nullable: true
description: Error message when probe failed or fingerprints differ
VerificationResult:
type: object
properties:
job_id:
type: string
target_id:
type: string
expected_fingerprint:
type: string
description: SHA-256 fingerprint (hex) of the certificate deployed by this job
actual_fingerprint:
type: string
description: SHA-256 fingerprint (hex) observed on the live TLS endpoint
verified:
type: boolean
verified_at:
type: string
format: date-time
error:
type: string
description: Error message when verification failed
+60 -17
View File
@@ -16,7 +16,6 @@ import (
"github.com/shankar0123/certctl/internal/api/middleware" "github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/api/router" "github.com/shankar0123/certctl/internal/api/router"
"github.com/shankar0123/certctl/internal/config" "github.com/shankar0123/certctl/internal/config"
"github.com/shankar0123/certctl/internal/crypto"
"github.com/shankar0123/certctl/internal/domain" "github.com/shankar0123/certctl/internal/domain"
discoveryawssm "github.com/shankar0123/certctl/internal/connector/discovery/awssm" discoveryawssm "github.com/shankar0123/certctl/internal/connector/discovery/awssm"
discoveryazurekv "github.com/shankar0123/certctl/internal/connector/discovery/azurekv" discoveryazurekv "github.com/shankar0123/certctl/internal/connector/discovery/azurekv"
@@ -82,12 +81,20 @@ func main() {
logger.Info("initialized all repositories") logger.Info("initialized all repositories")
// Initialize dynamic issuer registry. // Initialize dynamic issuer registry.
// Issuers are loaded from the database (with AES-GCM encrypted config). // Issuers are loaded from the database (with AES-256-GCM encrypted config).
// On first boot with an empty database, env var issuers are seeded automatically. // On first boot with an empty database, env var issuers are seeded automatically.
var encryptionKey []byte //
if cfg.Encryption.ConfigEncryptionKey != "" { // M-8 (CWE-916 / CWE-329): the encryption passphrase is passed as a raw
encryptionKey = crypto.DeriveKey(cfg.Encryption.ConfigEncryptionKey) // string into IssuerService / TargetService / IssuerRegistry. Each call to
logger.Info("config encryption enabled (AES-256-GCM)") // crypto.EncryptIfKeySet generates a fresh 16-byte PBKDF2 salt and emits a
// v2 blob (magic 0x02 || salt || nonce || sealed). Decryption auto-detects
// v1 legacy blobs (no magic) and falls back to the fixed v1 salt for
// backward compatibility; v1 blobs transparently upgrade to v2 on next
// write. DO NOT pre-derive the key here with crypto.DeriveKey — that was
// the v1 fixed-salt behaviour that M-8 removes.
encryptionKey := cfg.Encryption.ConfigEncryptionKey
if encryptionKey != "" {
logger.Info("config encryption enabled (AES-256-GCM, per-ciphertext PBKDF2 salt)")
} else { } else {
// C-2 fix: fail closed at startup when database-sourced issuer or target // C-2 fix: fail closed at startup when database-sourced issuer or target
// rows exist without a configured encryption key. Previously the server // rows exist without a configured encryption key. Previously the server
@@ -138,6 +145,7 @@ func main() {
// Initialize services (following the dependency graph) // Initialize services (following the dependency graph)
auditService := service.NewAuditService(auditRepo) auditService := service.NewAuditService(auditRepo)
policyService := service.NewPolicyService(policyRepo, auditService) policyService := service.NewPolicyService(policyRepo, auditService)
policyService.SetCertRepo(certificateRepo) // D-008: CertificateLifetime arm needs CertificateVersion.NotBefore/NotAfter
certificateService := service.NewCertificateService(certificateRepo, policyService, auditService) certificateService := service.NewCertificateService(certificateRepo, policyService, auditService)
notifierRegistry := make(map[string]service.Notifier) notifierRegistry := make(map[string]service.Notifier)
@@ -246,9 +254,15 @@ func main() {
Name: "Network Scanner (Server-Side)", Name: "Network Scanner (Server-Side)",
Status: domain.AgentStatusOnline, Status: domain.AgentStatusOnline,
} }
if err := agentRepo.Create(context.Background(), sentinelAgent); err != nil { // M-6: use CreateIfNotExists so duplicate rows on restart/upgrade are
// Ignore duplicate key errors (agent already exists) // idempotent without swallowing unrelated DB failures (CWE-662).
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelAgentID) created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelAgent)
if err != nil {
logger.Error("sentinel agent creation failed", "id", service.SentinelAgentID, "error", err)
} else if created {
logger.Info("sentinel agent created", "id", service.SentinelAgentID)
} else {
logger.Debug("sentinel agent already exists", "id", service.SentinelAgentID)
} }
} }
@@ -267,8 +281,14 @@ func main() {
Name: "AWS Secrets Manager Discovery", Name: "AWS Secrets Manager Discovery",
Status: domain.AgentStatusOnline, Status: domain.AgentStatusOnline,
} }
if err := agentRepo.Create(context.Background(), sentinelAWS); err != nil { // M-6: idempotent create (CWE-662).
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelAWSSecretsMgr) created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelAWS)
if err != nil {
logger.Error("sentinel agent creation failed", "id", service.SentinelAWSSecretsMgr, "error", err)
} else if created {
logger.Info("sentinel agent created", "id", service.SentinelAWSSecretsMgr)
} else {
logger.Debug("sentinel agent already exists", "id", service.SentinelAWSSecretsMgr)
} }
} }
@@ -286,8 +306,14 @@ func main() {
Name: "Azure Key Vault Discovery", Name: "Azure Key Vault Discovery",
Status: domain.AgentStatusOnline, Status: domain.AgentStatusOnline,
} }
if err := agentRepo.Create(context.Background(), sentinelAzure); err != nil { // M-6: idempotent create (CWE-662).
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelAzureKeyVault) created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelAzure)
if err != nil {
logger.Error("sentinel agent creation failed", "id", service.SentinelAzureKeyVault, "error", err)
} else if created {
logger.Info("sentinel agent created", "id", service.SentinelAzureKeyVault)
} else {
logger.Debug("sentinel agent already exists", "id", service.SentinelAzureKeyVault)
} }
} }
@@ -300,8 +326,14 @@ func main() {
Name: "GCP Secret Manager Discovery", Name: "GCP Secret Manager Discovery",
Status: domain.AgentStatusOnline, Status: domain.AgentStatusOnline,
} }
if err := agentRepo.Create(context.Background(), sentinelGCP); err != nil { // M-6: idempotent create (CWE-662).
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelGCPSecretMgr) created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelGCP)
if err != nil {
logger.Error("sentinel agent creation failed", "id", service.SentinelGCPSecretMgr, "error", err)
} else if created {
logger.Info("sentinel agent created", "id", service.SentinelGCPSecretMgr)
} else {
logger.Debug("sentinel agent already exists", "id", service.SentinelGCPSecretMgr)
} }
} }
@@ -558,7 +590,7 @@ func main() {
bodyLimitMiddleware, bodyLimitMiddleware,
corsMiddleware, corsMiddleware,
authMiddleware, authMiddleware,
auditMiddleware, auditMiddleware.Middleware,
} }
// Add rate limiter if enabled // Add rate limiter if enabled
@@ -575,7 +607,7 @@ func main() {
rateLimiter, rateLimiter,
corsMiddleware, corsMiddleware,
authMiddleware, authMiddleware,
auditMiddleware, auditMiddleware.Middleware,
} }
logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize) logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize)
} }
@@ -693,6 +725,17 @@ func main() {
logger.Error("HTTP server shutdown error", "error", err) logger.Error("HTTP server shutdown error", "error", err)
} }
// Drain in-flight audit-recording goroutines before closing the DB pool.
// The audit middleware spawns one goroutine per non-excluded request; those
// goroutines run detached from the request context and write to the
// audit_events table via the same *sql.DB. Without this drain, SIGTERM
// would close the DB pool while recordings were mid-flight, silently
// dropping audit events (M-1, CWE-662 / CWE-400).
logger.Info("flushing audit middleware in-flight recordings")
if err := auditMiddleware.Flush(shutdownCtx); err != nil {
logger.Warn("audit middleware flush did not complete in time", "error", err)
}
// Close database connection // Close database connection
if err := db.Close(); err != nil { if err := db.Close(); err != nil {
logger.Error("error closing database connection", "error", err) logger.Error("error closing database connection", "error", err)
+28
View File
@@ -808,6 +808,34 @@ All shell-facing inputs (connector scripts, domain names, ACME tokens) are valid
All incoming HTTP request bodies are capped by `http.MaxBytesReader` middleware (default 1MB, configurable via `CERTCTL_MAX_BODY_SIZE`). Requests exceeding the limit receive a 413 Request Entity Too Large response. The middleware is positioned before authentication in the chain so oversized payloads are rejected early, before any auth processing or database work occurs. Requests without bodies (GET, HEAD, nil body) skip the limit check. All incoming HTTP request bodies are capped by `http.MaxBytesReader` middleware (default 1MB, configurable via `CERTCTL_MAX_BODY_SIZE`). Requests exceeding the limit receive a 413 Request Entity Too Large response. The middleware is positioned before authentication in the chain so oversized payloads are rejected early, before any auth processing or database work occurs. Requests without bodies (GET, HEAD, nil body) skip the limit check.
### Config Encryption at Rest
Dynamic issuer and target configurations (rows with `source='database'`) contain credentials — ACME EAB HMACs, Vault tokens, DigiCert/Sectigo API keys, SSH private keys, WinRM passwords, F5 BIG-IP passwords, and similar. These are sealed at rest in PostgreSQL via `internal/crypto/encryption.go` using AES-256-GCM with a key derived from the operator passphrase `CERTCTL_CONFIG_ENCRYPTION_KEY` through PBKDF2-SHA256 (100,000 rounds, 32-byte output).
**v2 wire format (current, M-8 remediation, CWE-916 / CWE-329):**
```
magic(0x02) || salt(16) || nonce(12) || ciphertext+tag
```
Every call to `EncryptIfKeySet` draws 16 fresh bytes from `crypto/rand` as the PBKDF2 salt, so the derived AES-256 key is distinct per ciphertext and per re-encryption. The salt is stored alongside the ciphertext; decryption reads the magic byte, splits out the salt, re-derives the key, and verifies the AEAD tag.
**v1 legacy format (read-only):**
```
nonce(12) || ciphertext+tag
```
Pre-M-8 blobs were sealed with a package-level fixed salt `"certctl-config-encryption-v1"`. `DecryptIfKeySet` preserves the v1 read path unchanged — a blob whose first byte is not `0x02`, or whose v2 AEAD verification fails (including the 1/256 case where a v1 nonce happens to begin with `0x02`), falls through to a v1 attempt against the legacy fixed salt. v1 blobs are never written by the post-M-8 code path; they re-seal as v2 naturally on the next UPDATE through the normal service CRUD flow. No operator migration ceremony is required.
**Fail-closed behavior (C-2 sentinel, CWE-311):** both `EncryptIfKeySet` and `DecryptIfKeySet` return `ErrEncryptionKeyRequired` when invoked with an empty passphrase. The server refuses to start if any `source='database'` rows already exist without `CERTCTL_CONFIG_ENCRYPTION_KEY` set.
**Low-level primitives preserved byte-identical.** `Encrypt`, `Decrypt`, and `DeriveKey` are kept bit-stable so v1 fixtures on disk remain decryptable unchanged and so callers outside the config-encryption path (none today, but the symbols are exported) do not see a breaking change. The new per-ciphertext salt path is reached via the helper `deriveKeyWithSalt(passphrase, salt)`.
**Passphrase plumbing.** Services (`IssuerService`, `TargetService`, `IssuerRegistry`) hold the operator passphrase as a raw `string` and delegate PBKDF2 to the crypto package per ciphertext. This replaces the pre-M-8 design that pre-derived a single `[]byte` key at service construction and reused it for every row, which was the direct consequence of the fixed-salt KDF.
**Coverage gate.** CI enforces `internal/crypto/...` coverage ≥ 85% (observed 86.7%) — the encryption primitives are a security-critical gate, and the v2 format plus v1 fallback plus C-2 sentinel paths all need exhaustive coverage to avoid silent regressions.
### CORS ### CORS
CORS uses a **deny-by-default** posture: when `CERTCTL_CORS_ORIGINS` is empty, no CORS headers are set and only same-origin requests can read responses. Operators must explicitly configure allowed origins. This prevents accidental exposure of the API to cross-origin requests in production. CORS uses a **deny-by-default** posture: when `CERTCTL_CORS_ORIGINS` is empty, no CORS headers are set and only same-origin requests can read responses. Operators must explicitly configure allowed origins. This prevents accidental exposure of the API to cross-origin requests in production.
@@ -27,6 +27,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -120,7 +121,7 @@ func TestGetCertificate_PathInjection(t *testing.T) {
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
// Force a 404 so we can distinguish "service was called" from // Force a 404 so we can distinguish "service was called" from
// "parser accepted the ID"; a 200 with null body is also fine. // "parser accepted the ID"; a 200 with null body is also fine.
mock.GetCertificateFn = func(id string) (*domain.ManagedCertificate, error) { mock.GetCertificateFn = func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
} }
@@ -156,7 +157,7 @@ func TestUpdateCertificate_PathInjection(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.UpdateCertificateFn = func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { mock.UpdateCertificateFn = func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
} }
@@ -184,7 +185,7 @@ func TestArchiveCertificate_PathInjection(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ArchiveCertificateFn = func(id string) error { return ErrMockNotFound } mock.ArchiveCertificateFn = func(_ context.Context, id string) error { return ErrMockNotFound }
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/x", nil) req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/x", nil)
req.URL.Path = "/api/v1/certificates/" + tc.input req.URL.Path = "/api/v1/certificates/" + tc.input
@@ -227,7 +228,7 @@ func TestGetCertificateVersions_MultiSegment(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.GetCertificateVersionsFn = func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { mock.GetCertificateVersionsFn = func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
return []domain.CertificateVersion{}, 0, nil return []domain.CertificateVersion{}, 0, nil
} }
@@ -277,7 +278,7 @@ func TestHandleOCSP_MultiSegment(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.GetOCSPResponseFn = func(issuerID, serialHex string) ([]byte, error) { mock.GetOCSPResponseFn = func(_ context.Context, issuerID, serialHex string) ([]byte, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
} }
@@ -311,7 +312,7 @@ func TestGetDERCRL_IssuerPathInjection(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.GenerateDERCRLFn = func(issuerID string) ([]byte, error) { mock.GenerateDERCRLFn = func(_ context.Context, issuerID string) ([]byte, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
} }
+12 -11
View File
@@ -19,6 +19,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -76,7 +77,7 @@ func TestListCertificates_PaginationAbuse(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Sanity: page/perPage on the filter must never be negative // Sanity: page/perPage on the filter must never be negative
// and perPage must never exceed 500 after parsing. // and perPage must never exceed 500 after parsing.
if filter.Page < 1 { if filter.Page < 1 {
@@ -133,7 +134,7 @@ func TestListCertificates_SortAbuse(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
} }
@@ -175,7 +176,7 @@ func TestListCertificates_FieldsAbuse(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
} }
@@ -219,7 +220,7 @@ func TestListCertificates_TimeRangeAbuse(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
} }
@@ -263,7 +264,7 @@ func TestListCertificates_CursorAbuse(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
} }
@@ -314,7 +315,7 @@ func TestListCertificates_FilterInjection(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
} }
@@ -374,7 +375,7 @@ func TestCreateCertificate_BodyAbuse(t *testing.T) {
}() }()
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
// If we ever reach this, the handler accepted a malformed // If we ever reach this, the handler accepted a malformed
// body. Return a sentinel that passes but flag it. // body. Return a sentinel that passes but flag it.
c := cert c := cert
@@ -419,7 +420,7 @@ func TestCreateCertificate_HugeBody(t *testing.T) {
sb.WriteString(`]}`) sb.WriteString(`]}`)
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
c := cert c := cert
c.ID = "mc-huge" c.ID = "mc-huge"
return &c, nil return &c, nil
@@ -476,7 +477,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) {
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
// The mock always returns "invalid revocation reason" so we // The mock always returns "invalid revocation reason" so we
// verify the handler's errMsg→status mapping turns it into a 400. // verify the handler's errMsg→status mapping turns it into a 400.
mock.RevokeCertificateFn = func(id string, reason string) error { mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
// The service uses domain.IsValidRevocationReason. If we got // The service uses domain.IsValidRevocationReason. If we got
// through to here with something bogus, simulate a real // through to here with something bogus, simulate a real
// service error. // service error.
@@ -500,7 +501,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) {
// service error message, which is fragile — this test catches regressions. // service error message, which is fragile — this test catches regressions.
func TestRevokeCertificate_AlreadyRevoked(t *testing.T) { func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.RevokeCertificateFn = func(id string, reason string) error { mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
return fmt.Errorf("cannot revoke: certificate is already revoked") return fmt.Errorf("cannot revoke: certificate is already revoked")
} }
@@ -520,7 +521,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
// TestRevokeCertificate_NotFound verifies 404 mapping. // TestRevokeCertificate_NotFound verifies 404 mapping.
func TestRevokeCertificate_NotFound(t *testing.T) { func TestRevokeCertificate_NotFound(t *testing.T) {
handler, mock := newCertHandlerWithMock() handler, mock := newCertHandlerWithMock()
mock.RevokeCertificateFn = func(id string, reason string) error { mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
return fmt.Errorf("certificate not found") return fmt.Errorf("certificate not found")
} }
+5 -4
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -11,8 +12,8 @@ import (
// AuditService defines the service interface for audit event operations. // AuditService defines the service interface for audit event operations.
type AuditService interface { type AuditService interface {
ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) ListAuditEvents(ctx context.Context, page, perPage int) ([]domain.AuditEvent, int64, error)
GetAuditEvent(id string) (*domain.AuditEvent, error) GetAuditEvent(ctx context.Context, id string) (*domain.AuditEvent, error)
} }
// AuditHandler handles HTTP requests for audit event operations. // AuditHandler handles HTTP requests for audit event operations.
@@ -49,7 +50,7 @@ func (h AuditHandler) ListAuditEvents(w http.ResponseWriter, r *http.Request) {
} }
} }
events, total, err := h.svc.ListAuditEvents(page, perPage) events, total, err := h.svc.ListAuditEvents(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list audit events", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list audit events", requestID)
return return
@@ -83,7 +84,7 @@ func (h AuditHandler) GetAuditEvent(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
event, err := h.svc.GetAuditEvent(id) event, err := h.svc.GetAuditEvent(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Audit event not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Audit event not found", requestID)
return return
+2 -2
View File
@@ -19,14 +19,14 @@ type mockAuditService struct {
getFunc func(id string) (*domain.AuditEvent, error) getFunc func(id string) (*domain.AuditEvent, error)
} }
func (m *mockAuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) { func (m *mockAuditService) ListAuditEvents(_ context.Context, page, perPage int) ([]domain.AuditEvent, int64, error) {
if m.listFunc != nil { if m.listFunc != nil {
return m.listFunc(page, perPage) return m.listFunc(page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *mockAuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) { func (m *mockAuditService) GetAuditEvent(_ context.Context, id string) (*domain.AuditEvent, error) {
if m.getFunc != nil { if m.getFunc != nil {
return m.getFunc(id) return m.getFunc(id)
} }
+148 -88
View File
@@ -17,116 +17,116 @@ import (
// MockCertificateService is a mock implementation of CertificateService interface. // MockCertificateService is a mock implementation of CertificateService interface.
type MockCertificateService struct { type MockCertificateService struct {
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) ListCertificatesFn func(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilterFn func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) ListCertificatesWithFilterFn func(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificateFn func(id string) (*domain.ManagedCertificate, error) GetCertificateFn func(ctx context.Context, id string) (*domain.ManagedCertificate, error)
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) CreateCertificateFn func(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) UpdateCertificateFn func(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificateFn func(id string) error ArchiveCertificateFn func(ctx context.Context, id string) error
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) GetCertificateVersionsFn func(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewalFn func(certID string) error TriggerRenewalFn func(ctx context.Context, certID string, actor string) error
TriggerDeploymentFn func(certID string, targetID string) error TriggerDeploymentFn func(ctx context.Context, certID string, targetID string, actor string) error
RevokeCertificateFn func(certID string, reason string) error RevokeCertificateFn func(ctx context.Context, certID string, reason string, actor string) error
GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error) GetRevokedCertificatesFn func(ctx context.Context) ([]*domain.CertificateRevocation, error)
GenerateDERCRLFn func(issuerID string) ([]byte, error) GenerateDERCRLFn func(ctx context.Context, issuerID string) ([]byte, error)
GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error) GetOCSPResponseFn func(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
GetCertificateDeploymentsFn func(certID string) ([]domain.DeploymentTarget, error) GetCertificateDeploymentsFn func(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
} }
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { func (m *MockCertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if m.ListCertificatesFn != nil { if m.ListCertificatesFn != nil {
return m.ListCertificatesFn(status, environment, ownerID, teamID, issuerID, page, perPage) return m.ListCertificatesFn(ctx, status, environment, ownerID, teamID, issuerID, page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockCertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) { func (m *MockCertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
if m.GetCertificateFn != nil { if m.GetCertificateFn != nil {
return m.GetCertificateFn(id) return m.GetCertificateFn(ctx, id)
} }
return nil, nil return nil, nil
} }
func (m *MockCertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { func (m *MockCertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if m.CreateCertificateFn != nil { if m.CreateCertificateFn != nil {
return m.CreateCertificateFn(cert) return m.CreateCertificateFn(ctx, cert)
} }
return nil, nil return nil, nil
} }
func (m *MockCertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { func (m *MockCertificateService) UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if m.UpdateCertificateFn != nil { if m.UpdateCertificateFn != nil {
return m.UpdateCertificateFn(id, cert) return m.UpdateCertificateFn(ctx, id, cert)
} }
return nil, nil return nil, nil
} }
func (m *MockCertificateService) ArchiveCertificate(id string) error { func (m *MockCertificateService) ArchiveCertificate(ctx context.Context, id string) error {
if m.ArchiveCertificateFn != nil { if m.ArchiveCertificateFn != nil {
return m.ArchiveCertificateFn(id) return m.ArchiveCertificateFn(ctx, id)
} }
return nil return nil
} }
func (m *MockCertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { func (m *MockCertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if m.GetCertificateVersionsFn != nil { if m.GetCertificateVersionsFn != nil {
return m.GetCertificateVersionsFn(certID, page, perPage) return m.GetCertificateVersionsFn(ctx, certID, page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockCertificateService) TriggerRenewal(certID string) error { func (m *MockCertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error {
if m.TriggerRenewalFn != nil { if m.TriggerRenewalFn != nil {
return m.TriggerRenewalFn(certID) return m.TriggerRenewalFn(ctx, certID, actor)
} }
return nil return nil
} }
func (m *MockCertificateService) TriggerDeployment(certID string, targetID string) error { func (m *MockCertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error {
if m.TriggerDeploymentFn != nil { if m.TriggerDeploymentFn != nil {
return m.TriggerDeploymentFn(certID, targetID) return m.TriggerDeploymentFn(ctx, certID, targetID, actor)
} }
return nil return nil
} }
func (m *MockCertificateService) RevokeCertificate(certID string, reason string) error { func (m *MockCertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error {
if m.RevokeCertificateFn != nil { if m.RevokeCertificateFn != nil {
return m.RevokeCertificateFn(certID, reason) return m.RevokeCertificateFn(ctx, certID, reason, actor)
} }
return nil return nil
} }
func (m *MockCertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) { func (m *MockCertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if m.GetRevokedCertificatesFn != nil { if m.GetRevokedCertificatesFn != nil {
return m.GetRevokedCertificatesFn() return m.GetRevokedCertificatesFn(ctx)
} }
return nil, nil return nil, nil
} }
func (m *MockCertificateService) GenerateDERCRL(issuerID string) ([]byte, error) { func (m *MockCertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
if m.GenerateDERCRLFn != nil { if m.GenerateDERCRLFn != nil {
return m.GenerateDERCRLFn(issuerID) return m.GenerateDERCRLFn(ctx, issuerID)
} }
return nil, nil return nil, nil
} }
func (m *MockCertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) { func (m *MockCertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
if m.GetOCSPResponseFn != nil { if m.GetOCSPResponseFn != nil {
return m.GetOCSPResponseFn(issuerID, serialHex) return m.GetOCSPResponseFn(ctx, issuerID, serialHex)
} }
return nil, nil return nil, nil
} }
func (m *MockCertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { func (m *MockCertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if m.ListCertificatesWithFilterFn != nil { if m.ListCertificatesWithFilterFn != nil {
return m.ListCertificatesWithFilterFn(filter) return m.ListCertificatesWithFilterFn(ctx, filter)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockCertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) { func (m *MockCertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) {
if m.GetCertificateDeploymentsFn != nil { if m.GetCertificateDeploymentsFn != nil {
return m.GetCertificateDeploymentsFn(certID) return m.GetCertificateDeploymentsFn(ctx, certID)
} }
return nil, nil return nil, nil
} }
@@ -158,7 +158,7 @@ func TestListCertificates_Success(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Page == 1 && filter.PerPage == 50 { if filter.Page == 1 && filter.PerPage == 50 {
return []domain.ManagedCertificate{cert1, cert2}, 2, nil return []domain.ManagedCertificate{cert1, cert2}, 2, nil
} }
@@ -197,7 +197,7 @@ func TestListCertificates_Success(t *testing.T) {
// Test ListCertificates - with filters // Test ListCertificates - with filters
func TestListCertificates_WithFilters(t *testing.T) { func TestListCertificates_WithFilters(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Status == "Active" && filter.Environment == "prod" { if filter.Status == "Active" && filter.Environment == "prod" {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
} }
@@ -236,7 +236,7 @@ func TestListCertificates_MethodNotAllowed(t *testing.T) {
// Test ListCertificates - service error // Test ListCertificates - service error
func TestListCertificates_ServiceError(t *testing.T) { func TestListCertificates_ServiceError(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return nil, 0, ErrMockServiceFailed return nil, 0, ErrMockServiceFailed
}, },
} }
@@ -266,7 +266,7 @@ func TestGetCertificate_Success(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) { GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
if id == "mc-prod-001" { if id == "mc-prod-001" {
return cert, nil return cert, nil
} }
@@ -298,7 +298,7 @@ func TestGetCertificate_Success(t *testing.T) {
// Test GetCertificate - not found // Test GetCertificate - not found
func TestGetCertificate_NotFound(t *testing.T) { func TestGetCertificate_NotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) { GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
}, },
} }
@@ -345,7 +345,7 @@ func TestCreateCertificate_Success(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return created, nil return created, nil
}, },
} }
@@ -403,7 +403,7 @@ func TestCreateCertificate_InvalidBody(t *testing.T) {
// Test CreateCertificate - service error // Test CreateCertificate - service error
func TestCreateCertificate_ServiceError(t *testing.T) { func TestCreateCertificate_ServiceError(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
return nil, ErrMockServiceFailed return nil, ErrMockServiceFailed
}, },
} }
@@ -432,6 +432,66 @@ func TestCreateCertificate_ServiceError(t *testing.T) {
} }
} }
// TestCreateCertificate_MissingRequiredField_Returns400 pins the C-001 handler
// contract: handler MUST reject a create payload that omits any of the five
// required fields (name, common_name, owner_id, team_id, issuer_id,
// renewal_policy_id) with HTTP 400 before the service is invoked. The mock
// service here would succeed if called; every subtest proving 400 therefore
// proves the handler guard fires.
func TestCreateCertificate_MissingRequiredField_Returns400(t *testing.T) {
baseBody := map[string]interface{}{
"name": "API Prod",
"common_name": "api.example.com",
"owner_id": "o-alice",
"team_id": "t-platform",
"issuer_id": "iss-local",
"renewal_policy_id": "rp-standard",
}
cases := []struct {
name string
missingField string
}{
{"missing name", "name"},
{"missing common_name", "common_name"},
{"missing owner_id", "owner_id"},
{"missing team_id", "team_id"},
{"missing issuer_id", "issuer_id"},
{"missing renewal_policy_id", "renewal_policy_id"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
body := make(map[string]interface{}, len(baseBody))
for k, v := range baseBody {
body[k] = v
}
delete(body, tc.missingField)
bodyBytes, _ := json.Marshal(body)
mock := &MockCertificateService{
CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
// Would succeed if handler guard did not fire.
cert.ID = "mc-would-be-created"
return &cert, nil
},
}
handler := NewCertificateHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", bytes.NewReader(bodyBytes))
req = req.WithContext(contextWithRequestID())
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
handler.CreateCertificate(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("%s: expected 400, got %d — body=%s", tc.name, w.Code, w.Body.String())
}
})
}
}
// Test UpdateCertificate - success case // Test UpdateCertificate - success case
func TestUpdateCertificate_Success(t *testing.T) { func TestUpdateCertificate_Success(t *testing.T) {
updated := &domain.ManagedCertificate{ updated := &domain.ManagedCertificate{
@@ -445,7 +505,7 @@ func TestUpdateCertificate_Success(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
UpdateCertificateFn: func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { UpdateCertificateFn: func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if id == "mc-prod-001" { if id == "mc-prod-001" {
return updated, nil return updated, nil
} }
@@ -501,7 +561,7 @@ func TestUpdateCertificate_InvalidBody(t *testing.T) {
// Test ArchiveCertificate - success case // Test ArchiveCertificate - success case
func TestArchiveCertificate_Success(t *testing.T) { func TestArchiveCertificate_Success(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ArchiveCertificateFn: func(id string) error { ArchiveCertificateFn: func(_ context.Context, id string) error {
if id == "mc-prod-001" { if id == "mc-prod-001" {
return nil return nil
} }
@@ -524,7 +584,7 @@ func TestArchiveCertificate_Success(t *testing.T) {
// Test ArchiveCertificate - not found // Test ArchiveCertificate - not found
func TestArchiveCertificate_NotFound(t *testing.T) { func TestArchiveCertificate_NotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ArchiveCertificateFn: func(id string) error { ArchiveCertificateFn: func(_ context.Context, id string) error {
return ErrMockNotFound return ErrMockNotFound
}, },
} }
@@ -554,7 +614,7 @@ func TestGetCertificateVersions_Success(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if certID == "mc-prod-001" { if certID == "mc-prod-001" {
return []domain.CertificateVersion{ver1}, 1, nil return []domain.CertificateVersion{ver1}, 1, nil
} }
@@ -586,7 +646,7 @@ func TestGetCertificateVersions_Success(t *testing.T) {
// Test GetCertificateVersions - not found // Test GetCertificateVersions - not found
func TestGetCertificateVersions_NotFound(t *testing.T) { func TestGetCertificateVersions_NotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
return nil, 0, ErrMockNotFound return nil, 0, ErrMockNotFound
}, },
} }
@@ -606,7 +666,7 @@ func TestGetCertificateVersions_NotFound(t *testing.T) {
// Test TriggerRenewal - success case // Test TriggerRenewal - success case
func TestTriggerRenewal_Success(t *testing.T) { func TestTriggerRenewal_Success(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
TriggerRenewalFn: func(certID string) error { TriggerRenewalFn: func(_ context.Context, certID string, _ string) error {
if certID == "mc-prod-001" { if certID == "mc-prod-001" {
return nil return nil
} }
@@ -638,7 +698,7 @@ func TestTriggerRenewal_Success(t *testing.T) {
// Test TriggerRenewal - service error // Test TriggerRenewal - service error
func TestTriggerRenewal_ServiceError(t *testing.T) { func TestTriggerRenewal_ServiceError(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
TriggerRenewalFn: func(certID string) error { TriggerRenewalFn: func(_ context.Context, certID string, _ string) error {
return ErrMockServiceFailed return ErrMockServiceFailed
}, },
} }
@@ -658,7 +718,7 @@ func TestTriggerRenewal_ServiceError(t *testing.T) {
// Test TriggerDeployment - success case // Test TriggerDeployment - success case
func TestTriggerDeployment_Success(t *testing.T) { func TestTriggerDeployment_Success(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
TriggerDeploymentFn: func(certID string, targetID string) error { TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error {
if certID == "mc-prod-001" { if certID == "mc-prod-001" {
return nil return nil
} }
@@ -695,7 +755,7 @@ func TestTriggerDeployment_Success(t *testing.T) {
// Test TriggerDeployment - without target ID // Test TriggerDeployment - without target ID
func TestTriggerDeployment_NoTargetID(t *testing.T) { func TestTriggerDeployment_NoTargetID(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
TriggerDeploymentFn: func(certID string, targetID string) error { TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error {
// Should accept empty targetID (deploy to all) // Should accept empty targetID (deploy to all)
return nil return nil
}, },
@@ -716,7 +776,7 @@ func TestTriggerDeployment_NoTargetID(t *testing.T) {
// Test ListCertificates - invalid page parameter // Test ListCertificates - invalid page parameter
func TestListCertificates_InvalidPageParam(t *testing.T) { func TestListCertificates_InvalidPageParam(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Should default to page 1 // Should default to page 1
if filter.Page == 1 { if filter.Page == 1 {
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
@@ -740,7 +800,7 @@ func TestListCertificates_InvalidPageParam(t *testing.T) {
// Test ListCertificates - per_page exceeds max // Test ListCertificates - per_page exceeds max
func TestListCertificates_PerPageExceedsMax(t *testing.T) { func TestListCertificates_PerPageExceedsMax(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Should cap perPage at 500 // Should cap perPage at 500
if filter.PerPage == 50 { // defaults to 50 if > 500 if filter.PerPage == 50 { // defaults to 50 if > 500
return []domain.ManagedCertificate{}, 0, nil return []domain.ManagedCertificate{}, 0, nil
@@ -765,7 +825,7 @@ func TestListCertificates_PerPageExceedsMax(t *testing.T) {
func TestRevokeCertificate_Handler_Success(t *testing.T) { func TestRevokeCertificate_Handler_Success(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
if certID != "mc-prod-001" { if certID != "mc-prod-001" {
t.Errorf("expected certID mc-prod-001, got %s", certID) t.Errorf("expected certID mc-prod-001, got %s", certID)
} }
@@ -798,7 +858,7 @@ func TestRevokeCertificate_Handler_Success(t *testing.T) {
func TestRevokeCertificate_Handler_NoBody(t *testing.T) { func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
// Empty reason is OK — service defaults to "unspecified" // Empty reason is OK — service defaults to "unspecified"
return nil return nil
}, },
@@ -818,7 +878,7 @@ func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) { func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("certificate is already revoked") return fmt.Errorf("certificate is already revoked")
}, },
} }
@@ -839,7 +899,7 @@ func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
func TestRevokeCertificate_Handler_NotFound(t *testing.T) { func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("failed to fetch certificate: not found") return fmt.Errorf("failed to fetch certificate: not found")
}, },
} }
@@ -858,7 +918,7 @@ func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) { func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("invalid revocation reason: badReason") return fmt.Errorf("invalid revocation reason: badReason")
}, },
} }
@@ -922,7 +982,7 @@ func TestRevokeCertificate_Handler_EmptyID(t *testing.T) {
func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) { func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("cannot revoke archived certificate") return fmt.Errorf("cannot revoke archived certificate")
}, },
} }
@@ -941,7 +1001,7 @@ func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
func TestRevokeCertificate_Handler_ServerError(t *testing.T) { func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
RevokeCertificateFn: func(certID string, reason string) error { RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
return fmt.Errorf("database connection lost") return fmt.Errorf("database connection lost")
}, },
} }
@@ -962,7 +1022,7 @@ func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
func TestGetCRL_Success(t *testing.T) { func TestGetCRL_Success(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) { GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
return []*domain.CertificateRevocation{ return []*domain.CertificateRevocation{
{ {
ID: "rev-1", ID: "rev-1",
@@ -1022,7 +1082,7 @@ func TestGetCRL_Success(t *testing.T) {
func TestGetCRL_Empty(t *testing.T) { func TestGetCRL_Empty(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) { GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
return nil, nil return nil, nil
}, },
} }
@@ -1047,7 +1107,7 @@ func TestGetCRL_Empty(t *testing.T) {
func TestGetCRL_ServiceError(t *testing.T) { func TestGetCRL_ServiceError(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) { GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
return nil, fmt.Errorf("revocation repository not configured") return nil, fmt.Errorf("revocation repository not configured")
}, },
} }
@@ -1083,7 +1143,7 @@ func TestGetCRL_MethodNotAllowed(t *testing.T) {
func TestGetDERCRL_Success(t *testing.T) { func TestGetDERCRL_Success(t *testing.T) {
derCRLData := []byte{0x30, 0x82, 0x01, 0x00} // Mock DER CRL bytes derCRLData := []byte{0x30, 0x82, 0x01, 0x00} // Mock DER CRL bytes
mock := &MockCertificateService{ mock := &MockCertificateService{
GenerateDERCRLFn: func(issuerID string) ([]byte, error) { GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
if issuerID == "iss-local" { if issuerID == "iss-local" {
return derCRLData, nil return derCRLData, nil
} }
@@ -1111,7 +1171,7 @@ func TestGetDERCRL_Success(t *testing.T) {
func TestGetDERCRL_IssuerNotFound(t *testing.T) { func TestGetDERCRL_IssuerNotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GenerateDERCRLFn: func(issuerID string) ([]byte, error) { GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
return nil, fmt.Errorf("issuer not found") return nil, fmt.Errorf("issuer not found")
}, },
} }
@@ -1130,7 +1190,7 @@ func TestGetDERCRL_IssuerNotFound(t *testing.T) {
func TestGetDERCRL_NotSupported(t *testing.T) { func TestGetDERCRL_NotSupported(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GenerateDERCRLFn: func(issuerID string) ([]byte, error) { GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
return nil, fmt.Errorf("issuer does not support CRL generation") return nil, fmt.Errorf("issuer does not support CRL generation")
}, },
} }
@@ -1165,7 +1225,7 @@ func TestGetDERCRL_MethodNotAllowed(t *testing.T) {
func TestHandleOCSP_Success(t *testing.T) { func TestHandleOCSP_Success(t *testing.T) {
ocspResponseBytes := []byte{0x30, 0x82, 0x02, 0x00} // Mock OCSP response ocspResponseBytes := []byte{0x30, 0x82, 0x02, 0x00} // Mock OCSP response
mock := &MockCertificateService{ mock := &MockCertificateService{
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) { GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
if issuerID == "iss-local" && serialHex == "12345" { if issuerID == "iss-local" && serialHex == "12345" {
return ocspResponseBytes, nil return ocspResponseBytes, nil
} }
@@ -1206,7 +1266,7 @@ func TestHandleOCSP_MissingSerial(t *testing.T) {
func TestHandleOCSP_IssuerNotFound(t *testing.T) { func TestHandleOCSP_IssuerNotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) { GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
return nil, fmt.Errorf("issuer not found") return nil, fmt.Errorf("issuer not found")
}, },
} }
@@ -1225,7 +1285,7 @@ func TestHandleOCSP_IssuerNotFound(t *testing.T) {
func TestHandleOCSP_CertNotFound(t *testing.T) { func TestHandleOCSP_CertNotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) { GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
return nil, fmt.Errorf("certificate not found") return nil, fmt.Errorf("certificate not found")
}, },
} }
@@ -1261,7 +1321,7 @@ func TestHandleOCSP_MethodNotAllowed(t *testing.T) {
// TestListCertificates_SortParam tests sort parameter parsing and passing to service. // TestListCertificates_SortParam tests sort parameter parsing and passing to service.
func TestListCertificates_SortParam(t *testing.T) { func TestListCertificates_SortParam(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
// Handler strips the '-' prefix and sets SortDesc = true // Handler strips the '-' prefix and sets SortDesc = true
if filter.Sort != "notAfter" || !filter.SortDesc { if filter.Sort != "notAfter" || !filter.SortDesc {
t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc) t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
@@ -1284,7 +1344,7 @@ func TestListCertificates_SortParam(t *testing.T) {
// TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending). // TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending).
func TestListCertificates_SortParam_Ascending(t *testing.T) { func TestListCertificates_SortParam_Ascending(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Sort != "createdAt" || filter.SortDesc { if filter.Sort != "createdAt" || filter.SortDesc {
t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc) t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
} }
@@ -1309,7 +1369,7 @@ func TestListCertificates_TimeRangeFilters(t *testing.T) {
after := time.Now().AddDate(0, 0, -90) after := time.Now().AddDate(0, 0, -90)
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.ExpiresBefore == nil { if filter.ExpiresBefore == nil {
t.Error("expected ExpiresBefore to be set") t.Error("expected ExpiresBefore to be set")
} }
@@ -1339,7 +1399,7 @@ func TestListCertificates_CreatedAfterFilter(t *testing.T) {
past := time.Now().AddDate(-1, 0, 0) past := time.Now().AddDate(-1, 0, 0)
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.CreatedAfter == nil { if filter.CreatedAfter == nil {
t.Error("expected CreatedAfter to be set") t.Error("expected CreatedAfter to be set")
} }
@@ -1369,7 +1429,7 @@ func TestListCertificates_CursorPagination(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
return []domain.ManagedCertificate{cert}, 1, nil return []domain.ManagedCertificate{cert}, 1, nil
}, },
} }
@@ -1409,7 +1469,7 @@ func TestListCertificates_SparseFields(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if len(filter.Fields) != 2 { if len(filter.Fields) != 2 {
t.Errorf("expected 2 fields, got %d", len(filter.Fields)) t.Errorf("expected 2 fields, got %d", len(filter.Fields))
} }
@@ -1456,7 +1516,7 @@ func TestListCertificates_SparseFields(t *testing.T) {
// TestListCertificates_ProfileFilter tests profile_id filter. // TestListCertificates_ProfileFilter tests profile_id filter.
func TestListCertificates_ProfileFilter(t *testing.T) { func TestListCertificates_ProfileFilter(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.ProfileID != "prof-standard" { if filter.ProfileID != "prof-standard" {
t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID) t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID)
} }
@@ -1479,7 +1539,7 @@ func TestListCertificates_ProfileFilter(t *testing.T) {
// TestListCertificates_AgentIDFilter tests agent_id filter. // TestListCertificates_AgentIDFilter tests agent_id filter.
func TestListCertificates_AgentIDFilter(t *testing.T) { func TestListCertificates_AgentIDFilter(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.AgentID != "agent-prod-001" { if filter.AgentID != "agent-prod-001" {
t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID) t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID)
} }
@@ -1502,7 +1562,7 @@ func TestListCertificates_AgentIDFilter(t *testing.T) {
// TestListCertificates_CombinedFilters tests multiple filters together. // TestListCertificates_CombinedFilters tests multiple filters together.
func TestListCertificates_CombinedFilters(t *testing.T) { func TestListCertificates_CombinedFilters(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" { if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" {
t.Error("expected all filters to be set") t.Error("expected all filters to be set")
} }
@@ -1540,7 +1600,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) {
} }
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
if certID != "mc-prod-001" { if certID != "mc-prod-001" {
return nil, ErrMockNotFound return nil, ErrMockNotFound
} }
@@ -1576,7 +1636,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) {
// TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate. // TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate.
func TestGetCertificateDeployments_NotFound(t *testing.T) { func TestGetCertificateDeployments_NotFound(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
return nil, fmt.Errorf("certificate not found") return nil, fmt.Errorf("certificate not found")
}, },
} }
@@ -1596,7 +1656,7 @@ func TestGetCertificateDeployments_NotFound(t *testing.T) {
// TestGetCertificateDeployments_Empty tests successful response with no deployments. // TestGetCertificateDeployments_Empty tests successful response with no deployments.
func TestGetCertificateDeployments_Empty(t *testing.T) { func TestGetCertificateDeployments_Empty(t *testing.T) {
mock := &MockCertificateService{ mock := &MockCertificateService{
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
if certID == "mc-no-deployments" { if certID == "mc-no-deployments" {
return []domain.DeploymentTarget{}, nil return []domain.DeploymentTarget{}, nil
} }
+28 -27
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -15,20 +16,20 @@ import (
// CertificateService defines the service interface for certificate operations. // CertificateService defines the service interface for certificate operations.
type CertificateService interface { type CertificateService interface {
ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
GetCertificate(id string) (*domain.ManagedCertificate, error) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error)
CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
ArchiveCertificate(id string) error ArchiveCertificate(ctx context.Context, id string) error
GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
TriggerRenewal(certID string) error TriggerRenewal(ctx context.Context, certID string, actor string) error
TriggerDeployment(certID string, targetID string) error TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error
RevokeCertificate(certID string, reason string) error RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error
GetRevokedCertificates() ([]*domain.CertificateRevocation, error) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error)
GenerateDERCRL(issuerID string) ([]byte, error) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error)
GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
} }
// CertificateHandler handles HTTP requests for certificate operations. // CertificateHandler handles HTTP requests for certificate operations.
@@ -128,7 +129,7 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ
filter.Fields = strings.Split(fieldsStr, ",") filter.Fields = strings.Split(fieldsStr, ",")
} }
certs, total, err := h.svc.ListCertificatesWithFilter(filter) certs, total, err := h.svc.ListCertificatesWithFilter(r.Context(), filter)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
return return
@@ -186,7 +187,7 @@ func (h CertificateHandler) GetCertificate(w http.ResponseWriter, r *http.Reques
return return
} }
cert, err := h.svc.GetCertificate(id) cert, err := h.svc.GetCertificate(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return return
@@ -241,7 +242,7 @@ func (h CertificateHandler) CreateCertificate(w http.ResponseWriter, r *http.Req
return return
} }
created, err := h.svc.CreateCertificate(cert) created, err := h.svc.CreateCertificate(r.Context(), cert)
if err != nil { if err != nil {
slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name) slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name)
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID)
@@ -295,7 +296,7 @@ func (h CertificateHandler) UpdateCertificate(w http.ResponseWriter, r *http.Req
} }
} }
updated, err := h.svc.UpdateCertificate(id, cert) updated, err := h.svc.UpdateCertificate(r.Context(), id, cert)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -325,7 +326,7 @@ func (h CertificateHandler) ArchiveCertificate(w http.ResponseWriter, r *http.Re
return return
} }
if err := h.svc.ArchiveCertificate(id); err != nil { if err := h.svc.ArchiveCertificate(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return return
@@ -370,7 +371,7 @@ func (h CertificateHandler) GetCertificateVersions(w http.ResponseWriter, r *htt
} }
} }
versions, total, err := h.svc.GetCertificateVersions(certID, page, perPage) versions, total, err := h.svc.GetCertificateVersions(r.Context(), certID, page, perPage)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -410,7 +411,7 @@ func (h CertificateHandler) TriggerRenewal(w http.ResponseWriter, r *http.Reques
} }
certID := parts[0] certID := parts[0]
if err := h.svc.TriggerRenewal(certID); err != nil { if err := h.svc.TriggerRenewal(r.Context(), certID, "api"); err != nil {
errMsg := err.Error() errMsg := err.Error()
if strings.Contains(errMsg, "not found") { if strings.Contains(errMsg, "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
@@ -466,7 +467,7 @@ func (h CertificateHandler) TriggerDeployment(w http.ResponseWriter, r *http.Req
} }
} }
if err := h.svc.TriggerDeployment(certID, req.TargetID); err != nil { if err := h.svc.TriggerDeployment(r.Context(), certID, req.TargetID, "api"); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID)
return return
} }
@@ -508,7 +509,7 @@ func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Req
} }
} }
if err := h.svc.RevokeCertificate(certID, req.Reason); err != nil { if err := h.svc.RevokeCertificate(r.Context(), certID, req.Reason, "api"); err != nil {
// Distinguish between client errors and server errors // Distinguish between client errors and server errors
errMsg := err.Error() errMsg := err.Error()
if strings.Contains(errMsg, "already revoked") || if strings.Contains(errMsg, "already revoked") ||
@@ -540,7 +541,7 @@ func (h CertificateHandler) GetCRL(w http.ResponseWriter, r *http.Request) {
requestID := middleware.GetRequestID(r.Context()) requestID := middleware.GetRequestID(r.Context())
revocations, err := h.svc.GetRevokedCertificates() revocations, err := h.svc.GetRevokedCertificates(r.Context())
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID)
return return
@@ -585,7 +586,7 @@ func (h CertificateHandler) GetDERCRL(w http.ResponseWriter, r *http.Request) {
return return
} }
derBytes, err := h.svc.GenerateDERCRL(issuerID) derBytes, err := h.svc.GenerateDERCRL(r.Context(), issuerID)
if err != nil { if err != nil {
errMsg := err.Error() errMsg := err.Error()
if strings.Contains(errMsg, "not found") { if strings.Contains(errMsg, "not found") {
@@ -627,7 +628,7 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) {
issuerID := parts[0] issuerID := parts[0]
serialHex := parts[1] serialHex := parts[1]
derBytes, err := h.svc.GetOCSPResponse(issuerID, serialHex) derBytes, err := h.svc.GetOCSPResponse(r.Context(), issuerID, serialHex)
if err != nil { if err != nil {
errMsg := err.Error() errMsg := err.Error()
if strings.Contains(errMsg, "not found") { if strings.Contains(errMsg, "not found") {
@@ -667,7 +668,7 @@ func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r *
} }
certID := parts[0] certID := parts[0]
deployments, err := h.svc.GetCertificateDeployments(certID) deployments, err := h.svc.GetCertificateDeployments(r.Context(), certID)
if err != nil { if err != nil {
errMsg := err.Error() errMsg := err.Error()
if strings.Contains(errMsg, "not found") { if strings.Contains(errMsg, "not found") {
+33 -32
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -15,52 +16,52 @@ import (
// MockIssuerService is a mock implementation of IssuerService interface. // MockIssuerService is a mock implementation of IssuerService interface.
type MockIssuerService struct { type MockIssuerService struct {
ListIssuersFn func(page, perPage int) ([]domain.Issuer, int64, error) ListIssuersFn func(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error)
GetIssuerFn func(id string) (*domain.Issuer, error) GetIssuerFn func(ctx context.Context, id string) (*domain.Issuer, error)
CreateIssuerFn func(issuer domain.Issuer) (*domain.Issuer, error) CreateIssuerFn func(ctx context.Context, issuer domain.Issuer) (*domain.Issuer, error)
UpdateIssuerFn func(id string, issuer domain.Issuer) (*domain.Issuer, error) UpdateIssuerFn func(ctx context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error)
DeleteIssuerFn func(id string) error DeleteIssuerFn func(ctx context.Context, id string) error
TestConnectionFn func(id string) error TestConnectionFn func(ctx context.Context, id string) error
} }
func (m *MockIssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) { func (m *MockIssuerService) ListIssuers(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
if m.ListIssuersFn != nil { if m.ListIssuersFn != nil {
return m.ListIssuersFn(page, perPage) return m.ListIssuersFn(ctx, page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockIssuerService) GetIssuer(id string) (*domain.Issuer, error) { func (m *MockIssuerService) GetIssuer(ctx context.Context, id string) (*domain.Issuer, error) {
if m.GetIssuerFn != nil { if m.GetIssuerFn != nil {
return m.GetIssuerFn(id) return m.GetIssuerFn(ctx, id)
} }
return nil, nil return nil, nil
} }
func (m *MockIssuerService) CreateIssuer(issuer domain.Issuer) (*domain.Issuer, error) { func (m *MockIssuerService) CreateIssuer(ctx context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
if m.CreateIssuerFn != nil { if m.CreateIssuerFn != nil {
return m.CreateIssuerFn(issuer) return m.CreateIssuerFn(ctx, issuer)
} }
return nil, nil return nil, nil
} }
func (m *MockIssuerService) UpdateIssuer(id string, issuer domain.Issuer) (*domain.Issuer, error) { func (m *MockIssuerService) UpdateIssuer(ctx context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error) {
if m.UpdateIssuerFn != nil { if m.UpdateIssuerFn != nil {
return m.UpdateIssuerFn(id, issuer) return m.UpdateIssuerFn(ctx, id, issuer)
} }
return nil, nil return nil, nil
} }
func (m *MockIssuerService) DeleteIssuer(id string) error { func (m *MockIssuerService) DeleteIssuer(ctx context.Context, id string) error {
if m.DeleteIssuerFn != nil { if m.DeleteIssuerFn != nil {
return m.DeleteIssuerFn(id) return m.DeleteIssuerFn(ctx, id)
} }
return nil return nil
} }
func (m *MockIssuerService) TestConnection(id string) error { func (m *MockIssuerService) TestConnection(ctx context.Context, id string) error {
if m.TestConnectionFn != nil { if m.TestConnectionFn != nil {
return m.TestConnectionFn(id) return m.TestConnectionFn(ctx, id)
} }
return nil return nil
} }
@@ -85,7 +86,7 @@ func TestListIssuers_Success(t *testing.T) {
} }
mock := &MockIssuerService{ mock := &MockIssuerService{
ListIssuersFn: func(page, perPage int) ([]domain.Issuer, int64, error) { ListIssuersFn: func(_ context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
return []domain.Issuer{iss1, iss2}, 2, nil return []domain.Issuer{iss1, iss2}, 2, nil
}, },
} }
@@ -113,7 +114,7 @@ func TestListIssuers_Success(t *testing.T) {
func TestListIssuers_Pagination(t *testing.T) { func TestListIssuers_Pagination(t *testing.T) {
var capturedPage, capturedPerPage int var capturedPage, capturedPerPage int
mock := &MockIssuerService{ mock := &MockIssuerService{
ListIssuersFn: func(page, perPage int) ([]domain.Issuer, int64, error) { ListIssuersFn: func(_ context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
capturedPage = page capturedPage = page
capturedPerPage = perPage capturedPerPage = perPage
return []domain.Issuer{}, 0, nil return []domain.Issuer{}, 0, nil
@@ -137,7 +138,7 @@ func TestListIssuers_Pagination(t *testing.T) {
func TestListIssuers_ServiceError(t *testing.T) { func TestListIssuers_ServiceError(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
ListIssuersFn: func(page, perPage int) ([]domain.Issuer, int64, error) { ListIssuersFn: func(_ context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
return nil, 0, ErrMockServiceFailed return nil, 0, ErrMockServiceFailed
}, },
} }
@@ -169,7 +170,7 @@ func TestListIssuers_MethodNotAllowed(t *testing.T) {
func TestGetIssuer_Success(t *testing.T) { func TestGetIssuer_Success(t *testing.T) {
now := time.Now() now := time.Now()
mock := &MockIssuerService{ mock := &MockIssuerService{
GetIssuerFn: func(id string) (*domain.Issuer, error) { GetIssuerFn: func(_ context.Context, id string) (*domain.Issuer, error) {
return &domain.Issuer{ return &domain.Issuer{
ID: id, ID: id,
Name: "Local CA", Name: "Local CA",
@@ -195,7 +196,7 @@ func TestGetIssuer_Success(t *testing.T) {
func TestGetIssuer_NotFound(t *testing.T) { func TestGetIssuer_NotFound(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
GetIssuerFn: func(id string) (*domain.Issuer, error) { GetIssuerFn: func(_ context.Context, id string) (*domain.Issuer, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
}, },
} }
@@ -228,7 +229,7 @@ func TestGetIssuer_EmptyID(t *testing.T) {
func TestCreateIssuer_Success(t *testing.T) { func TestCreateIssuer_Success(t *testing.T) {
now := time.Now() now := time.Now()
mock := &MockIssuerService{ mock := &MockIssuerService{
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) { CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
issuer.ID = "iss-new" issuer.ID = "iss-new"
issuer.CreatedAt = now issuer.CreatedAt = now
issuer.UpdatedAt = now issuer.UpdatedAt = now
@@ -328,7 +329,7 @@ func TestCreateIssuer_NameTooLong(t *testing.T) {
func TestCreateIssuer_DuplicateName(t *testing.T) { func TestCreateIssuer_DuplicateName(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) { CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
return nil, fmt.Errorf("failed to create issuer: duplicate key value violates unique constraint \"issuers_name_key\"") return nil, fmt.Errorf("failed to create issuer: duplicate key value violates unique constraint \"issuers_name_key\"")
}, },
} }
@@ -361,7 +362,7 @@ func TestCreateIssuer_DuplicateName(t *testing.T) {
func TestCreateIssuer_UnsupportedType(t *testing.T) { func TestCreateIssuer_UnsupportedType(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) { CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
return nil, fmt.Errorf("unsupported issuer type: FakeCA") return nil, fmt.Errorf("unsupported issuer type: FakeCA")
}, },
} }
@@ -394,7 +395,7 @@ func TestCreateIssuer_UnsupportedType(t *testing.T) {
func TestCreateIssuer_GenericServiceError(t *testing.T) { func TestCreateIssuer_GenericServiceError(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) { CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
return nil, fmt.Errorf("failed to encrypt config: cipher error") return nil, fmt.Errorf("failed to encrypt config: cipher error")
}, },
} }
@@ -419,7 +420,7 @@ func TestCreateIssuer_GenericServiceError(t *testing.T) {
func TestUpdateIssuer_DuplicateName(t *testing.T) { func TestUpdateIssuer_DuplicateName(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
UpdateIssuerFn: func(id string, issuer domain.Issuer) (*domain.Issuer, error) { UpdateIssuerFn: func(_ context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error) {
return nil, fmt.Errorf("failed to update issuer: duplicate key value violates unique constraint") return nil, fmt.Errorf("failed to update issuer: duplicate key value violates unique constraint")
}, },
} }
@@ -445,7 +446,7 @@ func TestUpdateIssuer_DuplicateName(t *testing.T) {
func TestDeleteIssuer_Success(t *testing.T) { func TestDeleteIssuer_Success(t *testing.T) {
var deletedID string var deletedID string
mock := &MockIssuerService{ mock := &MockIssuerService{
DeleteIssuerFn: func(id string) error { DeleteIssuerFn: func(_ context.Context, id string) error {
deletedID = id deletedID = id
return nil return nil
}, },
@@ -468,7 +469,7 @@ func TestDeleteIssuer_Success(t *testing.T) {
func TestDeleteIssuer_ServiceError(t *testing.T) { func TestDeleteIssuer_ServiceError(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
DeleteIssuerFn: func(id string) error { DeleteIssuerFn: func(_ context.Context, id string) error {
return ErrMockServiceFailed return ErrMockServiceFailed
}, },
} }
@@ -487,7 +488,7 @@ func TestDeleteIssuer_ServiceError(t *testing.T) {
func TestTestConnection_Success(t *testing.T) { func TestTestConnection_Success(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
TestConnectionFn: func(id string) error { TestConnectionFn: func(_ context.Context, id string) error {
return nil return nil
}, },
} }
@@ -514,7 +515,7 @@ func TestTestConnection_Success(t *testing.T) {
func TestTestConnection_Failure(t *testing.T) { func TestTestConnection_Failure(t *testing.T) {
mock := &MockIssuerService{ mock := &MockIssuerService{
TestConnectionFn: func(id string) error { TestConnectionFn: func(_ context.Context, id string) error {
return ErrMockServiceFailed return ErrMockServiceFailed
}, },
} }
+13 -12
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -13,12 +14,12 @@ import (
// IssuerService defines the service interface for issuer operations. // IssuerService defines the service interface for issuer operations.
type IssuerService interface { type IssuerService interface {
ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) ListIssuers(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error)
GetIssuer(id string) (*domain.Issuer, error) GetIssuer(ctx context.Context, id string) (*domain.Issuer, error)
CreateIssuer(issuer domain.Issuer) (*domain.Issuer, error) CreateIssuer(ctx context.Context, issuer domain.Issuer) (*domain.Issuer, error)
UpdateIssuer(id string, issuer domain.Issuer) (*domain.Issuer, error) UpdateIssuer(ctx context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error)
DeleteIssuer(id string) error DeleteIssuer(ctx context.Context, id string) error
TestConnection(id string) error TestConnection(ctx context.Context, id string) error
} }
// IssuerHandler handles HTTP requests for issuer operations. // IssuerHandler handles HTTP requests for issuer operations.
@@ -61,7 +62,7 @@ func (h IssuerHandler) ListIssuers(w http.ResponseWriter, r *http.Request) {
} }
} }
issuers, total, err := h.svc.ListIssuers(page, perPage) issuers, total, err := h.svc.ListIssuers(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list issuers", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list issuers", requestID)
return return
@@ -93,7 +94,7 @@ func (h IssuerHandler) GetIssuer(w http.ResponseWriter, r *http.Request) {
return return
} }
issuer, err := h.svc.GetIssuer(id) issuer, err := h.svc.GetIssuer(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Issuer not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Issuer not found", requestID)
return return
@@ -132,7 +133,7 @@ func (h IssuerHandler) CreateIssuer(w http.ResponseWriter, r *http.Request) {
return return
} }
created, err := h.svc.CreateIssuer(issuer) created, err := h.svc.CreateIssuer(r.Context(), issuer)
if err != nil { if err != nil {
h.logger.Error("failed to create issuer", "error", err, "name", issuer.Name, "type", issuer.Type) h.logger.Error("failed to create issuer", "error", err, "name", issuer.Name, "type", issuer.Type)
errMsg := err.Error() errMsg := err.Error()
@@ -174,7 +175,7 @@ func (h IssuerHandler) UpdateIssuer(w http.ResponseWriter, r *http.Request) {
return return
} }
updated, err := h.svc.UpdateIssuer(id, issuer) updated, err := h.svc.UpdateIssuer(r.Context(), id, issuer)
if err != nil { if err != nil {
h.logger.Error("failed to update issuer", "error", err, "id", id) h.logger.Error("failed to update issuer", "error", err, "id", id)
errMsg := err.Error() errMsg := err.Error()
@@ -208,7 +209,7 @@ func (h IssuerHandler) DeleteIssuer(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.svc.DeleteIssuer(id); err != nil { if err := h.svc.DeleteIssuer(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") { if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") {
ErrorWithRequestID(w, http.StatusConflict, "Cannot delete issuer: certificates are still using this issuer", requestID) ErrorWithRequestID(w, http.StatusConflict, "Cannot delete issuer: certificates are still using this issuer", requestID)
} else if strings.Contains(err.Error(), "not found") { } else if strings.Contains(err.Error(), "not found") {
@@ -241,7 +242,7 @@ func (h IssuerHandler) TestConnection(w http.ResponseWriter, r *http.Request) {
} }
issuerID := parts[0] issuerID := parts[0]
if err := h.svc.TestConnection(issuerID); err != nil { if err := h.svc.TestConnection(r.Context(), issuerID); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Connection test failed", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Connection test failed", requestID)
return return
} }
+6 -5
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -21,35 +22,35 @@ type MockJobService struct {
RejectJobFn func(id string, reason string) error RejectJobFn func(id string, reason string) error
} }
func (m *MockJobService) ListJobs(status, jobType string, page, perPage int) ([]domain.Job, int64, error) { func (m *MockJobService) ListJobs(_ context.Context, status, jobType string, page, perPage int) ([]domain.Job, int64, error) {
if m.ListJobsFn != nil { if m.ListJobsFn != nil {
return m.ListJobsFn(status, jobType, page, perPage) return m.ListJobsFn(status, jobType, page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockJobService) GetJob(id string) (*domain.Job, error) { func (m *MockJobService) GetJob(_ context.Context, id string) (*domain.Job, error) {
if m.GetJobFn != nil { if m.GetJobFn != nil {
return m.GetJobFn(id) return m.GetJobFn(id)
} }
return nil, nil return nil, nil
} }
func (m *MockJobService) CancelJob(id string) error { func (m *MockJobService) CancelJob(_ context.Context, id string) error {
if m.CancelJobFn != nil { if m.CancelJobFn != nil {
return m.CancelJobFn(id) return m.CancelJobFn(id)
} }
return nil return nil
} }
func (m *MockJobService) ApproveJob(id string) error { func (m *MockJobService) ApproveJob(_ context.Context, id string) error {
if m.ApproveJobFn != nil { if m.ApproveJobFn != nil {
return m.ApproveJobFn(id) return m.ApproveJobFn(id)
} }
return nil return nil
} }
func (m *MockJobService) RejectJob(id string, reason string) error { func (m *MockJobService) RejectJob(_ context.Context, id string, reason string) error {
if m.RejectJobFn != nil { if m.RejectJobFn != nil {
return m.RejectJobFn(id, reason) return m.RejectJobFn(id, reason)
} }
+11 -10
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -13,11 +14,11 @@ import (
// JobService defines the service interface for job operations. // JobService defines the service interface for job operations.
type JobService interface { type JobService interface {
ListJobs(status, jobType string, page, perPage int) ([]domain.Job, int64, error) ListJobs(ctx context.Context, status, jobType string, page, perPage int) ([]domain.Job, int64, error)
GetJob(id string) (*domain.Job, error) GetJob(ctx context.Context, id string) (*domain.Job, error)
CancelJob(id string) error CancelJob(ctx context.Context, id string) error
ApproveJob(id string) error ApproveJob(ctx context.Context, id string) error
RejectJob(id string, reason string) error RejectJob(ctx context.Context, id string, reason string) error
} }
// JobHandler handles HTTP requests for job operations. // JobHandler handles HTTP requests for job operations.
@@ -57,7 +58,7 @@ func (h JobHandler) ListJobs(w http.ResponseWriter, r *http.Request) {
} }
} }
jobs, total, err := h.svc.ListJobs(status, jobType, page, perPage) jobs, total, err := h.svc.ListJobs(r.Context(), status, jobType, page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list jobs", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list jobs", requestID)
return return
@@ -91,7 +92,7 @@ func (h JobHandler) GetJob(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
job, err := h.svc.GetJob(id) job, err := h.svc.GetJob(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
return return
@@ -119,7 +120,7 @@ func (h JobHandler) CancelJob(w http.ResponseWriter, r *http.Request) {
} }
jobID := parts[0] jobID := parts[0]
if err := h.svc.CancelJob(jobID); err != nil { if err := h.svc.CancelJob(r.Context(), jobID); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to cancel job", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to cancel job", requestID)
return return
} }
@@ -149,7 +150,7 @@ func (h JobHandler) ApproveJob(w http.ResponseWriter, r *http.Request) {
} }
jobID := parts[0] jobID := parts[0]
if err := h.svc.ApproveJob(jobID); err != nil { if err := h.svc.ApproveJob(r.Context(), jobID); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
return return
@@ -193,7 +194,7 @@ func (h JobHandler) RejectJob(w http.ResponseWriter, r *http.Request) {
} }
} }
if err := h.svc.RejectJob(jobID, body.Reason); err != nil { if err := h.svc.RejectJob(r.Context(), jobID, body.Reason); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
return return
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -17,21 +18,21 @@ type MockNotificationService struct {
MarkAsReadFn func(id string) error MarkAsReadFn func(id string) error
} }
func (m *MockNotificationService) ListNotifications(page, perPage int) ([]domain.NotificationEvent, int64, error) { func (m *MockNotificationService) ListNotifications(_ context.Context, page, perPage int) ([]domain.NotificationEvent, int64, error) {
if m.ListNotificationsFn != nil { if m.ListNotificationsFn != nil {
return m.ListNotificationsFn(page, perPage) return m.ListNotificationsFn(page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockNotificationService) GetNotification(id string) (*domain.NotificationEvent, error) { func (m *MockNotificationService) GetNotification(_ context.Context, id string) (*domain.NotificationEvent, error) {
if m.GetNotificationFn != nil { if m.GetNotificationFn != nil {
return m.GetNotificationFn(id) return m.GetNotificationFn(id)
} }
return nil, nil return nil, nil
} }
func (m *MockNotificationService) MarkAsRead(id string) error { func (m *MockNotificationService) MarkAsRead(_ context.Context, id string) error {
if m.MarkAsReadFn != nil { if m.MarkAsReadFn != nil {
return m.MarkAsReadFn(id) return m.MarkAsReadFn(id)
} }
+7 -6
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -11,9 +12,9 @@ import (
// NotificationService defines the service interface for notification operations. // NotificationService defines the service interface for notification operations.
type NotificationService interface { type NotificationService interface {
ListNotifications(page, perPage int) ([]domain.NotificationEvent, int64, error) ListNotifications(ctx context.Context, page, perPage int) ([]domain.NotificationEvent, int64, error)
GetNotification(id string) (*domain.NotificationEvent, error) GetNotification(ctx context.Context, id string) (*domain.NotificationEvent, error)
MarkAsRead(id string) error MarkAsRead(ctx context.Context, id string) error
} }
// NotificationHandler handles HTTP requests for notification operations. // NotificationHandler handles HTTP requests for notification operations.
@@ -50,7 +51,7 @@ func (h NotificationHandler) ListNotifications(w http.ResponseWriter, r *http.Re
} }
} }
notifications, total, err := h.svc.ListNotifications(page, perPage) notifications, total, err := h.svc.ListNotifications(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list notifications", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list notifications", requestID)
return return
@@ -84,7 +85,7 @@ func (h NotificationHandler) GetNotification(w http.ResponseWriter, r *http.Requ
} }
id = parts[0] id = parts[0]
notification, err := h.svc.GetNotification(id) notification, err := h.svc.GetNotification(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Notification not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Notification not found", requestID)
return return
@@ -112,7 +113,7 @@ func (h NotificationHandler) MarkAsRead(w http.ResponseWriter, r *http.Request)
} }
notificationID := parts[0] notificationID := parts[0]
if err := h.svc.MarkAsRead(notificationID); err != nil { if err := h.svc.MarkAsRead(r.Context(), notificationID); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to mark notification as read", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to mark notification as read", requestID)
return return
} }
+6 -5
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -20,35 +21,35 @@ type MockOwnerService struct {
DeleteOwnerFn func(id string) error DeleteOwnerFn func(id string) error
} }
func (m *MockOwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) { func (m *MockOwnerService) ListOwners(_ context.Context, page, perPage int) ([]domain.Owner, int64, error) {
if m.ListOwnersFn != nil { if m.ListOwnersFn != nil {
return m.ListOwnersFn(page, perPage) return m.ListOwnersFn(page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockOwnerService) GetOwner(id string) (*domain.Owner, error) { func (m *MockOwnerService) GetOwner(_ context.Context, id string) (*domain.Owner, error) {
if m.GetOwnerFn != nil { if m.GetOwnerFn != nil {
return m.GetOwnerFn(id) return m.GetOwnerFn(id)
} }
return nil, nil return nil, nil
} }
func (m *MockOwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) { func (m *MockOwnerService) CreateOwner(_ context.Context, owner domain.Owner) (*domain.Owner, error) {
if m.CreateOwnerFn != nil { if m.CreateOwnerFn != nil {
return m.CreateOwnerFn(owner) return m.CreateOwnerFn(owner)
} }
return nil, nil return nil, nil
} }
func (m *MockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) { func (m *MockOwnerService) UpdateOwner(_ context.Context, id string, owner domain.Owner) (*domain.Owner, error) {
if m.UpdateOwnerFn != nil { if m.UpdateOwnerFn != nil {
return m.UpdateOwnerFn(id, owner) return m.UpdateOwnerFn(id, owner)
} }
return nil, nil return nil, nil
} }
func (m *MockOwnerService) DeleteOwner(id string) error { func (m *MockOwnerService) DeleteOwner(_ context.Context, id string) error {
if m.DeleteOwnerFn != nil { if m.DeleteOwnerFn != nil {
return m.DeleteOwnerFn(id) return m.DeleteOwnerFn(id)
} }
+11 -10
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@@ -12,11 +13,11 @@ import (
// OwnerService defines the service interface for owner operations. // OwnerService defines the service interface for owner operations.
type OwnerService interface { type OwnerService interface {
ListOwners(page, perPage int) ([]domain.Owner, int64, error) ListOwners(ctx context.Context, page, perPage int) ([]domain.Owner, int64, error)
GetOwner(id string) (*domain.Owner, error) GetOwner(ctx context.Context, id string) (*domain.Owner, error)
CreateOwner(owner domain.Owner) (*domain.Owner, error) CreateOwner(ctx context.Context, owner domain.Owner) (*domain.Owner, error)
UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) UpdateOwner(ctx context.Context, id string, owner domain.Owner) (*domain.Owner, error)
DeleteOwner(id string) error DeleteOwner(ctx context.Context, id string) error
} }
// OwnerHandler handles HTTP requests for owner operations. // OwnerHandler handles HTTP requests for owner operations.
@@ -53,7 +54,7 @@ func (h OwnerHandler) ListOwners(w http.ResponseWriter, r *http.Request) {
} }
} }
owners, total, err := h.svc.ListOwners(page, perPage) owners, total, err := h.svc.ListOwners(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list owners", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list owners", requestID)
return return
@@ -87,7 +88,7 @@ func (h OwnerHandler) GetOwner(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
owner, err := h.svc.GetOwner(id) owner, err := h.svc.GetOwner(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Owner not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Owner not found", requestID)
return return
@@ -122,7 +123,7 @@ func (h OwnerHandler) CreateOwner(w http.ResponseWriter, r *http.Request) {
return return
} }
created, err := h.svc.CreateOwner(owner) created, err := h.svc.CreateOwner(r.Context(), owner)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create owner", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create owner", requestID)
return return
@@ -155,7 +156,7 @@ func (h OwnerHandler) UpdateOwner(w http.ResponseWriter, r *http.Request) {
return return
} }
updated, err := h.svc.UpdateOwner(id, owner) updated, err := h.svc.UpdateOwner(r.Context(), id, owner)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update owner", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update owner", requestID)
return return
@@ -182,7 +183,7 @@ func (h OwnerHandler) DeleteOwner(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
if err := h.svc.DeleteOwner(id); err != nil { if err := h.svc.DeleteOwner(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") { if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") {
ErrorWithRequestID(w, http.StatusConflict, "Cannot delete owner: certificates are still assigned to this owner", requestID) ErrorWithRequestID(w, http.StatusConflict, "Cannot delete owner: certificates are still assigned to this owner", requestID)
} else if strings.Contains(err.Error(), "not found") { } else if strings.Contains(err.Error(), "not found") {
+30 -12
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@@ -12,12 +13,12 @@ import (
// PolicyService defines the service interface for policy rule operations. // PolicyService defines the service interface for policy rule operations.
type PolicyService interface { type PolicyService interface {
ListPolicies(page, perPage int) ([]domain.PolicyRule, int64, error) ListPolicies(ctx context.Context, page, perPage int) ([]domain.PolicyRule, int64, error)
GetPolicy(id string) (*domain.PolicyRule, error) GetPolicy(ctx context.Context, id string) (*domain.PolicyRule, error)
CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRule, error) CreatePolicy(ctx context.Context, policy domain.PolicyRule) (*domain.PolicyRule, error)
UpdatePolicy(id string, policy domain.PolicyRule) (*domain.PolicyRule, error) UpdatePolicy(ctx context.Context, id string, policy domain.PolicyRule) (*domain.PolicyRule, error)
DeletePolicy(id string) error DeletePolicy(ctx context.Context, id string) error
ListViolations(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) ListViolations(ctx context.Context, policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error)
} }
// PolicyHandler handles HTTP requests for policy rule operations. // PolicyHandler handles HTTP requests for policy rule operations.
@@ -54,7 +55,7 @@ func (h PolicyHandler) ListPolicies(w http.ResponseWriter, r *http.Request) {
} }
} }
policies, total, err := h.svc.ListPolicies(page, perPage) policies, total, err := h.svc.ListPolicies(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list policies", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list policies", requestID)
return return
@@ -88,7 +89,7 @@ func (h PolicyHandler) GetPolicy(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
policy, err := h.svc.GetPolicy(id) policy, err := h.svc.GetPolicy(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Policy not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Policy not found", requestID)
return return
@@ -126,8 +127,19 @@ func (h PolicyHandler) CreatePolicy(w http.ResponseWriter, r *http.Request) {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID) ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return return
} }
// Severity is optional on create; default matches the DB default.
// Any explicit value must pass the TitleCase allowlist; the DB CHECK
// constraint enforces the same set, but catching it here gives a 400
// with a clear message instead of a 500 on constraint violation.
if policy.Severity == "" {
policy.Severity = domain.PolicySeverityWarning
}
if err := ValidatePolicySeverity(policy.Severity); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
created, err := h.svc.CreatePolicy(policy) created, err := h.svc.CreatePolicy(r.Context(), policy)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create policy", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create policy", requestID)
return return
@@ -173,8 +185,14 @@ func (h PolicyHandler) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
if policy.Severity != "" {
if err := ValidatePolicySeverity(policy.Severity); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
}
updated, err := h.svc.UpdatePolicy(id, policy) updated, err := h.svc.UpdatePolicy(r.Context(), id, policy)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update policy", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update policy", requestID)
return return
@@ -201,7 +219,7 @@ func (h PolicyHandler) DeletePolicy(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
if err := h.svc.DeletePolicy(id); err != nil { if err := h.svc.DeletePolicy(r.Context(), id); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete policy", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete policy", requestID)
return return
} }
@@ -242,7 +260,7 @@ func (h PolicyHandler) ListViolations(w http.ResponseWriter, r *http.Request) {
} }
} }
violations, total, err := h.svc.ListViolations(policyID, page, perPage) violations, total, err := h.svc.ListViolations(r.Context(), policyID, page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list violations", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list violations", requestID)
return return
+7 -6
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -21,42 +22,42 @@ type MockPolicyService struct {
ListViolationsFn func(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) ListViolationsFn func(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error)
} }
func (m *MockPolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, int64, error) { func (m *MockPolicyService) ListPolicies(_ context.Context, page, perPage int) ([]domain.PolicyRule, int64, error) {
if m.ListPoliciesFn != nil { if m.ListPoliciesFn != nil {
return m.ListPoliciesFn(page, perPage) return m.ListPoliciesFn(page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockPolicyService) GetPolicy(id string) (*domain.PolicyRule, error) { func (m *MockPolicyService) GetPolicy(_ context.Context, id string) (*domain.PolicyRule, error) {
if m.GetPolicyFn != nil { if m.GetPolicyFn != nil {
return m.GetPolicyFn(id) return m.GetPolicyFn(id)
} }
return nil, nil return nil, nil
} }
func (m *MockPolicyService) CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRule, error) { func (m *MockPolicyService) CreatePolicy(_ context.Context, policy domain.PolicyRule) (*domain.PolicyRule, error) {
if m.CreatePolicyFn != nil { if m.CreatePolicyFn != nil {
return m.CreatePolicyFn(policy) return m.CreatePolicyFn(policy)
} }
return nil, nil return nil, nil
} }
func (m *MockPolicyService) UpdatePolicy(id string, policy domain.PolicyRule) (*domain.PolicyRule, error) { func (m *MockPolicyService) UpdatePolicy(_ context.Context, id string, policy domain.PolicyRule) (*domain.PolicyRule, error) {
if m.UpdatePolicyFn != nil { if m.UpdatePolicyFn != nil {
return m.UpdatePolicyFn(id, policy) return m.UpdatePolicyFn(id, policy)
} }
return nil, nil return nil, nil
} }
func (m *MockPolicyService) DeletePolicy(id string) error { func (m *MockPolicyService) DeletePolicy(_ context.Context, id string) error {
if m.DeletePolicyFn != nil { if m.DeletePolicyFn != nil {
return m.DeletePolicyFn(id) return m.DeletePolicyFn(id)
} }
return nil return nil
} }
func (m *MockPolicyService) ListViolations(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) { func (m *MockPolicyService) ListViolations(_ context.Context, policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) {
if m.ListViolationsFn != nil { if m.ListViolationsFn != nil {
return m.ListViolationsFn(policyID, page, perPage) return m.ListViolationsFn(policyID, page, perPage)
} }
+6 -5
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -20,35 +21,35 @@ type MockProfileService struct {
DeleteProfileFn func(id string) error DeleteProfileFn func(id string) error
} }
func (m *MockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { func (m *MockProfileService) ListProfiles(_ context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error) {
if m.ListProfilesFn != nil { if m.ListProfilesFn != nil {
return m.ListProfilesFn(page, perPage) return m.ListProfilesFn(page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { func (m *MockProfileService) GetProfile(_ context.Context, id string) (*domain.CertificateProfile, error) {
if m.GetProfileFn != nil { if m.GetProfileFn != nil {
return m.GetProfileFn(id) return m.GetProfileFn(id)
} }
return nil, nil return nil, nil
} }
func (m *MockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { func (m *MockProfileService) CreateProfile(_ context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
if m.CreateProfileFn != nil { if m.CreateProfileFn != nil {
return m.CreateProfileFn(profile) return m.CreateProfileFn(profile)
} }
return nil, nil return nil, nil
} }
func (m *MockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { func (m *MockProfileService) UpdateProfile(_ context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
if m.UpdateProfileFn != nil { if m.UpdateProfileFn != nil {
return m.UpdateProfileFn(id, profile) return m.UpdateProfileFn(id, profile)
} }
return nil, nil return nil, nil
} }
func (m *MockProfileService) DeleteProfile(id string) error { func (m *MockProfileService) DeleteProfile(_ context.Context, id string) error {
if m.DeleteProfileFn != nil { if m.DeleteProfileFn != nil {
return m.DeleteProfileFn(id) return m.DeleteProfileFn(id)
} }
+11 -10
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@@ -12,11 +13,11 @@ import (
// ProfileService defines the service interface for certificate profile operations. // ProfileService defines the service interface for certificate profile operations.
type ProfileService interface { type ProfileService interface {
ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) ListProfiles(ctx context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error)
GetProfile(id string) (*domain.CertificateProfile, error) GetProfile(ctx context.Context, id string) (*domain.CertificateProfile, error)
CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) CreateProfile(ctx context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error)
UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) UpdateProfile(ctx context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error)
DeleteProfile(id string) error DeleteProfile(ctx context.Context, id string) error
} }
// ProfileHandler handles HTTP requests for certificate profile operations. // ProfileHandler handles HTTP requests for certificate profile operations.
@@ -53,7 +54,7 @@ func (h ProfileHandler) ListProfiles(w http.ResponseWriter, r *http.Request) {
} }
} }
profiles, total, err := h.svc.ListProfiles(page, perPage) profiles, total, err := h.svc.ListProfiles(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list profiles", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list profiles", requestID)
return return
@@ -85,7 +86,7 @@ func (h ProfileHandler) GetProfile(w http.ResponseWriter, r *http.Request) {
return return
} }
profile, err := h.svc.GetProfile(id) profile, err := h.svc.GetProfile(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
return return
@@ -120,7 +121,7 @@ func (h ProfileHandler) CreateProfile(w http.ResponseWriter, r *http.Request) {
return return
} }
created, err := h.svc.CreateProfile(profile) created, err := h.svc.CreateProfile(r.Context(), profile)
if err != nil { if err != nil {
// Check if it's a validation error from the service // Check if it's a validation error from the service
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") || if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") ||
@@ -159,7 +160,7 @@ func (h ProfileHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
return return
} }
updated, err := h.svc.UpdateProfile(id, profile) updated, err := h.svc.UpdateProfile(r.Context(), id, profile)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
@@ -193,7 +194,7 @@ func (h ProfileHandler) DeleteProfile(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.svc.DeleteProfile(id); err != nil { if err := h.svc.DeleteProfile(r.Context(), id); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
return return
+95 -30
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -9,56 +10,57 @@ import (
"time" "time"
"github.com/shankar0123/certctl/internal/domain" "github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/service"
) )
// MockTargetService is a mock implementation of TargetService interface. // MockTargetService is a mock implementation of TargetService interface.
type MockTargetService struct { type MockTargetService struct {
ListTargetsFn func(page, perPage int) ([]domain.DeploymentTarget, int64, error) ListTargetsFn func(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error)
GetTargetFn func(id string) (*domain.DeploymentTarget, error) GetTargetFn func(ctx context.Context, id string) (*domain.DeploymentTarget, error)
CreateTargetFn func(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) CreateTargetFn func(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
UpdateTargetFn func(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) UpdateTargetFn func(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
DeleteTargetFn func(id string) error DeleteTargetFn func(ctx context.Context, id string) error
TestTargetConnectionFn func(id string) error TestConnectionFn func(ctx context.Context, id string) error
} }
func (m *MockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) { func (m *MockTargetService) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
if m.ListTargetsFn != nil { if m.ListTargetsFn != nil {
return m.ListTargetsFn(page, perPage) return m.ListTargetsFn(ctx, page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockTargetService) GetTarget(id string) (*domain.DeploymentTarget, error) { func (m *MockTargetService) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
if m.GetTargetFn != nil { if m.GetTargetFn != nil {
return m.GetTargetFn(id) return m.GetTargetFn(ctx, id)
} }
return nil, nil return nil, nil
} }
func (m *MockTargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { func (m *MockTargetService) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
if m.CreateTargetFn != nil { if m.CreateTargetFn != nil {
return m.CreateTargetFn(target) return m.CreateTargetFn(ctx, target)
} }
return nil, nil return nil, nil
} }
func (m *MockTargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { func (m *MockTargetService) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
if m.UpdateTargetFn != nil { if m.UpdateTargetFn != nil {
return m.UpdateTargetFn(id, target) return m.UpdateTargetFn(ctx, id, target)
} }
return nil, nil return nil, nil
} }
func (m *MockTargetService) DeleteTarget(id string) error { func (m *MockTargetService) DeleteTarget(ctx context.Context, id string) error {
if m.DeleteTargetFn != nil { if m.DeleteTargetFn != nil {
return m.DeleteTargetFn(id) return m.DeleteTargetFn(ctx, id)
} }
return nil return nil
} }
func (m *MockTargetService) TestTargetConnection(id string) error { func (m *MockTargetService) TestConnection(ctx context.Context, id string) error {
if m.TestTargetConnectionFn != nil { if m.TestConnectionFn != nil {
return m.TestTargetConnectionFn(id) return m.TestConnectionFn(ctx, id)
} }
return nil return nil
} }
@@ -85,7 +87,7 @@ func TestListTargets_Success(t *testing.T) {
} }
mock := &MockTargetService{ mock := &MockTargetService{
ListTargetsFn: func(page, perPage int) ([]domain.DeploymentTarget, int64, error) { ListTargetsFn: func(_ context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
return []domain.DeploymentTarget{t1, t2}, 2, nil return []domain.DeploymentTarget{t1, t2}, 2, nil
}, },
} }
@@ -113,7 +115,7 @@ func TestListTargets_Success(t *testing.T) {
func TestListTargets_Pagination(t *testing.T) { func TestListTargets_Pagination(t *testing.T) {
var capturedPage, capturedPerPage int var capturedPage, capturedPerPage int
mock := &MockTargetService{ mock := &MockTargetService{
ListTargetsFn: func(page, perPage int) ([]domain.DeploymentTarget, int64, error) { ListTargetsFn: func(_ context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
capturedPage = page capturedPage = page
capturedPerPage = perPage capturedPerPage = perPage
return []domain.DeploymentTarget{}, 0, nil return []domain.DeploymentTarget{}, 0, nil
@@ -137,7 +139,7 @@ func TestListTargets_Pagination(t *testing.T) {
func TestListTargets_ServiceError(t *testing.T) { func TestListTargets_ServiceError(t *testing.T) {
mock := &MockTargetService{ mock := &MockTargetService{
ListTargetsFn: func(page, perPage int) ([]domain.DeploymentTarget, int64, error) { ListTargetsFn: func(_ context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
return nil, 0, ErrMockServiceFailed return nil, 0, ErrMockServiceFailed
}, },
} }
@@ -169,7 +171,7 @@ func TestListTargets_MethodNotAllowed(t *testing.T) {
func TestGetTarget_Success(t *testing.T) { func TestGetTarget_Success(t *testing.T) {
now := time.Now() now := time.Now()
mock := &MockTargetService{ mock := &MockTargetService{
GetTargetFn: func(id string) (*domain.DeploymentTarget, error) { GetTargetFn: func(_ context.Context, id string) (*domain.DeploymentTarget, error) {
return &domain.DeploymentTarget{ return &domain.DeploymentTarget{
ID: id, ID: id,
Name: "NGINX Proxy", Name: "NGINX Proxy",
@@ -196,7 +198,7 @@ func TestGetTarget_Success(t *testing.T) {
func TestGetTarget_NotFound(t *testing.T) { func TestGetTarget_NotFound(t *testing.T) {
mock := &MockTargetService{ mock := &MockTargetService{
GetTargetFn: func(id string) (*domain.DeploymentTarget, error) { GetTargetFn: func(_ context.Context, id string) (*domain.DeploymentTarget, error) {
return nil, ErrMockNotFound return nil, ErrMockNotFound
}, },
} }
@@ -229,7 +231,7 @@ func TestGetTarget_EmptyID(t *testing.T) {
func TestCreateTarget_Success(t *testing.T) { func TestCreateTarget_Success(t *testing.T) {
now := time.Now() now := time.Now()
mock := &MockTargetService{ mock := &MockTargetService{
CreateTargetFn: func(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { CreateTargetFn: func(_ context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
target.ID = "t-new" target.ID = "t-new"
target.CreatedAt = now target.CreatedAt = now
target.UpdatedAt = now target.UpdatedAt = now
@@ -240,6 +242,7 @@ func TestCreateTarget_Success(t *testing.T) {
body := map[string]interface{}{ body := map[string]interface{}{
"name": "New Target", "name": "New Target",
"type": "nginx", "type": "nginx",
"agent_id": "agent-001",
} }
bodyBytes, _ := json.Marshal(body) bodyBytes, _ := json.Marshal(body)
@@ -258,6 +261,7 @@ func TestCreateTarget_Success(t *testing.T) {
func TestCreateTarget_MissingName(t *testing.T) { func TestCreateTarget_MissingName(t *testing.T) {
body := map[string]interface{}{ body := map[string]interface{}{
"type": "nginx", "type": "nginx",
"agent_id": "agent-001",
} }
bodyBytes, _ := json.Marshal(body) bodyBytes, _ := json.Marshal(body)
@@ -276,6 +280,7 @@ func TestCreateTarget_MissingName(t *testing.T) {
func TestCreateTarget_MissingType(t *testing.T) { func TestCreateTarget_MissingType(t *testing.T) {
body := map[string]interface{}{ body := map[string]interface{}{
"name": "New Target", "name": "New Target",
"agent_id": "agent-001",
} }
bodyBytes, _ := json.Marshal(body) bodyBytes, _ := json.Marshal(body)
@@ -312,6 +317,7 @@ func TestCreateTarget_NameTooLong(t *testing.T) {
body := map[string]interface{}{ body := map[string]interface{}{
"name": longName, "name": longName,
"type": "nginx", "type": "nginx",
"agent_id": "agent-001",
} }
bodyBytes, _ := json.Marshal(body) bodyBytes, _ := json.Marshal(body)
@@ -339,10 +345,69 @@ func TestCreateTarget_MethodNotAllowed(t *testing.T) {
} }
} }
// TestCreateTarget_MissingAgentID_Returns400 pins the C-002 handler contract:
// handler MUST reject a create payload that omits agent_id with HTTP 400
// before the service is invoked. Using a mock that would return 201-worthy
// success proves the guard fires.
func TestCreateTarget_MissingAgentID_Returns400(t *testing.T) {
body := map[string]interface{}{
"name": "New Target",
"type": "nginx",
// agent_id intentionally omitted
}
bodyBytes, _ := json.Marshal(body)
mock := &MockTargetService{
CreateTargetFn: func(_ context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
// Would succeed if handler guard did not fire.
target.ID = "t-would-be-created"
return &target, nil
},
}
handler := NewTargetHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/targets", bytes.NewReader(bodyBytes))
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.CreateTarget(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d — body=%s", w.Code, w.Body.String())
}
}
// TestCreateTarget_NonexistentAgent_Returns400 pins the C-002 handler↔service
// translation: when the service returns service.ErrAgentNotFound, the handler
// MUST map it to HTTP 400, not the generic 500 used for other service errors.
func TestCreateTarget_NonexistentAgent_Returns400(t *testing.T) {
mock := &MockTargetService{
CreateTargetFn: func(_ context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
return nil, service.ErrAgentNotFound
},
}
body := map[string]interface{}{
"name": "New Target",
"type": "nginx",
"agent_id": "agent-does-not-exist",
}
bodyBytes, _ := json.Marshal(body)
handler := NewTargetHandler(mock)
req := httptest.NewRequest(http.MethodPost, "/api/v1/targets", bytes.NewReader(bodyBytes))
req = req.WithContext(contextWithRequestID())
w := httptest.NewRecorder()
handler.CreateTarget(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for nonexistent agent, got %d — body=%s", w.Code, w.Body.String())
}
}
func TestUpdateTarget_Success(t *testing.T) { func TestUpdateTarget_Success(t *testing.T) {
now := time.Now() now := time.Now()
mock := &MockTargetService{ mock := &MockTargetService{
UpdateTargetFn: func(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { UpdateTargetFn: func(_ context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
return &domain.DeploymentTarget{ return &domain.DeploymentTarget{
ID: id, ID: id,
Name: target.Name, Name: target.Name,
@@ -375,7 +440,7 @@ func TestUpdateTarget_Success(t *testing.T) {
func TestDeleteTarget_Success(t *testing.T) { func TestDeleteTarget_Success(t *testing.T) {
var deletedID string var deletedID string
mock := &MockTargetService{ mock := &MockTargetService{
DeleteTargetFn: func(id string) error { DeleteTargetFn: func(_ context.Context, id string) error {
deletedID = id deletedID = id
return nil return nil
}, },
@@ -398,7 +463,7 @@ func TestDeleteTarget_Success(t *testing.T) {
func TestDeleteTarget_ServiceError(t *testing.T) { func TestDeleteTarget_ServiceError(t *testing.T) {
mock := &MockTargetService{ mock := &MockTargetService{
DeleteTargetFn: func(id string) error { DeleteTargetFn: func(_ context.Context, id string) error {
return ErrMockServiceFailed return ErrMockServiceFailed
}, },
} }
@@ -430,7 +495,7 @@ func TestDeleteTarget_EmptyID(t *testing.T) {
func TestTestTargetConnection_Success(t *testing.T) { func TestTestTargetConnection_Success(t *testing.T) {
mock := &MockTargetService{ mock := &MockTargetService{
TestTargetConnectionFn: func(id string) error { TestConnectionFn: func(_ context.Context, id string) error {
return nil return nil
}, },
} }
@@ -457,7 +522,7 @@ func TestTestTargetConnection_Success(t *testing.T) {
func TestTestTargetConnection_Failed(t *testing.T) { func TestTestTargetConnection_Failed(t *testing.T) {
mock := &MockTargetService{ mock := &MockTargetService{
TestTargetConnectionFn: func(id string) error { TestConnectionFn: func(_ context.Context, id string) error {
return ErrMockServiceFailed return ErrMockServiceFailed
}, },
} }
+29 -12
View File
@@ -1,23 +1,26 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"github.com/shankar0123/certctl/internal/api/middleware" "github.com/shankar0123/certctl/internal/api/middleware"
"github.com/shankar0123/certctl/internal/domain" "github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/service"
) )
// TargetService defines the service interface for deployment target operations. // TargetService defines the service interface for deployment target operations.
type TargetService interface { type TargetService interface {
ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error)
GetTarget(id string) (*domain.DeploymentTarget, error) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error)
CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
DeleteTarget(id string) error DeleteTarget(ctx context.Context, id string) error
TestTargetConnection(id string) error TestConnection(ctx context.Context, id string) error
} }
// TargetHandler handles HTTP requests for deployment target operations. // TargetHandler handles HTTP requests for deployment target operations.
@@ -54,7 +57,7 @@ func (h TargetHandler) ListTargets(w http.ResponseWriter, r *http.Request) {
} }
} }
targets, total, err := h.svc.ListTargets(page, perPage) targets, total, err := h.svc.ListTargets(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list targets", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list targets", requestID)
return return
@@ -86,7 +89,7 @@ func (h TargetHandler) GetTarget(w http.ResponseWriter, r *http.Request) {
return return
} }
target, err := h.svc.GetTarget(id) target, err := h.svc.GetTarget(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Target not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Target not found", requestID)
return return
@@ -124,9 +127,23 @@ func (h TargetHandler) CreateTarget(w http.ResponseWriter, r *http.Request) {
ErrorWithRequestID(w, http.StatusBadRequest, "type is required", requestID) ErrorWithRequestID(w, http.StatusBadRequest, "type is required", requestID)
return return
} }
// C-002: agent_id is a NOT NULL FK in deployment_targets (migration 000001
// line 104). Reject empty values at the boundary so callers get a clean 400
// with the field name rather than a generic "Failed to create target" 500.
if err := ValidateRequired("agent_id", target.AgentID); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
created, err := h.svc.CreateTarget(target) created, err := h.svc.CreateTarget(r.Context(), target)
if err != nil { if err != nil {
// C-002: a nonexistent agent_id is a client error, not a server error.
// The service returns ErrAgentNotFound (wrapped via fmt.Errorf %w) when
// agentRepo.Get fails; we translate that to 400 via errors.Is.
if errors.Is(err, service.ErrAgentNotFound) {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create target", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create target", requestID)
return return
} }
@@ -158,7 +175,7 @@ func (h TargetHandler) UpdateTarget(w http.ResponseWriter, r *http.Request) {
return return
} }
updated, err := h.svc.UpdateTarget(id, target) updated, err := h.svc.UpdateTarget(r.Context(), id, target)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update target", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update target", requestID)
return return
@@ -183,7 +200,7 @@ func (h TargetHandler) DeleteTarget(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.svc.DeleteTarget(id); err != nil { if err := h.svc.DeleteTarget(r.Context(), id); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete target", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete target", requestID)
return return
} }
@@ -210,7 +227,7 @@ func (h TargetHandler) TestTargetConnection(w http.ResponseWriter, r *http.Reque
} }
id := parts[0] id := parts[0]
if err := h.svc.TestTargetConnection(id); err != nil { if err := h.svc.TestConnection(r.Context(), id); err != nil {
JSON(w, http.StatusOK, map[string]interface{}{ JSON(w, http.StatusOK, map[string]interface{}{
"status": "failed", "status": "failed",
"message": err.Error(), "message": err.Error(),
+6 -5
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -20,35 +21,35 @@ type MockTeamService struct {
DeleteTeamFn func(id string) error DeleteTeamFn func(id string) error
} }
func (m *MockTeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) { func (m *MockTeamService) ListTeams(_ context.Context, page, perPage int) ([]domain.Team, int64, error) {
if m.ListTeamsFn != nil { if m.ListTeamsFn != nil {
return m.ListTeamsFn(page, perPage) return m.ListTeamsFn(page, perPage)
} }
return nil, 0, nil return nil, 0, nil
} }
func (m *MockTeamService) GetTeam(id string) (*domain.Team, error) { func (m *MockTeamService) GetTeam(_ context.Context, id string) (*domain.Team, error) {
if m.GetTeamFn != nil { if m.GetTeamFn != nil {
return m.GetTeamFn(id) return m.GetTeamFn(id)
} }
return nil, nil return nil, nil
} }
func (m *MockTeamService) CreateTeam(team domain.Team) (*domain.Team, error) { func (m *MockTeamService) CreateTeam(_ context.Context, team domain.Team) (*domain.Team, error) {
if m.CreateTeamFn != nil { if m.CreateTeamFn != nil {
return m.CreateTeamFn(team) return m.CreateTeamFn(team)
} }
return nil, nil return nil, nil
} }
func (m *MockTeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) { func (m *MockTeamService) UpdateTeam(_ context.Context, id string, team domain.Team) (*domain.Team, error) {
if m.UpdateTeamFn != nil { if m.UpdateTeamFn != nil {
return m.UpdateTeamFn(id, team) return m.UpdateTeamFn(id, team)
} }
return nil, nil return nil, nil
} }
func (m *MockTeamService) DeleteTeam(id string) error { func (m *MockTeamService) DeleteTeam(_ context.Context, id string) error {
if m.DeleteTeamFn != nil { if m.DeleteTeamFn != nil {
return m.DeleteTeamFn(id) return m.DeleteTeamFn(id)
} }
+11 -10
View File
@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@@ -12,11 +13,11 @@ import (
// TeamService defines the service interface for team operations. // TeamService defines the service interface for team operations.
type TeamService interface { type TeamService interface {
ListTeams(page, perPage int) ([]domain.Team, int64, error) ListTeams(ctx context.Context, page, perPage int) ([]domain.Team, int64, error)
GetTeam(id string) (*domain.Team, error) GetTeam(ctx context.Context, id string) (*domain.Team, error)
CreateTeam(team domain.Team) (*domain.Team, error) CreateTeam(ctx context.Context, team domain.Team) (*domain.Team, error)
UpdateTeam(id string, team domain.Team) (*domain.Team, error) UpdateTeam(ctx context.Context, id string, team domain.Team) (*domain.Team, error)
DeleteTeam(id string) error DeleteTeam(ctx context.Context, id string) error
} }
// TeamHandler handles HTTP requests for team operations. // TeamHandler handles HTTP requests for team operations.
@@ -53,7 +54,7 @@ func (h TeamHandler) ListTeams(w http.ResponseWriter, r *http.Request) {
} }
} }
teams, total, err := h.svc.ListTeams(page, perPage) teams, total, err := h.svc.ListTeams(r.Context(), page, perPage)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list teams", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list teams", requestID)
return return
@@ -87,7 +88,7 @@ func (h TeamHandler) GetTeam(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
team, err := h.svc.GetTeam(id) team, err := h.svc.GetTeam(r.Context(), id)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Team not found", requestID) ErrorWithRequestID(w, http.StatusNotFound, "Team not found", requestID)
return return
@@ -122,7 +123,7 @@ func (h TeamHandler) CreateTeam(w http.ResponseWriter, r *http.Request) {
return return
} }
created, err := h.svc.CreateTeam(team) created, err := h.svc.CreateTeam(r.Context(), team)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create team", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create team", requestID)
return return
@@ -155,7 +156,7 @@ func (h TeamHandler) UpdateTeam(w http.ResponseWriter, r *http.Request) {
return return
} }
updated, err := h.svc.UpdateTeam(id, team) updated, err := h.svc.UpdateTeam(r.Context(), id, team)
if err != nil { if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update team", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update team", requestID)
return return
@@ -182,7 +183,7 @@ func (h TeamHandler) DeleteTeam(w http.ResponseWriter, r *http.Request) {
} }
id = parts[0] id = parts[0]
if err := h.svc.DeleteTeam(id); err != nil { if err := h.svc.DeleteTeam(r.Context(), id); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete team", requestID) ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete team", requestID)
return return
} }
+2 -1
View File
@@ -71,10 +71,11 @@ func ValidatePolicyType(policyType interface{}) error {
"RequiredMetadata": true, "RequiredMetadata": true,
"AllowedEnvironments": true, "AllowedEnvironments": true,
"RenewalLeadTime": true, "RenewalLeadTime": true,
"CertificateLifetime": true,
} }
typeStr := fmt.Sprintf("%v", policyType) typeStr := fmt.Sprintf("%v", policyType)
if !validTypes[typeStr] { if !validTypes[typeStr] {
return ValidationError{Field: "type", Message: "type must be one of: AllowedIssuers, AllowedDomains, RequiredMetadata, AllowedEnvironments, RenewalLeadTime"} return ValidationError{Field: "type", Message: "type must be one of: AllowedIssuers, AllowedDomains, RequiredMetadata, AllowedEnvironments, RenewalLeadTime, CertificateLifetime"}
} }
return nil return nil
} }
+115 -14
View File
@@ -4,16 +4,22 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
) )
// AuditRecorder is the interface that the audit middleware uses to record API calls. // AuditRecorder is the interface that the audit middleware uses to record API calls.
// This avoids importing the service package directly, maintaining dependency inversion. // This avoids importing the service package directly, maintaining dependency inversion.
//
// Implementations may perform I/O (e.g., database writes). The middleware invokes
// RecordAPICall from a tracked goroutine so that callers can drain in-flight
// recordings during graceful shutdown via AuditMiddleware.Flush.
type AuditRecorder interface { type AuditRecorder interface {
RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error
} }
@@ -26,10 +32,42 @@ type AuditConfig struct {
Logger *slog.Logger Logger *slog.Logger
} }
// NewAuditLog creates a middleware that records every API call to the audit trail. // ErrAuditFlushTimeout is returned by AuditMiddleware.Flush when in-flight audit
// It captures method, path, authenticated actor, request body hash, response status, and latency. // recordings do not complete before the provided context is cancelled or its
// Audit recording is best-effort — failures are logged but don't affect the HTTP response. // deadline elapses. It mirrors scheduler.ErrSchedulerShutdownTimeout so callers
func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) http.Handler { // can branch on graceful-shutdown timeouts consistently across subsystems.
var ErrAuditFlushTimeout = errors.New("audit middleware flush timeout")
// AuditMiddleware is the handle returned by NewAuditLog. It wraps the audit
// logging HTTP middleware and tracks the goroutines spawned to record each API
// call, so that callers can drain them during graceful shutdown (M-1, CWE-662
// / CWE-400). The goroutines themselves still run detached from the request
// context — the shutdown-drain signal flows through this struct's WaitGroup
// instead of the per-request context.
type AuditMiddleware struct {
recorder AuditRecorder
logger *slog.Logger
excludeSet map[string]bool
// wg tracks every audit-recording goroutine spawned by Middleware so Flush
// can block until they complete before the DB pool is torn down.
wg sync.WaitGroup
}
// NewAuditLog constructs the API audit logging middleware. The returned
// *AuditMiddleware exposes the HTTP middleware via the Middleware method value
// (same func(http.Handler) http.Handler shape) and a Flush method that the
// process shutdown path must call after the HTTP server has stopped accepting
// new requests but before the audit recorder's backing store (e.g., the
// database connection pool) is closed.
//
// The middleware records method, path, authenticated actor, request body hash,
// response status, and latency. Recording is best-effort — individual failures
// are logged and do not affect the HTTP response. Shutdown is NOT best-effort:
// Flush must succeed (or time out, returning ErrAuditFlushTimeout) so that
// in-flight events are not lost when the audit recorder's connection pool is
// closed out from under the goroutines.
func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) *AuditMiddleware {
excludeSet := make(map[string]bool, len(cfg.ExcludePaths)) excludeSet := make(map[string]bool, len(cfg.ExcludePaths))
for _, p := range cfg.ExcludePaths { for _, p := range cfg.ExcludePaths {
excludeSet[p] = true excludeSet[p] = true
@@ -40,10 +78,20 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt
logger = slog.Default() logger = slog.Default()
} }
return func(next http.Handler) http.Handler { return &AuditMiddleware{
recorder: recorder,
logger: logger,
excludeSet: excludeSet,
}
}
// Middleware is the http.Handler wrapper. It has the standard
// func(http.Handler) http.Handler middleware signature so it can be composed
// into an existing middleware chain via a method value (auditMiddleware.Middleware).
func (a *AuditMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip excluded paths (health, readiness probes) // Skip excluded paths (health, readiness probes)
for prefix := range excludeSet { for prefix := range a.excludeSet {
if strings.HasPrefix(r.URL.Path, prefix) { if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
@@ -78,31 +126,84 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt
latency := time.Since(start).Milliseconds() latency := time.Since(start).Milliseconds()
// Snapshot request-derived inputs so the goroutine does not race with
// the http.Server reusing r after this handler returns.
method := r.Method
path := r.URL.Path
status := wrapped.statusCode
// Derive a detached context that preserves request-scoped values
// (trace IDs, auth info carried via context keys) but is not cancelled
// when the HTTP server finalizes the request. Using r.Context()
// directly would cause the async audit write to observe ctx.Done()
// as soon as the response completes; using context.Background() would
// discard useful observability metadata. WithoutCancel gives us both
// (M-2 / D-3).
auditCtx := context.WithoutCancel(r.Context())
// Record audit event asynchronously (best-effort, don't block response). // Record audit event asynchronously (best-effort, don't block response).
// SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI) // SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI)
// to prevent query parameters from being recorded in the immutable audit trail. // to prevent query parameters from being recorded in the immutable audit trail.
// Query strings may contain cursor tokens, API keys passed as params, or other // Query strings may contain cursor tokens, API keys passed as params, or other
// sensitive filter values. Since the audit trail is append-only with no deletion // sensitive filter values. Since the audit trail is append-only with no deletion
// capability, any sensitive data recorded would persist permanently. // capability, any sensitive data recorded would persist permanently.
//
// The goroutine is tracked in a.wg so AuditMiddleware.Flush can drain
// in-flight recordings during graceful shutdown. Without this (M-1,
// CWE-662 / CWE-400), SIGTERM would close the DB pool while recordings
// were still mid-flight, silently dropping audit events.
a.wg.Add(1)
go func() { go func() {
if err := recorder.RecordAPICall( defer a.wg.Done()
context.Background(), if err := a.recorder.RecordAPICall(
r.Method, auditCtx,
r.URL.Path, method,
path,
actor, actor,
bodyHash, bodyHash,
wrapped.statusCode, status,
latency, latency,
); err != nil { ); err != nil {
logger.Error("failed to record API audit event", a.logger.Error("failed to record API audit event",
"error", err, "error", err,
"method", r.Method, "method", method,
"path", r.URL.Path, "path", path,
) )
} }
}() }()
}) })
} }
// Flush blocks until every audit-recording goroutine spawned by Middleware has
// completed, or until ctx is cancelled / its deadline elapses. It must be
// called from the process shutdown path after http.Server.Shutdown has
// returned (so no new requests are being accepted) but before the backing
// audit recorder's resources (DB pool, etc.) are torn down.
//
// On timeout or cancellation Flush returns ErrAuditFlushTimeout wrapped with
// any context error; in-flight goroutines continue to run and may still write
// to the recorder once they unblock — the caller is responsible for deciding
// whether to proceed with teardown anyway or surface the error.
//
// Flush mirrors the idiom used by scheduler.Scheduler.WaitForCompletion so
// that the two subsystems drain identically at shutdown.
func (a *AuditMiddleware) Flush(ctx context.Context) error {
done := make(chan struct{})
go func() {
a.wg.Wait()
close(done)
}()
select {
case <-done:
a.logger.Info("audit middleware flush complete")
return nil
case <-ctx.Done():
a.logger.Warn("audit middleware flush did not complete before context cancellation",
"error", ctx.Err(),
)
return fmt.Errorf("%w: %w", ErrAuditFlushTimeout, ctx.Err())
}
} }
// AuditServiceAdapter adapts the AuditService to the AuditRecorder interface. // AuditServiceAdapter adapts the AuditService to the AuditRecorder interface.
+127 -9
View File
@@ -2,6 +2,7 @@ package middleware
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -17,6 +18,7 @@ type mockAuditRecorder struct {
mu sync.Mutex mu sync.Mutex
calls []auditCall calls []auditCall
err error // if non-nil, RecordAPICall returns this err error // if non-nil, RecordAPICall returns this
block chan struct{} // if non-nil, RecordAPICall blocks on receive before returning
} }
type auditCall struct { type auditCall struct {
@@ -29,6 +31,13 @@ type auditCall struct {
} }
func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error { func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
// Optional: block the recorder until a signal is received so tests can
// exercise the shutdown-drain path deterministically. The block happens
// before any state mutation so Flush-timeout tests see the call
// "in-flight" (wg counter > 0) with no recorded entries yet.
if m.block != nil {
<-m.block
}
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.calls = append(m.calls, auditCall{ m.calls = append(m.calls, auditCall{
@@ -90,7 +99,7 @@ func (w *waitableAuditRecorder) Wait(timeout time.Duration) bool {
func TestAuditLog_RecordsAPICall(t *testing.T) { func TestAuditLog_RecordsAPICall(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -130,7 +139,7 @@ func TestAuditLog_RecordsAPICall(t *testing.T) {
func TestAuditLog_CapturesStatusCode(t *testing.T) { func TestAuditLog_CapturesStatusCode(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
@@ -157,7 +166,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{ mw := NewAuditLog(recorder, AuditConfig{
ExcludePaths: []string{"/health", "/ready"}, ExcludePaths: []string{"/health", "/ready"},
}) }).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -193,7 +202,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) {
func TestAuditLog_HashesRequestBody(t *testing.T) { func TestAuditLog_HashesRequestBody(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
// Handler verifies body was restored // Handler verifies body was restored
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -228,7 +237,7 @@ func TestAuditLog_HashesRequestBody(t *testing.T) {
func TestAuditLog_EmptyBodyNoHash(t *testing.T) { func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -253,7 +262,7 @@ func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) { func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -285,7 +294,7 @@ func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) { func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")} recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")}
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -304,7 +313,7 @@ func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
func TestAuditLog_CapturesLatency(t *testing.T) { func TestAuditLog_CapturesLatency(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -330,7 +339,7 @@ func TestAuditLog_CapturesLatency(t *testing.T) {
func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) { func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) {
recorder := newWaitableAuditRecorder() recorder := newWaitableAuditRecorder()
mw := NewAuditLog(recorder, AuditConfig{}) mw := NewAuditLog(recorder, AuditConfig{}).Middleware
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -429,3 +438,112 @@ func TestAuditServiceAdapter_PropagatesError(t *testing.T) {
t.Errorf("expected database error, got %v", err) t.Errorf("expected database error, got %v", err)
} }
} }
// TestAuditLog_FlushDrainsInFlightGoroutines verifies the M-1 shutdown-drain
// contract: Flush blocks until every audit-recording goroutine spawned by the
// middleware completes, then returns nil. Without the drain (pre-M-1 code),
// the DB pool would be closed while in-flight goroutines were still calling
// RecordAPICall, silently dropping audit events (CWE-662 / CWE-400).
func TestAuditLog_FlushDrainsInFlightGoroutines(t *testing.T) {
// Recorder blocks on `unblock` until the test releases it. This simulates
// a slow DB write still in flight when shutdown begins.
unblock := make(chan struct{})
recorder := &mockAuditRecorder{block: unblock}
auditMW := NewAuditLog(recorder, AuditConfig{})
handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Fire a request. Handler returns immediately; recorder goroutine is
// parked on the `unblock` channel inside RecordAPICall.
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Start Flush in a goroutine — it must block on the WaitGroup until we
// release the recorder.
flushDone := make(chan error, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
flushDone <- auditMW.Flush(ctx)
}()
// Confirm Flush is actually blocked (not returning immediately).
select {
case err := <-flushDone:
t.Fatalf("Flush returned before recorder unblocked: err=%v", err)
case <-time.After(50 * time.Millisecond):
// expected: Flush is blocked on wg.Wait
}
// Release the recorder. Flush should now observe wg counter drop to 0
// and return nil.
close(unblock)
select {
case err := <-flushDone:
if err != nil {
t.Fatalf("expected nil from Flush after drain, got %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("Flush did not return after recorder unblocked")
}
// Verify the audit event was actually recorded (i.e., the goroutine
// completed its write — not just that Flush unblocked).
calls := recorder.getCalls()
if len(calls) != 1 {
t.Fatalf("expected 1 recorded audit call, got %d", len(calls))
}
if calls[0].Path != "/api/v1/certificates" {
t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path)
}
}
// TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout verifies that Flush
// respects its context: when in-flight goroutines exceed the shutdown budget,
// Flush returns an error wrapping ErrAuditFlushTimeout plus ctx.Err(). The
// caller can then decide whether to proceed with teardown anyway.
func TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout(t *testing.T) {
// Recorder will never unblock on its own — we unblock at end of test for
// a clean race-safe teardown.
unblock := make(chan struct{})
recorder := &mockAuditRecorder{block: unblock}
auditMW := NewAuditLog(recorder, AuditConfig{})
handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Flush with a tiny deadline — must time out.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := auditMW.Flush(ctx)
if err == nil {
// Release the blocked goroutine before failing so the race detector
// doesn't trip on teardown.
close(unblock)
t.Fatal("expected Flush to return an error on timeout, got nil")
}
if !errors.Is(err, ErrAuditFlushTimeout) {
close(unblock)
t.Fatalf("expected error to wrap ErrAuditFlushTimeout, got %v", err)
}
if !errors.Is(err, context.DeadlineExceeded) {
close(unblock)
t.Fatalf("expected error to wrap context.DeadlineExceeded, got %v", err)
}
// Race-safe teardown: unblock the recorder goroutine so it exits cleanly
// before the test returns. The goroutine itself is still detached and
// will record to the mock even after Flush timed out — that's the
// documented behavior (Flush surfaces the timeout; caller decides).
close(unblock)
}
+10 -2
View File
@@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/hex" "encoding/hex"
"fmt"
"log" "log"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -78,10 +79,17 @@ func NewLogging(logger *slog.Logger) func(http.Handler) http.Handler {
// Recovery middleware recovers from panics and returns a 500 error. // Recovery middleware recovers from panics and returns a 500 error.
func Recovery(next http.Handler) http.Handler { func Recovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
requestID := getRequestID(r.Context()) requestID := getRequestID(ctx)
log.Printf("[%s] PANIC: %v", requestID, err) // Use slog.ErrorContext so the panic log carries the same
// request-scoped trace/auth metadata as normal request logs
// (M-2 / D-3 — preserve ctx propagation on the panic path).
slog.ErrorContext(ctx, "panic recovered in HTTP handler",
"request_id", requestID,
"panic", fmt.Sprintf("%v", err),
)
http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError) http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError)
} }
}() }()
+5 -1
View File
@@ -547,7 +547,11 @@ func (c *Connector) solveAuthorizationsHTTP01(ctx context.Context, authzURLs []s
return fmt.Errorf("failed to start challenge server: %w", err) return fmt.Errorf("failed to start challenge server: %w", err)
} }
defer func() { defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // Derive the challenge-server shutdown context from the parent ctx so
// values (trace IDs, deadlines) propagate, but detach from its
// cancellation so Shutdown always gets its full budget even when the
// parent was cancelled (M-2 / D-3).
shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel() defer cancel()
_ = srv.Shutdown(shutdownCtx) _ = srv.Shutdown(shutdownCtx)
c.logger.Debug("challenge server stopped") c.logger.Debug("challenge server stopped")
+19
View File
@@ -359,6 +359,25 @@ func (c *Connector) loadCAFromDisk() error {
return fmt.Errorf("loaded CA certificate does not have KeyUsageCertSign") return fmt.Errorf("loaded CA certificate does not have KeyUsageCertSign")
} }
// Validate CA certificate validity window (M-5, CWE-672).
// An expired or not-yet-valid sub-CA produces child certificates that any
// RFC 5280 path-validator will reject. Fail closed at load time so operators
// learn about it at startup, not at 3am when a renewal cycle silently
// starts minting broken certs. See audit finding M-5.
now := time.Now()
if now.After(caCert.NotAfter) {
return fmt.Errorf("CA certificate %q has expired (not_after=%s, now=%s)",
caCert.Subject.CommonName,
caCert.NotAfter.UTC().Format(time.RFC3339),
now.UTC().Format(time.RFC3339))
}
if now.Before(caCert.NotBefore) {
return fmt.Errorf("CA certificate %q is not yet valid (not_before=%s, now=%s)",
caCert.Subject.CommonName,
caCert.NotBefore.UTC().Format(time.RFC3339),
now.UTC().Format(time.RFC3339))
}
// Load CA private key (supports RSA and ECDSA) // Load CA private key (supports RSA and ECDSA)
keyPEM, err := os.ReadFile(c.config.CAKeyPath) keyPEM, err := os.ReadFile(c.config.CAKeyPath)
if err != nil { if err != nil {
+120 -3
View File
@@ -14,6 +14,7 @@ import (
"math/big" "math/big"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
@@ -360,6 +361,114 @@ func TestSubCAMode(t *testing.T) {
t.Logf("Correctly rejected non-CA cert: %v", err) t.Logf("Correctly rejected non-CA cert: %v", err)
}) })
t.Run("SubCA_ExpiredCert_IsRejected", func(t *testing.T) {
// Sub-CA expired 1 hour ago. M-5: loadCAFromDisk must fail closed
// instead of minting child certs that immediately fail path validation
// at every relying party (CWE-672).
notBefore := time.Now().AddDate(-1, 0, 0)
notAfter := time.Now().Add(-1 * time.Hour)
certPath, keyPath := generateTestSubCAWithValidity(t, "rsa", notBefore, notAfter)
config := &local.Config{
ValidityDays: 30,
CACertPath: certPath,
CAKeyPath: keyPath,
}
connector := local.New(config, logger)
_, csrPEM, err := generateTestCSR("app.internal.corp")
if err != nil {
t.Fatalf("Failed to generate CSR: %v", err)
}
req := issuer.IssuanceRequest{
CommonName: "app.internal.corp",
CSRPEM: csrPEM,
}
_, err = connector.IssueCertificate(ctx, req)
if err == nil {
t.Fatal("Expected error when loading expired sub-CA; got nil")
}
if !strings.Contains(err.Error(), "expired") {
t.Errorf("Expected error to mention 'expired'; got: %v", err)
}
if !strings.Contains(err.Error(), "Test Sub-CA") {
t.Errorf("Expected error to include CA subject CN 'Test Sub-CA'; got: %v", err)
}
t.Logf("Correctly rejected expired sub-CA: %v", err)
})
t.Run("SubCA_NotYetValid_IsRejected", func(t *testing.T) {
// Sub-CA is not valid for another hour (clock skew or operator error
// pushing a pre-production CA into prod). M-5: loadCAFromDisk must
// fail closed.
notBefore := time.Now().Add(1 * time.Hour)
notAfter := time.Now().AddDate(5, 0, 0)
certPath, keyPath := generateTestSubCAWithValidity(t, "rsa", notBefore, notAfter)
config := &local.Config{
ValidityDays: 30,
CACertPath: certPath,
CAKeyPath: keyPath,
}
connector := local.New(config, logger)
_, csrPEM, err := generateTestCSR("app.internal.corp")
if err != nil {
t.Fatalf("Failed to generate CSR: %v", err)
}
req := issuer.IssuanceRequest{
CommonName: "app.internal.corp",
CSRPEM: csrPEM,
}
_, err = connector.IssueCertificate(ctx, req)
if err == nil {
t.Fatal("Expected error when loading not-yet-valid sub-CA; got nil")
}
if !strings.Contains(err.Error(), "not yet valid") {
t.Errorf("Expected error to mention 'not yet valid'; got: %v", err)
}
if !strings.Contains(err.Error(), "Test Sub-CA") {
t.Errorf("Expected error to include CA subject CN 'Test Sub-CA'; got: %v", err)
}
t.Logf("Correctly rejected not-yet-valid sub-CA: %v", err)
})
t.Run("SubCA_BarelyValid_IsAccepted", func(t *testing.T) {
// Sub-CA valid from 1 minute ago to 1 hour from now. Edge case:
// proves the M-5 window check doesn't over-reject CAs that are
// legitimately live but close to the boundaries.
notBefore := time.Now().Add(-1 * time.Minute)
notAfter := time.Now().Add(1 * time.Hour)
certPath, keyPath := generateTestSubCAWithValidity(t, "rsa", notBefore, notAfter)
config := &local.Config{
ValidityDays: 30,
CACertPath: certPath,
CAKeyPath: keyPath,
}
connector := local.New(config, logger)
_, csrPEM, err := generateTestCSR("app.internal.corp")
if err != nil {
t.Fatalf("Failed to generate CSR: %v", err)
}
req := issuer.IssuanceRequest{
CommonName: "app.internal.corp",
CSRPEM: csrPEM,
}
result, err := connector.IssueCertificate(ctx, req)
if err != nil {
t.Fatalf("Barely-valid sub-CA was wrongly rejected: %v", err)
}
if result.CertPEM == "" {
t.Error("CertPEM is empty")
}
t.Logf("Correctly accepted barely-valid sub-CA: serial=%s", result.Serial)
})
t.Run("SubCA_RenewCertificate", func(t *testing.T) { t.Run("SubCA_RenewCertificate", func(t *testing.T) {
certPath, keyPath := generateTestSubCA(t, "rsa") certPath, keyPath := generateTestSubCA(t, "rsa")
defer os.Remove(certPath) defer os.Remove(certPath)
@@ -396,8 +505,16 @@ func TestSubCAMode(t *testing.T) {
} }
// generateTestSubCA creates a self-signed CA cert+key pair and writes them to temp files. // generateTestSubCA creates a self-signed CA cert+key pair and writes them to temp files.
// keyType can be "rsa" or "ecdsa". // keyType can be "rsa" or "ecdsa". Validity window is [now, now+5y].
func generateTestSubCA(t *testing.T, keyType string) (certPath, keyPath string) { func generateTestSubCA(t *testing.T, keyType string) (certPath, keyPath string) {
t.Helper()
return generateTestSubCAWithValidity(t, keyType, time.Now(), time.Now().AddDate(5, 0, 0))
}
// generateTestSubCAWithValidity creates a self-signed CA cert+key pair with an
// explicit NotBefore/NotAfter window. Used by M-5 tests that exercise expired
// and not-yet-valid CA rejection in loadCAFromDisk.
func generateTestSubCAWithValidity(t *testing.T, keyType string, notBefore, notAfter time.Time) (certPath, keyPath string) {
t.Helper() t.Helper()
tmpDir := t.TempDir() tmpDir := t.TempDir()
certPath = filepath.Join(tmpDir, "ca.pem") certPath = filepath.Join(tmpDir, "ca.pem")
@@ -445,8 +562,8 @@ func generateTestSubCA(t *testing.T, keyType string) (certPath, keyPath string)
CommonName: "Test Sub-CA", CommonName: "Test Sub-CA",
Organization: []string{"CertCtl Test"}, Organization: []string{"CertCtl Test"},
}, },
NotBefore: time.Now(), NotBefore: notBefore,
NotAfter: time.Now().AddDate(5, 0, 0), NotAfter: notAfter,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true, BasicConstraintsValid: true,
IsCA: true, IsCA: true,
+197 -37
View File
@@ -1,4 +1,31 @@
// Package crypto provides AES-256-GCM encryption for sensitive configuration data. // Package crypto provides AES-256-GCM encryption for sensitive configuration data.
//
// The on-disk format for blobs produced by [EncryptIfKeySet] is versioned. Two
// versions coexist and both can be read by [DecryptIfKeySet]:
//
// v2 (current, M-8)
// magic(0x02) || salt(16) || nonce(12) || ciphertext+tag
// — 32-byte AES-256 key derived via PBKDF2-SHA256 from the operator
// passphrase and the per-ciphertext random salt.
//
// v1 (legacy, pre-M-8)
// nonce(12) || ciphertext+tag
// — 32-byte AES-256 key derived via PBKDF2-SHA256 from the operator
// passphrase and the package-level fixed salt
// "certctl-config-encryption-v1".
//
// v1 blobs are accepted by the read path for backward compatibility with rows
// persisted before the M-8 remediation. They are never produced by the write
// path. Any row that is updated after M-8 is re-sealed as v2 in-place via the
// normal UPDATE flow.
//
// Rationale for the per-ciphertext salt (see M-8 / CWE-916 / CWE-329): the
// pre-M-8 design reused a single 28-byte fixed salt for every ciphertext, which
// (a) removes one defense-in-depth layer against passphrase-space brute force
// and (b) makes every encrypted column across every row share the exact same
// derived key. v2 replaces the fixed salt with 16 fresh random bytes per write
// and stores the salt alongside the ciphertext. Derived keys now differ per
// row and per re-encryption.
package crypto package crypto
import ( import (
@@ -14,7 +41,8 @@ import (
) )
// ErrEncryptionKeyRequired is returned by EncryptIfKeySet and DecryptIfKeySet when // ErrEncryptionKeyRequired is returned by EncryptIfKeySet and DecryptIfKeySet when
// the caller provides an empty key but the data on the wire requires protection. // the caller provides an empty passphrase but the data on the wire requires
// protection.
// //
// Historically these helpers silently returned plaintext when no key was configured, // Historically these helpers silently returned plaintext when no key was configured,
// which produced a data-at-rest confidentiality bypass (CWE-311): sensitive fields // which produced a data-at-rest confidentiality bypass (CWE-311): sensitive fields
@@ -24,16 +52,58 @@ import (
// and plaintext branches at runtime, so the only visible signal was a warning // and plaintext branches at runtime, so the only visible signal was a warning
// line emitted once at startup. // line emitted once at startup.
// //
// The fix is to fail closed: EncryptIfKeySet/DecryptIfKeySet now require a key // The fix (C-2, commit fb4ce1a) is to fail closed: EncryptIfKeySet/DecryptIfKeySet
// whenever they are invoked on sensitive material, and the server refuses to // now require a passphrase whenever they are invoked on sensitive material, and
// start if any source='database' rows already exist without a configured key. // the server refuses to start if any source='database' rows already exist without
// a configured passphrase.
var ErrEncryptionKeyRequired = errors.New("crypto: CERTCTL_CONFIG_ENCRYPTION_KEY is required to encrypt or decrypt sensitive config") var ErrEncryptionKeyRequired = errors.New("crypto: CERTCTL_CONFIG_ENCRYPTION_KEY is required to encrypt or decrypt sensitive config")
// v2Magic is the first byte of every v2-format ciphertext blob. It distinguishes
// v2 blobs (per-ciphertext random salt, embedded in the blob) from v1 legacy
// blobs (no magic byte, fixed package-level salt).
//
// The choice of 0x02 is deliberate: v1 blobs begin with a random 12-byte AES-GCM
// nonce. A v1 nonce can coincidentally start with 0x02 with probability 1/256,
// which makes a pure magic-byte dispatch ambiguous. [DecryptIfKeySet] resolves
// the ambiguity by falling back to the v1 path when v2 AEAD verification fails.
const v2Magic byte = 0x02
// v2SaltSize is the length in bytes of the per-ciphertext salt embedded in a
// v2 blob. 16 bytes (128 bits) matches the lower bound recommended in NIST
// SP 800-132 §5.1 for PBKDF2 salts and is sufficient given the one-shot-per-row
// nature of the derivation.
const v2SaltSize = 16
// pbkdf2Iterations is the PBKDF2-SHA256 work factor applied uniformly to both
// v1 and v2 key derivations. The value is preserved from the pre-M-8 design so
// that v1 fallback reads stay bit-identical.
const pbkdf2Iterations = 100000
// aes256KeySize is the output length in bytes of both [DeriveKey] and
// [deriveKeyWithSalt]. It is also the only AES key length accepted by [Encrypt]
// and [Decrypt].
const aes256KeySize = 32
// legacyV1Salt is the fixed salt used by pre-M-8 config encryption. It is
// retained exclusively to preserve the v1 read path — any v1 blob that pre-dates
// M-8 remediation must be decryptable with a key derived from (passphrase,
// legacyV1Salt). The write path never uses this salt.
//
// Exposed as a package-level var rather than a local so that tests can reason
// about v1 fixture bytes symbolically.
var legacyV1Salt = []byte("certctl-config-encryption-v1")
// Encrypt encrypts plaintext using AES-256-GCM with a random 12-byte nonce prepended to the output. // Encrypt encrypts plaintext using AES-256-GCM with a random 12-byte nonce prepended to the output.
// The key must be exactly 32 bytes (AES-256). Returns [12-byte nonce][ciphertext+tag]. // The key must be exactly 32 bytes (AES-256). Returns [12-byte nonce][ciphertext+tag].
//
// Encrypt is a low-level primitive. It is intentionally kept byte-identical to
// the pre-M-8 implementation so that existing v1 blobs on disk remain
// decryptable via [Decrypt] when paired with a [DeriveKey]-derived key. New
// callers should prefer [EncryptIfKeySet], which handles key derivation and
// emits the v2 wire format.
func Encrypt(plaintext []byte, key []byte) ([]byte, error) { func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
if len(key) != 32 { if len(key) != aes256KeySize {
return nil, fmt.Errorf("encryption key must be exactly 32 bytes, got %d", len(key)) return nil, fmt.Errorf("encryption key must be exactly %d bytes, got %d", aes256KeySize, len(key))
} }
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
@@ -57,9 +127,14 @@ func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
// Decrypt decrypts ciphertext that was encrypted with Encrypt. // Decrypt decrypts ciphertext that was encrypted with Encrypt.
// Expects format: [12-byte nonce][ciphertext+tag]. Key must be exactly 32 bytes. // Expects format: [12-byte nonce][ciphertext+tag]. Key must be exactly 32 bytes.
//
// Decrypt is a low-level primitive. It is intentionally kept byte-identical to
// the pre-M-8 implementation so that [DecryptIfKeySet] can delegate to it for
// both the v2 inner blob (after stripping the magic byte + embedded salt) and
// the v1 legacy blob (unmodified).
func Decrypt(ciphertext []byte, key []byte) ([]byte, error) { func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
if len(key) != 32 { if len(key) != aes256KeySize {
return nil, fmt.Errorf("encryption key must be exactly 32 bytes, got %d", len(key)) return nil, fmt.Errorf("encryption key must be exactly %d bytes, got %d", aes256KeySize, len(key))
} }
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
@@ -86,48 +161,133 @@ func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
return plaintext, nil return plaintext, nil
} }
// DeriveKey derives a 32-byte AES-256 key from a passphrase using PBKDF2-SHA256. // DeriveKey derives a 32-byte AES-256 key from a passphrase using PBKDF2-SHA256
// Uses a fixed application-specific salt and 100,000 iterations for resistance // with the legacy v1 fixed salt.
// to brute-force attacks on weak passphrases. //
// This helper is preserved byte-identical to the pre-M-8 implementation so that
// v1 ciphertexts persisted before the M-8 remediation remain decryptable
// unchanged. New code paths should prefer [EncryptIfKeySet] and
// [DecryptIfKeySet], which use a per-ciphertext random salt.
func DeriveKey(passphrase string) []byte { func DeriveKey(passphrase string) []byte {
// Fixed salt is acceptable here because: return deriveKeyWithSalt(passphrase, legacyV1Salt)
// 1. Each certctl instance has its own passphrase
// 2. The salt prevents generic rainbow table attacks
// 3. Per-user salts are unnecessary (single server key, not user passwords)
salt := []byte("certctl-config-encryption-v1")
return pbkdf2.Key([]byte(passphrase), salt, 100000, 32, sha256.New)
} }
// EncryptIfKeySet encrypts plaintext with the supplied 32-byte AES-256 key. // deriveKeyWithSalt derives a 32-byte AES-256 key from a passphrase and an
// explicit salt using PBKDF2-SHA256 with [pbkdf2Iterations] rounds.
//
// The per-ciphertext random salt path (v2) calls this directly with a fresh
// 16-byte random salt embedded in the ciphertext blob. The legacy path
// ([DeriveKey]) calls it with the package-level fixed salt [legacyV1Salt].
func deriveKeyWithSalt(passphrase string, salt []byte) []byte {
return pbkdf2.Key([]byte(passphrase), salt, pbkdf2Iterations, aes256KeySize, sha256.New)
}
// IsLegacyFormat reports whether blob is in the v1 legacy wire format (no magic
// byte, fixed-salt derivation) as opposed to the v2 wire format
// (magic(0x02) || salt(16) || nonce(12) || ciphertext+tag).
//
// A return value of false is a necessary but not sufficient condition for a
// blob to be a valid v2 ciphertext: the shortest possible v2 blob is
// 1 + v2SaltSize + 12 = 29 bytes, and even a 29+ byte blob that starts with
// 0x02 may turn out to be a v1 ciphertext whose random nonce happens to begin
// with 0x02 (probability 1/256). [DecryptIfKeySet] resolves this ambiguity at
// decrypt time by falling back to v1 when v2 AEAD verification fails; callers
// of IsLegacyFormat should use it only as a heuristic (e.g. migration
// tooling, log annotation).
func IsLegacyFormat(blob []byte) bool {
if len(blob) == 0 {
return false
}
return blob[0] != v2Magic
}
// EncryptIfKeySet encrypts plaintext with the supplied passphrase and emits a
// v2 wire-format blob: magic(0x02) || salt(16) || nonce(12) || ciphertext+tag.
//
// Key derivation is performed internally per invocation with a fresh 16-byte
// random salt, producing a distinct AES-256 key for every ciphertext. The
// operator-supplied passphrase is the only cross-ciphertext shared secret.
// //
// The second return value is always true when err == nil — the "wasEncrypted" // The second return value is always true when err == nil — the "wasEncrypted"
// flag is retained for source-compatibility with callers that previously used it // flag is retained for source-compatibility with callers that previously used
// to log provenance. Callers MUST handle err: passing an empty key now returns // it to log provenance. Callers MUST handle err: passing an empty passphrase
// ErrEncryptionKeyRequired rather than silently emitting plaintext. See the // returns [ErrEncryptionKeyRequired] rather than silently emitting plaintext.
// package-level ErrEncryptionKeyRequired documentation for the history behind // See the package-level [ErrEncryptionKeyRequired] documentation for the
// this behavior change. // history behind this behavior change (C-2).
func EncryptIfKeySet(plaintext []byte, key []byte) ([]byte, bool, error) { //
if len(key) == 0 { // The write path never produces a v1 blob. v1 blobs are read-only legacy
// state — see [DecryptIfKeySet] for the compatibility fallback.
func EncryptIfKeySet(plaintext []byte, passphrase string) ([]byte, bool, error) {
if passphrase == "" {
return nil, false, ErrEncryptionKeyRequired return nil, false, ErrEncryptionKeyRequired
} }
encrypted, err := Encrypt(plaintext, key)
salt := make([]byte, v2SaltSize)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return nil, false, fmt.Errorf("failed to generate v2 salt: %w", err)
}
key := deriveKeyWithSalt(passphrase, salt)
inner, err := Encrypt(plaintext, key)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
return encrypted, true, nil
// v2 blob layout: magic(1) || salt(v2SaltSize) || inner
blob := make([]byte, 0, 1+v2SaltSize+len(inner))
blob = append(blob, v2Magic)
blob = append(blob, salt...)
blob = append(blob, inner...)
return blob, true, nil
} }
// DecryptIfKeySet decrypts ciphertext with the supplied 32-byte AES-256 key. // DecryptIfKeySet decrypts blob with the supplied passphrase, supporting both
// v2 (M-8 and later) and v1 (legacy) on-disk formats.
// //
// Passing an empty key now returns ErrEncryptionKeyRequired. Callers that // Dispatch is first-byte magic + AEAD fallback. If blob starts with
// legitimately store plaintext (e.g. env-seeded source='env' rows that keep // [v2Magic] and is long enough to contain a v2 header plus an AEAD-authenticated
// the raw JSON in the unencrypted `config` column) must branch on the presence // inner ciphertext, a v2 decrypt is attempted using a key derived from the
// of the ciphertext themselves rather than relying on this helper to silently // embedded salt. If that succeeds, its plaintext is returned. If v2 AEAD
// pass bytes through. See the package-level ErrEncryptionKeyRequired // verification fails — which covers both the "wrong passphrase" case and the
// documentation for the history behind this behavior change. // 1/256 case where a v1 blob's first byte happens to be 0x02 — the function
func DecryptIfKeySet(ciphertext []byte, key []byte) ([]byte, error) { // falls through to the v1 path and attempts decryption using a key derived
if len(key) == 0 { // from the package-level fixed salt [legacyV1Salt].
//
// Passing an empty passphrase returns [ErrEncryptionKeyRequired]. Callers that
// legitimately store plaintext (e.g. env-seeded source='env' rows that keep the
// raw JSON in the unencrypted `config` column) must branch on the presence of
// the ciphertext themselves rather than relying on this helper to silently
// pass bytes through. See the package-level [ErrEncryptionKeyRequired]
// documentation for the history behind this behavior change (C-2).
//
// The function never re-encrypts in place. A v1 blob that is successfully
// decrypted is returned to the caller as plaintext; re-sealing as v2 happens
// naturally on the next UPDATE via [EncryptIfKeySet].
func DecryptIfKeySet(blob []byte, passphrase string) ([]byte, error) {
if passphrase == "" {
return nil, ErrEncryptionKeyRequired return nil, ErrEncryptionKeyRequired
} }
return Decrypt(ciphertext, key) if len(blob) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
// v2 path: magic || salt(16) || nonce(12) || ciphertext+tag (min 29 bytes
// ignoring the GCM tag; the AEAD verify inside Decrypt enforces the tag).
if blob[0] == v2Magic && len(blob) >= 1+v2SaltSize+12 {
salt := blob[1 : 1+v2SaltSize]
sealed := blob[1+v2SaltSize:]
key := deriveKeyWithSalt(passphrase, salt)
if plaintext, err := Decrypt(sealed, key); err == nil {
return plaintext, nil
}
// v2 AEAD verification failed. Fall through to v1 so that a v1 blob
// whose first byte happens to be 0x02 (1/256 probability) is still
// decryptable. If this is truly a v2 blob with the wrong passphrase,
// the v1 attempt below will also fail and the v1 error is returned.
}
// v1 legacy path: blob is the full ciphertext with no header and was
// sealed with a key derived from (passphrase, legacyV1Salt).
key := DeriveKey(passphrase)
return Decrypt(blob, key)
} }
+238 -46
View File
@@ -2,6 +2,8 @@ package crypto
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher"
"errors" "errors"
"testing" "testing"
) )
@@ -126,21 +128,20 @@ func TestDeriveKeyDifferentPassphrases(t *testing.T) {
} }
func TestEncryptIfKeySet_WithKey(t *testing.T) { func TestEncryptIfKeySet_WithKey(t *testing.T) {
key := DeriveKey("test-key")
plaintext := []byte("config data") plaintext := []byte("config data")
result, wasEncrypted, err := EncryptIfKeySet(plaintext, key) result, wasEncrypted, err := EncryptIfKeySet(plaintext, "test-passphrase")
if err != nil { if err != nil {
t.Fatalf("EncryptIfKeySet failed: %v", err) t.Fatalf("EncryptIfKeySet failed: %v", err)
} }
if !wasEncrypted { if !wasEncrypted {
t.Fatal("expected wasEncrypted=true when key provided") t.Fatal("expected wasEncrypted=true when passphrase provided")
} }
if bytes.Equal(result, plaintext) { if bytes.Equal(result, plaintext) {
t.Fatal("result should be encrypted") t.Fatal("result should be encrypted")
} }
decrypted, err := DecryptIfKeySet(result, key) decrypted, err := DecryptIfKeySet(result, "test-passphrase")
if err != nil { if err != nil {
t.Fatalf("DecryptIfKeySet failed: %v", err) t.Fatalf("DecryptIfKeySet failed: %v", err)
} }
@@ -150,24 +151,14 @@ func TestEncryptIfKeySet_WithKey(t *testing.T) {
} }
// TestEncryptIfKeySet_EmptyKeyFailsClosed asserts the C-2 regression guard: // TestEncryptIfKeySet_EmptyKeyFailsClosed asserts the C-2 regression guard:
// EncryptIfKeySet must refuse to silently emit plaintext when no key is configured. // EncryptIfKeySet must refuse to silently emit plaintext when no passphrase is
// The pre-fix behavior was to return plaintext with wasEncrypted=false, which // configured. The pre-fix behavior was to return plaintext with
// produced a data-at-rest confidentiality bypass (CWE-311) for GUI-created // wasEncrypted=false, which produced a data-at-rest confidentiality bypass
// issuer and target configs. // (CWE-311) for GUI-created issuer and target configs.
func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) { func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
plaintext := []byte("config data") plaintext := []byte("config data")
cases := []struct { result, wasEncrypted, err := EncryptIfKeySet(plaintext, "")
name string
key []byte
}{
{"nil_key", nil},
{"empty_key", []byte{}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result, wasEncrypted, err := EncryptIfKeySet(plaintext, tc.key)
if err == nil { if err == nil {
t.Fatal("expected ErrEncryptionKeyRequired, got nil") t.Fatal("expected ErrEncryptionKeyRequired, got nil")
} }
@@ -180,27 +171,15 @@ func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
if result != nil { if result != nil {
t.Fatalf("expected nil result on error, got %q", result) t.Fatalf("expected nil result on error, got %q", result)
} }
})
}
} }
// TestDecryptIfKeySet_EmptyKeyFailsClosed asserts the matching C-2 regression // TestDecryptIfKeySet_EmptyKeyFailsClosed asserts the matching C-2 regression
// guard on the read path: DecryptIfKeySet must refuse to pass ciphertext // guard on the read path: DecryptIfKeySet must refuse to pass ciphertext
// through as plaintext when no key is configured. // through as plaintext when no passphrase is configured.
func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) { func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
data := []byte("plaintext config data") data := []byte("plaintext config data")
cases := []struct { result, err := DecryptIfKeySet(data, "")
name string
key []byte
}{
{"nil_key", nil},
{"empty_key", []byte{}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result, err := DecryptIfKeySet(data, tc.key)
if err == nil { if err == nil {
t.Fatal("expected ErrEncryptionKeyRequired, got nil") t.Fatal("expected ErrEncryptionKeyRequired, got nil")
} }
@@ -210,29 +189,26 @@ func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
if result != nil { if result != nil {
t.Fatalf("expected nil result on error, got %q", result) t.Fatalf("expected nil result on error, got %q", result)
} }
})
}
} }
// TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext proves the // TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext proves the
// "if set" helpers produce real AES-GCM output (not plaintext) and that a full // "if set" helpers produce real AES-GCM output (not plaintext) and that a full
// round-trip through both helpers recovers the original bytes. // round-trip through both helpers recovers the original bytes.
func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.T) { func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.T) {
key := DeriveKey("round-trip-key")
plaintext := []byte(`{"api_key":"s3cr3t","token":"abc"}`) plaintext := []byte(`{"api_key":"s3cr3t","token":"abc"}`)
encrypted, wasEncrypted, err := EncryptIfKeySet(plaintext, key) encrypted, wasEncrypted, err := EncryptIfKeySet(plaintext, "round-trip-key")
if err != nil { if err != nil {
t.Fatalf("EncryptIfKeySet failed: %v", err) t.Fatalf("EncryptIfKeySet failed: %v", err)
} }
if !wasEncrypted { if !wasEncrypted {
t.Fatal("wasEncrypted must be true when key is present") t.Fatal("wasEncrypted must be true when passphrase is present")
} }
if bytes.Equal(encrypted, plaintext) { if bytes.Equal(encrypted, plaintext) {
t.Fatal("EncryptIfKeySet returned plaintext — would regress C-2") t.Fatal("EncryptIfKeySet returned plaintext — would regress C-2")
} }
decrypted, err := DecryptIfKeySet(encrypted, key) decrypted, err := DecryptIfKeySet(encrypted, "round-trip-key")
if err != nil { if err != nil {
t.Fatalf("DecryptIfKeySet failed: %v", err) t.Fatalf("DecryptIfKeySet failed: %v", err)
} }
@@ -242,22 +218,24 @@ func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.
} }
// TestDecryptIfKeySet_RejectsTamperedCiphertext confirms the AEAD auth tag // TestDecryptIfKeySet_RejectsTamperedCiphertext confirms the AEAD auth tag
// still rejects modified ciphertext when routed through the helper. // still rejects modified ciphertext when routed through the helper. The v2
// wire format is magic(1) || salt(16) || nonce(12) || ciphertext+tag, so
// flipping a byte anywhere past offset 29 lands squarely inside the AEAD body.
func TestDecryptIfKeySet_RejectsTamperedCiphertext(t *testing.T) { func TestDecryptIfKeySet_RejectsTamperedCiphertext(t *testing.T) {
key := DeriveKey("tamper-test-key")
plaintext := []byte("authenticated data") plaintext := []byte("authenticated data")
encrypted, _, err := EncryptIfKeySet(plaintext, key) encrypted, _, err := EncryptIfKeySet(plaintext, "tamper-test-key")
if err != nil { if err != nil {
t.Fatalf("EncryptIfKeySet failed: %v", err) t.Fatalf("EncryptIfKeySet failed: %v", err)
} }
// Flip a byte inside the GCM body (past the 12-byte nonce) to invalidate the tag. // Flip a byte past the v2 header (1 + 16 + 12 = 29) to invalidate the tag.
if len(encrypted) <= 13 { const minV2HeaderLen = 1 + v2SaltSize + 12
if len(encrypted) <= minV2HeaderLen {
t.Fatalf("ciphertext too short to tamper: %d bytes", len(encrypted)) t.Fatalf("ciphertext too short to tamper: %d bytes", len(encrypted))
} }
encrypted[13] ^= 0xFF encrypted[minV2HeaderLen] ^= 0xFF
if _, err := DecryptIfKeySet(encrypted, key); err == nil { if _, err := DecryptIfKeySet(encrypted, "tamper-test-key"); err == nil {
t.Fatal("DecryptIfKeySet accepted tampered ciphertext — AEAD tag check bypassed") t.Fatal("DecryptIfKeySet accepted tampered ciphertext — AEAD tag check bypassed")
} }
} }
@@ -296,3 +274,217 @@ func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
t.Fatal("encrypting same plaintext twice should produce different ciphertexts (random nonce)") t.Fatal("encrypting same plaintext twice should produce different ciphertexts (random nonce)")
} }
} }
// ---------------------------------------------------------------------------
// M-8 additions: per-ciphertext salt + v2 wire format + v1 backward compat.
// ---------------------------------------------------------------------------
// TestDeriveKey_DifferentSaltsProduceDifferentKeys asserts that
// deriveKeyWithSalt fans out distinct 32-byte keys for the same passphrase
// across different salts. This is the core M-8 defense-in-depth property: even
// if an attacker obtains two v2 ciphertexts encrypted with the same master
// passphrase, the derived AES keys differ, and a brute-force attempt on one
// blob cannot be amortized across the other.
func TestDeriveKey_DifferentSaltsProduceDifferentKeys(t *testing.T) {
passphrase := "master-passphrase"
saltA := bytes.Repeat([]byte{0xAA}, v2SaltSize)
saltB := bytes.Repeat([]byte{0xBB}, v2SaltSize)
keyA := deriveKeyWithSalt(passphrase, saltA)
keyB := deriveKeyWithSalt(passphrase, saltB)
if len(keyA) != aes256KeySize || len(keyB) != aes256KeySize {
t.Fatalf("derived key length wrong: %d / %d", len(keyA), len(keyB))
}
if bytes.Equal(keyA, keyB) {
t.Fatal("deriveKeyWithSalt must produce different keys for different salts")
}
// Sanity-check that deterministic behaviour is preserved under a fixed salt.
keyA2 := deriveKeyWithSalt(passphrase, saltA)
if !bytes.Equal(keyA, keyA2) {
t.Fatal("deriveKeyWithSalt must be deterministic for a fixed (passphrase, salt)")
}
}
// TestEncryptIfKeySet_ProducesV2Format asserts the exact v2 wire-format bytes:
// magic(0x02) || salt(16) || nonce(12) || ciphertext+tag.
func TestEncryptIfKeySet_ProducesV2Format(t *testing.T) {
blob, _, err := EncryptIfKeySet([]byte("hello"), "any-passphrase")
if err != nil {
t.Fatalf("EncryptIfKeySet failed: %v", err)
}
const minLen = 1 + v2SaltSize + 12 + 16 // magic + salt + nonce + GCM tag (16)
if len(blob) < minLen {
t.Fatalf("v2 blob too short: got %d, want >= %d", len(blob), minLen)
}
if blob[0] != v2Magic {
t.Fatalf("v2 blob must start with magic byte 0x%02x, got 0x%02x", v2Magic, blob[0])
}
if IsLegacyFormat(blob) {
t.Fatal("IsLegacyFormat must return false for a freshly produced v2 blob")
}
}
// TestEncryptIfKeySet_SaltIsRandom asserts that two calls with the same
// passphrase and plaintext produce distinct embedded salts.
func TestEncryptIfKeySet_SaltIsRandom(t *testing.T) {
plaintext := []byte("same plaintext")
passphrase := "same-passphrase"
blob1, _, err := EncryptIfKeySet(plaintext, passphrase)
if err != nil {
t.Fatalf("EncryptIfKeySet #1 failed: %v", err)
}
blob2, _, err := EncryptIfKeySet(plaintext, passphrase)
if err != nil {
t.Fatalf("EncryptIfKeySet #2 failed: %v", err)
}
salt1 := blob1[1 : 1+v2SaltSize]
salt2 := blob2[1 : 1+v2SaltSize]
if bytes.Equal(salt1, salt2) {
t.Fatal("two EncryptIfKeySet invocations must produce distinct per-ciphertext salts")
}
if bytes.Equal(blob1, blob2) {
t.Fatal("two v2 blobs with same (passphrase, plaintext) must differ end-to-end")
}
}
// TestDecryptIfKeySet_V1BackwardCompat builds a deterministic v1-format
// ciphertext using the pre-M-8 recipe (DeriveKey with the fixed salt, then
// Encrypt with an all-zero nonce for reproducibility) and asserts that
// DecryptIfKeySet still decrypts it correctly. This is the migration guarantee:
// v1 blobs persisted before M-8 must remain decryptable.
func TestDecryptIfKeySet_V1BackwardCompat(t *testing.T) {
passphrase := "legacy-passphrase"
plaintext := []byte(`{"api_key":"legacy","org_id":"789"}`)
// Build a deterministic v1 blob directly: nonce(12 zero bytes) || ct+tag.
// This matches the exact wire shape that Encrypt produces, minus the random
// nonce, so the test is stable rather than 1/256 flaky.
key := DeriveKey(passphrase) // fixed-salt derivation (pre-M-8 behavior)
block, err := aes.NewCipher(key)
if err != nil {
t.Fatalf("aes.NewCipher: %v", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
t.Fatalf("cipher.NewGCM: %v", err)
}
nonce := make([]byte, gcm.NonceSize()) // all zeros → first byte != v2Magic
v1Blob := gcm.Seal(nonce, nonce, plaintext, nil)
if v1Blob[0] == v2Magic {
t.Fatalf("fixture nonce collided with v2 magic byte — test design error")
}
decrypted, err := DecryptIfKeySet(v1Blob, passphrase)
if err != nil {
t.Fatalf("DecryptIfKeySet(v1) failed: %v", err)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatalf("v1 decrypt mismatch: got %q, want %q", decrypted, plaintext)
}
// Cross-check: IsLegacyFormat should flag this as legacy.
if !IsLegacyFormat(v1Blob) {
t.Fatal("IsLegacyFormat must return true for a v1 blob whose first byte != v2Magic")
}
}
// TestDecryptIfKeySet_V1MagicByteCollisionFallsThrough covers the 1/256 edge
// case where a v1 ciphertext's random 12-byte nonce happens to begin with
// 0x02. The dispatch must attempt v2, see AEAD failure, and fall through to
// v1 — never return a decrypt error when the passphrase is correct.
func TestDecryptIfKeySet_V1MagicByteCollisionFallsThrough(t *testing.T) {
passphrase := "collision-passphrase"
plaintext := []byte("colliding v1 blob")
// Craft a v1 blob whose first byte equals v2Magic by choosing a nonce
// starting with 0x02 and sealing manually.
key := DeriveKey(passphrase)
block, err := aes.NewCipher(key)
if err != nil {
t.Fatalf("aes.NewCipher: %v", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
t.Fatalf("cipher.NewGCM: %v", err)
}
nonce := make([]byte, gcm.NonceSize())
nonce[0] = v2Magic // force collision
v1Blob := gcm.Seal(nonce, nonce, plaintext, nil)
if v1Blob[0] != v2Magic {
t.Fatal("fixture construction bug: first byte must equal v2Magic")
}
decrypted, err := DecryptIfKeySet(v1Blob, passphrase)
if err != nil {
t.Fatalf("DecryptIfKeySet must fall through to v1 on AEAD failure, got err: %v", err)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatalf("v1-via-fallback decrypt mismatch: got %q, want %q", decrypted, plaintext)
}
}
// TestDecryptIfKeySet_V2WithWrongPassphraseFails asserts that a v2 blob
// sealed under passphrase A cannot be decrypted under passphrase B. Both the
// v2 AEAD verify (with salt from the blob + passphrase B) and the v1 fallback
// (with fixed salt + passphrase B) must fail, and an error must be returned
// rather than silently-corrupt plaintext.
func TestDecryptIfKeySet_V2WithWrongPassphraseFails(t *testing.T) {
blob, _, err := EncryptIfKeySet([]byte("secret"), "passphrase-A")
if err != nil {
t.Fatalf("EncryptIfKeySet failed: %v", err)
}
got, err := DecryptIfKeySet(blob, "passphrase-B")
if err == nil {
t.Fatalf("DecryptIfKeySet must return error for wrong passphrase, got plaintext %q", got)
}
if got != nil {
t.Fatalf("result must be nil on decrypt error, got %q", got)
}
}
// TestDecryptIfKeySet_TruncatedV2Blob asserts that a blob starting with the v2
// magic byte but too short to contain a full v2 header does not trip an
// out-of-bounds slice and does not succeed. It either returns an error (v1
// fallback on the short bytes fails with "ciphertext too short") or at minimum
// never returns plaintext.
func TestDecryptIfKeySet_TruncatedV2Blob(t *testing.T) {
truncated := []byte{v2Magic, 0x00, 0x01, 0x02, 0x03} // 5 bytes — well below the 29-byte v2 minimum
got, err := DecryptIfKeySet(truncated, "any-passphrase")
if err == nil {
t.Fatalf("DecryptIfKeySet must reject a truncated v2 blob, got plaintext %q", got)
}
if got != nil {
t.Fatalf("result must be nil on decrypt error, got %q", got)
}
}
// TestIsLegacyFormat covers the three branches of the public magic-byte
// heuristic: v2 blob → false, v1 blob → true, empty blob → false.
func TestIsLegacyFormat(t *testing.T) {
v2Blob, _, err := EncryptIfKeySet([]byte("data"), "p")
if err != nil {
t.Fatalf("EncryptIfKeySet failed: %v", err)
}
if IsLegacyFormat(v2Blob) {
t.Fatal("v2 blob must not be flagged as legacy")
}
// Any blob whose first byte isn't v2Magic should be reported as legacy.
v1Shape := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0xFF}
if !IsLegacyFormat(v1Shape) {
t.Fatal("non-v2-magic blob must be flagged as legacy")
}
if IsLegacyFormat(nil) {
t.Fatal("nil blob must not be flagged as legacy (undefined)")
}
if IsLegacyFormat([]byte{}) {
t.Fatal("empty blob must not be flagged as legacy (undefined)")
}
}
+2
View File
@@ -12,6 +12,7 @@ type PolicyRule struct {
Type PolicyType `json:"type"` Type PolicyType `json:"type"`
Config json.RawMessage `json:"config"` Config json.RawMessage `json:"config"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Severity PolicySeverity `json:"severity"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
@@ -25,6 +26,7 @@ const (
PolicyTypeRequiredMetadata PolicyType = "RequiredMetadata" PolicyTypeRequiredMetadata PolicyType = "RequiredMetadata"
PolicyTypeAllowedEnvironments PolicyType = "AllowedEnvironments" PolicyTypeAllowedEnvironments PolicyType = "AllowedEnvironments"
PolicyTypeRenewalLeadTime PolicyType = "RenewalLeadTime" PolicyTypeRenewalLeadTime PolicyType = "RenewalLeadTime"
PolicyTypeCertificateLifetime PolicyType = "CertificateLifetime"
) )
// PolicyViolation records an instance of a certificate violating a policy rule. // PolicyViolation records an instance of a certificate violating a policy rule.
+1 -1
View File
@@ -158,7 +158,7 @@ func TestCrossResourceWorkflow(t *testing.T) {
payload := map[string]interface{}{ payload := map[string]interface{}{
"name": "Allowed Domains Policy", "name": "Allowed Domains Policy",
"type": "AllowedDomains", "type": "AllowedDomains",
"severity": "High", "severity": "Error",
"config": json.RawMessage(`{"domains": ["example.com", "*.example.com"]}`), "config": json.RawMessage(`{"domains": ["example.com", "*.example.com"]}`),
"description": "Restrict issuance to example.com domains", "description": "Restrict issuance to example.com domains",
} }
+35 -27
View File
@@ -70,7 +70,7 @@ func TestCertificateLifecycle(t *testing.T) {
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests // without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
// must supply a real key so the encrypt path runs instead of returning // must supply a real key so the encrypt path runs instead of returning
// ErrEncryptionKeyRequired. // ErrEncryptionKeyRequired.
testEncryptionKey := []byte("0123456789abcdef0123456789abcdef") testEncryptionKey := "0123456789abcdef0123456789abcdef"
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, slog.Default()) issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, slog.Default())
// Initialize handlers // Initialize handlers
@@ -772,6 +772,14 @@ func (m *mockAgentRepository) Create(ctx context.Context, agent *domain.Agent) e
return nil return nil
} }
func (m *mockAgentRepository) CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error) {
if _, exists := m.agents[agent.ID]; exists {
return false, nil
}
m.agents[agent.ID] = agent
return true, nil
}
func (m *mockAgentRepository) Update(ctx context.Context, agent *domain.Agent) error { func (m *mockAgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
m.agents[agent.ID] = agent m.agents[agent.ID] = agent
return nil return nil
@@ -1028,8 +1036,8 @@ type mockTargetService struct {
auditService *service.AuditService auditService *service.AuditService
} }
func (m *mockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) { func (m *mockTargetService) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
targets, err := m.targetRepo.List(context.Background()) targets, err := m.targetRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@@ -1040,99 +1048,99 @@ func (m *mockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentT
return result, int64(len(result)), nil return result, int64(len(result)), nil
} }
func (m *mockTargetService) GetTarget(id string) (*domain.DeploymentTarget, error) { func (m *mockTargetService) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
return m.targetRepo.Get(context.Background(), id) return m.targetRepo.Get(ctx, id)
} }
func (m *mockTargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { func (m *mockTargetService) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
if err := m.targetRepo.Create(context.Background(), &target); err != nil { if err := m.targetRepo.Create(ctx, &target); err != nil {
return nil, err return nil, err
} }
return &target, nil return &target, nil
} }
func (m *mockTargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { func (m *mockTargetService) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
target.ID = id target.ID = id
if err := m.targetRepo.Update(context.Background(), &target); err != nil { if err := m.targetRepo.Update(ctx, &target); err != nil {
return nil, err return nil, err
} }
return &target, nil return &target, nil
} }
func (m *mockTargetService) DeleteTarget(id string) error { func (m *mockTargetService) DeleteTarget(ctx context.Context, id string) error {
return m.targetRepo.Delete(context.Background(), id) return m.targetRepo.Delete(ctx, id)
} }
func (m *mockTargetService) TestTargetConnection(id string) error { func (m *mockTargetService) TestConnection(ctx context.Context, id string) error {
return nil // No-op for integration tests return nil // No-op for integration tests
} }
type mockTeamService struct{} type mockTeamService struct{}
func (m *mockTeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) { func (m *mockTeamService) ListTeams(_ context.Context, page, perPage int) ([]domain.Team, int64, error) {
return []domain.Team{}, 0, nil return []domain.Team{}, 0, nil
} }
func (m *mockTeamService) GetTeam(id string) (*domain.Team, error) { func (m *mockTeamService) GetTeam(_ context.Context, id string) (*domain.Team, error) {
return nil, fmt.Errorf("team not found") return nil, fmt.Errorf("team not found")
} }
func (m *mockTeamService) CreateTeam(team domain.Team) (*domain.Team, error) { func (m *mockTeamService) CreateTeam(_ context.Context, team domain.Team) (*domain.Team, error) {
return &team, nil return &team, nil
} }
func (m *mockTeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) { func (m *mockTeamService) UpdateTeam(_ context.Context, id string, team domain.Team) (*domain.Team, error) {
team.ID = id team.ID = id
return &team, nil return &team, nil
} }
func (m *mockTeamService) DeleteTeam(id string) error { func (m *mockTeamService) DeleteTeam(_ context.Context, id string) error {
return nil return nil
} }
type mockOwnerService struct{} type mockOwnerService struct{}
func (m *mockOwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) { func (m *mockOwnerService) ListOwners(_ context.Context, page, perPage int) ([]domain.Owner, int64, error) {
return []domain.Owner{}, 0, nil return []domain.Owner{}, 0, nil
} }
func (m *mockOwnerService) GetOwner(id string) (*domain.Owner, error) { func (m *mockOwnerService) GetOwner(_ context.Context, id string) (*domain.Owner, error) {
return nil, fmt.Errorf("owner not found") return nil, fmt.Errorf("owner not found")
} }
func (m *mockOwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) { func (m *mockOwnerService) CreateOwner(_ context.Context, owner domain.Owner) (*domain.Owner, error) {
return &owner, nil return &owner, nil
} }
func (m *mockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) { func (m *mockOwnerService) UpdateOwner(_ context.Context, id string, owner domain.Owner) (*domain.Owner, error) {
owner.ID = id owner.ID = id
return &owner, nil return &owner, nil
} }
func (m *mockOwnerService) DeleteOwner(id string) error { func (m *mockOwnerService) DeleteOwner(_ context.Context, id string) error {
return nil return nil
} }
type mockProfileService struct{} type mockProfileService struct{}
func (m *mockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { func (m *mockProfileService) ListProfiles(_ context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error) {
return []domain.CertificateProfile{}, 0, nil return []domain.CertificateProfile{}, 0, nil
} }
func (m *mockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { func (m *mockProfileService) GetProfile(_ context.Context, id string) (*domain.CertificateProfile, error) {
return nil, fmt.Errorf("profile not found") return nil, fmt.Errorf("profile not found")
} }
func (m *mockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { func (m *mockProfileService) CreateProfile(_ context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
return &profile, nil return &profile, nil
} }
func (m *mockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { func (m *mockProfileService) UpdateProfile(_ context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
profile.ID = id profile.ID = id
return &profile, nil return &profile, nil
} }
func (m *mockProfileService) DeleteProfile(id string) error { func (m *mockProfileService) DeleteProfile(_ context.Context, id string) error {
return nil return nil
} }
+1 -1
View File
@@ -62,7 +62,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests // without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
// must supply a real key so the encrypt path runs instead of returning // must supply a real key so the encrypt path runs instead of returning
// ErrEncryptionKeyRequired. // ErrEncryptionKeyRequired.
testEncryptionKey := []byte("0123456789abcdef0123456789abcdef") testEncryptionKey := "0123456789abcdef0123456789abcdef"
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, logger) issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, logger)
certificateHandler := handler.NewCertificateHandler(certificateService) certificateHandler := handler.NewCertificateHandler(certificateService)
+2 -2
View File
@@ -610,7 +610,7 @@ func registerPolicyTools(s *gomcp.Server, c *Client) {
gomcp.AddTool(s, &gomcp.Tool{ gomcp.AddTool(s, &gomcp.Tool{
Name: "certctl_create_policy", Name: "certctl_create_policy",
Description: "Create a new policy rule. Requires name and type.", Description: "Create a new policy rule. Requires name and type. Optional severity (Warning, Error, Critical) defaults to Warning.",
}, func(ctx context.Context, req *gomcp.CallToolRequest, input CreatePolicyInput) (*gomcp.CallToolResult, any, error) { }, func(ctx context.Context, req *gomcp.CallToolRequest, input CreatePolicyInput) (*gomcp.CallToolResult, any, error) {
data, err := c.Post("/api/v1/policies", input) data, err := c.Post("/api/v1/policies", input)
if err != nil { if err != nil {
@@ -621,7 +621,7 @@ func registerPolicyTools(s *gomcp.Server, c *Client) {
gomcp.AddTool(s, &gomcp.Tool{ gomcp.AddTool(s, &gomcp.Tool{
Name: "certctl_update_policy", Name: "certctl_update_policy",
Description: "Update a policy rule's name, type, configuration, or enabled status.", Description: "Update a policy rule's name, type, configuration, enabled status, or severity.",
}, func(ctx context.Context, req *gomcp.CallToolRequest, input UpdatePolicyInput) (*gomcp.CallToolResult, any, error) { }, func(ctx context.Context, req *gomcp.CallToolRequest, input UpdatePolicyInput) (*gomcp.CallToolResult, any, error) {
data, err := c.Put("/api/v1/policies/"+input.ID, input) data, err := c.Put("/api/v1/policies/"+input.ID, input)
if err != nil { if err != nil {
+4 -2
View File
@@ -35,7 +35,7 @@ type CreateCertificateInput struct {
TeamID string `json:"team_id" jsonschema:"Team ID (required)"` TeamID string `json:"team_id" jsonschema:"Team ID (required)"`
IssuerID string `json:"issuer_id" jsonschema:"Issuer connector ID"` IssuerID string `json:"issuer_id" jsonschema:"Issuer connector ID"`
TargetIDs []string `json:"target_ids,omitempty" jsonschema:"Deployment target IDs"` TargetIDs []string `json:"target_ids,omitempty" jsonschema:"Deployment target IDs"`
RenewalPolicyID string `json:"renewal_policy_id,omitempty" jsonschema:"Renewal policy ID"` RenewalPolicyID string `json:"renewal_policy_id" jsonschema:"Renewal policy ID (required)"`
ProfileID string `json:"certificate_profile_id,omitempty" jsonschema:"Certificate profile ID"` ProfileID string `json:"certificate_profile_id,omitempty" jsonschema:"Certificate profile ID"`
Tags map[string]string `json:"tags,omitempty" jsonschema:"Key-value tags"` Tags map[string]string `json:"tags,omitempty" jsonschema:"Key-value tags"`
} }
@@ -112,7 +112,7 @@ type CreateTargetInput struct {
ID string `json:"id,omitempty" jsonschema:"Target ID"` ID string `json:"id,omitempty" jsonschema:"Target ID"`
Name string `json:"name" jsonschema:"Target display name"` Name string `json:"name" jsonschema:"Target display name"`
Type string `json:"type" jsonschema:"Target type: NGINX, Apache, HAProxy, F5, IIS"` Type string `json:"type" jsonschema:"Target type: NGINX, Apache, HAProxy, F5, IIS"`
AgentID string `json:"agent_id,omitempty" jsonschema:"Agent ID that manages this target"` AgentID string `json:"agent_id" jsonschema:"Agent ID that manages this target (required)"`
Config interface{} `json:"config,omitempty" jsonschema:"Target-specific configuration"` Config interface{} `json:"config,omitempty" jsonschema:"Target-specific configuration"`
Enabled bool `json:"enabled,omitempty" jsonschema:"Whether the target is enabled"` Enabled bool `json:"enabled,omitempty" jsonschema:"Whether the target is enabled"`
} }
@@ -173,6 +173,7 @@ type CreatePolicyInput struct {
Type string `json:"type" jsonschema:"Policy type: AllowedIssuers, AllowedDomains, RequiredMetadata, AllowedEnvironments, RenewalLeadTime"` Type string `json:"type" jsonschema:"Policy type: AllowedIssuers, AllowedDomains, RequiredMetadata, AllowedEnvironments, RenewalLeadTime"`
Config interface{} `json:"config,omitempty" jsonschema:"Policy-specific configuration"` Config interface{} `json:"config,omitempty" jsonschema:"Policy-specific configuration"`
Enabled bool `json:"enabled,omitempty" jsonschema:"Whether the policy is enabled"` Enabled bool `json:"enabled,omitempty" jsonschema:"Whether the policy is enabled"`
Severity string `json:"severity,omitempty" jsonschema:"Violation severity: Warning, Error, or Critical (default: Warning)"`
} }
type UpdatePolicyInput struct { type UpdatePolicyInput struct {
@@ -181,6 +182,7 @@ type UpdatePolicyInput struct {
Type string `json:"type,omitempty" jsonschema:"Policy type"` Type string `json:"type,omitempty" jsonschema:"Policy type"`
Config interface{} `json:"config,omitempty" jsonschema:"Policy-specific configuration"` Config interface{} `json:"config,omitempty" jsonschema:"Policy-specific configuration"`
Enabled *bool `json:"enabled,omitempty" jsonschema:"Whether the policy is enabled"` Enabled *bool `json:"enabled,omitempty" jsonschema:"Whether the policy is enabled"`
Severity string `json:"severity,omitempty" jsonschema:"Violation severity: Warning, Error, or Critical"`
} }
type ListViolationsInput struct { type ListViolationsInput struct {
+11 -1
View File
@@ -90,8 +90,18 @@ type AgentRepository interface {
List(ctx context.Context) ([]*domain.Agent, error) List(ctx context.Context) ([]*domain.Agent, error)
// Get retrieves an agent by ID. // Get retrieves an agent by ID.
Get(ctx context.Context, id string) (*domain.Agent, error) Get(ctx context.Context, id string) (*domain.Agent, error)
// Create stores a new agent. // Create stores a new agent. Callers that want duplicate-key errors surfaced
// (e.g. real-agent registration) must use this method; sentinel/bootstrap
// paths that expect the row to already exist on restart should call
// CreateIfNotExists instead (M-6, CWE-662).
Create(ctx context.Context, agent *domain.Agent) error Create(ctx context.Context, agent *domain.Agent) error
// CreateIfNotExists creates an agent only if the ID doesn't already exist
// (INSERT ... ON CONFLICT (id) DO NOTHING). Returns true if the row was
// newly inserted, false if a row with the same ID already existed. Used
// by the sentinel-agent bootstrap path in cmd/server/main.go so restarts
// and upgrades are idempotent without swallowing unrelated database
// failures (M-6, CWE-662).
CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error)
// Update modifies an existing agent. // Update modifies an existing agent.
Update(ctx context.Context, agent *domain.Agent) error Update(ctx context.Context, agent *domain.Agent) error
// Delete removes an agent. // Delete removes an agent.
+41 -1
View File
@@ -70,7 +70,9 @@ func (r *AgentRepository) Get(ctx context.Context, id string) (*domain.Agent, er
return agent, nil return agent, nil
} }
// Create stores a new agent // Create stores a new agent. Duplicate-key errors surface to the caller —
// real-agent registration paths rely on this to detect collisions. Use
// CreateIfNotExists for sentinel/bootstrap paths where re-inserts are expected.
func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error { func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
if agent.ID == "" { if agent.ID == "" {
agent.ID = uuid.New().String() agent.ID = uuid.New().String()
@@ -92,6 +94,44 @@ func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error
return nil return nil
} }
// CreateIfNotExists creates an agent only if the ID doesn't already exist.
// Used for sentinel agents (server-scanner, cloud-aws-sm, cloud-azure-kv,
// cloud-gcp-sm) on first boot AND on every subsequent restart/upgrade — the
// pre-M-6 code used plain INSERT, swallowed the duplicate-key error, and so
// silently swallowed every other database failure too (CWE-662 /
// CWE-209-adjacent). ON CONFLICT (id) DO NOTHING + RETURNING id +
// sql.ErrNoRows distinguishes "row already existed" (created=false, err=nil)
// from genuine errors (connectivity, permission, constraint violations
// other than the id primary key) which still surface. Returns true if the
// row was newly inserted, false if a row with the same ID already existed.
func (r *AgentRepository) CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error) {
if agent.ID == "" {
agent.ID = uuid.New().String()
}
var id string
err := r.db.QueryRowContext(ctx, `
INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash,
os, architecture, ip_address, version)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (id) DO NOTHING
RETURNING id
`, agent.ID, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt,
agent.RegisteredAt, agent.APIKeyHash,
agent.OS, agent.Architecture, agent.IPAddress, agent.Version).Scan(&id)
if err != nil {
if err == sql.ErrNoRows {
// ON CONFLICT DO NOTHING — a row with this ID already existed.
return false, nil
}
return false, fmt.Errorf("failed to create agent: %w", err)
}
agent.ID = id
return true, nil
}
// Update modifies an existing agent // Update modifies an existing agent
func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error { func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
result, err := r.db.ExecContext(ctx, ` result, err := r.db.ExecContext(ctx, `
+178 -9
View File
@@ -190,18 +190,65 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
defer rows.Close() defer rows.Close()
var certs []*domain.ManagedCertificate var certs []*domain.ManagedCertificate
var certIDs []string
for rows.Next() { for rows.Next() {
cert, err := scanCertificate(rows) var cert domain.ManagedCertificate
var tagsJSON []byte
var sans pq.StringArray
var profileID sql.NullString
var revocationReason sql.NullString
err := rows.Scan(
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
&cert.Status, &cert.ExpiresAt, &tagsJSON,
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
&cert.CreatedAt, &cert.UpdatedAt)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, fmt.Errorf("failed to scan certificate: %w", err)
} }
certs = append(certs, cert)
cert.SANs = []string(sans)
if profileID.Valid {
cert.CertificateProfileID = profileID.String
}
if revocationReason.Valid {
cert.RevocationReason = revocationReason.String
}
// Unmarshal tags
if len(tagsJSON) > 0 {
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal tags: %w", err)
}
} else {
cert.Tags = make(map[string]string)
}
certs = append(certs, &cert)
certIDs = append(certIDs, cert.ID)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err) return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
} }
// Fetch target IDs for all certificates in a single query (avoid N+1)
if len(certIDs) > 0 {
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
if err != nil {
return nil, 0, err
}
for _, cert := range certs {
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
cert.TargetIDs = targetIDs
} else {
cert.TargetIDs = []string{}
}
}
}
return certs, total, nil return certs, total, nil
} }
@@ -214,7 +261,7 @@ func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.Man
WHERE id = $1 WHERE id = $1
`, id) `, id)
cert, err := scanCertificate(row) cert, err := r.scanCertificate(ctx, row)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, fmt.Errorf("certificate not found") return nil, fmt.Errorf("certificate not found")
@@ -421,18 +468,65 @@ func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, bef
defer rows.Close() defer rows.Close()
var certs []*domain.ManagedCertificate var certs []*domain.ManagedCertificate
var certIDs []string
for rows.Next() { for rows.Next() {
cert, err := scanCertificate(rows) var cert domain.ManagedCertificate
var tagsJSON []byte
var sans pq.StringArray
var profileID sql.NullString
var revocationReason sql.NullString
err := rows.Scan(
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
&cert.Status, &cert.ExpiresAt, &tagsJSON,
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
&cert.CreatedAt, &cert.UpdatedAt)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to scan certificate: %w", err)
} }
certs = append(certs, cert)
cert.SANs = []string(sans)
if profileID.Valid {
cert.CertificateProfileID = profileID.String
}
if revocationReason.Valid {
cert.RevocationReason = revocationReason.String
}
// Unmarshal tags
if len(tagsJSON) > 0 {
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
}
} else {
cert.Tags = make(map[string]string)
}
certs = append(certs, &cert)
certIDs = append(certIDs, cert.ID)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err) return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
} }
// Fetch target IDs for all certificates in a single query (avoid N+1)
if len(certIDs) > 0 {
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
if err != nil {
return nil, err
}
for _, cert := range certs {
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
cert.TargetIDs = targetIDs
} else {
cert.TargetIDs = []string{}
}
}
}
return certs, nil return certs, nil
} }
@@ -462,8 +556,76 @@ func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID str
return &v, nil return &v, nil
} }
// scanCertificate scans a certificate from a row or rows // getTargetIDs retrieves all target IDs for a given certificate from the junction table.
func scanCertificate(scanner interface { // Returns an empty slice (not nil) if no targets are found.
func (r *CertificateRepository) getTargetIDs(ctx context.Context, certID string) ([]string, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT target_id FROM certificate_target_mappings
WHERE certificate_id = $1
ORDER BY target_id ASC
`, certID)
if err != nil {
return nil, fmt.Errorf("failed to query target mappings: %w", err)
}
defer rows.Close()
var targetIDs []string
for rows.Next() {
var targetID string
if err := rows.Scan(&targetID); err != nil {
return nil, fmt.Errorf("failed to scan target ID: %w", err)
}
targetIDs = append(targetIDs, targetID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating target ID rows: %w", err)
}
// Return empty slice instead of nil for consistency with JSON marshaling
if targetIDs == nil {
targetIDs = []string{}
}
return targetIDs, nil
}
// getTargetIDsForCertificates retrieves target IDs for multiple certificates in a single query.
// Returns a map of certificate_id -> []target_id.
func (r *CertificateRepository) getTargetIDsForCertificates(ctx context.Context, certIDs []string) (map[string][]string, error) {
if len(certIDs) == 0 {
return make(map[string][]string), nil
}
rows, err := r.db.QueryContext(ctx, `
SELECT certificate_id, target_id FROM certificate_target_mappings
WHERE certificate_id = ANY($1)
ORDER BY certificate_id, target_id ASC
`, pq.Array(certIDs))
if err != nil {
return nil, fmt.Errorf("failed to query target mappings: %w", err)
}
defer rows.Close()
targetIDsMap := make(map[string][]string)
for rows.Next() {
var certID, targetID string
if err := rows.Scan(&certID, &targetID); err != nil {
return nil, fmt.Errorf("failed to scan target mapping: %w", err)
}
targetIDsMap[certID] = append(targetIDsMap[certID], targetID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating target mapping rows: %w", err)
}
return targetIDsMap, nil
}
// scanCertificate scans a certificate from a row or rows and populates its TargetIDs
// by querying the certificate_target_mappings junction table.
func (r *CertificateRepository) scanCertificate(ctx context.Context, scanner interface {
Scan(...interface{}) error Scan(...interface{}) error
}) (*domain.ManagedCertificate, error) { }) (*domain.ManagedCertificate, error) {
var cert domain.ManagedCertificate var cert domain.ManagedCertificate
@@ -500,6 +662,13 @@ func scanCertificate(scanner interface {
cert.Tags = make(map[string]string) cert.Tags = make(map[string]string)
} }
// Populate TargetIDs from junction table
targetIDs, err := r.getTargetIDs(ctx, cert.ID)
if err != nil {
return nil, err
}
cert.TargetIDs = targetIDs
return &cert, nil return &cert, nil
} }
@@ -0,0 +1,322 @@
// Package postgres_test — integration tests for M-7: Certificate.TargetIDs
// must be populated from certificate_target_mappings on read.
//
// Before M-7 the repository scan helper never consulted the junction table, so
// Get / List / GetExpiringCertificates always returned empty TargetIDs even when
// rows existed in certificate_target_mappings. These tests exercise all three
// read paths end-to-end against a real PostgreSQL 16 container.
//
// Runs against the shared testcontainer from testutil_test.go. Skipped when
// `-short` is set (CI uses short mode; local runs pick it up by default).
package postgres_test
import (
"context"
"database/sql"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository/postgres"
)
// insertAgentAndTargetsRaw creates one agent and N deployment_targets, returns
// the agent ID and the list of target IDs (in insertion order).
func insertAgentAndTargetsRaw(t *testing.T, db *sql.DB, ctx context.Context, suffix string, n int) (agentID string, targetIDs []string) {
t.Helper()
now := time.Now().Truncate(time.Microsecond)
agentID = "agent-" + suffix
_, err := db.ExecContext(ctx, `
INSERT INTO agents (id, name, hostname, status, registered_at, api_key_hash)
VALUES ($1, $2, $3, $4, $5, $6)
`, agentID, "agent-"+suffix, "host-"+suffix, "online", now, "hash-"+suffix)
if err != nil {
t.Fatalf("insertAgent failed: %v", err)
}
for i := 0; i < n; i++ {
tid := "t-" + suffix + "-" + intToStr(i)
_, err := db.ExecContext(ctx, `
INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, tid, tid, "NGINX", agentID, []byte(`{}`), true, now, now)
if err != nil {
t.Fatalf("insertTarget %d failed: %v", i, err)
}
targetIDs = append(targetIDs, tid)
}
return agentID, targetIDs
}
// intToStr converts a non-negative int to its decimal string.
// Local helper to avoid importing strconv for a single use.
func intToStr(n int) string {
if n == 0 {
return "0"
}
var buf [20]byte
i := len(buf)
for n > 0 {
i--
buf[i] = byte('0' + n%10)
n /= 10
}
return string(buf[i:])
}
// insertCertificateRow writes a minimal managed_certificates row via raw SQL.
// Bypasses the repository Create so we can isolate read-path tests from any
// write-path behavior. managed_certificates.sans is TEXT[], written here as an
// empty array literal.
func insertCertificateRow(t *testing.T, db *sql.DB, ctx context.Context, certID, ownerID, teamID, issuerID, policyID string, expiresAt time.Time) {
t.Helper()
now := time.Now().Truncate(time.Microsecond)
_, err := db.ExecContext(ctx, `
INSERT INTO managed_certificates (
id, name, common_name, sans, environment,
owner_id, team_id, issuer_id, renewal_policy_id,
status, expires_at, tags,
created_at, updated_at
) VALUES (
$1, $2, $3, ARRAY[]::TEXT[], $4,
$5, $6, $7, $8,
$9, $10, $11,
$12, $13
)
`,
certID, certID, certID+".example.com", "production",
ownerID, teamID, issuerID, policyID,
string(domain.CertificateStatusActive), expiresAt, []byte(`{}`),
now, now,
)
if err != nil {
t.Fatalf("insertCertificateRow failed: %v", err)
}
}
// insertMapping writes a single row into certificate_target_mappings via raw SQL.
func insertMapping(t *testing.T, db *sql.DB, ctx context.Context, certID, targetID string) {
t.Helper()
_, err := db.ExecContext(ctx,
`INSERT INTO certificate_target_mappings (certificate_id, target_id) VALUES ($1, $2)`,
certID, targetID)
if err != nil {
t.Fatalf("insertMapping(%s, %s) failed: %v", certID, targetID, err)
}
}
// --------------------------------------------------------------------
// Get() — single-cert read path
// --------------------------------------------------------------------
// TestGet_PopulatesTargetIDs_NoMappings: no mapping rows → TargetIDs must be
// an empty slice, not nil, so JSON serialisation emits "[]".
func TestGet_PopulatesTargetIDs_NoMappings(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewCertificateRepository(db)
ctx := context.Background()
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "getnone")
certID := "mc-getnone"
insertCertificateRow(t, db, ctx, certID, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
got, err := repo.Get(ctx, certID)
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.TargetIDs == nil {
t.Fatalf("TargetIDs = nil, want empty slice (JSON serialises nil as null and [] as [])")
}
if len(got.TargetIDs) != 0 {
t.Errorf("len(TargetIDs) = %d, want 0; got %v", len(got.TargetIDs), got.TargetIDs)
}
}
// TestGet_PopulatesTargetIDs_SingleTarget: one mapping → one entry.
func TestGet_PopulatesTargetIDs_SingleTarget(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewCertificateRepository(db)
ctx := context.Background()
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "getone")
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "getone", 1)
certID := "mc-getone"
insertCertificateRow(t, db, ctx, certID, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
insertMapping(t, db, ctx, certID, targets[0])
got, err := repo.Get(ctx, certID)
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if len(got.TargetIDs) != 1 {
t.Fatalf("len(TargetIDs) = %d, want 1; got %v", len(got.TargetIDs), got.TargetIDs)
}
if got.TargetIDs[0] != targets[0] {
t.Errorf("TargetIDs[0] = %q, want %q", got.TargetIDs[0], targets[0])
}
}
// TestGet_PopulatesTargetIDs_MultipleTargets: many mappings → sorted by target_id ASC.
func TestGet_PopulatesTargetIDs_MultipleTargets(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewCertificateRepository(db)
ctx := context.Background()
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "getmany")
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "getmany", 3)
certID := "mc-getmany"
insertCertificateRow(t, db, ctx, certID, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
// Insert mappings in reverse order to confirm ORDER BY target_id ASC in the query.
insertMapping(t, db, ctx, certID, targets[2])
insertMapping(t, db, ctx, certID, targets[0])
insertMapping(t, db, ctx, certID, targets[1])
got, err := repo.Get(ctx, certID)
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if len(got.TargetIDs) != 3 {
t.Fatalf("len(TargetIDs) = %d, want 3; got %v", len(got.TargetIDs), got.TargetIDs)
}
// Ascending order: t-getmany-0, t-getmany-1, t-getmany-2
want := []string{targets[0], targets[1], targets[2]}
for i, tid := range want {
if got.TargetIDs[i] != tid {
t.Errorf("TargetIDs[%d] = %q, want %q (full: %v)", i, got.TargetIDs[i], tid, got.TargetIDs)
}
}
}
// --------------------------------------------------------------------
// List() — batch read path, must avoid N+1
// --------------------------------------------------------------------
// TestList_PopulatesTargetIDs_BatchFetch: three certs with different mapping counts;
// all must have their TargetIDs populated correctly, and the cert with no mapping
// must get an empty (non-nil) slice.
func TestList_PopulatesTargetIDs_BatchFetch(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewCertificateRepository(db)
ctx := context.Background()
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "listbatch")
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "listbatch", 3)
certA := "mc-list-a"
certB := "mc-list-b"
certC := "mc-list-c"
insertCertificateRow(t, db, ctx, certA, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
insertCertificateRow(t, db, ctx, certB, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
insertCertificateRow(t, db, ctx, certC, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
// certA → 2 targets (t-0, t-1)
insertMapping(t, db, ctx, certA, targets[0])
insertMapping(t, db, ctx, certA, targets[1])
// certB → 1 target (t-2)
insertMapping(t, db, ctx, certB, targets[2])
// certC → 0 targets
got, total, err := repo.List(ctx, nil)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if total < 3 {
t.Fatalf("total = %d, want >= 3", total)
}
want := map[string][]string{
certA: {targets[0], targets[1]},
certB: {targets[2]},
certC: {},
}
seen := map[string]bool{}
for _, c := range got {
exp, ok := want[c.ID]
if !ok {
continue
}
seen[c.ID] = true
if c.TargetIDs == nil {
t.Errorf("cert %s: TargetIDs = nil, want %v", c.ID, exp)
continue
}
if len(c.TargetIDs) != len(exp) {
t.Errorf("cert %s: len(TargetIDs) = %d, want %d (got %v, want %v)", c.ID, len(c.TargetIDs), len(exp), c.TargetIDs, exp)
continue
}
for i, tid := range exp {
if c.TargetIDs[i] != tid {
t.Errorf("cert %s: TargetIDs[%d] = %q, want %q", c.ID, i, c.TargetIDs[i], tid)
}
}
}
for id := range want {
if !seen[id] {
t.Errorf("cert %s missing from List() result", id)
}
}
}
// --------------------------------------------------------------------
// GetExpiringCertificates() — scheduler read path
// --------------------------------------------------------------------
// TestGetExpiringCertificates_PopulatesTargetIDs: expiring certs must also carry
// their mapping information so renewal-triggered deployments can route work.
func TestGetExpiringCertificates_PopulatesTargetIDs(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewCertificateRepository(db)
ctx := context.Background()
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "expiring")
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "expiring", 2)
// Two expiring certs (expires in 3 days). Threshold = 7 days → both selected.
certA := "mc-exp-a"
certB := "mc-exp-b"
expiresSoon := time.Now().Add(3 * 24 * time.Hour)
insertCertificateRow(t, db, ctx, certA, ownerID, teamID, issuerID, policyID, expiresSoon)
insertCertificateRow(t, db, ctx, certB, ownerID, teamID, issuerID, policyID, expiresSoon)
insertMapping(t, db, ctx, certA, targets[0])
insertMapping(t, db, ctx, certA, targets[1])
// certB has no mappings.
threshold := time.Now().Add(7 * 24 * time.Hour)
got, err := repo.GetExpiringCertificates(ctx, threshold)
if err != nil {
t.Fatalf("GetExpiringCertificates failed: %v", err)
}
found := map[string]*domain.ManagedCertificate{}
for _, c := range got {
found[c.ID] = c
}
a, ok := found[certA]
if !ok {
t.Fatalf("cert %s not in expiring list", certA)
}
if len(a.TargetIDs) != 2 || a.TargetIDs[0] != targets[0] || a.TargetIDs[1] != targets[1] {
t.Errorf("cert %s: TargetIDs = %v, want %v", certA, a.TargetIDs, []string{targets[0], targets[1]})
}
b, ok := found[certB]
if !ok {
t.Fatalf("cert %s not in expiring list", certB)
}
if b.TargetIDs == nil {
t.Errorf("cert %s: TargetIDs = nil, want empty slice", certB)
}
if len(b.TargetIDs) != 0 {
t.Errorf("cert %s: len(TargetIDs) = %d, want 0", certB, len(b.TargetIDs))
}
}
+11 -10
View File
@@ -24,7 +24,7 @@ func NewPolicyRepository(db *sql.DB) *PolicyRepository {
// ListRules returns all policy rules // ListRules returns all policy rules
func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) { func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
rows, err := r.db.QueryContext(ctx, ` rows, err := r.db.QueryContext(ctx, `
SELECT id, name, type, config, enabled, created_at, updated_at SELECT id, name, type, config, enabled, severity, created_at, updated_at
FROM policy_rules FROM policy_rules
ORDER BY created_at DESC ORDER BY created_at DESC
`) `)
@@ -38,7 +38,7 @@ func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule,
for rows.Next() { for rows.Next() {
var rule domain.PolicyRule var rule domain.PolicyRule
if err := rows.Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config, if err := rows.Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config,
&rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt); err != nil { &rule.Enabled, &rule.Severity, &rule.CreatedAt, &rule.UpdatedAt); err != nil {
return nil, fmt.Errorf("failed to scan policy rule: %w", err) return nil, fmt.Errorf("failed to scan policy rule: %w", err)
} }
rules = append(rules, &rule) rules = append(rules, &rule)
@@ -55,11 +55,11 @@ func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule,
func (r *PolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) { func (r *PolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
var rule domain.PolicyRule var rule domain.PolicyRule
err := r.db.QueryRowContext(ctx, ` err := r.db.QueryRowContext(ctx, `
SELECT id, name, type, config, enabled, created_at, updated_at SELECT id, name, type, config, enabled, severity, created_at, updated_at
FROM policy_rules FROM policy_rules
WHERE id = $1 WHERE id = $1
`, id).Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config, `, id).Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config,
&rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt) &rule.Enabled, &rule.Severity, &rule.CreatedAt, &rule.UpdatedAt)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@@ -78,11 +78,11 @@ func (r *PolicyRepository) CreateRule(ctx context.Context, rule *domain.PolicyRu
} }
err := r.db.QueryRowContext(ctx, ` err := r.db.QueryRowContext(ctx, `
INSERT INTO policy_rules (id, name, type, config, enabled, created_at, updated_at) INSERT INTO policy_rules (id, name, type, config, enabled, severity, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id RETURNING id
`, rule.ID, rule.Name, rule.Type, rule.Config, rule.Enabled, `, rule.ID, rule.Name, rule.Type, rule.Config, rule.Enabled,
rule.CreatedAt, rule.UpdatedAt).Scan(&rule.ID) rule.Severity, rule.CreatedAt, rule.UpdatedAt).Scan(&rule.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to create policy rule: %w", err) return fmt.Errorf("failed to create policy rule: %w", err)
@@ -99,9 +99,10 @@ func (r *PolicyRepository) UpdateRule(ctx context.Context, rule *domain.PolicyRu
type = $2, type = $2,
config = $3, config = $3,
enabled = $4, enabled = $4,
updated_at = $5 severity = $5,
WHERE id = $6 updated_at = $6
`, rule.Name, rule.Type, rule.Config, rule.Enabled, rule.UpdatedAt, rule.ID) WHERE id = $7
`, rule.Name, rule.Type, rule.Config, rule.Enabled, rule.Severity, rule.UpdatedAt, rule.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update policy rule: %w", err) return fmt.Errorf("failed to update policy rule: %w", err)
+187
View File
@@ -457,6 +457,193 @@ func TestAgentRepository_Delete_NotFound(t *testing.T) {
} }
} }
// TestAgentRepository_CreateIfNotExists_FirstInsert verifies that a brand-new
// sentinel agent row is inserted and the helper reports created=true (M-6).
func TestAgentRepository_CreateIfNotExists_FirstInsert(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewAgentRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
agent := &domain.Agent{
ID: "server-scanner",
Name: "Network Scanner (Server-Side)",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
}
created, err := repo.CreateIfNotExists(ctx, agent)
if err != nil {
t.Fatalf("CreateIfNotExists failed: %v", err)
}
if !created {
t.Error("created = false on first insert, want true")
}
got, err := repo.Get(ctx, "server-scanner")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.Name != "Network Scanner (Server-Side)" {
t.Errorf("Name = %q, want %q", got.Name, "Network Scanner (Server-Side)")
}
}
// TestAgentRepository_CreateIfNotExists_Idempotent verifies that a second
// call with the same ID returns created=false and err=nil without mutating
// the existing row — the core M-6 upgrade/restart scenario (CWE-662).
func TestAgentRepository_CreateIfNotExists_Idempotent(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewAgentRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
first := &domain.Agent{
ID: "cloud-aws-sm",
Name: "AWS Secrets Manager Discovery",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
}
created, err := repo.CreateIfNotExists(ctx, first)
if err != nil {
t.Fatalf("first CreateIfNotExists failed: %v", err)
}
if !created {
t.Fatal("first created = false, want true")
}
// Second call with the same ID but a different name must be a no-op.
second := &domain.Agent{
ID: "cloud-aws-sm",
Name: "Overwritten Name Should Not Persist",
Status: domain.AgentStatusOffline,
RegisteredAt: now.Add(time.Hour),
}
created, err = repo.CreateIfNotExists(ctx, second)
if err != nil {
t.Fatalf("second CreateIfNotExists failed: %v", err)
}
if created {
t.Error("second created = true, want false (row already existed)")
}
// Row must still reflect the original insert.
got, err := repo.Get(ctx, "cloud-aws-sm")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.Name != "AWS Secrets Manager Discovery" {
t.Errorf("Name = %q, want %q (ON CONFLICT DO NOTHING must preserve original row)", got.Name, "AWS Secrets Manager Discovery")
}
if got.Status != domain.AgentStatusOnline {
t.Errorf("Status = %q, want %q", got.Status, domain.AgentStatusOnline)
}
}
// TestAgentRepository_CreateIfNotExists_ConcurrentRace fires N concurrent
// inserts for the same sentinel ID. Exactly one goroutine must see
// created=true; every other must see created=false and err=nil. No panics,
// no duplicate rows, no swallowed errors. This is the scenario that the
// pre-M-6 plain-INSERT path masked with a blanket error log.
func TestAgentRepository_CreateIfNotExists_ConcurrentRace(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewAgentRepository(db)
ctx := context.Background()
const N = 16
now := time.Now().Truncate(time.Microsecond)
var (
wg sync.WaitGroup
createdCount int64
errorCount int64
)
wg.Add(N)
for i := 0; i < N; i++ {
go func() {
defer wg.Done()
agent := &domain.Agent{
ID: "cloud-gcp-sm",
Name: "GCP Secret Manager Discovery",
Status: domain.AgentStatusOnline,
RegisteredAt: now,
}
created, err := repo.CreateIfNotExists(ctx, agent)
if err != nil {
atomic.AddInt64(&errorCount, 1)
t.Errorf("CreateIfNotExists returned error: %v", err)
return
}
if created {
atomic.AddInt64(&createdCount, 1)
}
}()
}
wg.Wait()
if errorCount != 0 {
t.Fatalf("errorCount = %d, want 0", errorCount)
}
if createdCount != 1 {
t.Errorf("createdCount = %d, want exactly 1 (only one goroutine may win the insert)", createdCount)
}
// Exactly one row must exist.
agents, err := repo.List(ctx)
if err != nil {
t.Fatalf("List failed: %v", err)
}
count := 0
for _, a := range agents {
if a.ID == "cloud-gcp-sm" {
count++
}
}
if count != 1 {
t.Errorf("row count for cloud-gcp-sm = %d, want 1", count)
}
}
// TestAgentRepository_CreateIfNotExists_GenericErrorSurfaces verifies that
// failures other than the primary-key duplicate (the only collision
// ON CONFLICT (id) absorbs) propagate to the caller instead of being
// swallowed. This is the security property that M-6 restores: the
// pre-fix plain-INSERT path logged every error at Debug level, so a
// connectivity or permission failure would vanish into the log without
// the server surfacing a problem on startup (CWE-662 / CWE-209-adjacent).
//
// Uses a pre-cancelled context to force QueryRowContext to fail with
// context.Canceled — a non-duplicate error class that must surface.
// Does NOT close the shared sql.DB (that would break sibling tests).
func TestAgentRepository_CreateIfNotExists_GenericErrorSurfaces(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewAgentRepository(db)
ctx, cancel := context.WithCancel(context.Background())
cancel() // pre-cancel so the driver round-trip fails immediately.
agent := &domain.Agent{
ID: "server-scanner",
Name: "Network Scanner (Server-Side)",
Status: domain.AgentStatusOnline,
RegisteredAt: time.Now(),
}
created, err := repo.CreateIfNotExists(ctx, agent)
if err == nil {
t.Fatal("expected error on cancelled context, got nil (error would have been swallowed pre-M-6)")
}
if created {
t.Error("created = true on failure, want false")
}
if err == sql.ErrNoRows {
t.Error("got sql.ErrNoRows, want a real connection/context error (ErrNoRows is the duplicate-row sentinel)")
}
}
// ============================================================ // ============================================================
// Issuer Repository Tests // Issuer Repository Tests
// ============================================================ // ============================================================
+2 -8
View File
@@ -91,8 +91,8 @@ func (s *AgentService) Register(ctx context.Context, name string, hostname strin
return agent, apiKey, nil return agent, apiKey, nil
} }
// HeartbeatWithContext updates an agent's last seen time, status, and metadata. // Heartbeat updates an agent's last seen time, status, and metadata.
func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error { func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
agent, err := s.agentRepo.Get(ctx, agentID) agent, err := s.agentRepo.Get(ctx, agentID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch agent: %w", err) return fmt.Errorf("failed to fetch agent: %w", err)
@@ -114,12 +114,6 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string,
return nil return nil
} }
// Heartbeat updates agent heartbeat (handler interface method).
// Note: This method is called from handlers which have a context; callers should prefer HeartbeatWithContext.
func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
return s.HeartbeatWithContext(ctx, agentID, metadata)
}
// SubmitCSR validates and processes a Certificate Signing Request from an agent. // SubmitCSR validates and processes a Certificate Signing Request from an agent.
// In agent keygen mode, this completes an AwaitingCSR renewal job by signing the CSR // In agent keygen mode, this completes an AwaitingCSR renewal job by signing the CSR
// and storing the cert version. The private key stays on the agent — only the CSR // and storing the cert version. The private key stays on the agent — only the CSR
+2 -2
View File
@@ -92,7 +92,7 @@ func TestHeartbeat(t *testing.T) {
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil) agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
err := agentService.HeartbeatWithContext(ctx, "agent-001", nil) err := agentService.Heartbeat(ctx, "agent-001", nil)
if err != nil { if err != nil {
t.Fatalf("Heartbeat failed: %v", err) t.Fatalf("Heartbeat failed: %v", err)
} }
@@ -125,7 +125,7 @@ func TestHeartbeat_NotFound(t *testing.T) {
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil) agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
err := agentService.HeartbeatWithContext(ctx, "nonexistent", nil) err := agentService.Heartbeat(ctx, "nonexistent", nil)
if err == nil { if err == nil {
t.Fatal("expected error for nonexistent agent") t.Fatal("expected error for nonexistent agent")
} }
+4 -4
View File
@@ -110,7 +110,7 @@ func (s *AuditService) ListByAction(ctx context.Context, action string, from, to
} }
// ListAuditEvents returns paginated audit events (handler interface method). // ListAuditEvents returns paginated audit events (handler interface method).
func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) { func (s *AuditService) ListAuditEvents(ctx context.Context, page, perPage int) ([]domain.AuditEvent, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -123,7 +123,7 @@ func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent,
PerPage: perPage, PerPage: perPage,
} }
events, err := s.auditRepo.List(context.Background(), filter) events, err := s.auditRepo.List(ctx, filter)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list audit events: %w", err) return nil, 0, fmt.Errorf("failed to list audit events: %w", err)
} }
@@ -143,13 +143,13 @@ func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent,
} }
// GetAuditEvent returns a single audit event (handler interface method). // GetAuditEvent returns a single audit event (handler interface method).
func (s *AuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) { func (s *AuditService) GetAuditEvent(ctx context.Context, id string) (*domain.AuditEvent, error) {
filter := &repository.AuditFilter{ filter := &repository.AuditFilter{
ResourceID: id, ResourceID: id,
PerPage: 1, PerPage: 1,
} }
events, err := s.auditRepo.List(context.Background(), filter) events, err := s.auditRepo.List(ctx, filter)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get audit event: %w", err) return nil, fmt.Errorf("failed to get audit event: %w", err)
} }
+13 -13
View File
@@ -41,7 +41,7 @@ func (s *CAOperationsSvc) SetIssuerRegistry(registry *IssuerRegistry) {
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer. // GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
// Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL. // Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL.
func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) { func (s *CAOperationsSvc) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
if s.revocationRepo == nil { if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured") return nil, fmt.Errorf("revocation repository not configured")
} }
@@ -54,7 +54,7 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
return nil, fmt.Errorf("issuer not found: %s", issuerID) return nil, fmt.Errorf("issuer not found: %s", issuerID)
} }
revocations, err := s.revocationRepo.ListAll(context.Background()) revocations, err := s.revocationRepo.ListAll(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to list revocations: %w", err) return nil, fmt.Errorf("failed to list revocations: %w", err)
} }
@@ -69,9 +69,9 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
// Check short-lived exemption: look up the cert's profile // Check short-lived exemption: look up the cert's profile
if s.profileRepo != nil && s.certRepo != nil { if s.profileRepo != nil && s.certRepo != nil {
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID) cert, err := s.certRepo.Get(ctx, rev.CertificateID)
if err == nil && cert.CertificateProfileID != "" { if err == nil && cert.CertificateProfileID != "" {
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID) profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
if err == nil && profile.IsShortLived() { if err == nil && profile.IsShortLived() {
slog.Debug("skipping short-lived cert from CRL", slog.Debug("skipping short-lived cert from CRL",
"certificate_id", rev.CertificateID, "certificate_id", rev.CertificateID,
@@ -92,11 +92,11 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
}) })
} }
return issuerConn.GenerateCRL(context.Background(), entries) return issuerConn.GenerateCRL(ctx, entries)
} }
// GetOCSPResponse generates a signed OCSP response for the given certificate serial. // GetOCSPResponse generates a signed OCSP response for the given certificate serial.
func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) { func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
if s.revocationRepo == nil { if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured") return nil, fmt.Errorf("revocation repository not configured")
} }
@@ -120,13 +120,13 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
// Look up cert by (issuer_id, serial) — per RFC 5280 §5.2.3, serial numbers // Look up cert by (issuer_id, serial) — per RFC 5280 §5.2.3, serial numbers
// are unique only within a single issuer. The OCSP URL path carries issuer_id, // are unique only within a single issuer. The OCSP URL path carries issuer_id,
// so we scope the lookup to avoid cross-issuer collisions. // so we scope the lookup to avoid cross-issuer collisions.
rev, _ := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex) rev, _ := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
if rev != nil { if rev != nil {
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID) cert, err := s.certRepo.Get(ctx, rev.CertificateID)
if err == nil && cert.CertificateProfileID != "" { if err == nil && cert.CertificateProfileID != "" {
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID) profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
if err == nil && profile.IsShortLived() { if err == nil && profile.IsShortLived() {
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{ return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial, CertSerial: serial,
CertStatus: 0, // good — short-lived exemption CertStatus: 0, // good — short-lived exemption
ThisUpdate: now, ThisUpdate: now,
@@ -138,10 +138,10 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
} }
// Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping. // Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping.
rev, err := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex) rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
if err != nil { if err != nil {
// Not revoked — return "good" status // Not revoked — return "good" status
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{ return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial, CertSerial: serial,
CertStatus: 0, // good CertStatus: 0, // good
ThisUpdate: now, ThisUpdate: now,
@@ -150,7 +150,7 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
} }
// Revoked // Revoked
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{ return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
CertSerial: serial, CertSerial: serial,
CertStatus: 1, // revoked CertStatus: 1, // revoked
RevokedAt: rev.RevokedAt, RevokedAt: rev.RevokedAt,
+5 -4
View File
@@ -3,6 +3,7 @@
package service package service
import ( import (
"context"
"log/slog" "log/slog"
"testing" "testing"
"time" "time"
@@ -48,7 +49,7 @@ func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) {
}, },
} }
crl, err := caSvc.GenerateDERCRL("iss-local") crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -71,7 +72,7 @@ func TestCAOperationsSvc_GenerateDERCRL_EmptyCRL(t *testing.T) {
// No revoked certs for this issuer // No revoked certs for this issuer
revocationRepo.Revocations = []*domain.CertificateRevocation{} revocationRepo.Revocations = []*domain.CertificateRevocation{}
crl, err := caSvc.GenerateDERCRL("iss-local") crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -112,7 +113,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) {
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version} certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
// Request OCSP response for good cert // Request OCSP response for good cert
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-GOOD-001") resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -165,7 +166,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) {
} }
// Request OCSP response for revoked cert // Request OCSP response for revoked cert
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001") resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
+31 -46
View File
@@ -71,8 +71,8 @@ func (s *CertificateService) List(ctx context.Context, filter *repository.Certif
// ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20). // ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20).
// This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers). // This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers).
func (s *CertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { func (s *CertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
certs, total, err := s.certRepo.List(context.Background(), filter) certs, total, err := s.certRepo.List(ctx, filter)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err) return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err)
} }
@@ -206,10 +206,10 @@ func (s *CertificateService) GetVersions(ctx context.Context, certID string) ([]
return versions, nil return versions, nil
} }
// TriggerRenewalWithActor initiates a renewal job if the certificate is eligible. // TriggerRenewal initiates a renewal job if the certificate is eligible.
// Creates a Renewal job (or Issuance for new certs) so the scheduler's job processor // Creates a Renewal job (or Issuance for new certs) so the scheduler's job processor
// can pick it up and route it through the issuer connector. // can pick it up and route it through the issuer connector.
func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID string, actor string) error { func (s *CertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error {
cert, err := s.certRepo.Get(ctx, certID) cert, err := s.certRepo.Get(ctx, certID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch certificate: %w", err) return fmt.Errorf("failed to fetch certificate: %w", err)
@@ -283,8 +283,11 @@ func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID
return nil return nil
} }
// TriggerDeploymentWithActor creates deployment jobs for all targets of a certificate. // TriggerDeployment creates deployment jobs for all targets of a certificate.
func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, certID string, actor string) error { // The targetID parameter is accepted from the handler interface but currently unused;
// deployment coordination happens per-certificate across all of its targets.
func (s *CertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error {
_ = targetID
cert, err := s.certRepo.Get(ctx, certID) cert, err := s.certRepo.Get(ctx, certID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch certificate: %w", err) return fmt.Errorf("failed to fetch certificate: %w", err)
@@ -306,7 +309,7 @@ func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, cer
} }
// ListCertificates returns paginated certificates with optional filtering (handler interface method). // ListCertificates returns paginated certificates with optional filtering (handler interface method).
func (s *CertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { func (s *CertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -325,7 +328,7 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team
PerPage: perPage, PerPage: perPage,
} }
certs, total, err := s.certRepo.List(context.Background(), filter) certs, total, err := s.certRepo.List(ctx, filter)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list certificates: %w", err) return nil, 0, fmt.Errorf("failed to list certificates: %w", err)
} }
@@ -341,12 +344,12 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team
} }
// GetCertificate returns a single certificate (handler interface method). // GetCertificate returns a single certificate (handler interface method).
func (s *CertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) { func (s *CertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
return s.certRepo.Get(context.Background(), id) return s.certRepo.Get(ctx, id)
} }
// CreateCertificate creates a new certificate (handler interface method). // CreateCertificate creates a new certificate (handler interface method).
func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) { func (s *CertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
if cert.ID == "" { if cert.ID == "" {
cert.ID = generateID("cert") cert.ID = generateID("cert")
} }
@@ -365,16 +368,14 @@ func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (
if cert.Tags == nil { if cert.Tags == nil {
cert.Tags = make(map[string]string) cert.Tags = make(map[string]string)
} }
if err := s.certRepo.Create(context.Background(), &cert); err != nil { if err := s.certRepo.Create(ctx, &cert); err != nil {
return nil, fmt.Errorf("failed to create certificate: %w", err) return nil, fmt.Errorf("failed to create certificate: %w", err)
} }
return &cert, nil return &cert, nil
} }
// UpdateCertificate modifies a certificate (handler interface method). // UpdateCertificate modifies a certificate (handler interface method).
func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) { func (s *CertificateService) UpdateCertificate(ctx context.Context, id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
ctx := context.Background()
// Fetch existing certificate so partial updates don't zero out fields // Fetch existing certificate so partial updates don't zero out fields
existing, err := s.certRepo.Get(ctx, id) existing, err := s.certRepo.Get(ctx, id)
if err != nil { if err != nil {
@@ -425,12 +426,12 @@ func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCe
} }
// ArchiveCertificate marks a certificate as archived (handler interface method). // ArchiveCertificate marks a certificate as archived (handler interface method).
func (s *CertificateService) ArchiveCertificate(id string) error { func (s *CertificateService) ArchiveCertificate(ctx context.Context, id string) error {
return s.certRepo.Archive(context.Background(), id) return s.certRepo.Archive(ctx, id)
} }
// GetCertificateVersions returns certificate versions (handler interface method). // GetCertificateVersions returns certificate versions (handler interface method).
func (s *CertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) { func (s *CertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -438,7 +439,7 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage
perPage = 50 perPage = 50
} }
versions, err := s.certRepo.ListVersions(context.Background(), certID) versions, err := s.certRepo.ListVersions(ctx, certID)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list certificate versions: %w", err) return nil, 0, fmt.Errorf("failed to list certificate versions: %w", err)
} }
@@ -463,24 +464,8 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage
return result, total, nil return result, total, nil
} }
// TriggerRenewal initiates renewal (handler interface method). // RevokeCertificate performs revocation with actor tracking. Delegates to RevocationSvc.
func (s *CertificateService) TriggerRenewal(certID string) error { func (s *CertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error {
return s.TriggerRenewalWithActor(context.Background(), certID, "api")
}
// TriggerDeployment triggers deployment (handler interface method).
func (s *CertificateService) TriggerDeployment(certID string, targetID string) error {
return s.TriggerDeploymentWithActor(context.Background(), certID, "api")
}
// RevokeCertificate revokes a certificate with the given reason (handler interface method).
func (s *CertificateService) RevokeCertificate(certID string, reason string) error {
return s.RevokeCertificateWithActor(context.Background(), certID, reason, "api")
}
// RevokeCertificateWithActor performs revocation with actor tracking.
// Delegates to RevocationSvc.
func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error {
if s.revSvc == nil { if s.revSvc == nil {
return fmt.Errorf("revocation service not configured") return fmt.Errorf("revocation service not configured")
} }
@@ -489,35 +474,35 @@ func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, cer
// GetRevokedCertificates returns all revoked certificate records (for CRL generation). // GetRevokedCertificates returns all revoked certificate records (for CRL generation).
// Delegates to RevocationSvc. // Delegates to RevocationSvc.
func (s *CertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) { func (s *CertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if s.revSvc == nil { if s.revSvc == nil {
return nil, fmt.Errorf("revocation service not configured") return nil, fmt.Errorf("revocation service not configured")
} }
return s.revSvc.GetRevokedCertificates() return s.revSvc.GetRevokedCertificates(ctx)
} }
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer. // GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
// Delegates to CAOperationsSvc. // Delegates to CAOperationsSvc.
func (s *CertificateService) GenerateDERCRL(issuerID string) ([]byte, error) { func (s *CertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
if s.caSvc == nil { if s.caSvc == nil {
return nil, fmt.Errorf("CA operations service not configured") return nil, fmt.Errorf("CA operations service not configured")
} }
return s.caSvc.GenerateDERCRL(issuerID) return s.caSvc.GenerateDERCRL(ctx, issuerID)
} }
// GetOCSPResponse generates a signed OCSP response for the given certificate serial. // GetOCSPResponse generates a signed OCSP response for the given certificate serial.
// Delegates to CAOperationsSvc. // Delegates to CAOperationsSvc.
func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) { func (s *CertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
if s.caSvc == nil { if s.caSvc == nil {
return nil, fmt.Errorf("CA operations service not configured") return nil, fmt.Errorf("CA operations service not configured")
} }
return s.caSvc.GetOCSPResponse(issuerID, serialHex) return s.caSvc.GetOCSPResponse(ctx, issuerID, serialHex)
} }
// GetCertificateDeployments returns all deployment targets for a certificate (M20). // GetCertificateDeployments returns all deployment targets for a certificate (M20).
func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) { func (s *CertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) {
// Verify certificate exists // Verify certificate exists
_, err := s.certRepo.Get(context.Background(), certID) _, err := s.certRepo.Get(ctx, certID)
if err != nil { if err != nil {
return nil, fmt.Errorf("certificate not found: %w", err) return nil, fmt.Errorf("certificate not found: %w", err)
} }
@@ -527,7 +512,7 @@ func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.
} }
// Get targets from repository // Get targets from repository
targets, err := s.targetRepo.ListByCertificate(context.Background(), certID) targets, err := s.targetRepo.ListByCertificate(ctx, certID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to list deployment targets: %w", err) return nil, fmt.Errorf("failed to list deployment targets: %w", err)
} }
+11 -11
View File
@@ -34,7 +34,7 @@ func TestCertificateService_RevokeCertificate_RevocationSvcNil(t *testing.T) {
certRepo.AddCert(cert) certRepo.AddCert(cert)
// Call RevokeCertificateWithActor with nil RevocationSvc // Call RevokeCertificateWithActor with nil RevocationSvc
err := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin") err := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
// Assert: Should return error, NOT panic // Assert: Should return error, NOT panic
if err == nil { if err == nil {
@@ -64,7 +64,7 @@ func TestCertificateService_GenerateDERCRL_CAOpsSvcNil(t *testing.T) {
// Note: NOT calling certService.SetCAOperationsSvc(...) // Note: NOT calling certService.SetCAOperationsSvc(...)
// Call GenerateDERCRL with nil CAOperationsSvc // Call GenerateDERCRL with nil CAOperationsSvc
_, err := certService.GenerateDERCRL("iss-local") _, err := certService.GenerateDERCRL(context.Background(), "iss-local")
// Assert: Should return error, NOT panic // Assert: Should return error, NOT panic
if err == nil { if err == nil {
@@ -94,7 +94,7 @@ func TestCertificateService_GetOCSPResponse_CAOpsSvcNil(t *testing.T) {
// Note: NOT calling certService.SetCAOperationsSvc(...) // Note: NOT calling certService.SetCAOperationsSvc(...)
// Call GetOCSPResponse with nil CAOperationsSvc // Call GetOCSPResponse with nil CAOperationsSvc
_, err := certService.GetOCSPResponse("iss-local", "serial123") _, err := certService.GetOCSPResponse(context.Background(), "iss-local", "serial123")
// Assert: Should return error, NOT panic // Assert: Should return error, NOT panic
if err == nil { if err == nil {
@@ -124,7 +124,7 @@ func TestCertificateService_GetRevokedCertificates_RevocationSvcNil(t *testing.T
// Note: NOT calling certService.SetRevocationSvc(...) // Note: NOT calling certService.SetRevocationSvc(...)
// Call GetRevokedCertificates with nil RevocationSvc // Call GetRevokedCertificates with nil RevocationSvc
_, err := certService.GetRevokedCertificates() _, err := certService.GetRevokedCertificates(context.Background())
// Assert: Should return error, NOT panic // Assert: Should return error, NOT panic
if err == nil { if err == nil {
@@ -177,7 +177,7 @@ func TestCertificateService_GetCertificateDeployments_Success(t *testing.T) {
targetRepo.AddTarget(target2) targetRepo.AddTarget(target2)
// Call GetCertificateDeployments // Call GetCertificateDeployments
deployments, err := certService.GetCertificateDeployments("cert-1") deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
// Assert: Should return deployment list successfully // Assert: Should return deployment list successfully
if err != nil { if err != nil {
@@ -218,7 +218,7 @@ func TestCertificateService_GetCertificateDeployments_RepositoryError(t *testing
certRepo.AddCert(cert) certRepo.AddCert(cert)
// Call GetCertificateDeployments with repo error // Call GetCertificateDeployments with repo error
_, err := certService.GetCertificateDeployments("cert-1") _, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
// Assert: Should return error, NOT panic // Assert: Should return error, NOT panic
if err == nil { if err == nil {
@@ -247,7 +247,7 @@ func TestCertificateService_GetCertificateDeployments_CertNotFound(t *testing.T)
certService.SetTargetRepo(targetRepo) certService.SetTargetRepo(targetRepo)
// Call GetCertificateDeployments with nonexistent certificate // Call GetCertificateDeployments with nonexistent certificate
_, err := certService.GetCertificateDeployments("nonexistent-cert") _, err := certService.GetCertificateDeployments(context.Background(), "nonexistent-cert")
// Assert: Should return error // Assert: Should return error
if err == nil { if err == nil {
@@ -283,7 +283,7 @@ func TestCertificateService_GetCertificateDeployments_NilTargetRepo(t *testing.T
certRepo.AddCert(cert) certRepo.AddCert(cert)
// Call GetCertificateDeployments with nil TargetRepo // Call GetCertificateDeployments with nil TargetRepo
deployments, err := certService.GetCertificateDeployments("cert-1") deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
// Assert: Should return empty list gracefully (not panic) // Assert: Should return empty list gracefully (not panic)
if err != nil { if err != nil {
@@ -337,19 +337,19 @@ func TestCertificateService_Multiple_NilSafetyChecks(t *testing.T) {
revSvc.SetIssuerRegistry(registry) revSvc.SetIssuerRegistry(registry)
// Test 1: RevokeCertificateWithActor should succeed (RevocationSvc is set) // Test 1: RevokeCertificateWithActor should succeed (RevocationSvc is set)
errRevoke := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin") errRevoke := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
if errRevoke != nil { if errRevoke != nil {
t.Fatalf("RevokeCertificateWithActor failed unexpectedly: %v", errRevoke) t.Fatalf("RevokeCertificateWithActor failed unexpectedly: %v", errRevoke)
} }
// Test 2: GenerateDERCRL should fail gracefully (CAOperationsSvc is nil) // Test 2: GenerateDERCRL should fail gracefully (CAOperationsSvc is nil)
_, errCRL := certService.GenerateDERCRL("iss-local") _, errCRL := certService.GenerateDERCRL(context.Background(), "iss-local")
if errCRL == nil { if errCRL == nil {
t.Fatal("GenerateDERCRL expected error, got nil") t.Fatal("GenerateDERCRL expected error, got nil")
} }
// Test 3: GetOCSPResponse should fail gracefully (CAOperationsSvc is nil) // Test 3: GetOCSPResponse should fail gracefully (CAOperationsSvc is nil)
_, errOCSP := certService.GetOCSPResponse("iss-local", "ABC123") _, errOCSP := certService.GetOCSPResponse(context.Background(), "iss-local", "ABC123")
if errOCSP == nil { if errOCSP == nil {
t.Fatal("GetOCSPResponse expected error, got nil") t.Fatal("GetOCSPResponse expected error, got nil")
} }
+4 -3
View File
@@ -294,7 +294,7 @@ func TestTriggerRenewal(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService) certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1") err := certService.TriggerRenewal(ctx, "cert-001", "user-1")
if err != nil { if err != nil {
t.Fatalf("TriggerRenewal failed: %v", err) t.Fatalf("TriggerRenewal failed: %v", err)
} }
@@ -333,13 +333,14 @@ func TestTriggerRenewal_Archived(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService) certService := NewCertificateService(certRepo, policyService, auditService)
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1") err := certService.TriggerRenewal(ctx, "cert-001", "user-1")
if err == nil { if err == nil {
t.Fatal("expected error for archived certificate") t.Fatal("expected error for archived certificate")
} }
} }
func TestListCertificates(t *testing.T) { func TestListCertificates(t *testing.T) {
ctx := context.Background()
now := time.Now() now := time.Now()
cert1 := &domain.ManagedCertificate{ cert1 := &domain.ManagedCertificate{
ID: "cert-001", ID: "cert-001",
@@ -369,7 +370,7 @@ func TestListCertificates(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
certService := NewCertificateService(certRepo, policyService, auditService) certService := NewCertificateService(certRepo, policyService, auditService)
certs, total, err := certService.ListCertificates("", "", "", "", "", 1, 50) certs, total, err := certService.ListCertificates(ctx, "", "", "", "", "", 1, 50)
if err != nil { if err != nil {
t.Fatalf("ListCertificates failed: %v", err) t.Fatalf("ListCertificates failed: %v", err)
} }
+3 -3
View File
@@ -159,7 +159,7 @@ func TestConcurrentAgentHeartbeats(t *testing.T) {
Architecture: "x86_64", Architecture: "x86_64",
} }
err := agentSvc.HeartbeatWithContext(ctx, agentID, metadata) err := agentSvc.Heartbeat(ctx, agentID, metadata)
if err != nil { if err != nil {
errChan <- fmt.Errorf("goroutine %d: failed heartbeat for agent %s: %w", idx, agentID, err) errChan <- fmt.Errorf("goroutine %d: failed heartbeat for agent %s: %w", idx, agentID, err)
return return
@@ -194,7 +194,7 @@ func TestConcurrentTargetCRUD(t *testing.T) {
Targets: make(map[string]*domain.DeploymentTarget), Targets: make(map[string]*domain.DeploymentTarget),
} }
targetSvc := NewTargetService(mockTargetRepo, nil, nil, nil, slog.New(slog.NewTextHandler(os.Stderr, nil))) targetSvc := NewTargetService(mockTargetRepo, nil, nil, "", slog.New(slog.NewTextHandler(os.Stderr, nil)))
var mu sync.Mutex var mu sync.Mutex
createdTargets := make([]string, 0) createdTargets := make([]string, 0)
@@ -403,7 +403,7 @@ func TestConcurrentMixedOperations(t *testing.T) {
// Setup services // Setup services
auditSvc := &AuditService{auditRepo: mockAuditRepo} auditSvc := &AuditService{auditRepo: mockAuditRepo}
certSvc := NewCertificateService(mockCertRepo, nil, auditSvc) certSvc := NewCertificateService(mockCertRepo, nil, auditSvc)
targetSvc := NewTargetService(mockTargetRepo, auditSvc, nil, nil, slog.New(slog.NewTextHandler(os.Stderr, nil))) targetSvc := NewTargetService(mockTargetRepo, auditSvc, nil, "", slog.New(slog.NewTextHandler(os.Stderr, nil)))
var wg sync.WaitGroup var wg sync.WaitGroup
errChan := make(chan error, 30) errChan := make(chan error, 30)
+5 -5
View File
@@ -142,7 +142,7 @@ func TestTargetService_ListWithCancelledContext(t *testing.T) {
mockTargetRepo := &mockTargetRepo{ mockTargetRepo := &mockTargetRepo{
Targets: make(map[string]*domain.DeploymentTarget), Targets: make(map[string]*domain.DeploymentTarget),
} }
targetSvc := NewTargetService(mockTargetRepo, nil, nil, nil, slog.New(slog.NewTextHandler(os.Stderr, nil))) targetSvc := NewTargetService(mockTargetRepo, nil, nil, "", slog.New(slog.NewTextHandler(os.Stderr, nil)))
_, _, err := targetSvc.List(ctx, 1, 50) _, _, err := targetSvc.List(ctx, 1, 50)
@@ -176,13 +176,13 @@ func TestAgentService_HeartbeatWithCancelledContext(t *testing.T) {
nil, // renewalService nil, // renewalService
) )
err := agentSvc.HeartbeatWithContext(ctx, "agent-1", &domain.AgentMetadata{}) err := agentSvc.Heartbeat(ctx, "agent-1", &domain.AgentMetadata{})
// Service should handle cancelled context // Service should handle cancelled context
if err == nil || ctx.Err() == context.Canceled { if err == nil || ctx.Err() == context.Canceled {
return return
} }
t.Logf("HeartbeatWithContext with cancelled context returned: %v", err) t.Logf("Heartbeat with cancelled context returned: %v", err)
} }
// Test with timeout context (should trigger deadline exceeded) // Test with timeout context (should trigger deadline exceeded)
@@ -229,11 +229,11 @@ func TestAgentService_HeartbeatWithDeadlineExceeded(t *testing.T) {
time.Sleep(10 * time.Millisecond) // Ensure deadline is exceeded time.Sleep(10 * time.Millisecond) // Ensure deadline is exceeded
err := agentSvc.HeartbeatWithContext(ctx, "agent-1", &domain.AgentMetadata{}) err := agentSvc.Heartbeat(ctx, "agent-1", &domain.AgentMetadata{})
// Service should handle deadline exceeded // Service should handle deadline exceeded
if err == nil || ctx.Err() == context.DeadlineExceeded { if err == nil || ctx.Err() == context.DeadlineExceeded {
return return
} }
t.Logf("HeartbeatWithContext with deadline exceeded returned: %v", err) t.Logf("Heartbeat with deadline exceeded returned: %v", err)
} }
+25 -23
View File
@@ -17,20 +17,27 @@ import (
) )
// IssuerService provides business logic for certificate issuer management. // IssuerService provides business logic for certificate issuer management.
//
// The encryptionKey field holds the raw passphrase (not a pre-derived 32-byte
// key). Per-ciphertext salt derivation is performed inside
// [crypto.EncryptIfKeySet] / [crypto.DecryptIfKeySet] on each call. See M-8
// in certctl-audit-report.md.
type IssuerService struct { type IssuerService struct {
issuerRepo repository.IssuerRepository issuerRepo repository.IssuerRepository
auditService *AuditService auditService *AuditService
registry *IssuerRegistry registry *IssuerRegistry
encryptionKey []byte encryptionKey string
logger *slog.Logger logger *slog.Logger
} }
// NewIssuerService creates a new issuer service. // NewIssuerService creates a new issuer service. The encryptionKey is the raw
// passphrase; it MUST NOT be pre-derived via crypto.DeriveKey (that was the
// v1 behavior, replaced in M-8 with per-ciphertext random salt).
func NewIssuerService( func NewIssuerService(
issuerRepo repository.IssuerRepository, issuerRepo repository.IssuerRepository,
auditService *AuditService, auditService *AuditService,
registry *IssuerRegistry, registry *IssuerRegistry,
encryptionKey []byte, encryptionKey string,
logger *slog.Logger, logger *slog.Logger,
) *IssuerService { ) *IssuerService {
return &IssuerService{ return &IssuerService{
@@ -253,9 +260,9 @@ func (s *IssuerService) Delete(ctx context.Context, id string, actor string) err
return nil return nil
} }
// TestConnectionWithContext tests the connection to an issuer by instantiating a throwaway // TestConnection tests the connection to an issuer by instantiating a throwaway
// connector and calling ValidateConfig. Records the result in the database. // connector and calling ValidateConfig. Records the result in the database.
func (s *IssuerService) TestConnectionWithContext(ctx context.Context, id string) error { func (s *IssuerService) TestConnection(ctx context.Context, id string) error {
iss, err := s.issuerRepo.Get(ctx, id) iss, err := s.issuerRepo.Get(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("issuer not found: %w", err) return fmt.Errorf("issuer not found: %w", err)
@@ -284,11 +291,6 @@ func (s *IssuerService) TestConnectionWithContext(ctx context.Context, id string
return nil return nil
} }
// TestConnection verifies the issuer connection (handler interface method).
func (s *IssuerService) TestConnection(id string) error {
return s.TestConnectionWithContext(context.Background(), id)
}
// BuildRegistry loads all enabled issuers from the database and rebuilds the dynamic registry. // BuildRegistry loads all enabled issuers from the database and rebuilds the dynamic registry.
// Called at server startup. Partial failures (individual issuers failing to load) are logged // Called at server startup. Partial failures (individual issuers failing to load) are logged
// as warnings but don't prevent the server from starting. // as warnings but don't prevent the server from starting.
@@ -626,7 +628,7 @@ func (s *IssuerService) buildEnvVarSeeds(cfg *config.Config) []*domain.Issuer {
} }
// ListIssuers returns paginated issuers (handler interface method). // ListIssuers returns paginated issuers (handler interface method).
func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) { func (s *IssuerService) ListIssuers(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -634,7 +636,7 @@ func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64,
perPage = 50 perPage = 50
} }
issuers, err := s.issuerRepo.List(context.Background()) issuers, err := s.issuerRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list issuers: %w", err) return nil, 0, fmt.Errorf("failed to list issuers: %w", err)
} }
@@ -651,12 +653,12 @@ func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64,
} }
// GetIssuer returns a single issuer (handler interface method). // GetIssuer returns a single issuer (handler interface method).
func (s *IssuerService) GetIssuer(id string) (*domain.Issuer, error) { func (s *IssuerService) GetIssuer(ctx context.Context, id string) (*domain.Issuer, error) {
return s.issuerRepo.Get(context.Background(), id) return s.issuerRepo.Get(ctx, id)
} }
// CreateIssuer creates a new issuer (handler interface method). // CreateIssuer creates a new issuer (handler interface method).
func (s *IssuerService) CreateIssuer(iss domain.Issuer) (*domain.Issuer, error) { func (s *IssuerService) CreateIssuer(ctx context.Context, iss domain.Issuer) (*domain.Issuer, error) {
iss.Type = normalizeIssuerType(iss.Type) iss.Type = normalizeIssuerType(iss.Type)
if !isValidIssuerType(iss.Type) { if !isValidIssuerType(iss.Type) {
return nil, fmt.Errorf("unsupported issuer type: %s", iss.Type) return nil, fmt.Errorf("unsupported issuer type: %s", iss.Type)
@@ -693,26 +695,26 @@ func (s *IssuerService) CreateIssuer(iss domain.Issuer) (*domain.Issuer, error)
iss.Config = redactConfigJSON(iss.Config) iss.Config = redactConfigJSON(iss.Config)
} }
if err := s.issuerRepo.Create(context.Background(), &iss); err != nil { if err := s.issuerRepo.Create(ctx, &iss); err != nil {
return nil, fmt.Errorf("failed to create issuer: %w", err) return nil, fmt.Errorf("failed to create issuer: %w", err)
} }
// Rebuild registry // Rebuild registry
if iss.Enabled { if iss.Enabled {
s.rebuildRegistryQuiet(context.Background()) s.rebuildRegistryQuiet(ctx)
} }
return &iss, nil return &iss, nil
} }
// UpdateIssuer modifies an issuer (handler interface method). // UpdateIssuer modifies an issuer (handler interface method).
func (s *IssuerService) UpdateIssuer(id string, iss domain.Issuer) (*domain.Issuer, error) { func (s *IssuerService) UpdateIssuer(ctx context.Context, id string, iss domain.Issuer) (*domain.Issuer, error) {
iss.ID = id iss.ID = id
iss.UpdatedAt = time.Now() iss.UpdatedAt = time.Now()
// Merge redacted fields with existing config // Merge redacted fields with existing config
if len(iss.Config) > 0 { if len(iss.Config) > 0 {
mergedConfig, err := s.mergeRedactedConfig(context.Background(), id, iss.Config) mergedConfig, err := s.mergeRedactedConfig(ctx, id, iss.Config)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to merge config: %w", err) return nil, fmt.Errorf("failed to merge config: %w", err)
} }
@@ -725,18 +727,18 @@ func (s *IssuerService) UpdateIssuer(id string, iss domain.Issuer) (*domain.Issu
iss.Config = redactConfigJSON(json.RawMessage(mergedConfig)) iss.Config = redactConfigJSON(json.RawMessage(mergedConfig))
} }
if err := s.issuerRepo.Update(context.Background(), &iss); err != nil { if err := s.issuerRepo.Update(ctx, &iss); err != nil {
return nil, fmt.Errorf("failed to update issuer: %w", err) return nil, fmt.Errorf("failed to update issuer: %w", err)
} }
s.rebuildRegistryQuiet(context.Background()) s.rebuildRegistryQuiet(ctx)
return &iss, nil return &iss, nil
} }
// DeleteIssuer removes an issuer (handler interface method). // DeleteIssuer removes an issuer (handler interface method).
func (s *IssuerService) DeleteIssuer(id string) error { func (s *IssuerService) DeleteIssuer(ctx context.Context, id string) error {
if err := s.issuerRepo.Delete(context.Background(), id); err != nil { if err := s.issuerRepo.Delete(ctx, id); err != nil {
return err return err
} }
if s.registry != nil { if s.registry != nil {
+8 -8
View File
@@ -26,7 +26,7 @@ func TestBuildEnvVarSeeds_ACMEConfig(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
// Call buildEnvVarSeeds (unexported method, but testable from same package) // Call buildEnvVarSeeds (unexported method, but testable from same package)
seeds := service.buildEnvVarSeeds(cfg) seeds := service.buildEnvVarSeeds(cfg)
@@ -82,7 +82,7 @@ func TestBuildEnvVarSeeds_VaultConfig(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
seeds := service.buildEnvVarSeeds(cfg) seeds := service.buildEnvVarSeeds(cfg)
@@ -136,7 +136,7 @@ func TestBuildEnvVarSeeds_NoConfig(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
seeds := service.buildEnvVarSeeds(cfg) seeds := service.buildEnvVarSeeds(cfg)
@@ -186,7 +186,7 @@ func TestBuildEnvVarSeeds_MultipleConfigs(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
seeds := service.buildEnvVarSeeds(cfg) seeds := service.buildEnvVarSeeds(cfg)
@@ -232,7 +232,7 @@ func TestSeedFromEnvVars_Empty(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
// Call SeedFromEnvVars on empty repo // Call SeedFromEnvVars on empty repo
service.SeedFromEnvVars(ctx, cfg) service.SeedFromEnvVars(ctx, cfg)
@@ -280,7 +280,7 @@ func TestSeedFromEnvVars_AlreadyExists(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
// Get count before seeding // Get count before seeding
beforeSeeding, _ := repo.List(ctx) beforeSeeding, _ := repo.List(ctx)
@@ -328,7 +328,7 @@ func TestBuildRegistry_Success(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
// Call BuildRegistry // Call BuildRegistry
err := service.BuildRegistry(ctx) err := service.BuildRegistry(ctx)
@@ -351,7 +351,7 @@ func TestBuildRegistry_EmptyDatabase(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
// Call BuildRegistry on empty database // Call BuildRegistry on empty database
err := service.BuildRegistry(ctx) err := service.BuildRegistry(ctx)
+6 -1
View File
@@ -72,7 +72,12 @@ func (r *IssuerRegistry) Len() int {
// For each enabled issuer, it decrypts the config (if encryption key is set), // For each enabled issuer, it decrypts the config (if encryption key is set),
// instantiates a connector via the factory, wraps it in an adapter, and // instantiates a connector via the factory, wraps it in an adapter, and
// atomically swaps the entire map. // atomically swaps the entire map.
func (r *IssuerRegistry) Rebuild(configs []*domain.Issuer, encryptionKey []byte) error { //
// The encryption passphrase is passed as a string; per-ciphertext salt derivation
// for v2 blobs is performed inside [crypto.DecryptIfKeySet]. Empty passphrase
// fails closed via [crypto.ErrEncryptionKeyRequired] when encrypted configs
// are encountered. See M-8 in certctl-audit-report.md.
func (r *IssuerRegistry) Rebuild(configs []*domain.Issuer, encryptionKey string) error {
newIssuers := make(map[string]IssuerConnector) newIssuers := make(map[string]IssuerConnector)
var errors []string var errors []string
+13 -11
View File
@@ -101,7 +101,7 @@ func TestIssuerRegistry_Rebuild_Enabled(t *testing.T) {
}, },
} }
err := reg.Rebuild(configs, nil) err := reg.Rebuild(configs, "")
if err != nil { if err != nil {
t.Fatalf("Rebuild failed: %v", err) t.Fatalf("Rebuild failed: %v", err)
} }
@@ -124,11 +124,12 @@ func TestIssuerRegistry_Rebuild_Enabled(t *testing.T) {
func TestIssuerRegistry_Rebuild_WithEncryption(t *testing.T) { func TestIssuerRegistry_Rebuild_WithEncryption(t *testing.T) {
reg := NewIssuerRegistry(registryTestLogger()) reg := NewIssuerRegistry(registryTestLogger())
key := crypto.DeriveKey("test-key")
configJSON := []byte(`{"ca_common_name":"Encrypted CA"}`) configJSON := []byte(`{"ca_common_name":"Encrypted CA"}`)
encrypted, err := crypto.Encrypt(configJSON, key) // M-8: EncryptIfKeySet now emits v2 (magic 0x02 || per-ciphertext salt || sealed).
// IssuerRegistry.Rebuild accepts the raw passphrase and delegates PBKDF2 to crypto.DecryptIfKeySet.
encrypted, _, err := crypto.EncryptIfKeySet(configJSON, "test-key")
if err != nil { if err != nil {
t.Fatalf("encrypt failed: %v", err) t.Fatalf("EncryptIfKeySet failed: %v", err)
} }
configs := []*domain.Issuer{ configs := []*domain.Issuer{
@@ -141,7 +142,7 @@ func TestIssuerRegistry_Rebuild_WithEncryption(t *testing.T) {
}, },
} }
err = reg.Rebuild(configs, key) err = reg.Rebuild(configs, "test-key")
if err != nil { if err != nil {
t.Fatalf("Rebuild with encryption failed: %v", err) t.Fatalf("Rebuild with encryption failed: %v", err)
} }
@@ -165,10 +166,11 @@ func TestIssuerRegistry_Rebuild_NilKeyFallback(t *testing.T) {
}, },
} }
// nil key should work — falls back to config column // Empty passphrase is safe when no EncryptedConfig is present — falls back to config column.
err := reg.Rebuild(configs, nil) // The C-2 fail-closed sentinel only fires when EncryptedConfig is non-empty.
err := reg.Rebuild(configs, "")
if err != nil { if err != nil {
t.Fatalf("Rebuild with nil key failed: %v", err) t.Fatalf("Rebuild with empty key failed: %v", err)
} }
_, ok := reg.Get("iss-plain") _, ok := reg.Get("iss-plain")
@@ -198,7 +200,7 @@ func TestIssuerRegistry_Rebuild_InvalidConfig(t *testing.T) {
} }
// Should return an error indicating partial failure, but still load valid issuers // Should return an error indicating partial failure, but still load valid issuers
err := reg.Rebuild(configs, nil) err := reg.Rebuild(configs, "")
if err == nil { if err == nil {
t.Fatal("Rebuild should return error when some issuers fail to load") t.Fatal("Rebuild should return error when some issuers fail to load")
} }
@@ -230,7 +232,7 @@ func TestIssuerRegistry_Rebuild_ReplacesExisting(t *testing.T) {
}, },
} }
err := reg.Rebuild(configs, nil) err := reg.Rebuild(configs, "")
if err != nil { if err != nil {
t.Fatalf("Rebuild failed: %v", err) t.Fatalf("Rebuild failed: %v", err)
} }
@@ -275,7 +277,7 @@ func TestIssuerRegistry_Rebuild_Empty(t *testing.T) {
reg.Set("iss-existing", &mockIssuerConnector{}) reg.Set("iss-existing", &mockIssuerConnector{})
err := reg.Rebuild([]*domain.Issuer{}, nil) err := reg.Rebuild([]*domain.Issuer{}, "")
if err != nil { if err != nil {
t.Fatalf("Rebuild with empty configs failed: %v", err) t.Fatalf("Rebuild with empty configs failed: %v", err)
} }
+26 -22
View File
@@ -50,7 +50,7 @@ func TestIssuerService_List(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
issuers, total, err := service.List(ctx, 1, 2) issuers, total, err := service.List(ctx, 1, 2)
@@ -87,7 +87,7 @@ func TestIssuerService_List_DefaultPagination(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
// Call with invalid page and perPage // Call with invalid page and perPage
issuers, total, err := service.List(ctx, 0, 0) issuers, total, err := service.List(ctx, 0, 0)
@@ -115,7 +115,7 @@ func TestIssuerService_List_RepositoryError(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
_, _, err := service.List(ctx, 1, 50) _, _, err := service.List(ctx, 1, 50)
@@ -137,7 +137,7 @@ func TestIssuerService_List_EmptyResult(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
issuers, total, err := service.List(ctx, 1, 50) issuers, total, err := service.List(ctx, 1, 50)
@@ -173,7 +173,7 @@ func TestIssuerService_Get(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
retrieved, err := service.Get(ctx, "iss-acme-prod") retrieved, err := service.Get(ctx, "iss-acme-prod")
@@ -199,7 +199,7 @@ func TestIssuerService_Get_NotFound(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
_, err := service.Get(ctx, "nonexistent-issuer") _, err := service.Get(ctx, "nonexistent-issuer")
@@ -280,7 +280,7 @@ func TestIssuerService_Create_EmptyName(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
issuer := &domain.Issuer{ issuer := &domain.Issuer{
Name: "", Name: "",
@@ -314,7 +314,7 @@ func TestIssuerService_Create_RepositoryError(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
issuer := &domain.Issuer{ issuer := &domain.Issuer{
Name: "Test Issuer", Name: "Test Issuer",
@@ -387,7 +387,7 @@ func TestIssuerService_Update_EmptyName(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
issuer := &domain.Issuer{ issuer := &domain.Issuer{
Name: "", Name: "",
@@ -415,7 +415,7 @@ func TestIssuerService_Delete(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
err := service.Delete(ctx, "iss-to-delete", "user-frank") err := service.Delete(ctx, "iss-to-delete", "user-frank")
@@ -447,7 +447,7 @@ func TestIssuerService_Delete_RepositoryError(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
err := service.Delete(ctx, "iss-bad-id", "user-grace") err := service.Delete(ctx, "iss-bad-id", "user-grace")
@@ -482,12 +482,12 @@ func TestIssuerService_TestConnection_Success(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
svc := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) svc := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
err := svc.TestConnectionWithContext(ctx, "iss-test-conn") err := svc.TestConnection(ctx, "iss-test-conn")
if err != nil { if err != nil {
t.Fatalf("TestConnectionWithContext failed: %v", err) t.Fatalf("TestConnection failed: %v", err)
} }
} }
@@ -500,9 +500,9 @@ func TestIssuerService_TestConnection_NotFound(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
err := service.TestConnectionWithContext(ctx, "nonexistent-issuer") err := service.TestConnection(ctx, "nonexistent-issuer")
if err == nil { if err == nil {
t.Fatal("expected error for nonexistent issuer") t.Fatal("expected error for nonexistent issuer")
@@ -540,9 +540,10 @@ func TestIssuerService_ListIssuers_HandlerInterface(t *testing.T) {
auditRepo := newMockAuditRepository() auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default()) service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
issuers, total, err := service.ListIssuers(1, 50) ctx := context.Background()
issuers, total, err := service.ListIssuers(ctx, 1, 50)
if err != nil { if err != nil {
t.Fatalf("ListIssuers failed: %v", err) t.Fatalf("ListIssuers failed: %v", err)
@@ -580,7 +581,8 @@ func TestIssuerService_CreateIssuer_HandlerInterface(t *testing.T) {
Enabled: true, Enabled: true,
} }
result, err := service.CreateIssuer(issuer) ctx := context.Background()
result, err := service.CreateIssuer(ctx, issuer)
if err != nil { if err != nil {
t.Fatalf("CreateIssuer failed: %v", err) t.Fatalf("CreateIssuer failed: %v", err)
@@ -606,9 +608,10 @@ func TestIssuerService_DeleteIssuer_HandlerInterface(t *testing.T) {
auditService := NewAuditService(auditRepo) auditService := NewAuditService(auditRepo)
registry := NewIssuerRegistry(slog.Default()) registry := NewIssuerRegistry(slog.Default())
service := NewIssuerService(repo, auditService, registry, nil, slog.Default()) service := NewIssuerService(repo, auditService, registry, "", slog.Default())
err := service.DeleteIssuer("iss-handler-delete") ctx := context.Background()
err := service.DeleteIssuer(ctx, "iss-handler-delete")
if err != nil { if err != nil {
t.Fatalf("DeleteIssuer failed: %v", err) t.Fatalf("DeleteIssuer failed: %v", err)
@@ -722,7 +725,8 @@ func TestIssuerService_CreateIssuer_LowercaseType(t *testing.T) {
Enabled: true, Enabled: true,
} }
result, err := service.CreateIssuer(issuer) ctx := context.Background()
result, err := service.CreateIssuer(ctx, issuer)
if err != nil { if err != nil {
t.Fatalf("CreateIssuer with lowercase 'stepca' should succeed, got: %v", err) t.Fatalf("CreateIssuer with lowercase 'stepca' should succeed, got: %v", err)
} }
+8 -15
View File
@@ -189,8 +189,8 @@ func (s *JobService) GetJobStatus(ctx context.Context, jobID string) (*domain.Jo
return job, nil return job, nil
} }
// CancelJobWithContext cancels a pending or running job. // CancelJob cancels a pending or running job (handler interface method).
func (s *JobService) CancelJobWithContext(ctx context.Context, jobID string) error { func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
job, err := s.jobRepo.Get(ctx, jobID) job, err := s.jobRepo.Get(ctx, jobID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch job: %w", err) return fmt.Errorf("failed to fetch job: %w", err)
@@ -208,13 +208,8 @@ func (s *JobService) CancelJobWithContext(ctx context.Context, jobID string) err
return nil return nil
} }
// CancelJob cancels a job (handler interface method).
func (s *JobService) CancelJob(id string) error {
return s.CancelJobWithContext(context.Background(), id)
}
// ListJobs returns paginated jobs with optional filtering (handler interface method). // ListJobs returns paginated jobs with optional filtering (handler interface method).
func (s *JobService) ListJobs(status, jobType string, page, perPage int) ([]domain.Job, int64, error) { func (s *JobService) ListJobs(ctx context.Context, status, jobType string, page, perPage int) ([]domain.Job, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -222,7 +217,7 @@ func (s *JobService) ListJobs(status, jobType string, page, perPage int) ([]doma
perPage = 50 perPage = 50
} }
allJobs, err := s.jobRepo.List(context.Background()) allJobs, err := s.jobRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list jobs: %w", err) return nil, 0, fmt.Errorf("failed to list jobs: %w", err)
} }
@@ -263,14 +258,13 @@ func (s *JobService) ListJobs(status, jobType string, page, perPage int) ([]doma
} }
// GetJob returns a single job (handler interface method). // GetJob returns a single job (handler interface method).
func (s *JobService) GetJob(id string) (*domain.Job, error) { func (s *JobService) GetJob(ctx context.Context, id string) (*domain.Job, error) {
return s.jobRepo.Get(context.Background(), id) return s.jobRepo.Get(ctx, id)
} }
// ApproveJob approves a renewal job that is awaiting approval. // ApproveJob approves a renewal job that is awaiting approval.
// Transitions the job from AwaitingApproval to Pending so the scheduler picks it up. // Transitions the job from AwaitingApproval to Pending so the scheduler picks it up.
func (s *JobService) ApproveJob(id string) error { func (s *JobService) ApproveJob(ctx context.Context, id string) error {
ctx := context.Background()
job, err := s.jobRepo.Get(ctx, id) job, err := s.jobRepo.Get(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("job not found: %w", err) return fmt.Errorf("job not found: %w", err)
@@ -290,8 +284,7 @@ func (s *JobService) ApproveJob(id string) error {
// RejectJob rejects a renewal job that is awaiting approval. // RejectJob rejects a renewal job that is awaiting approval.
// Transitions the job to Cancelled with a rejection reason. // Transitions the job to Cancelled with a rejection reason.
func (s *JobService) RejectJob(id string, reason string) error { func (s *JobService) RejectJob(ctx context.Context, id string, reason string) error {
ctx := context.Background()
job, err := s.jobRepo.Get(ctx, id) job, err := s.jobRepo.Get(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("job not found: %w", err) return fmt.Errorf("job not found: %w", err)
+11 -5
View File
@@ -99,7 +99,7 @@ func TestCancelJob(t *testing.T) {
jobService := newTestJobService(jobRepo) jobService := newTestJobService(jobRepo)
err := jobService.CancelJobWithContext(ctx, "job-001") err := jobService.CancelJob(ctx, "job-001")
if err != nil { if err != nil {
t.Fatalf("CancelJob failed: %v", err) t.Fatalf("CancelJob failed: %v", err)
} }
@@ -129,13 +129,15 @@ func TestCancelJob_AlreadyCompleted(t *testing.T) {
jobService := newTestJobService(jobRepo) jobService := newTestJobService(jobRepo)
err := jobService.CancelJobWithContext(ctx, "job-001") err := jobService.CancelJob(ctx, "job-001")
if err == nil { if err == nil {
t.Fatal("expected error for completed job") t.Fatal("expected error for completed job")
} }
} }
func TestGetJob(t *testing.T) { func TestGetJob(t *testing.T) {
ctx := context.Background()
now := time.Now() now := time.Now()
job := &domain.Job{ job := &domain.Job{
ID: "job-001", ID: "job-001",
@@ -153,7 +155,7 @@ func TestGetJob(t *testing.T) {
jobService := newTestJobService(jobRepo) jobService := newTestJobService(jobRepo)
retrieved, err := jobService.GetJob("job-001") retrieved, err := jobService.GetJob(ctx, "job-001")
if err != nil { if err != nil {
t.Fatalf("GetJob failed: %v", err) t.Fatalf("GetJob failed: %v", err)
} }
@@ -167,6 +169,8 @@ func TestGetJob(t *testing.T) {
} }
func TestListJobs(t *testing.T) { func TestListJobs(t *testing.T) {
ctx := context.Background()
now := time.Now() now := time.Now()
job1 := &domain.Job{ job1 := &domain.Job{
ID: "job-001", ID: "job-001",
@@ -192,7 +196,7 @@ func TestListJobs(t *testing.T) {
jobService := newTestJobService(jobRepo) jobService := newTestJobService(jobRepo)
jobs, total, err := jobService.ListJobs("", "", 1, 50) jobs, total, err := jobService.ListJobs(ctx, "", "", 1, 50)
if err != nil { if err != nil {
t.Fatalf("ListJobs failed: %v", err) t.Fatalf("ListJobs failed: %v", err)
} }
@@ -206,6 +210,8 @@ func TestListJobs(t *testing.T) {
} }
func TestListJobs_FilterByStatus(t *testing.T) { func TestListJobs_FilterByStatus(t *testing.T) {
ctx := context.Background()
now := time.Now() now := time.Now()
job1 := &domain.Job{ job1 := &domain.Job{
ID: "job-001", ID: "job-001",
@@ -231,7 +237,7 @@ func TestListJobs_FilterByStatus(t *testing.T) {
jobService := newTestJobService(jobRepo) jobService := newTestJobService(jobRepo)
jobs, total, err := jobService.ListJobs(string(domain.JobStatusPending), "", 1, 50) jobs, total, err := jobService.ListJobs(ctx, string(domain.JobStatusPending), "", 1, 50)
if err != nil { if err != nil {
t.Fatalf("ListJobs failed: %v", err) t.Fatalf("ListJobs failed: %v", err)
} }
+51 -15
View File
@@ -235,21 +235,19 @@ func (s *NetworkScanService) scanTarget(ctx context.Context, target *domain.Netw
timeout := time.Duration(target.TimeoutMs) * time.Millisecond timeout := time.Duration(target.TimeoutMs) * time.Millisecond
results := s.scanEndpoints(ctx, endpoints, timeout) results := s.scanEndpoints(ctx, endpoints, timeout)
// Collect discovered cert entries // Collect discovered cert entries and per-endpoint errors.
var entries []domain.DiscoveredCertEntry //
var scanErrors []string // M-9 (operator-observability): before this fix, scanErrors was declared
for _, result := range results { // but never appended to, so the "errors" count in the summary Info log
if result.Error != "" { // and the Errors field on the DiscoveryReport were always zero/nil —
// Only log connection errors at debug level (many hosts won't have TLS) // silently hiding per-endpoint failures from operators and from the
if s.logger != nil { // downstream scan history record. Per-endpoint failures are still logged
s.logger.Debug("scan endpoint error", // at Debug (sweep scans generate high connection-refused noise by design
"address", result.Address, // — most hosts in a CIDR won't have TLS on the probed port), but the
"error", result.Error) // aggregate count and the report's Errors field now reflect reality so
} // operators can see, via the scan summary and the stored scan record,
continue // how many endpoints failed without having to enable Debug logging.
} entries, scanErrors := s.collectScanResults(results)
entries = append(entries, result.Certs...)
}
scanDuration := time.Since(startTime) scanDuration := time.Since(startTime)
if s.logger != nil { if s.logger != nil {
@@ -385,6 +383,44 @@ func incrementIP(ip net.IP) {
} }
} }
// collectScanResults partitions per-endpoint scan results into discovered
// certificate entries and a list of per-endpoint error strings.
//
// M-9 (operator-observability): the summary Info log and the DiscoveryReport
// both report the count of endpoints that failed to probe. Before this helper
// existed, the caller accumulated entries but never populated the errors
// slice, so the aggregate error count was always zero and the scan record's
// Errors field was always nil — silently hiding per-endpoint failures.
//
// Per-endpoint errors remain logged at Debug (sweep scans generate high
// connection-refused noise by design — most hosts in a CIDR won't have TLS
// on the probed port). Aggregation surfaces the count at Info, preserving
// Debug-level detail for operators who want it without creating log spam
// at default verbosity.
func (s *NetworkScanService) collectScanResults(results []domain.NetworkScanResult) ([]domain.DiscoveredCertEntry, []string) {
var entries []domain.DiscoveredCertEntry
var scanErrors []string
for _, result := range results {
if result.Error != "" {
// Debug-level is intentional: a sweep scan of a /24 typically
// produces 200+ connection-refused results, and logging each
// at Warn would create log spam at default verbosity. The
// aggregate count in the Info-level scan-completed log surfaces
// the failure volume to operators; Debug provides the detail
// when diagnosing a specific endpoint.
if s.logger != nil {
s.logger.Debug("scan endpoint error",
"address", result.Address,
"error", result.Error)
}
scanErrors = append(scanErrors, fmt.Sprintf("%s: %s", result.Address, result.Error))
continue
}
entries = append(entries, result.Certs...)
}
return entries, scanErrors
}
// scanEndpoints probes TLS endpoints concurrently and returns results. // scanEndpoints probes TLS endpoints concurrently and returns results.
func (s *NetworkScanService) scanEndpoints(ctx context.Context, endpoints []string, timeout time.Duration) []domain.NetworkScanResult { func (s *NetworkScanService) scanEndpoints(ctx context.Context, endpoints []string, timeout time.Duration) []domain.NetworkScanResult {
results := make([]domain.NetworkScanResult, len(endpoints)) results := make([]domain.NetworkScanResult, len(endpoints))
+110
View File
@@ -491,3 +491,113 @@ func TestExpandCIDR_SingleLinkLocalIP(t *testing.T) {
t.Errorf("expected empty for cloud metadata IP, got %v", ips) t.Errorf("expected empty for cloud metadata IP, got %v", ips)
} }
} }
// TestCollectScanResults_AggregatesErrors is the M-9 regression guard:
// per-endpoint probe failures must accumulate into the errors slice so the
// summary Info log and the DiscoveryReport reflect the true failure count.
// Before the M-9 fix, scanErrors was declared but never appended to, so the
// aggregate count was always zero and the scan record's Errors field was
// always nil — silently hiding per-endpoint failures from operators.
func TestCollectScanResults_AggregatesErrors(t *testing.T) {
svc := &NetworkScanService{}
results := []domain.NetworkScanResult{
{Address: "203.0.113.1:443", Error: "connection refused"},
{Address: "203.0.113.2:443", Certs: []domain.DiscoveredCertEntry{
{CommonName: "example.com"},
}},
{Address: "203.0.113.3:443", Error: "tls handshake failure"},
{Address: "203.0.113.4:443", Certs: []domain.DiscoveredCertEntry{
{CommonName: "internal.example.com"},
}},
{Address: "203.0.113.5:443", Error: "i/o timeout"},
}
entries, errs := svc.collectScanResults(results)
if len(entries) != 2 {
t.Errorf("expected 2 entries (one per successful probe), got %d", len(entries))
}
if len(errs) != 3 {
t.Fatalf("expected 3 error strings (one per failed probe), got %d: %v", len(errs), errs)
}
// Each error string must be non-empty and include the endpoint address so
// the scan record lets operators correlate failures back to endpoints
// without needing Debug logging enabled.
for i, e := range errs {
if e == "" {
t.Errorf("error[%d]: expected non-empty error string", i)
}
}
// Spot-check that address is threaded through the error strings.
if want := "203.0.113.1:443"; errs[0] == "" || errs[0][:len(want)] != want {
t.Errorf("errs[0] should start with %q, got %q", want, errs[0])
}
if want := "203.0.113.3:443"; errs[1] == "" || errs[1][:len(want)] != want {
t.Errorf("errs[1] should start with %q, got %q", want, errs[1])
}
if want := "203.0.113.5:443"; errs[2] == "" || errs[2][:len(want)] != want {
t.Errorf("errs[2] should start with %q, got %q", want, errs[2])
}
}
// TestCollectScanResults_AllSuccess exercises the happy path: a scan where
// every endpoint returned certificates. The errors slice must be nil (not an
// empty non-nil slice) so the downstream DiscoveryReport.Errors field stays
// nil as well, preserving the JSON-omitempty behavior that callers rely on.
func TestCollectScanResults_AllSuccess(t *testing.T) {
svc := &NetworkScanService{}
results := []domain.NetworkScanResult{
{Address: "203.0.113.10:443", Certs: []domain.DiscoveredCertEntry{
{CommonName: "a.example.com"},
}},
{Address: "203.0.113.11:443", Certs: []domain.DiscoveredCertEntry{
{CommonName: "b.example.com"},
}},
}
entries, errs := svc.collectScanResults(results)
if len(entries) != 2 {
t.Errorf("expected 2 entries, got %d", len(entries))
}
if errs != nil {
t.Errorf("expected nil errors slice on all-success, got %v", errs)
}
}
// TestCollectScanResults_AllFailed exercises the worst-case sweep: every
// endpoint failed to probe. Entries must be nil, and every failure must be
// recorded in the errors slice so the scan record is complete.
func TestCollectScanResults_AllFailed(t *testing.T) {
svc := &NetworkScanService{}
results := []domain.NetworkScanResult{
{Address: "203.0.113.20:443", Error: "connection refused"},
{Address: "203.0.113.21:443", Error: "connection refused"},
{Address: "203.0.113.22:443", Error: "connection refused"},
}
entries, errs := svc.collectScanResults(results)
if entries != nil {
t.Errorf("expected nil entries on all-failed, got %v", entries)
}
if len(errs) != 3 {
t.Errorf("expected 3 error strings, got %d: %v", len(errs), errs)
}
}
// TestCollectScanResults_Empty guards against a degenerate empty-input case
// (scanEndpoints returns no results, e.g. if ctx was cancelled before the
// first probe ran). Both return slices must be nil.
func TestCollectScanResults_Empty(t *testing.T) {
svc := &NetworkScanService{}
entries, errs := svc.collectScanResults(nil)
if entries != nil {
t.Errorf("expected nil entries for empty input, got %v", entries)
}
if errs != nil {
t.Errorf("expected nil errors for empty input, got %v", errs)
}
}
+6 -6
View File
@@ -319,7 +319,7 @@ func (s *NotificationService) GetNotificationHistory(ctx context.Context, certID
} }
// ListNotifications returns paginated notifications (handler interface method). // ListNotifications returns paginated notifications (handler interface method).
func (s *NotificationService) ListNotifications(page, perPage int) ([]domain.NotificationEvent, int64, error) { func (s *NotificationService) ListNotifications(ctx context.Context, page, perPage int) ([]domain.NotificationEvent, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -332,7 +332,7 @@ func (s *NotificationService) ListNotifications(page, perPage int) ([]domain.Not
PerPage: perPage, PerPage: perPage,
} }
notifications, err := s.notifRepo.List(context.Background(), filter) notifications, err := s.notifRepo.List(ctx, filter)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list notifications: %w", err) return nil, 0, fmt.Errorf("failed to list notifications: %w", err)
} }
@@ -349,12 +349,12 @@ func (s *NotificationService) ListNotifications(page, perPage int) ([]domain.Not
} }
// GetNotification returns a single notification (handler interface method). // GetNotification returns a single notification (handler interface method).
func (s *NotificationService) GetNotification(id string) (*domain.NotificationEvent, error) { func (s *NotificationService) GetNotification(ctx context.Context, id string) (*domain.NotificationEvent, error) {
filter := &repository.NotificationFilter{ filter := &repository.NotificationFilter{
PerPage: 1, PerPage: 1,
} }
notifications, err := s.notifRepo.List(context.Background(), filter) notifications, err := s.notifRepo.List(ctx, filter)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get notification: %w", err) return nil, fmt.Errorf("failed to get notification: %w", err)
} }
@@ -370,6 +370,6 @@ func (s *NotificationService) GetNotification(id string) (*domain.NotificationEv
} }
// MarkAsRead marks a notification as read (handler interface method). // MarkAsRead marks a notification as read (handler interface method).
func (s *NotificationService) MarkAsRead(id string) error { func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
return s.notifRepo.UpdateStatus(context.Background(), id, "read", time.Now()) return s.notifRepo.UpdateStatus(ctx, id, "read", time.Now())
} }
+3 -3
View File
@@ -370,7 +370,7 @@ func TestListNotifications(t *testing.T) {
} }
// List with pagination // List with pagination
notifs, total, err := svc.ListNotifications(1, 3) notifs, total, err := svc.ListNotifications(context.Background(), 1, 3)
if err != nil { if err != nil {
t.Fatalf("ListNotifications failed: %v", err) t.Fatalf("ListNotifications failed: %v", err)
} }
@@ -404,7 +404,7 @@ func TestMarkAsRead(t *testing.T) {
notifRepo.AddNotification(notif) notifRepo.AddNotification(notif)
// Mark as read // Mark as read
err := svc.MarkAsRead(notif.ID) err := svc.MarkAsRead(context.Background(), notif.ID)
if err != nil { if err != nil {
t.Fatalf("MarkAsRead failed: %v", err) t.Fatalf("MarkAsRead failed: %v", err)
} }
@@ -434,7 +434,7 @@ func TestGetNotification(t *testing.T) {
notifRepo.AddNotification(notif) notifRepo.AddNotification(notif)
// Get the notification // Get the notification
retrieved, err := svc.GetNotification(notif.ID) retrieved, err := svc.GetNotification(context.Background(), notif.ID)
if err != nil { if err != nil {
t.Fatalf("GetNotification failed: %v", err) t.Fatalf("GetNotification failed: %v", err)
} }
+10 -10
View File
@@ -126,7 +126,7 @@ func (s *OwnerService) Delete(ctx context.Context, id string, actor string) erro
} }
// ListOwners returns paginated owners (handler interface method). // ListOwners returns paginated owners (handler interface method).
func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) { func (s *OwnerService) ListOwners(ctx context.Context, page, perPage int) ([]domain.Owner, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -134,7 +134,7 @@ func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, err
perPage = 50 perPage = 50
} }
owners, err := s.ownerRepo.List(context.Background()) owners, err := s.ownerRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list owners: %w", err) return nil, 0, fmt.Errorf("failed to list owners: %w", err)
} }
@@ -151,12 +151,12 @@ func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, err
} }
// GetOwner returns a single owner (handler interface method). // GetOwner returns a single owner (handler interface method).
func (s *OwnerService) GetOwner(id string) (*domain.Owner, error) { func (s *OwnerService) GetOwner(ctx context.Context, id string) (*domain.Owner, error) {
return s.ownerRepo.Get(context.Background(), id) return s.ownerRepo.Get(ctx, id)
} }
// CreateOwner creates a new owner (handler interface method). // CreateOwner creates a new owner (handler interface method).
func (s *OwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) { func (s *OwnerService) CreateOwner(ctx context.Context, owner domain.Owner) (*domain.Owner, error) {
if owner.ID == "" { if owner.ID == "" {
owner.ID = generateID("owner") owner.ID = generateID("owner")
} }
@@ -167,22 +167,22 @@ func (s *OwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
if owner.UpdatedAt.IsZero() { if owner.UpdatedAt.IsZero() {
owner.UpdatedAt = now owner.UpdatedAt = now
} }
if err := s.ownerRepo.Create(context.Background(), &owner); err != nil { if err := s.ownerRepo.Create(ctx, &owner); err != nil {
return nil, fmt.Errorf("failed to create owner: %w", err) return nil, fmt.Errorf("failed to create owner: %w", err)
} }
return &owner, nil return &owner, nil
} }
// UpdateOwner modifies an owner (handler interface method). // UpdateOwner modifies an owner (handler interface method).
func (s *OwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) { func (s *OwnerService) UpdateOwner(ctx context.Context, id string, owner domain.Owner) (*domain.Owner, error) {
owner.ID = id owner.ID = id
if err := s.ownerRepo.Update(context.Background(), &owner); err != nil { if err := s.ownerRepo.Update(ctx, &owner); err != nil {
return nil, fmt.Errorf("failed to update owner: %w", err) return nil, fmt.Errorf("failed to update owner: %w", err)
} }
return &owner, nil return &owner, nil
} }
// DeleteOwner removes an owner (handler interface method). // DeleteOwner removes an owner (handler interface method).
func (s *OwnerService) DeleteOwner(id string) error { func (s *OwnerService) DeleteOwner(ctx context.Context, id string) error {
return s.ownerRepo.Delete(context.Background(), id) return s.ownerRepo.Delete(ctx, id)
} }
+5 -5
View File
@@ -638,7 +638,7 @@ func TestOwnerService_ListOwners_HandlerInterface(t *testing.T) {
ownerService := NewOwnerService(ownerRepo, auditService) ownerService := NewOwnerService(ownerRepo, auditService)
owners, total, err := ownerService.ListOwners(1, 50) owners, total, err := ownerService.ListOwners(context.Background(), 1, 50)
if err != nil { if err != nil {
t.Fatalf("ListOwners failed: %v", err) t.Fatalf("ListOwners failed: %v", err)
} }
@@ -678,7 +678,7 @@ func TestOwnerService_GetOwner_HandlerInterface(t *testing.T) {
ownerService := NewOwnerService(ownerRepo, auditService) ownerService := NewOwnerService(ownerRepo, auditService)
retrieved, err := ownerService.GetOwner("owner-001") retrieved, err := ownerService.GetOwner(context.Background(), "owner-001")
if err != nil { if err != nil {
t.Fatalf("GetOwner failed: %v", err) t.Fatalf("GetOwner failed: %v", err)
} }
@@ -702,7 +702,7 @@ func TestOwnerService_CreateOwner_HandlerInterface(t *testing.T) {
TeamID: "team-001", TeamID: "team-001",
} }
created, err := ownerService.CreateOwner(owner) created, err := ownerService.CreateOwner(context.Background(), owner)
if err != nil { if err != nil {
t.Fatalf("CreateOwner failed: %v", err) t.Fatalf("CreateOwner failed: %v", err)
} }
@@ -752,7 +752,7 @@ func TestOwnerService_UpdateOwner_HandlerInterface(t *testing.T) {
TeamID: "team-002", TeamID: "team-002",
} }
updated, err := ownerService.UpdateOwner("owner-001", updatedOwner) updated, err := ownerService.UpdateOwner(context.Background(), "owner-001", updatedOwner)
if err != nil { if err != nil {
t.Fatalf("UpdateOwner failed: %v", err) t.Fatalf("UpdateOwner failed: %v", err)
} }
@@ -798,7 +798,7 @@ func TestOwnerService_DeleteOwner_HandlerInterface(t *testing.T) {
ownerService := NewOwnerService(ownerRepo, auditService) ownerService := NewOwnerService(ownerRepo, auditService)
err := ownerService.DeleteOwner("owner-001") err := ownerService.DeleteOwner(context.Background(), "owner-001")
if err != nil { if err != nil {
t.Fatalf("DeleteOwner failed: %v", err) t.Fatalf("DeleteOwner failed: %v", err)
} }
+232 -61
View File
@@ -2,8 +2,10 @@ package service
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"strings"
"time" "time"
"github.com/shankar0123/certctl/internal/domain" "github.com/shankar0123/certctl/internal/domain"
@@ -14,6 +16,11 @@ import (
type PolicyService struct { type PolicyService struct {
policyRepo repository.PolicyRepository policyRepo repository.PolicyRepository
auditService *AuditService auditService *AuditService
// certRepo is optional and only required by the CertificateLifetime rule
// arm, which must read NotBefore/NotAfter from the latest CertificateVersion.
// Wire via SetCertRepo after construction; rules other than
// CertificateLifetime operate without it.
certRepo repository.CertificateRepository
} }
// NewPolicyService creates a new policy service. // NewPolicyService creates a new policy service.
@@ -27,6 +34,16 @@ func NewPolicyService(
} }
} }
// SetCertRepo wires the certificate repository needed for the CertificateLifetime
// rule arm. Kept as a setter (not a constructor parameter) so the ~36 existing
// NewPolicyService call sites don't churn for a single new arm's dependency.
// Safe to call before or after construction; evaluateRule checks for nil and
// returns an error if a CertificateLifetime rule fires without a wired repo
// (the caller at ValidateCertificate logs and continues).
func (s *PolicyService) SetCertRepo(r repository.CertificateRepository) {
s.certRepo = r
}
// ValidateCertificate runs all enabled policy rules against a certificate. // ValidateCertificate runs all enabled policy rules against a certificate.
func (s *PolicyService) ValidateCertificate(ctx context.Context, cert *domain.ManagedCertificate) ([]*domain.PolicyViolation, error) { func (s *PolicyService) ValidateCertificate(ctx context.Context, cert *domain.ManagedCertificate) ([]*domain.PolicyViolation, error) {
rules, err := s.policyRepo.ListRules(ctx) rules, err := s.policyRepo.ListRules(ctx)
@@ -43,7 +60,7 @@ func (s *PolicyService) ValidateCertificate(ctx context.Context, cert *domain.Ma
} }
// Evaluate rule against certificate // Evaluate rule against certificate
v, err := s.evaluateRule(rule, cert) v, err := s.evaluateRule(ctx, rule, cert)
if err != nil { if err != nil {
slog.Error("failed to evaluate rule", "rule_id", rule.ID, "error", err) slog.Error("failed to evaluate rule", "rule_id", rule.ID, "error", err)
continue continue
@@ -58,73 +75,163 @@ func (s *PolicyService) ValidateCertificate(ctx context.Context, cert *domain.Ma
} }
// evaluateRule checks if a certificate violates a single policy rule. // evaluateRule checks if a certificate violates a single policy rule.
func (s *PolicyService) evaluateRule(rule *domain.PolicyRule, cert *domain.ManagedCertificate) (*domain.PolicyViolation, error) { //
// D-008 closes the engine loop by:
// 1. Consuming rule.Severity on every violation (the pre-D-008 engine
// hardcoded PolicySeverityWarning, which silently defeated the D-006
// per-rule severity column).
// 2. Parsing rule.Config per-arm so rules carry real thresholds / allowlists
// instead of the pre-D-008 "metadata absent" placeholders. Empty/null
// Config preserves the pre-D-008 missing-field behavior as a
// backward-compat invariant — a rule without config still fires on the
// absent-field shape but using its configured severity.
// 3. Adding the CertificateLifetime arm, which reads NotBefore/NotAfter from
// the latest CertificateVersion (injected via SetCertRepo). Required
// because ManagedCertificate tracks ExpiresAt but not issuance date.
//
// Bad-config failure mode: json.Unmarshal error returns (nil, error) shaped
// as `invalid config for rule <id> (type=<type>): <err>`; the caller at
// ValidateCertificate logs and continues so one malformed rule doesn't fail
// the entire pass.
func (s *PolicyService) evaluateRule(ctx context.Context, rule *domain.PolicyRule, cert *domain.ManagedCertificate) (*domain.PolicyViolation, error) {
switch rule.Type { switch rule.Type {
case domain.PolicyTypeAllowedIssuers: case domain.PolicyTypeAllowedIssuers:
// Restrict to specific issuers // Config: {"allowed_issuer_ids": ["iss-a", "iss-b"]}
// Note: In a production implementation, we would parse rule.Config to extract parameters // Empty config = fire only on absent IssuerID (backward-compat).
var cfg struct {
AllowedIssuerIDs []string `json:"allowed_issuer_ids"`
}
if len(rule.Config) > 0 {
if err := json.Unmarshal(rule.Config, &cfg); err != nil {
return nil, fmt.Errorf("invalid config for rule %s (type=%s): %w", rule.ID, rule.Type, err)
}
}
if cert.IssuerID == "" { if cert.IssuerID == "" {
return &domain.PolicyViolation{ return s.violation(rule, cert, "certificate has no issuer assigned"), nil
ID: generateID("violation"), }
RuleID: rule.ID, if len(cfg.AllowedIssuerIDs) > 0 && !containsString(cfg.AllowedIssuerIDs, cert.IssuerID) {
CertificateID: cert.ID, return s.violation(rule, cert, fmt.Sprintf("issuer %q is not in the allowed list", cert.IssuerID)), nil
Severity: domain.PolicySeverityWarning,
Message: "certificate has no issuer assigned",
CreatedAt: time.Now(),
}, nil
} }
case domain.PolicyTypeAllowedDomains: case domain.PolicyTypeAllowedDomains:
// Ensure certificate domains are in allowed list // Config: {"allowed_domains": ["example.com", "*.internal.example.com"]}
// Wildcards are literal prefix matches (*.foo matches anything ending
// in .foo). Empty config = fire only on zero SANs (backward-compat).
var cfg struct {
AllowedDomains []string `json:"allowed_domains"`
}
if len(rule.Config) > 0 {
if err := json.Unmarshal(rule.Config, &cfg); err != nil {
return nil, fmt.Errorf("invalid config for rule %s (type=%s): %w", rule.ID, rule.Type, err)
}
}
if len(cert.SANs) == 0 { if len(cert.SANs) == 0 {
return &domain.PolicyViolation{ return s.violation(rule, cert, "certificate has no subject alternative names"), nil
ID: generateID("violation"), }
RuleID: rule.ID, if len(cfg.AllowedDomains) > 0 {
CertificateID: cert.ID, for _, san := range cert.SANs {
Severity: domain.PolicySeverityWarning, if !domainAllowed(san, cfg.AllowedDomains) {
Message: "certificate has no subject alternative names", return s.violation(rule, cert, fmt.Sprintf("SAN %q is not in the allowed domain list", san)), nil
CreatedAt: time.Now(), }
}, nil }
} }
case domain.PolicyTypeRequiredMetadata: case domain.PolicyTypeRequiredMetadata:
// Ensure certificate has required metadata/tags // Config: {"required_keys": ["owner", "cost-center"]}
// Empty config = fire only on zero tags (backward-compat).
var cfg struct {
RequiredKeys []string `json:"required_keys"`
}
if len(rule.Config) > 0 {
if err := json.Unmarshal(rule.Config, &cfg); err != nil {
return nil, fmt.Errorf("invalid config for rule %s (type=%s): %w", rule.ID, rule.Type, err)
}
}
if len(cert.Tags) == 0 { if len(cert.Tags) == 0 {
return &domain.PolicyViolation{ return s.violation(rule, cert, "certificate has no tags or metadata"), nil
ID: generateID("violation"), }
RuleID: rule.ID, for _, key := range cfg.RequiredKeys {
CertificateID: cert.ID, if _, ok := cert.Tags[key]; !ok {
Severity: domain.PolicySeverityWarning, return s.violation(rule, cert, fmt.Sprintf("certificate is missing required metadata key %q", key)), nil
Message: "certificate has no tags or metadata", }
CreatedAt: time.Now(),
}, nil
} }
case domain.PolicyTypeAllowedEnvironments: case domain.PolicyTypeAllowedEnvironments:
// Restrict to specific environments // Config: {"allowed": ["prod", "staging"]}
// Empty config = fire only on empty Environment (backward-compat).
var cfg struct {
Allowed []string `json:"allowed"`
}
if len(rule.Config) > 0 {
if err := json.Unmarshal(rule.Config, &cfg); err != nil {
return nil, fmt.Errorf("invalid config for rule %s (type=%s): %w", rule.ID, rule.Type, err)
}
}
if cert.Environment == "" { if cert.Environment == "" {
return &domain.PolicyViolation{ return s.violation(rule, cert, "certificate has no environment assigned"), nil
ID: generateID("violation"), }
RuleID: rule.ID, if len(cfg.Allowed) > 0 && !containsString(cfg.Allowed, cert.Environment) {
CertificateID: cert.ID, return s.violation(rule, cert, fmt.Sprintf("environment %q is not in the allowed list", cert.Environment)), nil
Severity: domain.PolicySeverityWarning,
Message: "certificate has no environment assigned",
CreatedAt: time.Now(),
}, nil
} }
case domain.PolicyTypeRenewalLeadTime: case domain.PolicyTypeRenewalLeadTime:
// Ensure renewal begins before certificate expires // Config: {"lead_time_days": 30}
// Fires when remaining validity drops below lead_time_days and the
// cert is not already expired. Empty/zero config falls back to the
// pre-D-008 hardcoded 30-day threshold for backward compatibility.
var cfg struct {
LeadTimeDays int `json:"lead_time_days"`
}
if len(rule.Config) > 0 {
if err := json.Unmarshal(rule.Config, &cfg); err != nil {
return nil, fmt.Errorf("invalid config for rule %s (type=%s): %w", rule.ID, rule.Type, err)
}
}
leadDays := cfg.LeadTimeDays
if leadDays <= 0 {
leadDays = 30
}
daysUntilExpiry := time.Until(cert.ExpiresAt).Hours() / 24 daysUntilExpiry := time.Until(cert.ExpiresAt).Hours() / 24
if daysUntilExpiry < 30 && daysUntilExpiry > 0 { if daysUntilExpiry < float64(leadDays) && daysUntilExpiry > 0 {
return &domain.PolicyViolation{ return s.violation(rule, cert, fmt.Sprintf("certificate expires in %.1f days, plan renewal soon (policy lead time: %d days)", daysUntilExpiry, leadDays)), nil
ID: generateID("violation"), }
RuleID: rule.ID,
CertificateID: cert.ID, case domain.PolicyTypeCertificateLifetime:
Severity: domain.PolicySeverityWarning, // Config: {"max_days": 397}
Message: fmt.Sprintf("certificate expires in %.1f days, plan renewal soon", daysUntilExpiry), // Reads NotBefore/NotAfter from the latest CertificateVersion via the
CreatedAt: time.Now(), // injected certRepo. ManagedCertificate exposes ExpiresAt but not the
}, nil // issuance date, so lifetime math requires the version record.
//
// If certRepo wasn't wired (test misconfiguration / early boot),
// returns an error so the caller logs it — better a loud failure
// than silently ignoring the rule. If GetLatestVersion errors (e.g.,
// the cert hasn't been issued yet), we skip the check — a cert with
// no version has no lifetime to measure, matching the missing-field
// backward-compat pattern used by the other arms.
if s.certRepo == nil {
return nil, fmt.Errorf("CertificateLifetime rule %s requires cert repository (not wired via SetCertRepo)", rule.ID)
}
var cfg struct {
MaxDays int `json:"max_days"`
}
if len(rule.Config) > 0 {
if err := json.Unmarshal(rule.Config, &cfg); err != nil {
return nil, fmt.Errorf("invalid config for rule %s (type=%s): %w", rule.ID, rule.Type, err)
}
}
if cfg.MaxDays <= 0 {
// No threshold configured — nothing meaningful to enforce.
return nil, nil
}
version, err := s.certRepo.GetLatestVersion(ctx, cert.ID)
if err != nil {
// No version yet — nothing to measure. Not an engine error;
// the cert simply hasn't been issued.
return nil, nil
}
lifetimeDays := version.NotAfter.Sub(version.NotBefore).Hours() / 24
if lifetimeDays > float64(cfg.MaxDays) {
return s.violation(rule, cert, fmt.Sprintf("certificate lifetime is %.1f days, exceeds policy max of %d days", lifetimeDays, cfg.MaxDays)), nil
} }
default: default:
@@ -134,6 +241,56 @@ func (s *PolicyService) evaluateRule(rule *domain.PolicyRule, cert *domain.Manag
return nil, nil return nil, nil
} }
// violation constructs a PolicyViolation carrying the rule's configured
// severity. Centralizing the build eliminates the pre-D-008 bug where each
// arm independently stamped PolicySeverityWarning on its violation.
func (s *PolicyService) violation(rule *domain.PolicyRule, cert *domain.ManagedCertificate, message string) *domain.PolicyViolation {
return &domain.PolicyViolation{
ID: generateID("violation"),
RuleID: rule.ID,
CertificateID: cert.ID,
Severity: rule.Severity,
Message: message,
CreatedAt: time.Now(),
}
}
// containsString reports whether needle is present in haystack.
func containsString(haystack []string, needle string) bool {
for _, s := range haystack {
if s == needle {
return true
}
}
return false
}
// domainAllowed reports whether a SAN (hostname) matches any of the allowed
// domain patterns. Patterns may be exact matches or `*.example.com` wildcards
// (the wildcard consumes a single label: `*.foo.com` matches `bar.foo.com`
// but not `baz.bar.foo.com`, mirroring X.509 SAN wildcard semantics).
func domainAllowed(san string, allowed []string) bool {
san = strings.ToLower(strings.TrimSpace(san))
for _, pattern := range allowed {
pattern = strings.ToLower(strings.TrimSpace(pattern))
if pattern == san {
return true
}
if strings.HasPrefix(pattern, "*.") {
suffix := pattern[1:] // ".foo.com"
if strings.HasSuffix(san, suffix) {
// Ensure wildcard consumes exactly one label — reject
// sub-subdomains.
head := strings.TrimSuffix(san, suffix)
if head != "" && !strings.Contains(head, ".") {
return true
}
}
}
}
return false
}
// CreateRule stores a new policy rule. // CreateRule stores a new policy rule.
func (s *PolicyService) CreateRule(ctx context.Context, rule *domain.PolicyRule, actor string) error { func (s *PolicyService) CreateRule(ctx context.Context, rule *domain.PolicyRule, actor string) error {
if rule.ID == "" { if rule.ID == "" {
@@ -230,7 +387,7 @@ func (s *PolicyService) ListViolationsWithContext(ctx context.Context, filter *r
} }
// ListPolicies returns paginated policies (handler interface method). // ListPolicies returns paginated policies (handler interface method).
func (s *PolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, int64, error) { func (s *PolicyService) ListPolicies(ctx context.Context, page, perPage int) ([]domain.PolicyRule, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -238,7 +395,7 @@ func (s *PolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, in
perPage = 50 perPage = 50
} }
rules, err := s.policyRepo.ListRules(context.Background()) rules, err := s.policyRepo.ListRules(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list policies: %w", err) return nil, 0, fmt.Errorf("failed to list policies: %w", err)
} }
@@ -264,12 +421,12 @@ func (s *PolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, in
} }
// GetPolicy returns a single policy (handler interface method). // GetPolicy returns a single policy (handler interface method).
func (s *PolicyService) GetPolicy(id string) (*domain.PolicyRule, error) { func (s *PolicyService) GetPolicy(ctx context.Context, id string) (*domain.PolicyRule, error) {
return s.policyRepo.GetRule(context.Background(), id) return s.policyRepo.GetRule(ctx, id)
} }
// CreatePolicy creates a new policy (handler interface method). // CreatePolicy creates a new policy (handler interface method).
func (s *PolicyService) CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRule, error) { func (s *PolicyService) CreatePolicy(ctx context.Context, policy domain.PolicyRule) (*domain.PolicyRule, error) {
if policy.ID == "" { if policy.ID == "" {
policy.ID = generateID("rule") policy.ID = generateID("rule")
} }
@@ -277,30 +434,44 @@ func (s *PolicyService) CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRu
policy.CreatedAt = time.Now() policy.CreatedAt = time.Now()
} }
if err := s.policyRepo.CreateRule(context.Background(), &policy); err != nil { if err := s.policyRepo.CreateRule(ctx, &policy); err != nil {
return nil, fmt.Errorf("failed to create policy: %w", err) return nil, fmt.Errorf("failed to create policy: %w", err)
} }
return &policy, nil return &policy, nil
} }
// UpdatePolicy modifies a policy (handler interface method). // UpdatePolicy modifies a policy (handler interface method).
func (s *PolicyService) UpdatePolicy(id string, policy domain.PolicyRule) (*domain.PolicyRule, error) { func (s *PolicyService) UpdatePolicy(ctx context.Context, id string, policy domain.PolicyRule) (*domain.PolicyRule, error) {
policy.ID = id policy.ID = id
policy.UpdatedAt = time.Now() policy.UpdatedAt = time.Now()
if err := s.policyRepo.UpdateRule(context.Background(), &policy); err != nil { // Severity is NOT NULL with a CHECK constraint at the DB level
// (migration 000013). If the client omits severity on a PUT (zero-value
// empty string after json.Decode), preserve the existing severity rather
// than letting the CHECK reject the write. Preserves partial-update
// semantics for the new column without changing the pre-existing behavior
// for Name/Type, which is out of scope for D-005/D-006.
if policy.Severity == "" {
existing, err := s.policyRepo.GetRule(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch existing rule for severity preservation: %w", err)
}
policy.Severity = existing.Severity
}
if err := s.policyRepo.UpdateRule(ctx, &policy); err != nil {
return nil, fmt.Errorf("failed to update policy: %w", err) return nil, fmt.Errorf("failed to update policy: %w", err)
} }
return &policy, nil return &policy, nil
} }
// DeletePolicy removes a policy (handler interface method). // DeletePolicy removes a policy (handler interface method).
func (s *PolicyService) DeletePolicy(id string) error { func (s *PolicyService) DeletePolicy(ctx context.Context, id string) error {
return s.policyRepo.DeleteRule(context.Background(), id) return s.policyRepo.DeleteRule(ctx, id)
} }
// ListViolations returns policy violations with pagination (handler interface method). // ListViolations returns policy violations with pagination (handler interface method).
func (s *PolicyService) ListViolations(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) { func (s *PolicyService) ListViolations(ctx context.Context, policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -313,7 +484,7 @@ func (s *PolicyService) ListViolations(policyID string, page, perPage int) ([]do
PerPage: 1000, // Get all violations for the policy PerPage: 1000, // Get all violations for the policy
} }
violations, err := s.policyRepo.ListViolations(context.Background(), filter) violations, err := s.policyRepo.ListViolations(ctx, filter)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list violations: %w", err) return nil, 0, fmt.Errorf("failed to list violations: %w", err)
} }
+536 -2
View File
@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"time" "time"
@@ -376,7 +377,7 @@ func TestListPolicies(t *testing.T) {
policyService := NewPolicyService(policyRepo, auditService) policyService := NewPolicyService(policyRepo, auditService)
policies, total, err := policyService.ListPolicies(1, 50) policies, total, err := policyService.ListPolicies(context.Background(), 1, 50)
if err != nil { if err != nil {
t.Fatalf("ListPolicies failed: %v", err) t.Fatalf("ListPolicies failed: %v", err)
} }
@@ -407,7 +408,7 @@ func TestCreatePolicy(t *testing.T) {
CreatedAt: now, CreatedAt: now,
} }
created, err := policyService.CreatePolicy(policy) created, err := policyService.CreatePolicy(context.Background(), policy)
if err != nil { if err != nil {
t.Fatalf("CreatePolicy failed: %v", err) t.Fatalf("CreatePolicy failed: %v", err)
} }
@@ -420,3 +421,536 @@ func TestCreatePolicy(t *testing.T) {
t.Errorf("expected 1 rule in repo, got %d", len(policyRepo.Rules)) t.Errorf("expected 1 rule in repo, got %d", len(policyRepo.Rules))
} }
} }
// ============================================================================
// D-008 regression tests
//
// These pin the behavior that closes the D-006 loop:
// 1. evaluateRule copies rule.Severity onto every violation (pre-D-008 the
// engine hardcoded Warning regardless of the rule's configured severity).
// 2. evaluateRule parses rule.Config per-arm so rules enforce real thresholds
// and allowlists (pre-D-008 the configs were ignored; rules fired only on
// the missing-field shape).
// 3. An empty/zero Config preserves the pre-D-008 missing-field violation
// (backward-compat invariant).
// 4. Malformed Config returns an error; the caller logs and skips the rule
// instead of producing a zero-value violation.
// 5. CertificateLifetime (new 6th arm) reads NotBefore/NotAfter from the
// latest CertificateVersion via the cert repo wired with SetCertRepo.
// ============================================================================
// mkRule is a tiny constructor used by the D-008 tests to keep the table rows
// readable. Every rule is enabled; test-specific fields layer on top.
func mkRule(id string, t domain.PolicyType, sev domain.PolicySeverity, cfg string) *domain.PolicyRule {
return &domain.PolicyRule{
ID: id,
Name: id,
Type: t,
Config: json.RawMessage(cfg),
Enabled: true,
Severity: sev,
}
}
// evalCert is a minimal cert used by the arms that don't look at much beyond
// the shape of the field they're testing. Tests shadow fields as needed.
func evalCert() *domain.ManagedCertificate {
return &domain.ManagedCertificate{
ID: "cert-001",
CommonName: "example.com",
Status: domain.CertificateStatusActive,
ExpiresAt: time.Now().AddDate(1, 0, 0),
}
}
// TestEvaluateRule_SeverityPassThrough pins invariant #1 — every arm stamps
// rule.Severity onto the violation. The pre-D-008 bug was that arms
// independently hardcoded PolicySeverityWarning. We test each arm with a
// severity that isn't the legacy default so a regression would be visible.
func TestEvaluateRule_SeverityPassThrough(t *testing.T) {
ctx := context.Background()
// Cert shaped to fail every non-empty-config check via the backward-compat
// missing-field path. Each row picks a severity intentionally ≠ Warning to
// make a stray hardcoded default obvious.
cases := []struct {
name string
rule *domain.PolicyRule
cert *domain.ManagedCertificate
setupFn func(svc *PolicyService)
expected domain.PolicySeverity
}{
{
name: "AllowedIssuers Critical via missing IssuerID",
rule: mkRule("r-ai", domain.PolicyTypeAllowedIssuers, domain.PolicySeverityCritical, ""),
cert: func() *domain.ManagedCertificate {
c := evalCert()
c.IssuerID = ""
return c
}(),
expected: domain.PolicySeverityCritical,
},
{
name: "AllowedDomains Error via empty SANs",
rule: mkRule("r-ad", domain.PolicyTypeAllowedDomains, domain.PolicySeverityError, ""),
cert: func() *domain.ManagedCertificate {
c := evalCert()
c.SANs = nil
return c
}(),
expected: domain.PolicySeverityError,
},
{
name: "RequiredMetadata Critical via empty Tags",
rule: mkRule("r-rm", domain.PolicyTypeRequiredMetadata, domain.PolicySeverityCritical, ""),
cert: func() *domain.ManagedCertificate {
c := evalCert()
c.Tags = nil
return c
}(),
expected: domain.PolicySeverityCritical,
},
{
name: "AllowedEnvironments Warning via empty Environment",
rule: mkRule("r-ae", domain.PolicyTypeAllowedEnvironments, domain.PolicySeverityWarning, ""),
cert: func() *domain.ManagedCertificate {
c := evalCert()
c.Environment = ""
return c
}(),
expected: domain.PolicySeverityWarning,
},
{
name: "RenewalLeadTime Critical via short remaining validity",
rule: mkRule("r-rl", domain.PolicyTypeRenewalLeadTime, domain.PolicySeverityCritical, `{"lead_time_days": 60}`),
cert: func() *domain.ManagedCertificate {
c := evalCert()
c.ExpiresAt = time.Now().AddDate(0, 0, 30) // 30d remaining < 60d lead
return c
}(),
expected: domain.PolicySeverityCritical,
},
{
name: "CertificateLifetime Error via 365d span vs 90d max",
rule: mkRule("r-cl", domain.PolicyTypeCertificateLifetime, domain.PolicySeverityError, `{"max_days": 90}`),
cert: evalCert(),
setupFn: func(svc *PolicyService) {
// Seed a version with 365d lifetime on the same cert ID used
// by evalCert().
cr := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{},
Versions: map[string][]*domain.CertificateVersion{},
}
now := time.Now()
cr.Versions["cert-001"] = []*domain.CertificateVersion{{
ID: "ver-001",
CertificateID: "cert-001",
NotBefore: now.AddDate(0, 0, -10),
NotAfter: now.AddDate(1, 0, -10), // ~365d lifetime
}}
svc.SetCertRepo(cr)
},
expected: domain.PolicySeverityError,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{tc.rule.ID: tc.rule},
Violations: []*domain.PolicyViolation{},
}
auditService := NewAuditService(&mockAuditRepo{})
svc := NewPolicyService(policyRepo, auditService)
if tc.setupFn != nil {
tc.setupFn(svc)
}
violations, err := svc.ValidateCertificate(ctx, tc.cert)
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
if len(violations) != 1 {
t.Fatalf("expected 1 violation, got %d", len(violations))
}
if violations[0].Severity != tc.expected {
t.Errorf("expected severity %q, got %q", tc.expected, violations[0].Severity)
}
if violations[0].RuleID != tc.rule.ID {
t.Errorf("expected rule ID %q, got %q", tc.rule.ID, violations[0].RuleID)
}
})
}
}
// TestEvaluateRule_ConfigConsumed pins invariant #2 — non-empty Config drives
// arm behavior (allowlists, thresholds, keys). Each subtest supplies a config
// that the cert would satisfy under the backward-compat missing-field path
// but violates under the config-aware path. A regression to the pre-D-008
// "config silently dropped" behavior would make these pass with 0 violations.
func TestEvaluateRule_ConfigConsumed(t *testing.T) {
ctx := context.Background()
t.Run("AllowedIssuers rejects issuer not in allowlist", func(t *testing.T) {
rule := mkRule("r-ai", domain.PolicyTypeAllowedIssuers, domain.PolicySeverityWarning,
`{"allowed_issuer_ids": ["iss-acme"]}`)
cert := evalCert()
cert.IssuerID = "iss-wrong"
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Fatalf("expected 1 violation for disallowed issuer, got %d", len(violations))
}
if !strings.Contains(violations[0].Message, "iss-wrong") {
t.Errorf("expected message to mention issuer ID, got %q", violations[0].Message)
}
})
t.Run("AllowedIssuers accepts issuer in allowlist", func(t *testing.T) {
rule := mkRule("r-ai", domain.PolicyTypeAllowedIssuers, domain.PolicySeverityWarning,
`{"allowed_issuer_ids": ["iss-acme"]}`)
cert := evalCert()
cert.IssuerID = "iss-acme"
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 0 {
t.Errorf("expected 0 violations for allowed issuer, got %d", len(violations))
}
})
t.Run("AllowedDomains rejects SAN outside allowlist", func(t *testing.T) {
rule := mkRule("r-ad", domain.PolicyTypeAllowedDomains, domain.PolicySeverityWarning,
`{"allowed_domains": ["*.foo.com"]}`)
cert := evalCert()
cert.SANs = []string{"bar.elsewhere.com"}
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Fatalf("expected 1 violation for disallowed SAN, got %d", len(violations))
}
})
t.Run("AllowedDomains wildcard matches single-label subdomain", func(t *testing.T) {
rule := mkRule("r-ad", domain.PolicyTypeAllowedDomains, domain.PolicySeverityWarning,
`{"allowed_domains": ["*.foo.com"]}`)
cert := evalCert()
cert.SANs = []string{"bar.foo.com"}
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 0 {
t.Errorf("expected 0 violations for single-label wildcard match, got %d", len(violations))
}
})
t.Run("AllowedDomains wildcard rejects multi-label subdomain", func(t *testing.T) {
// X.509 wildcard semantics: *.foo consumes exactly one label.
rule := mkRule("r-ad", domain.PolicyTypeAllowedDomains, domain.PolicySeverityWarning,
`{"allowed_domains": ["*.foo.com"]}`)
cert := evalCert()
cert.SANs = []string{"baz.bar.foo.com"}
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Errorf("expected 1 violation for multi-label wildcard (X.509 semantics), got %d", len(violations))
}
})
t.Run("RequiredMetadata rejects missing key", func(t *testing.T) {
rule := mkRule("r-rm", domain.PolicyTypeRequiredMetadata, domain.PolicySeverityWarning,
`{"required_keys": ["owner"]}`)
cert := evalCert()
cert.Tags = map[string]string{"team": "platform"}
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Fatalf("expected 1 violation for missing owner key, got %d", len(violations))
}
if !strings.Contains(violations[0].Message, "owner") {
t.Errorf("expected message to mention the missing key, got %q", violations[0].Message)
}
})
t.Run("RequiredMetadata accepts all required keys present", func(t *testing.T) {
rule := mkRule("r-rm", domain.PolicyTypeRequiredMetadata, domain.PolicySeverityWarning,
`{"required_keys": ["owner"]}`)
cert := evalCert()
cert.Tags = map[string]string{"owner": "alice"}
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 0 {
t.Errorf("expected 0 violations when all required keys present, got %d", len(violations))
}
})
t.Run("AllowedEnvironments rejects env outside allowlist", func(t *testing.T) {
rule := mkRule("r-ae", domain.PolicyTypeAllowedEnvironments, domain.PolicySeverityWarning,
`{"allowed": ["production", "staging"]}`)
cert := evalCert()
cert.Environment = "wild-west"
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Fatalf("expected 1 violation for disallowed env, got %d", len(violations))
}
})
t.Run("RenewalLeadTime fires when remaining < configured lead", func(t *testing.T) {
rule := mkRule("r-rl", domain.PolicyTypeRenewalLeadTime, domain.PolicySeverityWarning,
`{"lead_time_days": 60}`)
cert := evalCert()
cert.ExpiresAt = time.Now().AddDate(0, 0, 30) // 30d < 60d lead
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Fatalf("expected 1 violation for 30d remaining vs 60d lead, got %d", len(violations))
}
})
t.Run("RenewalLeadTime quiet when remaining > configured lead", func(t *testing.T) {
rule := mkRule("r-rl", domain.PolicyTypeRenewalLeadTime, domain.PolicySeverityWarning,
`{"lead_time_days": 14}`)
cert := evalCert()
cert.ExpiresAt = time.Now().AddDate(0, 0, 60) // 60d > 14d lead
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 0 {
t.Errorf("expected 0 violations when plenty of runway remains, got %d", len(violations))
}
})
t.Run("CertificateLifetime fires when lifetime exceeds max", func(t *testing.T) {
rule := mkRule("r-cl", domain.PolicyTypeCertificateLifetime, domain.PolicySeverityWarning,
`{"max_days": 90}`)
cert := evalCert()
now := time.Now()
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{},
Versions: map[string][]*domain.CertificateVersion{},
}
certRepo.Versions["cert-001"] = []*domain.CertificateVersion{{
ID: "ver-001",
CertificateID: "cert-001",
NotBefore: now.AddDate(0, 0, -1),
NotAfter: now.AddDate(1, 0, -1), // ~365d > 90d
}}
violations := runEval(ctx, t, rule, cert, certRepo)
if len(violations) != 1 {
t.Fatalf("expected 1 violation for 365d lifetime vs 90d max, got %d", len(violations))
}
if !strings.Contains(violations[0].Message, "90 days") {
t.Errorf("expected message to mention max_days threshold, got %q", violations[0].Message)
}
})
t.Run("CertificateLifetime quiet when lifetime within max", func(t *testing.T) {
rule := mkRule("r-cl", domain.PolicyTypeCertificateLifetime, domain.PolicySeverityWarning,
`{"max_days": 90}`)
cert := evalCert()
now := time.Now()
certRepo := &mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{},
Versions: map[string][]*domain.CertificateVersion{},
}
certRepo.Versions["cert-001"] = []*domain.CertificateVersion{{
ID: "ver-001",
CertificateID: "cert-001",
NotBefore: now.AddDate(0, 0, -10),
NotAfter: now.AddDate(0, 0, 60), // 70d lifetime < 90d
}}
violations := runEval(ctx, t, rule, cert, certRepo)
if len(violations) != 0 {
t.Errorf("expected 0 violations for 70d lifetime under 90d max, got %d", len(violations))
}
})
}
// TestEvaluateRule_EmptyConfig_BackCompat pins invariant #3 — a rule with no
// Config (e.g., a legacy row from a pre-D-008 migration) still fires on the
// pre-D-008 missing-field shape using its configured severity. This is how
// we let existing deployments migrate without a schema rewrite.
func TestEvaluateRule_EmptyConfig_BackCompat(t *testing.T) {
ctx := context.Background()
t.Run("RequiredMetadata fires on zero tags", func(t *testing.T) {
rule := mkRule("r-rm", domain.PolicyTypeRequiredMetadata, domain.PolicySeverityError, "")
cert := evalCert()
cert.Tags = nil
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Fatalf("expected 1 backcompat violation, got %d", len(violations))
}
if violations[0].Severity != domain.PolicySeverityError {
t.Errorf("expected severity Error (passed through from rule), got %q", violations[0].Severity)
}
})
t.Run("RequiredMetadata quiet when any tags present under empty config", func(t *testing.T) {
// Empty config means "only fire on missing-field shape" — so a cert
// with any tags (even not what a human would call meaningful) passes.
rule := mkRule("r-rm", domain.PolicyTypeRequiredMetadata, domain.PolicySeverityError, "")
cert := evalCert()
cert.Tags = map[string]string{"arbitrary": "value"}
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 0 {
t.Errorf("expected 0 violations under backcompat shape w/ tags set, got %d", len(violations))
}
})
t.Run("RenewalLeadTime uses 30d default under empty/zero config", func(t *testing.T) {
rule := mkRule("r-rl", domain.PolicyTypeRenewalLeadTime, domain.PolicySeverityWarning, "")
cert := evalCert()
cert.ExpiresAt = time.Now().AddDate(0, 0, 15) // 15d < 30d default
violations := runEval(ctx, t, rule, cert, nil)
if len(violations) != 1 {
t.Errorf("expected 1 violation under 30d backcompat default, got %d", len(violations))
}
})
}
// TestEvaluateRule_BadConfig_SkipsRule pins invariant #4 — malformed JSON in
// Config returns an error from evaluateRule, which ValidateCertificate logs
// and swallows. The pass continues; no zero-value violation is emitted.
// Co-located rules still fire normally.
func TestEvaluateRule_BadConfig_SkipsRule(t *testing.T) {
ctx := context.Background()
// Rule 1 has malformed JSON — should log+skip.
// Rule 2 is a healthy AllowedIssuers rule that should still emit its
// violation on the missing-IssuerID cert. If the bad rule poisoned the
// loop, we'd see 0 or 2 violations instead of exactly 1.
badRule := mkRule("r-bad", domain.PolicyTypeAllowedIssuers, domain.PolicySeverityError,
`{"allowed_issuer_ids": [`) // unterminated JSON
goodRule := mkRule("r-good", domain.PolicyTypeAllowedEnvironments, domain.PolicySeverityWarning, "")
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{
badRule.ID: badRule,
goodRule.ID: goodRule,
},
Violations: []*domain.PolicyViolation{},
}
auditService := NewAuditService(&mockAuditRepo{})
svc := NewPolicyService(policyRepo, auditService)
cert := evalCert()
cert.IssuerID = "" // would trigger the bad rule if it wasn't skipped
cert.Environment = "" // triggers goodRule via missing-field backcompat
violations, err := svc.ValidateCertificate(ctx, cert)
if err != nil {
t.Fatalf("ValidateCertificate should swallow rule-eval errors, got %v", err)
}
if len(violations) != 1 {
t.Fatalf("expected exactly 1 violation (bad rule skipped, good rule fires), got %d", len(violations))
}
if violations[0].RuleID != goodRule.ID {
t.Errorf("expected violation from r-good, got %q", violations[0].RuleID)
}
}
// TestEvaluateRule_CertificateLifetime_RepoScenarios pins the setter-injection
// pattern for the 6th arm. SetCertRepo wires the dependency; without it the
// arm errors (logged+skipped by the caller). With it but no version present,
// the arm silently returns nil (matching the missing-field backcompat shape).
func TestEvaluateRule_CertificateLifetime_RepoScenarios(t *testing.T) {
ctx := context.Background()
t.Run("repo not wired logs and skips", func(t *testing.T) {
rule := mkRule("r-cl", domain.PolicyTypeCertificateLifetime, domain.PolicySeverityError,
`{"max_days": 90}`)
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{rule.ID: rule},
Violations: []*domain.PolicyViolation{},
}
svc := NewPolicyService(policyRepo, NewAuditService(&mockAuditRepo{}))
// deliberately do NOT call SetCertRepo
violations, err := svc.ValidateCertificate(ctx, evalCert())
if err != nil {
t.Fatalf("ValidateCertificate should swallow the nil-repo error, got %v", err)
}
if len(violations) != 0 {
t.Errorf("expected 0 violations when repo unwired (rule skipped), got %d", len(violations))
}
})
t.Run("version missing silently skips", func(t *testing.T) {
rule := mkRule("r-cl", domain.PolicyTypeCertificateLifetime, domain.PolicySeverityError,
`{"max_days": 90}`)
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{rule.ID: rule},
Violations: []*domain.PolicyViolation{},
}
svc := NewPolicyService(policyRepo, NewAuditService(&mockAuditRepo{}))
// Empty Versions map — GetLatestVersion returns errNotFound, arm skips.
svc.SetCertRepo(&mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{},
Versions: map[string][]*domain.CertificateVersion{},
})
violations, err := svc.ValidateCertificate(ctx, evalCert())
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
if len(violations) != 0 {
t.Errorf("expected 0 violations when no version exists (nothing to measure), got %d", len(violations))
}
})
t.Run("max_days zero/absent means no enforcement", func(t *testing.T) {
// Even with a version, max_days=0 is a no-op (matches the
// no-threshold-configured guard in the arm).
rule := mkRule("r-cl", domain.PolicyTypeCertificateLifetime, domain.PolicySeverityError, "")
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{rule.ID: rule},
Violations: []*domain.PolicyViolation{},
}
svc := NewPolicyService(policyRepo, NewAuditService(&mockAuditRepo{}))
now := time.Now()
svc.SetCertRepo(&mockCertRepo{
Certs: map[string]*domain.ManagedCertificate{},
Versions: map[string][]*domain.CertificateVersion{
"cert-001": {{
CertificateID: "cert-001",
NotBefore: now.AddDate(0, 0, -1),
NotAfter: now.AddDate(10, 0, 0), // 10 years — huge but unchecked
}},
},
})
violations, err := svc.ValidateCertificate(ctx, evalCert())
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
if len(violations) != 0 {
t.Errorf("expected 0 violations when max_days absent (no enforcement), got %d", len(violations))
}
})
}
// runEval is a test helper that exercises ValidateCertificate against a
// single-rule configuration and returns the violation slice. Optionally
// wires a cert repo for the CertificateLifetime arm.
func runEval(ctx context.Context, t *testing.T, rule *domain.PolicyRule, cert *domain.ManagedCertificate, certRepo *mockCertRepo) []*domain.PolicyViolation {
t.Helper()
policyRepo := &mockPolicyRepo{
Rules: map[string]*domain.PolicyRule{rule.ID: rule},
Violations: []*domain.PolicyViolation{},
}
svc := NewPolicyService(policyRepo, NewAuditService(&mockAuditRepo{}))
if certRepo != nil {
svc.SetCertRepo(certRepo)
}
violations, err := svc.ValidateCertificate(ctx, cert)
if err != nil {
t.Fatalf("ValidateCertificate failed: %v", err)
}
return violations
}
+13 -13
View File
@@ -28,7 +28,7 @@ func NewProfileService(
} }
// ListProfiles returns all profiles (handler interface method). // ListProfiles returns all profiles (handler interface method).
func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) { func (s *ProfileService) ListProfiles(ctx context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -36,7 +36,7 @@ func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificatePr
perPage = 50 perPage = 50
} }
profiles, err := s.profileRepo.List(context.Background()) profiles, err := s.profileRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list profiles: %w", err) return nil, 0, fmt.Errorf("failed to list profiles: %w", err)
} }
@@ -53,12 +53,12 @@ func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificatePr
} }
// GetProfile returns a single profile (handler interface method). // GetProfile returns a single profile (handler interface method).
func (s *ProfileService) GetProfile(id string) (*domain.CertificateProfile, error) { func (s *ProfileService) GetProfile(ctx context.Context, id string) (*domain.CertificateProfile, error) {
return s.profileRepo.Get(context.Background(), id) return s.profileRepo.Get(ctx, id)
} }
// CreateProfile creates a new profile with validation (handler interface method). // CreateProfile creates a new profile with validation (handler interface method).
func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) { func (s *ProfileService) CreateProfile(ctx context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
if err := validateProfile(&profile); err != nil { if err := validateProfile(&profile); err != nil {
return nil, err return nil, err
} }
@@ -82,12 +82,12 @@ func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*doma
profile.AllowedEKUs = domain.DefaultEKUs() profile.AllowedEKUs = domain.DefaultEKUs()
} }
if err := s.profileRepo.Create(context.Background(), &profile); err != nil { if err := s.profileRepo.Create(ctx, &profile); err != nil {
return nil, fmt.Errorf("failed to create profile: %w", err) return nil, fmt.Errorf("failed to create profile: %w", err)
} }
if s.auditService != nil { if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, if auditErr := s.auditService.RecordEvent(context.WithoutCancel(ctx), "api", domain.ActorTypeUser,
"create_profile", "certificate_profile", profile.ID, nil); auditErr != nil { "create_profile", "certificate_profile", profile.ID, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr) slog.Error("failed to record audit event", "error", auditErr)
} }
@@ -97,18 +97,18 @@ func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*doma
} }
// UpdateProfile modifies an existing profile (handler interface method). // UpdateProfile modifies an existing profile (handler interface method).
func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) { func (s *ProfileService) UpdateProfile(ctx context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
if err := validateProfile(&profile); err != nil { if err := validateProfile(&profile); err != nil {
return nil, err return nil, err
} }
profile.ID = id profile.ID = id
if err := s.profileRepo.Update(context.Background(), &profile); err != nil { if err := s.profileRepo.Update(ctx, &profile); err != nil {
return nil, fmt.Errorf("failed to update profile: %w", err) return nil, fmt.Errorf("failed to update profile: %w", err)
} }
if s.auditService != nil { if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, if auditErr := s.auditService.RecordEvent(context.WithoutCancel(ctx), "api", domain.ActorTypeUser,
"update_profile", "certificate_profile", id, nil); auditErr != nil { "update_profile", "certificate_profile", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr) slog.Error("failed to record audit event", "error", auditErr)
} }
@@ -118,13 +118,13 @@ func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProf
} }
// DeleteProfile removes a profile (handler interface method). // DeleteProfile removes a profile (handler interface method).
func (s *ProfileService) DeleteProfile(id string) error { func (s *ProfileService) DeleteProfile(ctx context.Context, id string) error {
if err := s.profileRepo.Delete(context.Background(), id); err != nil { if err := s.profileRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("failed to delete profile: %w", err) return fmt.Errorf("failed to delete profile: %w", err)
} }
if s.auditService != nil { if s.auditService != nil {
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser, if auditErr := s.auditService.RecordEvent(context.WithoutCancel(ctx), "api", domain.ActorTypeUser,
"delete_profile", "certificate_profile", id, nil); auditErr != nil { "delete_profile", "certificate_profile", id, nil); auditErr != nil {
slog.Error("failed to record audit event", "error", auditErr) slog.Error("failed to record audit event", "error", auditErr)
} }
+13 -13
View File
@@ -82,7 +82,7 @@ func TestProfileService_ListProfiles(t *testing.T) {
repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true}) repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true})
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
profiles, total, err := svc.ListProfiles(1, 50) profiles, total, err := svc.ListProfiles(context.Background(), 1, 50)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -98,7 +98,7 @@ func TestProfileService_ListProfiles_Empty(t *testing.T) {
repo := newMockProfileRepository() repo := newMockProfileRepository()
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
profiles, total, err := svc.ListProfiles(1, 50) profiles, total, err := svc.ListProfiles(context.Background(), 1, 50)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -115,7 +115,7 @@ func TestProfileService_ListProfiles_RepoError(t *testing.T) {
repo.ListErr = errors.New("db error") repo.ListErr = errors.New("db error")
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
_, _, err := svc.ListProfiles(1, 50) _, _, err := svc.ListProfiles(context.Background(), 1, 50)
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -126,7 +126,7 @@ func TestProfileService_GetProfile(t *testing.T) {
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"}) repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"})
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
profile, err := svc.GetProfile("prof-1") profile, err := svc.GetProfile(context.Background(), "prof-1")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -139,7 +139,7 @@ func TestProfileService_GetProfile_NotFound(t *testing.T) {
repo := newMockProfileRepository() repo := newMockProfileRepository()
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
_, err := svc.GetProfile("nonexistent") _, err := svc.GetProfile(context.Background(), "nonexistent")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -156,7 +156,7 @@ func TestProfileService_CreateProfile_Defaults(t *testing.T) {
MaxTTLSeconds: 86400, MaxTTLSeconds: 86400,
} }
created, err := svc.CreateProfile(profile) created, err := svc.CreateProfile(context.Background(), profile)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -258,7 +258,7 @@ func TestProfileService_CreateProfile_ValidationErrors(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := svc.CreateProfile(tt.profile) _, err := svc.CreateProfile(context.Background(), tt.profile)
if err == nil { if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.errMsg) t.Fatalf("expected error containing %q, got nil", tt.errMsg)
} }
@@ -274,7 +274,7 @@ func TestProfileService_CreateProfile_RepoError(t *testing.T) {
repo.CreateErr = errors.New("db create failed") repo.CreateErr = errors.New("db create failed")
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
_, err := svc.CreateProfile(domain.CertificateProfile{Name: "Valid"}) _, err := svc.CreateProfile(context.Background(), domain.CertificateProfile{Name: "Valid"})
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -287,7 +287,7 @@ func TestProfileService_UpdateProfile(t *testing.T) {
auditSvc := NewAuditService(auditRepo) auditSvc := NewAuditService(auditRepo)
svc := NewProfileService(repo, auditSvc) svc := NewProfileService(repo, auditSvc)
updated, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{ updated, err := svc.UpdateProfile(context.Background(), "prof-1", domain.CertificateProfile{
Name: "Updated", Name: "Updated",
MaxTTLSeconds: 43200, MaxTTLSeconds: 43200,
}) })
@@ -306,7 +306,7 @@ func TestProfileService_UpdateProfile_ValidationError(t *testing.T) {
repo := newMockProfileRepository() repo := newMockProfileRepository()
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
_, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{Name: ""}) _, err := svc.UpdateProfile(context.Background(), "prof-1", domain.CertificateProfile{Name: ""})
if err == nil { if err == nil {
t.Fatal("expected validation error, got nil") t.Fatal("expected validation error, got nil")
} }
@@ -319,7 +319,7 @@ func TestProfileService_DeleteProfile(t *testing.T) {
auditSvc := NewAuditService(auditRepo) auditSvc := NewAuditService(auditRepo)
svc := NewProfileService(repo, auditSvc) svc := NewProfileService(repo, auditSvc)
err := svc.DeleteProfile("prof-1") err := svc.DeleteProfile(context.Background(), "prof-1")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -333,7 +333,7 @@ func TestProfileService_DeleteProfile_RepoError(t *testing.T) {
repo.DeleteErr = errors.New("db delete failed") repo.DeleteErr = errors.New("db delete failed")
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
err := svc.DeleteProfile("prof-1") err := svc.DeleteProfile(context.Background(), "prof-1")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -344,7 +344,7 @@ func TestProfileService_CreateProfile_ValidShortLived(t *testing.T) {
svc := NewProfileService(repo, nil) svc := NewProfileService(repo, nil)
// Short-lived with TTL under 1 hour should succeed // Short-lived with TTL under 1 hour should succeed
created, err := svc.CreateProfile(domain.CertificateProfile{ created, err := svc.CreateProfile(context.Background(), domain.CertificateProfile{
Name: "CI Ephemeral", Name: "CI Ephemeral",
AllowShortLived: true, AllowShortLived: true,
MaxTTLSeconds: 300, // 5 minutes MaxTTLSeconds: 300, // 5 minutes
+2 -2
View File
@@ -151,9 +151,9 @@ func (s *RevocationSvc) RevokeCertificateWithActor(ctx context.Context, certID s
} }
// GetRevokedCertificates returns all revoked certificate records (for CRL generation). // GetRevokedCertificates returns all revoked certificate records (for CRL generation).
func (s *RevocationSvc) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) { func (s *RevocationSvc) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
if s.revocationRepo == nil { if s.revocationRepo == nil {
return nil, fmt.Errorf("revocation repository not configured") return nil, fmt.Errorf("revocation repository not configured")
} }
return s.revocationRepo.ListAll(context.Background()) return s.revocationRepo.ListAll(ctx)
} }
+1 -1
View File
@@ -122,7 +122,7 @@ func TestRevocationSvc_GetRevokedCertificates_Success(t *testing.T) {
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()}, {ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
} }
revocations, err := revSvc.GetRevokedCertificates() revocations, err := revSvc.GetRevokedCertificates(context.Background())
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
+23 -23
View File
@@ -62,7 +62,7 @@ func TestRevokeCertificate_Success(t *testing.T) {
certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version} certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version}
// Revoke // Revoke
err := svc.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin") err := svc.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -125,7 +125,7 @@ func TestRevokeCertificate_DefaultReason(t *testing.T) {
} }
// Revoke with empty reason — should default to "unspecified" // Revoke with empty reason — should default to "unspecified"
err := svc.RevokeCertificateWithActor(context.Background(), "cert-2", "", "api") err := svc.RevokeCertificate(context.Background(), "cert-2", "", "api")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -158,7 +158,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
} }
certRepo.AddCert(cert) certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-3", "superseded", "admin") err := svc.RevokeCertificate(context.Background(), "cert-3", "superseded", "admin")
if err == nil { if err == nil {
t.Fatal("expected error for already revoked certificate") t.Fatal("expected error for already revoked certificate")
} }
@@ -179,7 +179,7 @@ func TestRevokeCertificate_ArchivedCert(t *testing.T) {
} }
certRepo.AddCert(cert) certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-4", "keyCompromise", "admin") err := svc.RevokeCertificate(context.Background(), "cert-4", "keyCompromise", "admin")
if err == nil { if err == nil {
t.Fatal("expected error for archived certificate") t.Fatal("expected error for archived certificate")
} }
@@ -200,7 +200,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) {
} }
certRepo.AddCert(cert) certRepo.AddCert(cert)
err := svc.RevokeCertificateWithActor(context.Background(), "cert-5", "notAValidReason", "admin") err := svc.RevokeCertificate(context.Background(), "cert-5", "notAValidReason", "admin")
if err == nil { if err == nil {
t.Fatal("expected error for invalid reason") t.Fatal("expected error for invalid reason")
} }
@@ -212,7 +212,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) {
func TestRevokeCertificate_NotFound(t *testing.T) { func TestRevokeCertificate_NotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService() svc, _, _, _ := newRevocationTestService()
err := svc.RevokeCertificateWithActor(context.Background(), "nonexistent-cert", "keyCompromise", "admin") err := svc.RevokeCertificate(context.Background(), "nonexistent-cert", "keyCompromise", "admin")
if err == nil { if err == nil {
t.Fatal("expected error for nonexistent certificate") t.Fatal("expected error for nonexistent certificate")
} }
@@ -231,7 +231,7 @@ func TestRevokeCertificate_NoVersion(t *testing.T) {
certRepo.AddCert(cert) certRepo.AddCert(cert)
// No versions added — should fail // No versions added — should fail
err := svc.RevokeCertificateWithActor(context.Background(), "cert-6", "keyCompromise", "admin") err := svc.RevokeCertificate(context.Background(), "cert-6", "keyCompromise", "admin")
if err == nil { if err == nil {
t.Fatal("expected error when no certificate version exists") t.Fatal("expected error when no certificate version exists")
} }
@@ -258,7 +258,7 @@ func TestRevokeCertificate_WithIssuerNotification(t *testing.T) {
{ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()}, {ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()},
} }
err := svc.RevokeCertificateWithActor(context.Background(), "cert-7", "cessationOfOperation", "admin") err := svc.RevokeCertificate(context.Background(), "cert-7", "cessationOfOperation", "admin")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -293,7 +293,7 @@ func TestRevokeCertificate_WithNotificationService(t *testing.T) {
{ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()}, {ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()},
} }
err := svc.RevokeCertificateWithActor(context.Background(), "cert-8", "keyCompromise", "admin") err := svc.RevokeCertificate(context.Background(), "cert-8", "keyCompromise", "admin")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -336,7 +336,7 @@ func TestRevokeCertificate_AllValidReasons(t *testing.T) {
{ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()}, {ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()},
} }
err := svc.RevokeCertificateWithActor(context.Background(), "cert-"+reason, reason, "admin") err := svc.RevokeCertificate(context.Background(), "cert-"+reason, reason, "admin")
if err != nil { if err != nil {
t.Fatalf("expected no error for reason %s, got: %v", reason, err) t.Fatalf("expected no error for reason %s, got: %v", reason, err)
} }
@@ -358,7 +358,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) {
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()}, {ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
} }
revocations, err := svc.GetRevokedCertificates() revocations, err := svc.GetRevokedCertificates(context.Background())
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -370,7 +370,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) {
func TestGetRevokedCertificates_Empty(t *testing.T) { func TestGetRevokedCertificates_Empty(t *testing.T) {
svc, _, _, _ := newRevocationTestService() svc, _, _, _ := newRevocationTestService()
revocations, err := svc.GetRevokedCertificates() revocations, err := svc.GetRevokedCertificates(context.Background())
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -390,7 +390,7 @@ func TestGetRevokedCertificates_NoRepo(t *testing.T) {
svc := NewCertificateService(certRepo, policyService, auditService) svc := NewCertificateService(certRepo, policyService, auditService)
// Do NOT set revocation repo // Do NOT set revocation repo
_, err := svc.GetRevokedCertificates() _, err := svc.GetRevokedCertificates(context.Background())
if err == nil { if err == nil {
t.Fatal("expected error when revocation repo not configured") t.Fatal("expected error when revocation repo not configured")
} }
@@ -411,8 +411,8 @@ func TestRevokeCertificate_HandlerInterfaceMethod(t *testing.T) {
{ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()}, {ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()},
} }
// Test the handler interface method (no actor param) // Test the handler interface method (actor collapsed to required positional arg per D-2)
err := svc.RevokeCertificate("cert-handler", "superseded") err := svc.RevokeCertificate(context.Background(), "cert-handler", "superseded", "api")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
} }
@@ -449,7 +449,7 @@ func TestGenerateDERCRL_Success(t *testing.T) {
}, },
} }
crl, err := svc.GenerateDERCRL("iss-local") crl, err := svc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -472,7 +472,7 @@ func TestGenerateDERCRL_EmptyCRL(t *testing.T) {
// No revoked certs for this issuer // No revoked certs for this issuer
revocationRepo.Revocations = []*domain.CertificateRevocation{} revocationRepo.Revocations = []*domain.CertificateRevocation{}
crl, err := svc.GenerateDERCRL("iss-local") crl, err := svc.GenerateDERCRL(context.Background(), "iss-local")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -493,7 +493,7 @@ func TestGenerateDERCRL_IssuerNotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService() svc, _, _, _ := newRevocationTestService()
// Try to generate CRL for unknown issuer // Try to generate CRL for unknown issuer
crl, err := svc.GenerateDERCRL("iss-unknown") crl, err := svc.GenerateDERCRL(context.Background(), "iss-unknown")
// Should return error or nil CRL depending on implementation // Should return error or nil CRL depending on implementation
if crl != nil && err == nil { if crl != nil && err == nil {
@@ -527,7 +527,7 @@ func TestGetOCSPResponse_Good(t *testing.T) {
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version} certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
// Request OCSP response for good cert // Request OCSP response for good cert
resp, err := svc.GetOCSPResponse("iss-local", "OCSP-GOOD-001") resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -580,7 +580,7 @@ func TestGetOCSPResponse_Revoked(t *testing.T) {
} }
// Request OCSP response for revoked cert // Request OCSP response for revoked cert
resp, err := svc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001") resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001")
if err != nil { if err != nil {
t.Fatalf("expected no error, got: %v", err) t.Fatalf("expected no error, got: %v", err)
@@ -597,7 +597,7 @@ func TestGetOCSPResponse_Unknown(t *testing.T) {
svc, _, _, _ := newRevocationTestService() svc, _, _, _ := newRevocationTestService()
// Request OCSP response for unknown cert // Request OCSP response for unknown cert
resp, err := svc.GetOCSPResponse("iss-local", "UNKNOWN-SERIAL") resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "UNKNOWN-SERIAL")
if err != nil { if err != nil {
t.Fatalf("expected no error (should return unknown status), got: %v", err) t.Fatalf("expected no error (should return unknown status), got: %v", err)
@@ -615,7 +615,7 @@ func TestGetOCSPResponse_IssuerNotFound(t *testing.T) {
svc, _, _, _ := newRevocationTestService() svc, _, _, _ := newRevocationTestService()
// Request OCSP response for unknown issuer // Request OCSP response for unknown issuer
resp, err := svc.GetOCSPResponse("iss-unknown", "SOME-SERIAL") resp, err := svc.GetOCSPResponse(context.Background(), "iss-unknown", "SOME-SERIAL")
// Should return error since issuer doesn't exist // Should return error since issuer doesn't exist
if err == nil && resp != nil { if err == nil && resp != nil {
@@ -629,7 +629,7 @@ func TestGetOCSPResponse_InvalidSerial(t *testing.T) {
svc, _, _, _ := newRevocationTestService() svc, _, _, _ := newRevocationTestService()
// Request OCSP response with invalid serial format // Request OCSP response with invalid serial format
resp, err := svc.GetOCSPResponse("iss-local", "") resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "")
if err == nil && resp != nil { if err == nil && resp != nil {
// Empty serial might return unknown status; that's ok // Empty serial might return unknown status; that's ok
+42 -19
View File
@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"time" "time"
@@ -12,6 +13,13 @@ import (
"github.com/shankar0123/certctl/internal/repository" "github.com/shankar0123/certctl/internal/repository"
) )
// ErrAgentNotFound is returned by [TargetService.CreateTarget] when the caller
// references an agent_id that is empty or does not correspond to a registered
// agent. The handler layer maps this to HTTP 400 via [errors.Is]. See C-002 in
// cowork/certctl-coverage-gap-audit.md — this sentinel replaces a silent
// Postgres FK violation (23503 → HTTP 500) with a deterministic 400.
var ErrAgentNotFound = errors.New("referenced agent does not exist")
// validTargetTypes is the set of allowed target types for validation. // validTargetTypes is the set of allowed target types for validation.
var validTargetTypes = map[domain.TargetType]bool{ var validTargetTypes = map[domain.TargetType]bool{
domain.TargetTypeNGINX: true, domain.TargetTypeNGINX: true,
@@ -36,20 +44,27 @@ func isValidTargetType(t domain.TargetType) bool {
} }
// TargetService provides business logic for deployment target management. // TargetService provides business logic for deployment target management.
//
// The encryptionKey field holds the raw passphrase (not a pre-derived 32-byte
// key). Per-ciphertext salt derivation is performed inside
// [crypto.EncryptIfKeySet] / [crypto.DecryptIfKeySet] on each call. See M-8
// in certctl-audit-report.md.
type TargetService struct { type TargetService struct {
targetRepo repository.TargetRepository targetRepo repository.TargetRepository
agentRepo repository.AgentRepository agentRepo repository.AgentRepository
auditService *AuditService auditService *AuditService
encryptionKey []byte encryptionKey string
logger *slog.Logger logger *slog.Logger
} }
// NewTargetService creates a new target service. // NewTargetService creates a new target service. The encryptionKey is the raw
// passphrase; it MUST NOT be pre-derived via crypto.DeriveKey (that was the
// v1 behavior, replaced in M-8 with per-ciphertext random salt).
func NewTargetService( func NewTargetService(
targetRepo repository.TargetRepository, targetRepo repository.TargetRepository,
auditService *AuditService, auditService *AuditService,
agentRepo repository.AgentRepository, agentRepo repository.AgentRepository,
encryptionKey []byte, encryptionKey string,
logger *slog.Logger, logger *slog.Logger,
) *TargetService { ) *TargetService {
return &TargetService{ return &TargetService{
@@ -235,7 +250,7 @@ func (s *TargetService) TestConnection(ctx context.Context, id string) error {
} }
// ListTargets returns paginated targets (handler interface method). // ListTargets returns paginated targets (handler interface method).
func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) { func (s *TargetService) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -243,7 +258,7 @@ func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarge
perPage = 50 perPage = 50
} }
targets, err := s.targetRepo.List(context.Background()) targets, err := s.targetRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list targets: %w", err) return nil, 0, fmt.Errorf("failed to list targets: %w", err)
} }
@@ -260,15 +275,28 @@ func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarge
} }
// GetTarget returns a single target (handler interface method). // GetTarget returns a single target (handler interface method).
func (s *TargetService) GetTarget(id string) (*domain.DeploymentTarget, error) { func (s *TargetService) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
return s.targetRepo.Get(context.Background(), id) return s.targetRepo.Get(ctx, id)
} }
// CreateTarget creates a new target (handler interface method). // CreateTarget creates a new target (handler interface method).
func (s *TargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { func (s *TargetService) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
if !isValidTargetType(target.Type) { if !isValidTargetType(target.Type) {
return nil, fmt.Errorf("unsupported target type: %s", target.Type) return nil, fmt.Errorf("unsupported target type: %s", target.Type)
} }
// C-002: enforce agent_id FK at service layer so we return a clean 400
// instead of bubbling a Postgres 23503 foreign-key violation out as 500.
// The schema (migrations/000001 line 104) declares agent_id TEXT NOT NULL
// with a FK to agents(id); we mirror that contract here for deterministic
// error mapping.
if target.AgentID == "" {
return nil, fmt.Errorf("%w: agent_id is required", ErrAgentNotFound)
}
if _, err := s.agentRepo.Get(ctx, target.AgentID); err != nil {
return nil, fmt.Errorf("%w: %s", ErrAgentNotFound, target.AgentID)
}
if target.ID == "" { if target.ID == "" {
target.ID = generateID("target") target.ID = generateID("target")
} }
@@ -301,20 +329,20 @@ func (s *TargetService) CreateTarget(target domain.DeploymentTarget) (*domain.De
target.Config = redactConfigJSON(target.Config) target.Config = redactConfigJSON(target.Config)
} }
if err := s.targetRepo.Create(context.Background(), &target); err != nil { if err := s.targetRepo.Create(ctx, &target); err != nil {
return nil, fmt.Errorf("failed to create target: %w", err) return nil, fmt.Errorf("failed to create target: %w", err)
} }
return &target, nil return &target, nil
} }
// UpdateTarget modifies a target (handler interface method). // UpdateTarget modifies a target (handler interface method).
func (s *TargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { func (s *TargetService) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
target.ID = id target.ID = id
target.UpdatedAt = time.Now() target.UpdatedAt = time.Now()
// Merge redacted fields with existing config // Merge redacted fields with existing config
if len(target.Config) > 0 { if len(target.Config) > 0 {
mergedConfig, err := s.mergeRedactedConfig(context.Background(), id, target.Config) mergedConfig, err := s.mergeRedactedConfig(ctx, id, target.Config)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to merge config: %w", err) return nil, fmt.Errorf("failed to merge config: %w", err)
} }
@@ -327,20 +355,15 @@ func (s *TargetService) UpdateTarget(id string, target domain.DeploymentTarget)
target.Config = redactConfigJSON(json.RawMessage(mergedConfig)) target.Config = redactConfigJSON(json.RawMessage(mergedConfig))
} }
if err := s.targetRepo.Update(context.Background(), &target); err != nil { if err := s.targetRepo.Update(ctx, &target); err != nil {
return nil, fmt.Errorf("failed to update target: %w", err) return nil, fmt.Errorf("failed to update target: %w", err)
} }
return &target, nil return &target, nil
} }
// DeleteTarget removes a target (handler interface method). // DeleteTarget removes a target (handler interface method).
func (s *TargetService) DeleteTarget(id string) error { func (s *TargetService) DeleteTarget(ctx context.Context, id string) error {
return s.targetRepo.Delete(context.Background(), id) return s.targetRepo.Delete(ctx, id)
}
// TestTargetConnection tests target connectivity (handler interface method).
func (s *TargetService) TestTargetConnection(id string) error {
return s.TestConnection(context.Background(), id)
} }
// --- Internal helpers --- // --- Internal helpers ---
+67 -7
View File
@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"log/slog" "log/slog"
"os" "os"
"testing" "testing"
@@ -344,7 +345,8 @@ func TestTargetService_ListTargets_Success(t *testing.T) {
targetRepo.AddTarget(target2) targetRepo.AddTarget(target2)
// Call handler-interface method // Call handler-interface method
targets, total, err := svc.ListTargets(1, 50) ctx := context.Background()
targets, total, err := svc.ListTargets(ctx, 1, 50)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -364,7 +366,8 @@ func TestTargetService_GetTarget_Success(t *testing.T) {
target := &domain.DeploymentTarget{ID: "t-1", Name: "Target 1", Type: domain.TargetTypeNGINX} target := &domain.DeploymentTarget{ID: "t-1", Name: "Target 1", Type: domain.TargetTypeNGINX}
targetRepo.AddTarget(target) targetRepo.AddTarget(target)
result, err := svc.GetTarget("t-1") ctx := context.Background()
result, err := svc.GetTarget(ctx, "t-1")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -375,14 +378,21 @@ func TestTargetService_GetTarget_Success(t *testing.T) {
} }
func TestTargetService_CreateTarget_Success(t *testing.T) { func TestTargetService_CreateTarget_Success(t *testing.T) {
svc, targetRepo, _, _ := newTestTargetService() svc, targetRepo, _, agentRepo := newTestTargetService()
// C-002: CreateTarget now pre-validates agent_id against agentRepo. Seed a
// real agent so the happy path still exercises the normal creation flow
// without tripping the new ErrAgentNotFound guard.
agentRepo.AddAgent(&domain.Agent{ID: "a-1", Name: "test-agent"})
target := domain.DeploymentTarget{ target := domain.DeploymentTarget{
Name: "New Target", Name: "New Target",
Type: domain.TargetTypeNGINX, Type: domain.TargetTypeNGINX,
AgentID: "a-1",
} }
result, err := svc.CreateTarget(target) ctx := context.Background()
result, err := svc.CreateTarget(ctx, target)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -405,12 +415,60 @@ func TestTargetService_CreateTarget_InvalidType(t *testing.T) {
Type: domain.TargetType("Unknown"), Type: domain.TargetType("Unknown"),
} }
_, err := svc.CreateTarget(target) ctx := context.Background()
_, err := svc.CreateTarget(ctx, target)
if err == nil { if err == nil {
t.Fatalf("expected error for invalid type, got nil") t.Fatalf("expected error for invalid type, got nil")
} }
} }
// TestTargetService_CreateTarget_MissingAgentID verifies the C-002 service-layer
// guard: an empty agent_id must be rejected with ErrAgentNotFound before the
// repository layer is ever consulted. The handler maps this sentinel to HTTP
// 400, so a 500 from a Postgres 23503 FK violation is never surfaced.
func TestTargetService_CreateTarget_MissingAgentID(t *testing.T) {
svc, _, _, _ := newTestTargetService()
target := domain.DeploymentTarget{
Name: "No Agent",
Type: domain.TargetTypeNGINX,
// AgentID intentionally empty
}
ctx := context.Background()
_, err := svc.CreateTarget(ctx, target)
if err == nil {
t.Fatalf("expected error for missing agent_id, got nil")
}
if !errors.Is(err, ErrAgentNotFound) {
t.Errorf("expected errors.Is(err, ErrAgentNotFound) to be true, got err=%v", err)
}
}
// TestTargetService_CreateTarget_NonexistentAgentID verifies the second half of
// the C-002 guard: a non-empty agent_id that does not resolve in agentRepo
// still returns ErrAgentNotFound rather than letting the FK violation escape to
// Postgres. This is the realistic failure mode for a GUI sending a stale
// agent_id or a CLI caller with a typo.
func TestTargetService_CreateTarget_NonexistentAgentID(t *testing.T) {
svc, _, _, _ := newTestTargetService()
target := domain.DeploymentTarget{
Name: "Bad Agent Ref",
Type: domain.TargetTypeNGINX,
AgentID: "a-does-not-exist",
}
ctx := context.Background()
_, err := svc.CreateTarget(ctx, target)
if err == nil {
t.Fatalf("expected error for nonexistent agent_id, got nil")
}
if !errors.Is(err, ErrAgentNotFound) {
t.Errorf("expected errors.Is(err, ErrAgentNotFound) to be true, got err=%v", err)
}
}
func TestTargetService_UpdateTarget_Success(t *testing.T) { func TestTargetService_UpdateTarget_Success(t *testing.T) {
svc, targetRepo, _, _ := newTestTargetService() svc, targetRepo, _, _ := newTestTargetService()
@@ -424,7 +482,8 @@ func TestTargetService_UpdateTarget_Success(t *testing.T) {
Type: domain.TargetTypeApache, Type: domain.TargetTypeApache,
} }
result, err := svc.UpdateTarget("t-1", updated) ctx := context.Background()
result, err := svc.UpdateTarget(ctx, "t-1", updated)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -442,7 +501,8 @@ func TestTargetService_DeleteTarget_Success(t *testing.T) {
targetRepo.AddTarget(target) targetRepo.AddTarget(target)
// Delete it // Delete it
err := svc.DeleteTarget("t-1") ctx := context.Background()
err := svc.DeleteTarget(ctx, "t-1")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
+10 -10
View File
@@ -126,7 +126,7 @@ func (s *TeamService) Delete(ctx context.Context, id string, actor string) error
} }
// ListTeams returns paginated teams (handler interface method). // ListTeams returns paginated teams (handler interface method).
func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) { func (s *TeamService) ListTeams(ctx context.Context, page, perPage int) ([]domain.Team, int64, error) {
if page < 1 { if page < 1 {
page = 1 page = 1
} }
@@ -134,7 +134,7 @@ func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error)
perPage = 50 perPage = 50
} }
teams, err := s.teamRepo.List(context.Background()) teams, err := s.teamRepo.List(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("failed to list teams: %w", err) return nil, 0, fmt.Errorf("failed to list teams: %w", err)
} }
@@ -151,12 +151,12 @@ func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error)
} }
// GetTeam returns a single team (handler interface method). // GetTeam returns a single team (handler interface method).
func (s *TeamService) GetTeam(id string) (*domain.Team, error) { func (s *TeamService) GetTeam(ctx context.Context, id string) (*domain.Team, error) {
return s.teamRepo.Get(context.Background(), id) return s.teamRepo.Get(ctx, id)
} }
// CreateTeam creates a new team (handler interface method). // CreateTeam creates a new team (handler interface method).
func (s *TeamService) CreateTeam(team domain.Team) (*domain.Team, error) { func (s *TeamService) CreateTeam(ctx context.Context, team domain.Team) (*domain.Team, error) {
if team.ID == "" { if team.ID == "" {
team.ID = generateID("team") team.ID = generateID("team")
} }
@@ -167,22 +167,22 @@ func (s *TeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
if team.UpdatedAt.IsZero() { if team.UpdatedAt.IsZero() {
team.UpdatedAt = now team.UpdatedAt = now
} }
if err := s.teamRepo.Create(context.Background(), &team); err != nil { if err := s.teamRepo.Create(ctx, &team); err != nil {
return nil, fmt.Errorf("failed to create team: %w", err) return nil, fmt.Errorf("failed to create team: %w", err)
} }
return &team, nil return &team, nil
} }
// UpdateTeam modifies a team (handler interface method). // UpdateTeam modifies a team (handler interface method).
func (s *TeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) { func (s *TeamService) UpdateTeam(ctx context.Context, id string, team domain.Team) (*domain.Team, error) {
team.ID = id team.ID = id
if err := s.teamRepo.Update(context.Background(), &team); err != nil { if err := s.teamRepo.Update(ctx, &team); err != nil {
return nil, fmt.Errorf("failed to update team: %w", err) return nil, fmt.Errorf("failed to update team: %w", err)
} }
return &team, nil return &team, nil
} }
// DeleteTeam removes a team (handler interface method). // DeleteTeam removes a team (handler interface method).
func (s *TeamService) DeleteTeam(id string) error { func (s *TeamService) DeleteTeam(ctx context.Context, id string) error {
return s.teamRepo.Delete(context.Background(), id) return s.teamRepo.Delete(ctx, id)
} }
+5 -5
View File
@@ -544,7 +544,7 @@ func TestTeamService_ListTeams_HandlerInterface(t *testing.T) {
}) })
} }
teams, total, err := teamService.ListTeams(1, 2) teams, total, err := teamService.ListTeams(context.Background(), 1, 2)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -571,7 +571,7 @@ func TestTeamService_GetTeam_HandlerInterface(t *testing.T) {
} }
mockTeamRepo.AddTeam(testTeam) mockTeamRepo.AddTeam(testTeam)
team, err := teamService.GetTeam("handler-team") team, err := teamService.GetTeam(context.Background(), "handler-team")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -593,7 +593,7 @@ func TestTeamService_CreateTeam_HandlerInterface(t *testing.T) {
Description: "Created via handler", Description: "Created via handler",
} }
result, err := teamService.CreateTeam(team) result, err := teamService.CreateTeam(context.Background(), team)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -629,7 +629,7 @@ func TestTeamService_UpdateTeam_HandlerInterface(t *testing.T) {
Description: "Handler update", Description: "Handler update",
} }
result, err := teamService.UpdateTeam("handler-update-team", updateTeam) result, err := teamService.UpdateTeam(context.Background(), "handler-update-team", updateTeam)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -656,7 +656,7 @@ func TestTeamService_DeleteTeam_HandlerInterface(t *testing.T) {
Name: "To Delete", Name: "To Delete",
}) })
err := teamService.DeleteTeam("handler-delete-team") err := teamService.DeleteTeam(context.Background(), "handler-delete-team")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
+20 -4
View File
@@ -12,12 +12,15 @@ import (
var errNotFound = errors.New("not found") var errNotFound = errors.New("not found")
// testEncryptionKey is a deterministic 32-byte AES-256 key for unit tests that // testEncryptionKey is a deterministic passphrase for unit tests that
// exercise IssuerService/TargetService write paths. After the C-2 remediation // exercise IssuerService/TargetService write paths. After the C-2 remediation
// these services fail closed when no key is configured, so happy-path tests // these services fail closed when no key is configured, so happy-path tests
// must supply a real key. Using a constant keeps wire-format assertions stable // must supply a real passphrase. M-8 reshaped the type from []byte to string
// across runs and avoids flaky PBKDF2 timing. // because services now hold the raw passphrase and delegate PBKDF2 to
var testEncryptionKey = []byte("0123456789abcdef0123456789abcdef") // 32 bytes // crypto.EncryptIfKeySet / crypto.DecryptIfKeySet (which apply a fresh random
// salt per ciphertext). Using a constant keeps wire-format assertions stable
// across runs.
var testEncryptionKey = "0123456789abcdef0123456789abcdef"
// mockCertRepo is a test implementation of CertificateRepository // mockCertRepo is a test implementation of CertificateRepository
type mockCertRepo struct { type mockCertRepo struct {
@@ -599,6 +602,19 @@ func (m *mockAgentRepo) Create(ctx context.Context, agent *domain.Agent) error {
return nil return nil
} }
func (m *mockAgentRepo) CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.CreateErr != nil {
return false, m.CreateErr
}
if _, exists := m.Agents[agent.ID]; exists {
return false, nil
}
m.Agents[agent.ID] = agent
return true, nil
}
func (m *mockAgentRepo) Update(ctx context.Context, agent *domain.Agent) error { func (m *mockAgentRepo) Update(ctx context.Context, agent *domain.Agent) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -0,0 +1,8 @@
-- Rollback migration 000013: remove per-rule severity.
--
-- DROP COLUMN removes the column, its CHECK constraint, and the default in
-- one statement. Any downstream code still referencing severity after
-- rollback will fail at query time — that's intentional, since running this
-- rollback implies severity as a concept is being abandoned.
ALTER TABLE policy_rules DROP COLUMN IF EXISTS severity;
@@ -0,0 +1,24 @@
-- Migration 000013: Per-Rule Severity on policy_rules
--
-- Prior to this migration, PolicyRule had no severity column. The TypeScript
-- frontend (PoliciesPage.tsx) sent a `severity` field on create/update, but
-- Go's json.Decoder silently dropped it (no matching struct field) and the
-- value never reached PostgreSQL. Reloading the page always showed severity
-- reverting to a default — the classic "silent drop" bug.
--
-- This migration adds severity as a first-class column on policy_rules.
-- Default `'Warning'` covers pre-existing rows; the CHECK constraint gives
-- defense-in-depth against casing drift (the application-layer validator in
-- internal/api/handler/validation.go already enforces the TitleCase allowlist,
-- but the DB should reject a bypassed write too).
--
-- No index: three-value column on a table that stays in the low thousands of
-- rows. The planner will seq-scan regardless; write cost without read benefit.
-- If measurements later justify it, add the index then.
--
-- PG 11+ makes ADD COLUMN with a literal DEFAULT a metadata-only operation
-- (no table rewrite), so this is safe to run on a live server.
ALTER TABLE policy_rules
ADD COLUMN IF NOT EXISTS severity VARCHAR(50) NOT NULL DEFAULT 'Warning'
CHECK (severity IN ('Warning', 'Error', 'Critical'));
@@ -0,0 +1,9 @@
-- Rollback migration 000014: drop the policy_violations severity CHECK.
--
-- Drops the named CHECK constraint added by the up migration. The severity
-- column itself stays (it predates this migration — see 000001 line 183),
-- so any application code that reads/writes the column continues to work.
-- Only the DB-level enforcement of the TitleCase allowlist is removed.
ALTER TABLE policy_violations
DROP CONSTRAINT IF EXISTS policy_violations_severity_check;
@@ -0,0 +1,29 @@
-- Migration 000014: CHECK constraint on policy_violations.severity
--
-- Sibling to migration 000013, which added severity + CHECK to policy_rules.
-- policy_violations has carried a severity column since the initial schema
-- (000001, line 183) but without any CHECK. The engine used to hardcode
-- `Warning` on every violation regardless of the triggering rule's severity
-- (see pre-D-008 internal/service/policy.go:evaluateRule), so the column
-- value was uniform by accident of implementation, not by constraint.
--
-- D-008 rewrites evaluateRule to copy rule.Severity into the violation. The
-- engine now writes values drawn from the application-layer PolicySeverity
-- allowlist, but nothing at the DB level prevents a future caller — or a
-- bypassed write from a migration or psql session — from inserting casing
-- drift ('warning', 'ERROR', etc.) and re-opening the same class of bug
-- that D-005 and D-006 closed. This constraint is the defense-in-depth
-- complement to the handler validator.
--
-- Pre-existing seed_demo.sql rows use lowercase severity values. D-008
-- updates those in the same commit so this migration can apply cleanly
-- against both a fresh install and an upgraded install that has already
-- seeded the demo data.
--
-- Named constraint (policy_violations_severity_check) so the down migration
-- can DROP it by name without ambiguity; un-named CHECK constraints use
-- a synthesized PostgreSQL name that varies by environment.
ALTER TABLE policy_violations
ADD CONSTRAINT policy_violations_severity_check
CHECK (severity IN ('Warning', 'Error', 'Critical'));
+32 -16
View File
@@ -12,42 +12,58 @@ VALUES (
'[30, 14, 7, 0]'::jsonb '[30, 14, 7, 0]'::jsonb
) ON CONFLICT (id) DO NOTHING; ) ON CONFLICT (id) DO NOTHING;
-- Policy rules: Require owner assignment -- Policy rules: Require owner assignment, bound environments, cap lifetime,
INSERT INTO policy_rules (id, name, type, config, enabled) -- and enforce a renewal lead-time.
--
-- Severity is differentiated per rule (D-006) and the types are now the
-- TitleCase canonicals the engine actually recognizes (D-008). Pre-D-008 the
-- types were lowercase strings (`ownership`, `environment`, `lifetime`,
-- `renewal_window`) that the engine silently dropped through to its
-- default-case error path — the rules looked alive in the GUI but did not
-- enforce anything. The backend CHECK constraint (migration 000013) enforces
-- the TitleCase severity allowlist Warning/Error/Critical. Configs are also
-- reshaped to match the D-008 per-arm schemas so the rules actually exercise
-- the config-consuming paths instead of falling back to the missing-field
-- placeholders.
INSERT INTO policy_rules (id, name, type, config, enabled, severity)
VALUES ( VALUES (
'pr-require-owner', 'pr-require-owner',
'require-owner', 'require-owner',
'ownership', 'RequiredMetadata',
'{"requirement": "owner_id must be set"}'::jsonb, '{"required_keys": ["owner"]}'::jsonb,
true true,
'Warning'
) ON CONFLICT (id) DO NOTHING; ) ON CONFLICT (id) DO NOTHING;
-- Policy rules: Allowed environments -- Policy rules: Allowed environments
INSERT INTO policy_rules (id, name, type, config, enabled) INSERT INTO policy_rules (id, name, type, config, enabled, severity)
VALUES ( VALUES (
'pr-allowed-environments', 'pr-allowed-environments',
'allowed-environments', 'allowed-environments',
'environment', 'AllowedEnvironments',
'{"allowed": ["production", "staging", "development"]}'::jsonb, '{"allowed": ["production", "staging", "development"]}'::jsonb,
true true,
'Error'
) ON CONFLICT (id) DO NOTHING; ) ON CONFLICT (id) DO NOTHING;
-- Policy rules: Maximum certificate lifetime -- Policy rules: Maximum certificate lifetime
INSERT INTO policy_rules (id, name, type, config, enabled) INSERT INTO policy_rules (id, name, type, config, enabled, severity)
VALUES ( VALUES (
'pr-max-certificate-lifetime', 'pr-max-certificate-lifetime',
'max-certificate-lifetime', 'max-certificate-lifetime',
'lifetime', 'CertificateLifetime',
'{"max_days": 90}'::jsonb, '{"max_days": 90}'::jsonb,
true true,
'Critical'
) ON CONFLICT (id) DO NOTHING; ) ON CONFLICT (id) DO NOTHING;
-- Policy rules: Minimum renewal window -- Policy rules: Minimum renewal window (renew at least 14 days before expiry)
INSERT INTO policy_rules (id, name, type, config, enabled) INSERT INTO policy_rules (id, name, type, config, enabled, severity)
VALUES ( VALUES (
'pr-min-renewal-window', 'pr-min-renewal-window',
'min-renewal-window', 'min-renewal-window',
'renewal_window', 'RenewalLeadTime',
'{"min_days": 14}'::jsonb, '{"lead_time_days": 14}'::jsonb,
true true,
'Warning'
) ON CONFLICT (id) DO NOTHING; ) ON CONFLICT (id) DO NOTHING;
+13 -6
View File
@@ -478,13 +478,20 @@ ON CONFLICT (id) DO NOTHING;
-- ============================================================ -- ============================================================
-- 13. Policy Violations -- 13. Policy Violations
-- ============================================================ -- ============================================================
-- D-008: severity values rewritten to TitleCase canonicals (Warning/Error/Critical).
-- Pre-D-008 these rows used lowercase strings ('critical', 'error', 'warning'). Those
-- values were silently tolerated by the pre-D-008 engine, which hardcoded 'Warning'
-- on every new violation regardless of the triggering rule's severity. D-008 rewires
-- evaluateRule to copy rule.Severity into the violation AND migration 000014 adds a
-- CHECK constraint enforcing the TitleCase allowlist at the DB level. Both paths now
-- round-trip correctly against these demo rows.
INSERT INTO policy_violations (id, certificate_id, rule_id, message, severity, created_at) VALUES INSERT INTO policy_violations (id, certificate_id, rule_id, message, severity, created_at) VALUES
('pv-001', 'mc-legacy-prod', 'pr-max-certificate-lifetime', 'Certificate has expired and exceeds maximum lifetime policy', 'critical', NOW() - INTERVAL '3 days'), ('pv-001', 'mc-legacy-prod', 'pr-max-certificate-lifetime', 'Certificate has expired and exceeds maximum lifetime policy', 'Critical', NOW() - INTERVAL '3 days'),
('pv-002', 'mc-old-api', 'pr-max-certificate-lifetime', 'Certificate expired 15 days ago', 'critical', NOW() - INTERVAL '15 days'), ('pv-002', 'mc-old-api', 'pr-max-certificate-lifetime', 'Certificate expired 15 days ago', 'Critical', NOW() - INTERVAL '15 days'),
('pv-003', 'mc-vpn-prod', 'pr-min-renewal-window', 'Renewal failed within minimum renewal window', 'error', NOW() - INTERVAL '3 days'), ('pv-003', 'mc-vpn-prod', 'pr-min-renewal-window', 'Renewal failed within minimum renewal window', 'Error', NOW() - INTERVAL '3 days'),
('pv-004', 'mc-mail-prod', 'pr-min-renewal-window', 'Certificate expiring in 5 days, below 14-day minimum window','warning', NOW() - INTERVAL '20 minutes'), ('pv-004', 'mc-mail-prod', 'pr-min-renewal-window', 'Certificate expiring in 5 days, below 14-day minimum window','Warning', NOW() - INTERVAL '20 minutes'),
('pv-005', 'mc-wiki-prod', 'pr-max-certificate-lifetime', 'Certificate expired 7 days ago', 'critical', NOW() - INTERVAL '7 days'), ('pv-005', 'mc-wiki-prod', 'pr-max-certificate-lifetime', 'Certificate expired 7 days ago', 'Critical', NOW() - INTERVAL '7 days'),
('pv-006', 'mc-compromised', 'pr-min-renewal-window', 'Certificate revoked due to key compromise', 'critical', NOW() - INTERVAL '14 days') ('pv-006', 'mc-compromised', 'pr-min-renewal-window', 'Certificate revoked due to key compromise', 'Critical', NOW() - INTERVAL '14 days')
ON CONFLICT (id) DO NOTHING; ON CONFLICT (id) DO NOTHING;
-- ============================================================ -- ============================================================
+60
View File
@@ -0,0 +1,60 @@
import { describe, it, expect } from 'vitest';
import { POLICY_TYPES, POLICY_SEVERITIES } from './types';
/**
* Regression tests for the policy enum tuples.
*
* These tuples are the GUI's source of truth for the policy type and severity
* dropdowns. They MUST stay in lockstep with the backend enum values:
* - internal/domain/policy.go defines the PolicyType / PolicySeverity consts
* - internal/api/handler/validators.go rejects anything outside the allowlist
* - migration 000013 enforces the severity allowlist at the DB level via CHECK
*
* Audit history (D-005, D-006):
* - The GUI previously sent lowercase values (e.g. 'key_algorithm',
* 'ownership'), which the backend validator rejected with a 400. Every
* attempt to create a policy from the "+ New Policy" button silently
* failed until the modal was closed.
* - The severity dropdown carried a four-value `low/medium/high/critical`
* tuple that shared zero values with the backend's
* `Warning/Error/Critical` the `medium` option has no backend analog
* and is removed.
*
* If these tests fail because a backend enum changed, DO NOT update the
* expected arrays without also updating the backend consts and the migration.
* Frontend/backend drift on these tuples is precisely what this regression
* guards against.
*/
describe('POLICY_TYPES', () => {
it('matches the backend PolicyType TitleCase allowlist exactly', () => {
expect(POLICY_TYPES).toEqual([
'AllowedIssuers',
'AllowedDomains',
'RequiredMetadata',
'AllowedEnvironments',
'RenewalLeadTime',
'CertificateLifetime',
]);
});
it('has no duplicate entries', () => {
expect(new Set(POLICY_TYPES).size).toBe(POLICY_TYPES.length);
});
});
describe('POLICY_SEVERITIES', () => {
it('matches the backend PolicySeverity TitleCase allowlist exactly', () => {
expect(POLICY_SEVERITIES).toEqual(['Warning', 'Error', 'Critical']);
});
it('has no duplicate entries', () => {
expect(new Set(POLICY_SEVERITIES).size).toBe(POLICY_SEVERITIES.length);
});
it('does not include the removed pre-fix `medium` value', () => {
// Explicit negative assertion. Pre-fix the GUI offered four severities
// (low/medium/high/critical); `medium` never had a backend analog.
expect(POLICY_SEVERITIES as readonly string[]).not.toContain('medium');
});
});
+30 -3
View File
@@ -112,11 +112,38 @@ export interface AuditEvent {
timestamp: string; timestamp: string;
} }
/**
* Policy rule type enum pinned to the backend's TitleCase constants in
* internal/domain/policy.go. Historical note (D-005): the GUI previously sent
* lowercase values (`ownership`, `environment`, etc.) that the handler's
* ValidatePolicyType rejected with a 400. These tuples are the canonical
* source of truth for the dropdown options; the regression test in
* types.test.ts pins them so future drift is caught at CI time.
*/
export const POLICY_TYPES = [
'AllowedIssuers',
'AllowedDomains',
'RequiredMetadata',
'AllowedEnvironments',
'RenewalLeadTime',
'CertificateLifetime',
] as const;
export type PolicyType = (typeof POLICY_TYPES)[number];
/**
* Policy severity enum pinned to the backend's PolicySeverity constants.
* The backend CHECK constraint on policy_rules.severity enforces the same
* allowlist (migration 000013). The 4-value `medium` option that used to
* appear in the GUI was never a valid backend value and has been removed.
*/
export const POLICY_SEVERITIES = ['Warning', 'Error', 'Critical'] as const;
export type PolicySeverity = (typeof POLICY_SEVERITIES)[number];
export interface PolicyRule { export interface PolicyRule {
id: string; id: string;
name: string; name: string;
type: string; type: PolicyType;
severity: string; severity: PolicySeverity;
config: Record<string, unknown>; config: Record<string, unknown>;
enabled: boolean; enabled: boolean;
created_at: string; created_at: string;
@@ -127,7 +154,7 @@ export interface PolicyViolation {
id: string; id: string;
rule_id: string; rule_id: string;
certificate_id: string; certificate_id: string;
severity: string; severity: PolicySeverity;
message: string; message: string;
created_at: string; created_at: string;
} }
+53 -14
View File
@@ -1,7 +1,7 @@
import { useState } from 'react'; import { useState } from 'react';
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { getCertificates, createCertificate, triggerRenewal, revokeCertificate, updateCertificate, getOwners, getProfiles, getIssuers, bulkRevokeCertificates } from '../api/client'; import { getCertificates, createCertificate, triggerRenewal, revokeCertificate, updateCertificate, getOwners, getTeams, getPolicies, getProfiles, getIssuers, bulkRevokeCertificates } from '../api/client';
import { REVOCATION_REASONS } from '../api/types'; import { REVOCATION_REASONS } from '../api/types';
import PageHeader from '../components/PageHeader'; import PageHeader from '../components/PageHeader';
import DataTable from '../components/DataTable'; import DataTable from '../components/DataTable';
@@ -35,8 +35,27 @@ function CreateCertificateModal({ onClose, onSuccess }: { onClose: () => void; o
queryKey: ['issuers'], queryKey: ['issuers'],
queryFn: () => getIssuers(), queryFn: () => getIssuers(),
}); });
// C-001: owner_id, team_id, and renewal_policy_id are required by the
// server (handler in internal/api/handler/certificates.go) and by OpenAPI.
// Load the catalog so the user selects valid FKs instead of typing free-text
// IDs that would 400 at the server.
const { data: ownersResp } = useQuery({
queryKey: ['owners', 'form'],
queryFn: () => getOwners({ per_page: '500' }),
});
const { data: teamsResp } = useQuery({
queryKey: ['teams', 'form'],
queryFn: () => getTeams({ per_page: '500' }),
});
const { data: policiesResp } = useQuery({
queryKey: ['renewal-policies', 'form'],
queryFn: () => getPolicies({ per_page: '500' }),
});
const profiles = profilesResp?.data || []; const profiles = profilesResp?.data || [];
const issuers = issuersResp?.data || []; const issuers = issuersResp?.data || [];
const owners = ownersResp?.data || [];
const teams = teamsResp?.data || [];
const policies = policiesResp?.data || [];
const selectedProfile = profiles.find(p => p.id === form.certificate_profile_id); const selectedProfile = profiles.find(p => p.id === form.certificate_profile_id);
const ttlLabel = selectedProfile const ttlLabel = selectedProfile
@@ -143,24 +162,36 @@ function CreateCertificateModal({ onClose, onSuccess }: { onClose: () => void; o
</select> </select>
</div> </div>
<div> <div>
<label className="text-xs text-ink-muted block mb-1">Policy</label> <label className="text-xs text-ink-muted block mb-1">Policy *</label>
<input value={form.renewal_policy_id} onChange={e => setForm(f => ({ ...f, renewal_policy_id: e.target.value }))} <select value={form.renewal_policy_id} onChange={e => setForm(f => ({ ...f, renewal_policy_id: e.target.value }))}
className={inputClass} className={selectClass}>
placeholder="rp-standard" /> <option value="">Select policy...</option>
{policies.map(p => (
<option key={p.id} value={p.id}>{p.name}</option>
))}
</select>
</div> </div>
</div> </div>
<div className="grid grid-cols-2 gap-3"> <div className="grid grid-cols-2 gap-3">
<div> <div>
<label className="text-xs text-ink-muted block mb-1">Owner</label> <label className="text-xs text-ink-muted block mb-1">Owner *</label>
<input value={form.owner_id} onChange={e => setForm(f => ({ ...f, owner_id: e.target.value }))} <select value={form.owner_id} onChange={e => setForm(f => ({ ...f, owner_id: e.target.value }))}
className={inputClass} className={selectClass}>
placeholder="o-alice" /> <option value="">Select owner...</option>
{owners.map(o => (
<option key={o.id} value={o.id}>{o.name} ({o.email})</option>
))}
</select>
</div> </div>
<div> <div>
<label className="text-xs text-ink-muted block mb-1">Team</label> <label className="text-xs text-ink-muted block mb-1">Team *</label>
<input value={form.team_id} onChange={e => setForm(f => ({ ...f, team_id: e.target.value }))} <select value={form.team_id} onChange={e => setForm(f => ({ ...f, team_id: e.target.value }))}
className={inputClass} className={selectClass}>
placeholder="t-platform" /> <option value="">Select team...</option>
{teams.map(t => (
<option key={t.id} value={t.id}>{t.name}</option>
))}
</select>
</div> </div>
</div> </div>
<div> <div>
@@ -175,7 +206,15 @@ function CreateCertificateModal({ onClose, onSuccess }: { onClose: () => void; o
<button onClick={onClose} className="btn btn-ghost text-sm">Cancel</button> <button onClick={onClose} className="btn btn-ghost text-sm">Cancel</button>
<button <button
onClick={() => mutation.mutate()} onClick={() => mutation.mutate()}
disabled={!form.name || !form.common_name || !form.issuer_id || mutation.isPending} disabled={
!form.name ||
!form.common_name ||
!form.issuer_id ||
!form.owner_id ||
!form.team_id ||
!form.renewal_policy_id ||
mutation.isPending
}
className="btn btn-primary text-sm disabled:opacity-50" className="btn btn-primary text-sm disabled:opacity-50"
> >
{mutation.isPending ? 'Creating...' : 'Create Certificate'} {mutation.isPending ? 'Creating...' : 'Create Certificate'}
+28 -2
View File
@@ -2,7 +2,7 @@ import { useState } from 'react';
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query';
import { useNavigate, Link } from 'react-router-dom'; import { useNavigate, Link } from 'react-router-dom';
import { import {
getIssuers, getAgents, getProfiles, getIssuers, getAgents, getProfiles, getOwners,
createIssuer, testIssuerConnection, createIssuer, testIssuerConnection,
createCertificate, triggerRenewal, createCertificate, triggerRenewal,
getApiKey, getApiKey,
@@ -404,12 +404,14 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
const [sans, setSans] = useState(''); const [sans, setSans] = useState('');
const [issuerId, setIssuerId] = useState(createdIssuerId || ''); const [issuerId, setIssuerId] = useState(createdIssuerId || '');
const [profileId, setProfileId] = useState(''); const [profileId, setProfileId] = useState('');
const [ownerId, setOwnerId] = useState('');
const [error, setError] = useState(''); const [error, setError] = useState('');
const [created, setCreated] = useState(false); const [created, setCreated] = useState(false);
const { data: issuers } = useQuery({ queryKey: ['issuers'], queryFn: () => getIssuers() }); const { data: issuers } = useQuery({ queryKey: ['issuers'], queryFn: () => getIssuers() });
const { data: profiles } = useQuery({ queryKey: ['profiles'], queryFn: () => getProfiles() }); const { data: profiles } = useQuery({ queryKey: ['profiles'], queryFn: () => getProfiles() });
const { data: agents } = useQuery({ queryKey: ['agents'], queryFn: () => getAgents() }); const { data: agents } = useQuery({ queryKey: ['agents'], queryFn: () => getAgents() });
const { data: owners } = useQuery({ queryKey: ['owners'], queryFn: () => getOwners() });
const hasAgents = (agents?.data?.length ?? 0) > 0; const hasAgents = (agents?.data?.length ?? 0) > 0;
@@ -421,6 +423,7 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
sans: sanList, sans: sanList,
issuer_id: issuerId, issuer_id: issuerId,
certificate_profile_id: profileId || undefined, certificate_profile_id: profileId || undefined,
owner_id: ownerId,
environment: 'production', environment: 'production',
}); });
// Trigger issuance // Trigger issuance
@@ -521,6 +524,29 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
</select> </select>
</div> </div>
</div> </div>
<div>
<label className="block text-sm font-medium text-ink mb-2">
Owner <span className="text-red-600">*</span>
</label>
<select
value={ownerId}
onChange={e => setOwnerId(e.target.value)}
className="w-full px-3 py-2 bg-surface border border-surface-border rounded text-ink focus:outline-none focus:border-brand-500 transition-colors"
>
<option value="">Select owner...</option>
{owners?.data?.map(o => (
<option key={o.id} value={o.id}>
{o.name}{o.email ? ` (${o.email})` : ''}
</option>
))}
</select>
{(owners?.data?.length ?? 0) === 0 && (
<p className="mt-1 text-xs text-ink-muted">
No owners yet create one from the <Link to="/owners" className="underline hover:text-ink">Owners page</Link> first, then return here.
</p>
)}
</div>
</div> </div>
{/* Discovery hint */} {/* Discovery hint */}
@@ -547,7 +573,7 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
onSkip={onSkip} onSkip={onSkip}
onNext={() => createMutation.mutate()} onNext={() => createMutation.mutate()}
nextLabel={createMutation.isPending ? 'Creating...' : 'Issue Certificate'} nextLabel={createMutation.isPending ? 'Creating...' : 'Issue Certificate'}
nextDisabled={!commonName || !issuerId || createMutation.isPending} nextDisabled={!commonName || !issuerId || !ownerId || createMutation.isPending}
/> />
</div> </div>
); );
+44 -29
View File
@@ -6,22 +6,40 @@ import DataTable from '../components/DataTable';
import type { Column } from '../components/DataTable'; import type { Column } from '../components/DataTable';
import ErrorState from '../components/ErrorState'; import ErrorState from '../components/ErrorState';
import { formatDateTime } from '../api/utils'; import { formatDateTime } from '../api/utils';
import type { PolicyRule } from '../api/types'; import {
POLICY_TYPES,
POLICY_SEVERITIES,
type PolicyRule,
type PolicyType,
type PolicySeverity,
} from '../api/types';
const severityStyles: Record<string, string> = { /**
low: 'badge-info', * Severity badge style. Keyed on the backend's TitleCase PolicySeverity
medium: 'badge-warning', * enum values (D-006). The pre-fix map keyed on `low`/`medium`/`high`/`critical`
high: 'badge-danger', * which never matched the backend's `Warning`/`Error`/`Critical`, so every
critical: 'badge-danger', * existing rule fell through to the `badge-neutral` default.
*/
const severityStyles: Record<PolicySeverity, string> = {
Warning: 'badge-warning',
Error: 'badge-danger',
Critical: 'badge-danger',
}; };
const severityDots: Record<string, string> = { const severityDots: Record<PolicySeverity, string> = {
low: 'bg-emerald-500', Warning: 'bg-amber-500',
medium: 'bg-amber-500', Error: 'bg-orange-500',
high: 'bg-orange-500', Critical: 'bg-red-500',
critical: 'bg-red-500',
}; };
/**
* Convert TitleCase enum value to a human-readable label for display.
* "AllowedIssuers" "Allowed Issuers"
*/
function humanize(s: string): string {
return s.replace(/([A-Z])/g, ' $1').trim();
}
interface CreatePolicyModalProps { interface CreatePolicyModalProps {
isOpen: boolean; isOpen: boolean;
onClose: () => void; onClose: () => void;
@@ -32,8 +50,8 @@ interface CreatePolicyModalProps {
function CreatePolicyModal({ isOpen, onClose, onSuccess, isLoading, error }: CreatePolicyModalProps) { function CreatePolicyModal({ isOpen, onClose, onSuccess, isLoading, error }: CreatePolicyModalProps) {
const [name, setName] = useState(''); const [name, setName] = useState('');
const [type, setType] = useState('key_algorithm'); const [type, setType] = useState<PolicyType>(POLICY_TYPES[0]);
const [severity, setSeverity] = useState('medium'); const [severity, setSeverity] = useState<PolicySeverity>('Warning');
const [configStr, setConfigStr] = useState('{}'); const [configStr, setConfigStr] = useState('{}');
const [enabled, setEnabled] = useState(true); const [enabled, setEnabled] = useState(true);
@@ -43,8 +61,8 @@ function CreatePolicyModal({ isOpen, onClose, onSuccess, isLoading, error }: Cre
const config = JSON.parse(configStr); const config = JSON.parse(configStr);
await createPolicy({ name: name.trim(), type, severity, config, enabled }); await createPolicy({ name: name.trim(), type, severity, config, enabled });
setName(''); setName('');
setType('key_algorithm'); setType(POLICY_TYPES[0]);
setSeverity('medium'); setSeverity('Warning');
setConfigStr('{}'); setConfigStr('{}');
setEnabled(true); setEnabled(true);
onSuccess(); onSuccess();
@@ -72,27 +90,24 @@ function CreatePolicyModal({ isOpen, onClose, onSuccess, isLoading, error }: Cre
<label className="block text-sm font-medium text-ink mb-1">Type *</label> <label className="block text-sm font-medium text-ink mb-1">Type *</label>
<select <select
value={type} value={type}
onChange={e => setType(e.target.value)} onChange={e => setType(e.target.value as PolicyType)}
className="w-full bg-white border border-surface-border rounded px-3 py-2 text-sm text-ink focus:outline-none focus:border-brand-400" className="w-full bg-white border border-surface-border rounded px-3 py-2 text-sm text-ink focus:outline-none focus:border-brand-400"
> >
<option value="key_algorithm">Key Algorithm</option> {POLICY_TYPES.map(t => (
<option value="cert_lifetime">Certificate Lifetime</option> <option key={t} value={t}>{humanize(t)}</option>
<option value="san_pattern">SAN Pattern</option> ))}
<option value="key_usage">Key Usage</option>
<option value="revocation_check">Revocation Check</option>
</select> </select>
</div> </div>
<div> <div>
<label className="block text-sm font-medium text-ink mb-1">Severity *</label> <label className="block text-sm font-medium text-ink mb-1">Severity *</label>
<select <select
value={severity} value={severity}
onChange={e => setSeverity(e.target.value)} onChange={e => setSeverity(e.target.value as PolicySeverity)}
className="w-full bg-white border border-surface-border rounded px-3 py-2 text-sm text-ink focus:outline-none focus:border-brand-400" className="w-full bg-white border border-surface-border rounded px-3 py-2 text-sm text-ink focus:outline-none focus:border-brand-400"
> >
<option value="low">Low</option> {POLICY_SEVERITIES.map(s => (
<option value="medium">Medium</option> <option key={s} value={s}>{s}</option>
<option value="high">High</option> ))}
<option value="critical">Critical</option>
</select> </select>
</div> </div>
<div> <div>
@@ -182,7 +197,7 @@ export default function PoliciesPage() {
</div> </div>
), ),
}, },
{ key: 'type', label: 'Type', render: (p) => <span className="text-sm text-ink">{p.type.replace(/_/g, ' ')}</span> }, { key: 'type', label: 'Type', render: (p) => <span className="text-sm text-ink">{humanize(p.type)}</span> },
{ {
key: 'severity', key: 'severity',
label: 'Severity', label: 'Severity',
@@ -248,8 +263,8 @@ export default function PoliciesPage() {
</div> </div>
{Object.entries(bySeverity).map(([sev, count]) => ( {Object.entries(bySeverity).map(([sev, count]) => (
<div key={sev} className="flex items-center gap-1.5"> <div key={sev} className="flex items-center gap-1.5">
<div className={`w-2 h-2 rounded-full ${severityDots[sev] || 'bg-slate-400'}`} /> <div className={`w-2 h-2 rounded-full ${severityDots[sev as PolicySeverity] || 'bg-slate-400'}`} />
<span className="text-xs text-ink capitalize">{sev}</span> <span className="text-xs text-ink">{sev}</span>
<span className="text-xs text-ink-faint">{count}</span> <span className="text-xs text-ink-faint">{count}</span>
</div> </div>
))} ))}
+22 -6
View File
@@ -1,7 +1,7 @@
import { useState } from 'react'; import { useState } from 'react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { getTargets, createTarget, deleteTarget } from '../api/client'; import { getTargets, createTarget, deleteTarget, getAgents } from '../api/client';
import PageHeader from '../components/PageHeader'; import PageHeader from '../components/PageHeader';
import DataTable from '../components/DataTable'; import DataTable from '../components/DataTable';
import type { Column } from '../components/DataTable'; import type { Column } from '../components/DataTable';
@@ -180,6 +180,16 @@ function CreateTargetWizard({ onClose, onSuccess }: { onClose: () => void; onSuc
const [config, setConfig] = useState<Record<string, string>>({}); const [config, setConfig] = useState<Record<string, string>>({});
const [error, setError] = useState(''); const [error, setError] = useState('');
// C-002: agent_id is a NOT NULL FK in deployment_targets (migration 000001
// line 104). Load registered agents so the user picks a valid FK instead of
// typing a free-text ID that would 400 at the service layer (or, pre-fix,
// bubble up as a Postgres 23503 foreign-key violation → 500).
const { data: agentsResp } = useQuery({
queryKey: ['agents', 'form'],
queryFn: () => getAgents({ per_page: '500' }),
});
const agents = agentsResp?.data || [];
// Fields that backends expect as boolean (Go bool) // Fields that backends expect as boolean (Go bool)
const BOOL_FIELDS = new Set([ const BOOL_FIELDS = new Set([
'sni', 'insecure', 'sds_config', 'remove_expired', 'create_keystore', 'sni', 'insecure', 'sds_config', 'remove_expired', 'create_keystore',
@@ -244,7 +254,7 @@ function CreateTargetWizard({ onClose, onSuccess }: { onClose: () => void; onSuc
}); });
const fields = CONFIG_FIELDS[targetType] || []; const fields = CONFIG_FIELDS[targetType] || [];
const canProceedToReview = name && targetType && fields.filter(f => f.required).every(f => config[f.key]); const canProceedToReview = name && targetType && agentId && fields.filter(f => f.required).every(f => config[f.key]);
return ( return (
<div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}> <div className="fixed inset-0 bg-black/40 flex items-center justify-center z-50" onClick={onClose}>
@@ -314,10 +324,16 @@ function CreateTargetWizard({ onClose, onSuccess }: { onClose: () => void; onSuc
placeholder="web-server-1" /> placeholder="web-server-1" />
</div> </div>
<div> <div>
<label className="text-xs text-ink-muted block mb-1">Agent ID</label> <label className="text-xs text-ink-muted block mb-1">Agent *</label>
<input value={agentId} onChange={e => setAgentId(e.target.value)} <select value={agentId} onChange={e => setAgentId(e.target.value)}
className="w-full bg-white border border-surface-border rounded px-3 py-2 text-sm text-ink focus:outline-none focus:border-brand-400" className="w-full bg-white border border-surface-border rounded px-3 py-2 text-sm text-ink focus:outline-none focus:border-brand-400">
placeholder="agent-web1" /> <option value="">Select an agent...</option>
{agents.map(a => (
<option key={a.id} value={a.id}>
{a.hostname || a.id} ({a.id})
</option>
))}
</select>
</div> </div>
{fields.map(f => ( {fields.map(f => (
<div key={f.key}> <div key={f.key}>