mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-09 18:00:05 +00:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ad93e99158 | |||
| 9d0c3dfa15 | |||
| 2c9602db71 | |||
| ef670fa6da | |||
| 5a6ec39cfd | |||
| e3196e7b50 | |||
| bea69efd12 | |||
| 283ec27ca4 | |||
| a67a6b6c30 | |||
| ccd89c348f | |||
| 478a141498 | |||
| 2497be496d | |||
| 25dd6c07f3 | |||
| eb14236166 | |||
| bbb628243f | |||
| cdc9d03d5b | |||
| e951d319d0 | |||
| d14a45401b | |||
| 655e2879e6 | |||
| e757ef1471 | |||
| 27afa4463d | |||
| 80450c7180 | |||
| c655e0f8c5 | |||
| 5abeeb882b | |||
| b1df6dab27 |
@@ -45,11 +45,11 @@ jobs:
|
|||||||
run: govulncheck ./...
|
run: govulncheck ./...
|
||||||
|
|
||||||
- name: Race Detection
|
- name: Race Detection
|
||||||
run: go test -race ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/scheduler/... ./internal/connector/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -timeout 300s
|
run: go test -race ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/scheduler/... ./internal/connector/... ./internal/crypto/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -timeout 300s
|
||||||
|
|
||||||
- name: Go Test with Coverage
|
- name: Go Test with Coverage
|
||||||
run: |
|
run: |
|
||||||
go test ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/integration/... ./internal/connector/issuer/... ./internal/connector/target/... ./internal/connector/notifier/... ./internal/connector/discovery/... ./internal/mcp/... ./internal/cli/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -cover -coverprofile=coverage.out
|
go test ./internal/service/... ./internal/api/handler/... ./internal/api/middleware/... ./internal/integration/... ./internal/connector/issuer/... ./internal/connector/target/... ./internal/connector/notifier/... ./internal/connector/discovery/... ./internal/crypto/... ./internal/mcp/... ./internal/cli/... ./internal/domain/... ./internal/validation/... ./internal/tlsprobe/... -count=1 -cover -coverprofile=coverage.out
|
||||||
|
|
||||||
- name: Check Coverage Thresholds
|
- name: Check Coverage Thresholds
|
||||||
run: |
|
run: |
|
||||||
@@ -73,6 +73,13 @@ jobs:
|
|||||||
MIDDLEWARE_COV=$(go tool cover -func=coverage.out | grep 'internal/api/middleware' | awk '{print $NF}' | sed 's/%//' | awk '{sum+=$1; n++} END {if(n>0) printf "%.1f", sum/n; else print "0"}')
|
MIDDLEWARE_COV=$(go tool cover -func=coverage.out | grep 'internal/api/middleware' | awk '{print $NF}' | sed 's/%//' | awk '{sum+=$1; n++} END {if(n>0) printf "%.1f", sum/n; else print "0"}')
|
||||||
echo "Middleware layer coverage: ${MIDDLEWARE_COV}%"
|
echo "Middleware layer coverage: ${MIDDLEWARE_COV}%"
|
||||||
|
|
||||||
|
# Check crypto package coverage (target: 85%+)
|
||||||
|
# M-8 rationale: encryption primitives are a security-critical gate.
|
||||||
|
# v2 format, key-derivation, fallback, and fail-closed sentinel paths
|
||||||
|
# all need exhaustive coverage to avoid silent regressions (CWE-916 / CWE-329).
|
||||||
|
CRYPTO_COV=$(go tool cover -func=coverage.out | grep 'internal/crypto' | awk '{print $NF}' | sed 's/%//' | awk '{sum+=$1; n++} END {if(n>0) printf "%.1f", sum/n; else print "0"}')
|
||||||
|
echo "Crypto package coverage: ${CRYPTO_COV}%"
|
||||||
|
|
||||||
# Fail if thresholds not met
|
# Fail if thresholds not met
|
||||||
if [ "$(echo "$SERVICE_COV < 55" | bc -l)" -eq 1 ]; then
|
if [ "$(echo "$SERVICE_COV < 55" | bc -l)" -eq 1 ]; then
|
||||||
echo "::error::Service layer coverage ${SERVICE_COV}% is below 55% threshold"
|
echo "::error::Service layer coverage ${SERVICE_COV}% is below 55% threshold"
|
||||||
@@ -90,6 +97,10 @@ jobs:
|
|||||||
echo "::error::Middleware layer coverage ${MIDDLEWARE_COV}% is below 30% threshold"
|
echo "::error::Middleware layer coverage ${MIDDLEWARE_COV}% is below 30% threshold"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
if [ "$(echo "$CRYPTO_COV < 85" | bc -l)" -eq 1 ]; then
|
||||||
|
echo "::error::Crypto package coverage ${CRYPTO_COV}% is below 85% threshold"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
echo "Coverage thresholds passed!"
|
echo "Coverage thresholds passed!"
|
||||||
|
|
||||||
- name: Upload Coverage Report
|
- name: Upload Coverage Report
|
||||||
|
|||||||
+272
-43
@@ -7,40 +7,30 @@ on:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
REGISTRY: ghcr.io
|
REGISTRY: ghcr.io
|
||||||
GO_VERSION: '1.22'
|
# Keep in lock-step with .github/workflows/ci.yml (M-3).
|
||||||
|
GO_VERSION: '1.25.9'
|
||||||
|
IMAGE_NAMESPACE: shankar0123
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# Cross-compile agent and server binaries for multiple platforms
|
# ----------------------------------------------------------------------
|
||||||
|
# build-binaries (M-3): matrix build every (binary × OS × arch) tuple.
|
||||||
|
# For each tuple we produce: the binary, a SPDX-JSON SBOM, a keyless
|
||||||
|
# Cosign signature + certificate bundle, and a single-line sha256sum
|
||||||
|
# file. All artefacts are uploaded to a workflow-scoped artifact; the
|
||||||
|
# aggregate-checksums job fans them back in for release upload.
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
build-binaries:
|
build-binaries:
|
||||||
name: Build Cross-Platform Binaries
|
name: Build ${{ matrix.binary }} (${{ matrix.os }}/${{ matrix.arch }})
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: read
|
||||||
|
id-token: write # Cosign keyless OIDC identity token
|
||||||
strategy:
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
binary: [agent, server, cli, mcp-server]
|
||||||
# Agent binaries (4 platforms)
|
os: [linux, darwin]
|
||||||
- os: linux
|
arch: [amd64, arm64]
|
||||||
arch: amd64
|
|
||||||
binary: agent
|
|
||||||
- os: linux
|
|
||||||
arch: arm64
|
|
||||||
binary: agent
|
|
||||||
- os: darwin
|
|
||||||
arch: amd64
|
|
||||||
binary: agent
|
|
||||||
- os: darwin
|
|
||||||
arch: arm64
|
|
||||||
binary: agent
|
|
||||||
# Server binaries (2 platforms)
|
|
||||||
- os: linux
|
|
||||||
arch: amd64
|
|
||||||
binary: server
|
|
||||||
- os: linux
|
|
||||||
arch: arm64
|
|
||||||
binary: server
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
@@ -51,35 +41,171 @@ jobs:
|
|||||||
|
|
||||||
- name: Extract version from tag
|
- name: Extract version from tag
|
||||||
id: version
|
id: version
|
||||||
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Build ${{ matrix.binary }} binary (${{ matrix.os }}-${{ matrix.arch }})
|
- name: Build binary
|
||||||
|
id: build
|
||||||
env:
|
env:
|
||||||
GOOS: ${{ matrix.os }}
|
GOOS: ${{ matrix.os }}
|
||||||
GOARCH: ${{ matrix.arch }}
|
GOARCH: ${{ matrix.arch }}
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: '0'
|
||||||
|
VERSION: ${{ steps.version.outputs.VERSION }}
|
||||||
run: |
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
OUTPUT_NAME="certctl-${{ matrix.binary }}-${{ matrix.os }}-${{ matrix.arch }}"
|
OUTPUT_NAME="certctl-${{ matrix.binary }}-${{ matrix.os }}-${{ matrix.arch }}"
|
||||||
go build -ldflags="-w -s -X main.Version=${{ steps.version.outputs.VERSION }}" \
|
mkdir -p dist
|
||||||
|
go build \
|
||||||
|
-trimpath \
|
||||||
|
-ldflags="-w -s -X main.Version=${VERSION}" \
|
||||||
-o "dist/${OUTPUT_NAME}" \
|
-o "dist/${OUTPUT_NAME}" \
|
||||||
"./cmd/${{ matrix.binary }}"
|
"./cmd/${{ matrix.binary }}"
|
||||||
ls -lh "dist/${OUTPUT_NAME}"
|
ls -lh "dist/${OUTPUT_NAME}"
|
||||||
|
echo "output_name=${OUTPUT_NAME}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Upload binaries to release
|
- name: Generate SBOM (SPDX-JSON)
|
||||||
|
uses: anchore/sbom-action@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0
|
||||||
|
with:
|
||||||
|
file: dist/${{ steps.build.outputs.output_name }}
|
||||||
|
format: spdx-json
|
||||||
|
output-file: dist/${{ steps.build.outputs.output_name }}.sbom.spdx.json
|
||||||
|
upload-artifact: false
|
||||||
|
upload-release-assets: false
|
||||||
|
|
||||||
|
- name: Install Cosign
|
||||||
|
uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1
|
||||||
|
|
||||||
|
- name: Keyless-sign binary with Cosign
|
||||||
|
env:
|
||||||
|
OUTPUT_NAME: ${{ steps.build.outputs.output_name }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
cosign sign-blob \
|
||||||
|
--yes \
|
||||||
|
--output-signature "dist/${OUTPUT_NAME}.sig" \
|
||||||
|
--output-certificate "dist/${OUTPUT_NAME}.pem" \
|
||||||
|
"dist/${OUTPUT_NAME}"
|
||||||
|
|
||||||
|
- name: Compute SHA-256 sidecar
|
||||||
|
env:
|
||||||
|
OUTPUT_NAME: ${{ steps.build.outputs.output_name }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
cd dist
|
||||||
|
sha256sum "${OUTPUT_NAME}" > "${OUTPUT_NAME}.sha256"
|
||||||
|
cat "${OUTPUT_NAME}.sha256"
|
||||||
|
|
||||||
|
- name: Upload build artefacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: binary-${{ steps.build.outputs.output_name }}
|
||||||
|
path: |
|
||||||
|
dist/${{ steps.build.outputs.output_name }}
|
||||||
|
dist/${{ steps.build.outputs.output_name }}.sig
|
||||||
|
dist/${{ steps.build.outputs.output_name }}.pem
|
||||||
|
dist/${{ steps.build.outputs.output_name }}.sbom.spdx.json
|
||||||
|
dist/${{ steps.build.outputs.output_name }}.sha256
|
||||||
|
if-no-files-found: error
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# aggregate-checksums (M-3): fan in every matrix artefact, produce a
|
||||||
|
# single checksums.txt (sha256sum format, compatible with `sha256sum
|
||||||
|
# -c`), sign it with Cosign, upload everything to the GitHub Release,
|
||||||
|
# and emit a base64-encoded hash manifest for the SLSA generator.
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
aggregate-checksums:
|
||||||
|
name: Aggregate checksums & sign
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [build-binaries]
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
id-token: write # Cosign keyless OIDC identity token
|
||||||
|
outputs:
|
||||||
|
hashes: ${{ steps.hashes.outputs.hashes }}
|
||||||
|
steps:
|
||||||
|
- name: Download binary artefacts
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
pattern: binary-*
|
||||||
|
path: artifacts
|
||||||
|
merge-multiple: true
|
||||||
|
|
||||||
|
- name: Aggregate SHA-256 sums
|
||||||
|
id: hashes
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
cd artifacts
|
||||||
|
: > checksums.txt
|
||||||
|
for f in certctl-*; do
|
||||||
|
case "$f" in
|
||||||
|
*.sig|*.pem|*.sbom.spdx.json|*.sha256|checksums.txt)
|
||||||
|
continue ;;
|
||||||
|
esac
|
||||||
|
sha256sum "$f" >> checksums.txt
|
||||||
|
done
|
||||||
|
echo "=== checksums.txt ==="
|
||||||
|
cat checksums.txt
|
||||||
|
# base64 hashes (single line, no wrapping) for SLSA generator.
|
||||||
|
HASHES=$(base64 -w0 < checksums.txt)
|
||||||
|
echo "hashes=${HASHES}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Install Cosign
|
||||||
|
uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1
|
||||||
|
|
||||||
|
- name: Keyless-sign checksums.txt
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
cd artifacts
|
||||||
|
cosign sign-blob \
|
||||||
|
--yes \
|
||||||
|
--output-signature checksums.txt.sig \
|
||||||
|
--output-certificate checksums.txt.pem \
|
||||||
|
checksums.txt
|
||||||
|
|
||||||
|
- name: Upload artefacts to GitHub Release
|
||||||
uses: softprops/action-gh-release@v2
|
uses: softprops/action-gh-release@v2
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
with:
|
with:
|
||||||
files: |
|
files: |
|
||||||
dist/certctl-agent-*
|
artifacts/certctl-*
|
||||||
dist/certctl-server-*
|
artifacts/checksums.txt
|
||||||
|
artifacts/checksums.txt.sig
|
||||||
|
artifacts/checksums.txt.pem
|
||||||
|
|
||||||
# Build and push Docker images
|
# ----------------------------------------------------------------------
|
||||||
|
# provenance-binaries (M-3): SLSA Level 3 provenance for every binary.
|
||||||
|
# The SLSA generic generator reusable workflow runs in a hermetic
|
||||||
|
# workflow run, producing multiple.intoto.jsonl from the base64 hash
|
||||||
|
# manifest and uploading it as a release asset.
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
provenance-binaries:
|
||||||
|
name: SLSA provenance (binaries)
|
||||||
|
needs: [aggregate-checksums]
|
||||||
|
permissions:
|
||||||
|
actions: read
|
||||||
|
id-token: write
|
||||||
|
contents: write
|
||||||
|
uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v2.1.0
|
||||||
|
with:
|
||||||
|
base64-subjects: "${{ needs.aggregate-checksums.outputs.hashes }}"
|
||||||
|
upload-assets: true
|
||||||
|
provenance-name: multiple.intoto.jsonl
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# build-and-push-docker: push container images to GHCR with native
|
||||||
|
# SLSA L3 provenance (mode=max) and SBOM attestations emitted by
|
||||||
|
# docker/build-push-action@v6, plus a keyless Cosign signature on the
|
||||||
|
# image digest for identity-bound verification. The M-4 proxy-propagation
|
||||||
|
# build-args block is retained verbatim — M-3 only adds supply-chain
|
||||||
|
# steps; it never touches M-4 wiring.
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
build-and-push-docker:
|
build-and-push-docker:
|
||||||
name: Build & Push Docker Images
|
name: Build & Push Docker Images
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
packages: write
|
packages: write
|
||||||
|
id-token: write # Cosign keyless OIDC identity token
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -93,20 +219,24 @@ jobs:
|
|||||||
|
|
||||||
- name: Extract version from tag
|
- name: Extract version from tag
|
||||||
id: version
|
id: version
|
||||||
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Install Cosign
|
||||||
|
uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1
|
||||||
|
|
||||||
- name: Build and push server image
|
- name: Build and push server image
|
||||||
|
id: server-push
|
||||||
uses: docker/build-push-action@v6
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./Dockerfile
|
file: ./Dockerfile
|
||||||
push: true
|
push: true
|
||||||
tags: |
|
tags: |
|
||||||
${{ env.REGISTRY }}/shankar0123/certctl-server:${{ steps.version.outputs.VERSION }}
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-server:${{ steps.version.outputs.VERSION }}
|
||||||
${{ env.REGISTRY }}/shankar0123/certctl-server:latest
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-server:latest
|
||||||
# Proxy propagation (M-4, Issue #9) — forwards runner-level proxy
|
# Proxy propagation (M-4, Issue #9) — forwards runner-level proxy
|
||||||
# secrets into the Docker build so self-hosted runners behind
|
# secrets into the Docker build so self-hosted runners behind
|
||||||
# corporate proxies can reach public registries. GitHub-hosted
|
# corporate proxies can reach public registries. GitHub-hosted
|
||||||
@@ -117,18 +247,31 @@ jobs:
|
|||||||
HTTP_PROXY=${{ secrets.HTTP_PROXY }}
|
HTTP_PROXY=${{ secrets.HTTP_PROXY }}
|
||||||
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
|
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
|
||||||
NO_PROXY=${{ secrets.NO_PROXY }}
|
NO_PROXY=${{ secrets.NO_PROXY }}
|
||||||
|
# Supply-chain hardening (M-3): emit native SLSA L3 provenance
|
||||||
|
# and SBOM attestations bound to the image manifest.
|
||||||
|
provenance: mode=max
|
||||||
|
sbom: true
|
||||||
cache-from: type=gha
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max
|
||||||
|
|
||||||
|
- name: Keyless-sign server image with Cosign
|
||||||
|
env:
|
||||||
|
DIGEST: ${{ steps.server-push.outputs.digest }}
|
||||||
|
IMAGE: ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-server
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
cosign sign --yes "${IMAGE}@${DIGEST}"
|
||||||
|
|
||||||
- name: Build and push agent image
|
- name: Build and push agent image
|
||||||
|
id: agent-push
|
||||||
uses: docker/build-push-action@v6
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./Dockerfile.agent
|
file: ./Dockerfile.agent
|
||||||
push: true
|
push: true
|
||||||
tags: |
|
tags: |
|
||||||
${{ env.REGISTRY }}/shankar0123/certctl-agent:${{ steps.version.outputs.VERSION }}
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-agent:${{ steps.version.outputs.VERSION }}
|
||||||
${{ env.REGISTRY }}/shankar0123/certctl-agent:latest
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-agent:latest
|
||||||
# Proxy propagation (M-4, Issue #9) — see server-image step for
|
# Proxy propagation (M-4, Issue #9) — see server-image step for
|
||||||
# rationale. Empty secrets resolve to empty build args, leaving
|
# rationale. Empty secrets resolve to empty build args, leaving
|
||||||
# the un-proxied code path byte-identical to the pre-fix tree.
|
# the un-proxied code path byte-identical to the pre-fix tree.
|
||||||
@@ -136,14 +279,30 @@ jobs:
|
|||||||
HTTP_PROXY=${{ secrets.HTTP_PROXY }}
|
HTTP_PROXY=${{ secrets.HTTP_PROXY }}
|
||||||
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
|
HTTPS_PROXY=${{ secrets.HTTPS_PROXY }}
|
||||||
NO_PROXY=${{ secrets.NO_PROXY }}
|
NO_PROXY=${{ secrets.NO_PROXY }}
|
||||||
|
# Supply-chain hardening (M-3): emit native SLSA L3 provenance
|
||||||
|
# and SBOM attestations bound to the image manifest.
|
||||||
|
provenance: mode=max
|
||||||
|
sbom: true
|
||||||
cache-from: type=gha
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max
|
||||||
|
|
||||||
# Create release notes with all artifacts
|
- name: Keyless-sign agent image with Cosign
|
||||||
|
env:
|
||||||
|
DIGEST: ${{ steps.agent-push.outputs.digest }}
|
||||||
|
IMAGE: ${{ env.REGISTRY }}/${{ env.IMAGE_NAMESPACE }}/certctl-agent
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
cosign sign --yes "${IMAGE}@${DIGEST}"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# create-release: stamp the release body. The actual asset uploads are
|
||||||
|
# handled by aggregate-checksums (binaries, SBOMs, sigs, certs,
|
||||||
|
# checksums.txt + signature) and the SLSA generator (multiple.intoto.jsonl).
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
create-release:
|
create-release:
|
||||||
name: Create Release Notes
|
name: Create Release Notes
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [build-binaries, build-and-push-docker]
|
needs: [build-binaries, aggregate-checksums, provenance-binaries, build-and-push-docker]
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
|
|
||||||
@@ -152,7 +311,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Extract version from tag
|
- name: Extract version from tag
|
||||||
id: version
|
id: version
|
||||||
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Create release with notes
|
- name: Create release with notes
|
||||||
uses: softprops/action-gh-release@v2
|
uses: softprops/action-gh-release@v2
|
||||||
@@ -214,6 +373,76 @@ jobs:
|
|||||||
|
|
||||||
- **Linux x86_64**: `certctl-server-linux-amd64`
|
- **Linux x86_64**: `certctl-server-linux-amd64`
|
||||||
- **Linux ARM64**: `certctl-server-linux-arm64`
|
- **Linux ARM64**: `certctl-server-linux-arm64`
|
||||||
|
- **macOS x86_64**: `certctl-server-darwin-amd64`
|
||||||
|
- **macOS ARM64 (Apple Silicon)**: `certctl-server-darwin-arm64`
|
||||||
|
|
||||||
|
## CLI & MCP Server Binaries
|
||||||
|
|
||||||
|
The `certctl-cli` (REST API wrapper) and `certctl-mcp-server` (Model Context
|
||||||
|
Protocol bridge) binaries ship for all four platforms as well:
|
||||||
|
|
||||||
|
- `certctl-cli-{linux,darwin}-{amd64,arm64}`
|
||||||
|
- `certctl-mcp-server-{linux,darwin}-{amd64,arm64}`
|
||||||
|
|
||||||
|
## Verifying this release
|
||||||
|
|
||||||
|
Every binary, `checksums.txt`, and container image is signed with Cosign
|
||||||
|
keyless OIDC. Each binary ships with a SPDX-JSON SBOM. Binaries are covered
|
||||||
|
by SLSA Level 3 provenance; container images carry native SLSA L3 provenance
|
||||||
|
and SBOM attestations (docker/build-push-action `provenance: mode=max`,
|
||||||
|
`sbom: true`) in addition to a Cosign signature on the digest.
|
||||||
|
|
||||||
|
**1. Verify SHA-256 checksums:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sha256sum -c checksums.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Verify the Cosign signature on checksums.txt (keyless OIDC):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cosign verify-blob \
|
||||||
|
--certificate checksums.txt.pem \
|
||||||
|
--signature checksums.txt.sig \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
checksums.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `checksums.txt` with any individual binary name to verify that
|
||||||
|
artefact directly (each binary ships with its own `.sig` + `.pem` sidecar).
|
||||||
|
|
||||||
|
**3. Verify SLSA Level 3 provenance (binaries):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
slsa-verifier verify-artifact \
|
||||||
|
--provenance-path multiple.intoto.jsonl \
|
||||||
|
--source-uri github.com/shankar0123/certctl \
|
||||||
|
--source-tag ${{ steps.version.outputs.VERSION }} \
|
||||||
|
certctl-agent-linux-amd64
|
||||||
|
```
|
||||||
|
|
||||||
|
**4. Verify container image signature and attestations:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
IMAGE=ghcr.io/shankar0123/certctl-server:${{ steps.version.outputs.VERSION }}
|
||||||
|
cosign verify \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
"$IMAGE"
|
||||||
|
|
||||||
|
# SBOM attestation (SPDX-JSON) emitted by docker/build-push-action
|
||||||
|
cosign verify-attestation --type spdxjson \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
"$IMAGE"
|
||||||
|
|
||||||
|
# SLSA provenance attestation (mode=max)
|
||||||
|
cosign verify-attestation --type slsaprovenance \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
"$IMAGE"
|
||||||
|
```
|
||||||
|
|
||||||
## Helm Chart
|
## Helm Chart
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ run:
|
|||||||
linters:
|
linters:
|
||||||
default: none
|
default: none
|
||||||
enable:
|
enable:
|
||||||
|
- contextcheck
|
||||||
- govet
|
- govet
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- unused
|
- unused
|
||||||
|
|||||||
@@ -237,6 +237,72 @@ docker pull shankar0123.docker.scarf.sh/certctl-server
|
|||||||
docker pull shankar0123.docker.scarf.sh/certctl-agent
|
docker pull shankar0123.docker.scarf.sh/certctl-agent
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Verifying this release
|
||||||
|
|
||||||
|
Every `v*` tag publishes signed, attested release artefacts. Binaries
|
||||||
|
(`certctl-agent`, `certctl-server`, `certctl-cli`, `certctl-mcp-server` for
|
||||||
|
`linux|darwin × amd64|arm64`) ship alongside a `checksums.txt`, per-binary
|
||||||
|
SPDX-JSON SBOMs, Cosign signatures, and SLSA Level 3 provenance. Container
|
||||||
|
images on `ghcr.io/shankar0123/certctl-{server,agent}` are built with
|
||||||
|
`docker/build-push-action` `provenance: mode=max` + `sbom: true` and are
|
||||||
|
additionally signed with Cosign at the image digest.
|
||||||
|
|
||||||
|
All signatures use Cosign keyless OIDC; the signing identity is the
|
||||||
|
release workflow running on a signed tag.
|
||||||
|
|
||||||
|
**1. Verify SHA-256 checksums:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sha256sum -c checksums.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Verify the Cosign signature on `checksums.txt`:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cosign verify-blob \
|
||||||
|
--certificate checksums.txt.pem \
|
||||||
|
--signature checksums.txt.sig \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
checksums.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Every individual binary has its own `.sig` + `.pem` sidecar; swap
|
||||||
|
`checksums.txt` for any binary name to verify it directly.
|
||||||
|
|
||||||
|
**3. Verify SLSA Level 3 provenance on a binary:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
slsa-verifier verify-artifact \
|
||||||
|
--provenance-path multiple.intoto.jsonl \
|
||||||
|
--source-uri github.com/shankar0123/certctl \
|
||||||
|
--source-tag v2.1.0 \
|
||||||
|
certctl-agent-linux-amd64
|
||||||
|
```
|
||||||
|
|
||||||
|
**4. Verify a container image signature and its SBOM / provenance attestations:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
IMAGE=ghcr.io/shankar0123/certctl-server:v2.1.0
|
||||||
|
|
||||||
|
cosign verify \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/\.github/workflows/release\.yml@refs/tags/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
"$IMAGE"
|
||||||
|
|
||||||
|
# SBOM attestation (SPDX-JSON, emitted by docker/build-push-action)
|
||||||
|
cosign verify-attestation --type spdxjson \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
"$IMAGE"
|
||||||
|
|
||||||
|
# SLSA provenance attestation (docker/build-push-action `provenance: mode=max`)
|
||||||
|
cosign verify-attestation --type slsaprovenance \
|
||||||
|
--certificate-identity-regexp '^https://github\.com/shankar0123/certctl/' \
|
||||||
|
--certificate-oidc-issuer 'https://token.actions.githubusercontent.com' \
|
||||||
|
"$IMAGE"
|
||||||
|
```
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
Pick the scenario closest to your setup and have it running in 2 minutes.
|
Pick the scenario closest to your setup and have it running in 2 minutes.
|
||||||
|
|||||||
@@ -66,6 +66,12 @@ tags:
|
|||||||
description: Continuous TLS endpoint health checks with status tracking and probe history
|
description: Continuous TLS endpoint health checks with status tracking and probe history
|
||||||
- name: Digest
|
- name: Digest
|
||||||
description: Scheduled certificate digest email notifications
|
description: Scheduled certificate digest email notifications
|
||||||
|
- name: Verification
|
||||||
|
description: Post-deployment TLS endpoint fingerprint verification
|
||||||
|
- name: EST
|
||||||
|
description: Enrollment over Secure Transport (RFC 7030)
|
||||||
|
- name: SCEP
|
||||||
|
description: Simple Certificate Enrollment Protocol (RFC 8894)
|
||||||
|
|
||||||
paths:
|
paths:
|
||||||
# ─── Health & Auth ───────────────────────────────────────────────────
|
# ─── Health & Auth ───────────────────────────────────────────────────
|
||||||
@@ -816,6 +822,28 @@ paths:
|
|||||||
"500":
|
"500":
|
||||||
$ref: "#/components/responses/InternalError"
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
/api/v1/targets/{id}/test:
|
||||||
|
post:
|
||||||
|
tags: [Targets]
|
||||||
|
summary: Test target connection
|
||||||
|
description: |
|
||||||
|
Checks target connectivity by verifying the assigned agent's heartbeat status
|
||||||
|
(agent reported within the last 5 minutes). Always returns HTTP 200 — the
|
||||||
|
connectivity result is reflected in the response body's `status` field
|
||||||
|
(`success` when the agent is reachable, `failed` otherwise).
|
||||||
|
operationId: testTargetConnection
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/resourceId"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Connection test result (success or failed in body)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/StatusMessageResponse"
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
|
||||||
# ─── Agents ──────────────────────────────────────────────────────────
|
# ─── Agents ──────────────────────────────────────────────────────────
|
||||||
/api/v1/agents:
|
/api/v1/agents:
|
||||||
get:
|
get:
|
||||||
@@ -1177,6 +1205,66 @@ paths:
|
|||||||
"500":
|
"500":
|
||||||
$ref: "#/components/responses/InternalError"
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
/api/v1/jobs/{id}/verify:
|
||||||
|
post:
|
||||||
|
tags: [Verification]
|
||||||
|
summary: Record post-deployment verification result
|
||||||
|
description: |
|
||||||
|
Agents submit the result of probing a deployed certificate's live TLS endpoint.
|
||||||
|
Compares the served certificate's SHA-256 fingerprint against the expected
|
||||||
|
fingerprint. Best-effort: failures are recorded on the job but do not roll
|
||||||
|
back the deployment.
|
||||||
|
operationId: verifyDeployment
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/resourceId"
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/VerifyDeploymentRequest"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Verification result recorded
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
job_id:
|
||||||
|
type: string
|
||||||
|
verified:
|
||||||
|
type: boolean
|
||||||
|
verified_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
/api/v1/jobs/{id}/verification:
|
||||||
|
get:
|
||||||
|
tags: [Verification]
|
||||||
|
summary: Get post-deployment verification status
|
||||||
|
description: |
|
||||||
|
Returns the stored verification result for a deployment job — expected
|
||||||
|
and observed SHA-256 fingerprints, verified flag, and timestamp.
|
||||||
|
operationId: getJobVerification
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/resourceId"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Verification result for the job
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/VerificationResult"
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
# ─── Policies ────────────────────────────────────────────────────────
|
# ─── Policies ────────────────────────────────────────────────────────
|
||||||
/api/v1/policies:
|
/api/v1/policies:
|
||||||
get:
|
get:
|
||||||
@@ -2718,6 +2806,238 @@ paths:
|
|||||||
"500":
|
"500":
|
||||||
$ref: "#/components/responses/InternalError"
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
# ─── EST (RFC 7030) ────────────────────────────────────────────────
|
||||||
|
/.well-known/est/cacerts:
|
||||||
|
get:
|
||||||
|
tags: [EST]
|
||||||
|
summary: EST CA certificates distribution
|
||||||
|
description: |
|
||||||
|
Returns the CA certificate chain used to verify certctl-issued certificates.
|
||||||
|
Response is a base64-encoded degenerate PKCS#7 SignedData (certs-only) per
|
||||||
|
RFC 7030 §4.1.3.
|
||||||
|
operationId: estCACerts
|
||||||
|
security: []
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Base64-encoded PKCS#7 certs-only structure
|
||||||
|
headers:
|
||||||
|
Content-Transfer-Encoding:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
example: base64
|
||||||
|
content:
|
||||||
|
application/pkcs7-mime:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: byte
|
||||||
|
description: "Base64-encoded PKCS#7 (smime-type=certs-only)"
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
/.well-known/est/simpleenroll:
|
||||||
|
post:
|
||||||
|
tags: [EST]
|
||||||
|
summary: EST simple enrollment
|
||||||
|
description: |
|
||||||
|
Enrolls a new certificate from a PKCS#10 CSR per RFC 7030 §4.2.1.
|
||||||
|
The CSR MAY be supplied as base64-encoded DER (EST standard wire format)
|
||||||
|
or as PEM for convenience. Returns a base64-encoded PKCS#7 certs-only
|
||||||
|
structure containing the issued certificate.
|
||||||
|
operationId: estSimpleEnroll
|
||||||
|
security: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
description: "Base64-encoded DER PKCS#10 CSR, or PEM-encoded CSR"
|
||||||
|
content:
|
||||||
|
application/pkcs10:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: byte
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Base64-encoded PKCS#7 cert-only response with issued certificate
|
||||||
|
headers:
|
||||||
|
Content-Transfer-Encoding:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
example: base64
|
||||||
|
content:
|
||||||
|
application/pkcs7-mime:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: byte
|
||||||
|
description: "Base64-encoded PKCS#7 (smime-type=certs-only)"
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
"405":
|
||||||
|
description: Method not allowed (only POST accepted)
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
/.well-known/est/simplereenroll:
|
||||||
|
post:
|
||||||
|
tags: [EST]
|
||||||
|
summary: EST simple re-enrollment
|
||||||
|
description: |
|
||||||
|
Re-enrolls an existing certificate (same as simpleenroll in certctl's
|
||||||
|
implementation — re-enrollment is treated as a fresh issuance) per
|
||||||
|
RFC 7030 §4.2.2.
|
||||||
|
operationId: estSimpleReEnroll
|
||||||
|
security: []
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
description: "Base64-encoded DER PKCS#10 CSR, or PEM-encoded CSR"
|
||||||
|
content:
|
||||||
|
application/pkcs10:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: byte
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Base64-encoded PKCS#7 cert-only response with re-issued certificate
|
||||||
|
headers:
|
||||||
|
Content-Transfer-Encoding:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
example: base64
|
||||||
|
content:
|
||||||
|
application/pkcs7-mime:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: byte
|
||||||
|
description: "Base64-encoded PKCS#7 (smime-type=certs-only)"
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
"405":
|
||||||
|
description: Method not allowed (only POST accepted)
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
/.well-known/est/csrattrs:
|
||||||
|
get:
|
||||||
|
tags: [EST]
|
||||||
|
summary: EST CSR attributes
|
||||||
|
description: |
|
||||||
|
Returns attributes the EST client should include in its CSR per
|
||||||
|
RFC 7030 §4.5. certctl currently returns an empty attribute set
|
||||||
|
(HTTP 204) — profile-based constraints are enforced server-side
|
||||||
|
during enrollment rather than advertised here.
|
||||||
|
operationId: estCSRAttrs
|
||||||
|
security: []
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Base64-encoded CsrAttrs (when non-empty)
|
||||||
|
headers:
|
||||||
|
Content-Transfer-Encoding:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
example: base64
|
||||||
|
content:
|
||||||
|
application/csrattrs:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: byte
|
||||||
|
"204":
|
||||||
|
description: No CSR attributes defined (empty response)
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
|
# ─── SCEP (RFC 8894) ──────────────────────────────────────────────
|
||||||
|
/scep:
|
||||||
|
get:
|
||||||
|
tags: [SCEP]
|
||||||
|
summary: SCEP operation dispatch (GET)
|
||||||
|
description: |
|
||||||
|
Single SCEP entry point dispatched by the `operation` query parameter
|
||||||
|
per RFC 8894. GET is used for capability discovery (`GetCACaps`) and
|
||||||
|
CA certificate retrieval (`GetCACert`).
|
||||||
|
operationId: scepGet
|
||||||
|
security: []
|
||||||
|
parameters:
|
||||||
|
- name: operation
|
||||||
|
in: query
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
enum: [GetCACaps, GetCACert, PKIOperation]
|
||||||
|
description: SCEP operation selector
|
||||||
|
- name: message
|
||||||
|
in: query
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: Optional SCEP message parameter (base64-encoded for GET PKIOperation)
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: |
|
||||||
|
Success. Content-Type varies by operation:
|
||||||
|
- `GetCACaps` → `text/plain` capability list
|
||||||
|
- `GetCACert` (single cert) → `application/x-x509-ca-cert` (raw DER)
|
||||||
|
- `GetCACert` (chain) → `application/x-x509-ca-ra-cert` (PKCS#7)
|
||||||
|
- `PKIOperation` → `application/x-pki-message` (PKCS#7 SignedData)
|
||||||
|
content:
|
||||||
|
text/plain:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: "SCEP capabilities (GetCACaps only)"
|
||||||
|
application/x-x509-ca-cert:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
description: "CA certificate DER (GetCACert single)"
|
||||||
|
application/x-x509-ca-ra-cert:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
description: "CA chain PKCS#7 (GetCACert chain)"
|
||||||
|
application/x-pki-message:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
description: "PKCS#7 SignedData response (PKIOperation)"
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
post:
|
||||||
|
tags: [SCEP]
|
||||||
|
summary: SCEP PKIOperation (POST)
|
||||||
|
description: |
|
||||||
|
SCEP enrollment / renewal / revocation request per RFC 8894.
|
||||||
|
Request body is a PKCS#7 SignedData envelope wrapping the PKCS#10 CSR
|
||||||
|
or a degenerate raw CSR (fallback). The challenge password in the CSR
|
||||||
|
attributes is validated against `CERTCTL_SCEP_CHALLENGE_PASSWORD` when
|
||||||
|
configured.
|
||||||
|
operationId: scepPost
|
||||||
|
security: []
|
||||||
|
parameters:
|
||||||
|
- name: operation
|
||||||
|
in: query
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
enum: [PKIOperation]
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
description: PKCS#7 SignedData envelope wrapping a PKCS#10 CSR (or raw CSR as fallback)
|
||||||
|
content:
|
||||||
|
application/x-pki-message:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: PKCS#7 SignedData PKIMessage response
|
||||||
|
content:
|
||||||
|
application/x-pki-message:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/BadRequest"
|
||||||
|
"500":
|
||||||
|
$ref: "#/components/responses/InternalError"
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
components:
|
components:
|
||||||
securitySchemes:
|
securitySchemes:
|
||||||
@@ -3805,3 +4125,47 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
format: date-time
|
format: date-time
|
||||||
description: Timestamp of this probe
|
description: Timestamp of this probe
|
||||||
|
|
||||||
|
# ─── Verification (M25) ──────────────────────────────────────────
|
||||||
|
VerifyDeploymentRequest:
|
||||||
|
type: object
|
||||||
|
required: [target_id, expected_fingerprint, actual_fingerprint, verified]
|
||||||
|
properties:
|
||||||
|
target_id:
|
||||||
|
type: string
|
||||||
|
description: Deployment target the agent probed
|
||||||
|
expected_fingerprint:
|
||||||
|
type: string
|
||||||
|
description: SHA-256 fingerprint of the certificate that should be served (hex, lowercase)
|
||||||
|
actual_fingerprint:
|
||||||
|
type: string
|
||||||
|
description: SHA-256 fingerprint observed on the live TLS endpoint (hex, lowercase)
|
||||||
|
verified:
|
||||||
|
type: boolean
|
||||||
|
description: True when expected and actual fingerprints match
|
||||||
|
error:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
description: Error message when probe failed or fingerprints differ
|
||||||
|
|
||||||
|
VerificationResult:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
job_id:
|
||||||
|
type: string
|
||||||
|
target_id:
|
||||||
|
type: string
|
||||||
|
expected_fingerprint:
|
||||||
|
type: string
|
||||||
|
description: SHA-256 fingerprint (hex) of the certificate deployed by this job
|
||||||
|
actual_fingerprint:
|
||||||
|
type: string
|
||||||
|
description: SHA-256 fingerprint (hex) observed on the live TLS endpoint
|
||||||
|
verified:
|
||||||
|
type: boolean
|
||||||
|
verified_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
error:
|
||||||
|
type: string
|
||||||
|
description: Error message when verification failed
|
||||||
|
|||||||
+59
-17
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/shankar0123/certctl/internal/api/middleware"
|
"github.com/shankar0123/certctl/internal/api/middleware"
|
||||||
"github.com/shankar0123/certctl/internal/api/router"
|
"github.com/shankar0123/certctl/internal/api/router"
|
||||||
"github.com/shankar0123/certctl/internal/config"
|
"github.com/shankar0123/certctl/internal/config"
|
||||||
"github.com/shankar0123/certctl/internal/crypto"
|
|
||||||
"github.com/shankar0123/certctl/internal/domain"
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
discoveryawssm "github.com/shankar0123/certctl/internal/connector/discovery/awssm"
|
discoveryawssm "github.com/shankar0123/certctl/internal/connector/discovery/awssm"
|
||||||
discoveryazurekv "github.com/shankar0123/certctl/internal/connector/discovery/azurekv"
|
discoveryazurekv "github.com/shankar0123/certctl/internal/connector/discovery/azurekv"
|
||||||
@@ -82,12 +81,20 @@ func main() {
|
|||||||
logger.Info("initialized all repositories")
|
logger.Info("initialized all repositories")
|
||||||
|
|
||||||
// Initialize dynamic issuer registry.
|
// Initialize dynamic issuer registry.
|
||||||
// Issuers are loaded from the database (with AES-GCM encrypted config).
|
// Issuers are loaded from the database (with AES-256-GCM encrypted config).
|
||||||
// On first boot with an empty database, env var issuers are seeded automatically.
|
// On first boot with an empty database, env var issuers are seeded automatically.
|
||||||
var encryptionKey []byte
|
//
|
||||||
if cfg.Encryption.ConfigEncryptionKey != "" {
|
// M-8 (CWE-916 / CWE-329): the encryption passphrase is passed as a raw
|
||||||
encryptionKey = crypto.DeriveKey(cfg.Encryption.ConfigEncryptionKey)
|
// string into IssuerService / TargetService / IssuerRegistry. Each call to
|
||||||
logger.Info("config encryption enabled (AES-256-GCM)")
|
// crypto.EncryptIfKeySet generates a fresh 16-byte PBKDF2 salt and emits a
|
||||||
|
// v2 blob (magic 0x02 || salt || nonce || sealed). Decryption auto-detects
|
||||||
|
// v1 legacy blobs (no magic) and falls back to the fixed v1 salt for
|
||||||
|
// backward compatibility; v1 blobs transparently upgrade to v2 on next
|
||||||
|
// write. DO NOT pre-derive the key here with crypto.DeriveKey — that was
|
||||||
|
// the v1 fixed-salt behaviour that M-8 removes.
|
||||||
|
encryptionKey := cfg.Encryption.ConfigEncryptionKey
|
||||||
|
if encryptionKey != "" {
|
||||||
|
logger.Info("config encryption enabled (AES-256-GCM, per-ciphertext PBKDF2 salt)")
|
||||||
} else {
|
} else {
|
||||||
// C-2 fix: fail closed at startup when database-sourced issuer or target
|
// C-2 fix: fail closed at startup when database-sourced issuer or target
|
||||||
// rows exist without a configured encryption key. Previously the server
|
// rows exist without a configured encryption key. Previously the server
|
||||||
@@ -246,9 +253,15 @@ func main() {
|
|||||||
Name: "Network Scanner (Server-Side)",
|
Name: "Network Scanner (Server-Side)",
|
||||||
Status: domain.AgentStatusOnline,
|
Status: domain.AgentStatusOnline,
|
||||||
}
|
}
|
||||||
if err := agentRepo.Create(context.Background(), sentinelAgent); err != nil {
|
// M-6: use CreateIfNotExists so duplicate rows on restart/upgrade are
|
||||||
// Ignore duplicate key errors (agent already exists)
|
// idempotent without swallowing unrelated DB failures (CWE-662).
|
||||||
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelAgentID)
|
created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelAgent)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("sentinel agent creation failed", "id", service.SentinelAgentID, "error", err)
|
||||||
|
} else if created {
|
||||||
|
logger.Info("sentinel agent created", "id", service.SentinelAgentID)
|
||||||
|
} else {
|
||||||
|
logger.Debug("sentinel agent already exists", "id", service.SentinelAgentID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,8 +280,14 @@ func main() {
|
|||||||
Name: "AWS Secrets Manager Discovery",
|
Name: "AWS Secrets Manager Discovery",
|
||||||
Status: domain.AgentStatusOnline,
|
Status: domain.AgentStatusOnline,
|
||||||
}
|
}
|
||||||
if err := agentRepo.Create(context.Background(), sentinelAWS); err != nil {
|
// M-6: idempotent create (CWE-662).
|
||||||
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelAWSSecretsMgr)
|
created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelAWS)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("sentinel agent creation failed", "id", service.SentinelAWSSecretsMgr, "error", err)
|
||||||
|
} else if created {
|
||||||
|
logger.Info("sentinel agent created", "id", service.SentinelAWSSecretsMgr)
|
||||||
|
} else {
|
||||||
|
logger.Debug("sentinel agent already exists", "id", service.SentinelAWSSecretsMgr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,8 +305,14 @@ func main() {
|
|||||||
Name: "Azure Key Vault Discovery",
|
Name: "Azure Key Vault Discovery",
|
||||||
Status: domain.AgentStatusOnline,
|
Status: domain.AgentStatusOnline,
|
||||||
}
|
}
|
||||||
if err := agentRepo.Create(context.Background(), sentinelAzure); err != nil {
|
// M-6: idempotent create (CWE-662).
|
||||||
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelAzureKeyVault)
|
created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelAzure)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("sentinel agent creation failed", "id", service.SentinelAzureKeyVault, "error", err)
|
||||||
|
} else if created {
|
||||||
|
logger.Info("sentinel agent created", "id", service.SentinelAzureKeyVault)
|
||||||
|
} else {
|
||||||
|
logger.Debug("sentinel agent already exists", "id", service.SentinelAzureKeyVault)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,8 +325,14 @@ func main() {
|
|||||||
Name: "GCP Secret Manager Discovery",
|
Name: "GCP Secret Manager Discovery",
|
||||||
Status: domain.AgentStatusOnline,
|
Status: domain.AgentStatusOnline,
|
||||||
}
|
}
|
||||||
if err := agentRepo.Create(context.Background(), sentinelGCP); err != nil {
|
// M-6: idempotent create (CWE-662).
|
||||||
logger.Debug("sentinel agent creation", "status", "exists or created", "id", service.SentinelGCPSecretMgr)
|
created, err := agentRepo.CreateIfNotExists(context.Background(), sentinelGCP)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("sentinel agent creation failed", "id", service.SentinelGCPSecretMgr, "error", err)
|
||||||
|
} else if created {
|
||||||
|
logger.Info("sentinel agent created", "id", service.SentinelGCPSecretMgr)
|
||||||
|
} else {
|
||||||
|
logger.Debug("sentinel agent already exists", "id", service.SentinelGCPSecretMgr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -558,7 +589,7 @@ func main() {
|
|||||||
bodyLimitMiddleware,
|
bodyLimitMiddleware,
|
||||||
corsMiddleware,
|
corsMiddleware,
|
||||||
authMiddleware,
|
authMiddleware,
|
||||||
auditMiddleware,
|
auditMiddleware.Middleware,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add rate limiter if enabled
|
// Add rate limiter if enabled
|
||||||
@@ -575,7 +606,7 @@ func main() {
|
|||||||
rateLimiter,
|
rateLimiter,
|
||||||
corsMiddleware,
|
corsMiddleware,
|
||||||
authMiddleware,
|
authMiddleware,
|
||||||
auditMiddleware,
|
auditMiddleware.Middleware,
|
||||||
}
|
}
|
||||||
logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize)
|
logger.Info("rate limiting enabled", "rps", cfg.RateLimit.RPS, "burst", cfg.RateLimit.BurstSize)
|
||||||
}
|
}
|
||||||
@@ -693,6 +724,17 @@ func main() {
|
|||||||
logger.Error("HTTP server shutdown error", "error", err)
|
logger.Error("HTTP server shutdown error", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Drain in-flight audit-recording goroutines before closing the DB pool.
|
||||||
|
// The audit middleware spawns one goroutine per non-excluded request; those
|
||||||
|
// goroutines run detached from the request context and write to the
|
||||||
|
// audit_events table via the same *sql.DB. Without this drain, SIGTERM
|
||||||
|
// would close the DB pool while recordings were mid-flight, silently
|
||||||
|
// dropping audit events (M-1, CWE-662 / CWE-400).
|
||||||
|
logger.Info("flushing audit middleware in-flight recordings")
|
||||||
|
if err := auditMiddleware.Flush(shutdownCtx); err != nil {
|
||||||
|
logger.Warn("audit middleware flush did not complete in time", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Close database connection
|
// Close database connection
|
||||||
if err := db.Close(); err != nil {
|
if err := db.Close(); err != nil {
|
||||||
logger.Error("error closing database connection", "error", err)
|
logger.Error("error closing database connection", "error", err)
|
||||||
|
|||||||
@@ -808,6 +808,34 @@ All shell-facing inputs (connector scripts, domain names, ACME tokens) are valid
|
|||||||
|
|
||||||
All incoming HTTP request bodies are capped by `http.MaxBytesReader` middleware (default 1MB, configurable via `CERTCTL_MAX_BODY_SIZE`). Requests exceeding the limit receive a 413 Request Entity Too Large response. The middleware is positioned before authentication in the chain so oversized payloads are rejected early, before any auth processing or database work occurs. Requests without bodies (GET, HEAD, nil body) skip the limit check.
|
All incoming HTTP request bodies are capped by `http.MaxBytesReader` middleware (default 1MB, configurable via `CERTCTL_MAX_BODY_SIZE`). Requests exceeding the limit receive a 413 Request Entity Too Large response. The middleware is positioned before authentication in the chain so oversized payloads are rejected early, before any auth processing or database work occurs. Requests without bodies (GET, HEAD, nil body) skip the limit check.
|
||||||
|
|
||||||
|
### Config Encryption at Rest
|
||||||
|
|
||||||
|
Dynamic issuer and target configurations (rows with `source='database'`) contain credentials — ACME EAB HMACs, Vault tokens, DigiCert/Sectigo API keys, SSH private keys, WinRM passwords, F5 BIG-IP passwords, and similar. These are sealed at rest in PostgreSQL via `internal/crypto/encryption.go` using AES-256-GCM with a key derived from the operator passphrase `CERTCTL_CONFIG_ENCRYPTION_KEY` through PBKDF2-SHA256 (100,000 rounds, 32-byte output).
|
||||||
|
|
||||||
|
**v2 wire format (current, M-8 remediation, CWE-916 / CWE-329):**
|
||||||
|
|
||||||
|
```
|
||||||
|
magic(0x02) || salt(16) || nonce(12) || ciphertext+tag
|
||||||
|
```
|
||||||
|
|
||||||
|
Every call to `EncryptIfKeySet` draws 16 fresh bytes from `crypto/rand` as the PBKDF2 salt, so the derived AES-256 key is distinct per ciphertext and per re-encryption. The salt is stored alongside the ciphertext; decryption reads the magic byte, splits out the salt, re-derives the key, and verifies the AEAD tag.
|
||||||
|
|
||||||
|
**v1 legacy format (read-only):**
|
||||||
|
|
||||||
|
```
|
||||||
|
nonce(12) || ciphertext+tag
|
||||||
|
```
|
||||||
|
|
||||||
|
Pre-M-8 blobs were sealed with a package-level fixed salt `"certctl-config-encryption-v1"`. `DecryptIfKeySet` preserves the v1 read path unchanged — a blob whose first byte is not `0x02`, or whose v2 AEAD verification fails (including the 1/256 case where a v1 nonce happens to begin with `0x02`), falls through to a v1 attempt against the legacy fixed salt. v1 blobs are never written by the post-M-8 code path; they re-seal as v2 naturally on the next UPDATE through the normal service CRUD flow. No operator migration ceremony is required.
|
||||||
|
|
||||||
|
**Fail-closed behavior (C-2 sentinel, CWE-311):** both `EncryptIfKeySet` and `DecryptIfKeySet` return `ErrEncryptionKeyRequired` when invoked with an empty passphrase. The server refuses to start if any `source='database'` rows already exist without `CERTCTL_CONFIG_ENCRYPTION_KEY` set.
|
||||||
|
|
||||||
|
**Low-level primitives preserved byte-identical.** `Encrypt`, `Decrypt`, and `DeriveKey` are kept bit-stable so v1 fixtures on disk remain decryptable unchanged and so callers outside the config-encryption path (none today, but the symbols are exported) do not see a breaking change. The new per-ciphertext salt path is reached via the helper `deriveKeyWithSalt(passphrase, salt)`.
|
||||||
|
|
||||||
|
**Passphrase plumbing.** Services (`IssuerService`, `TargetService`, `IssuerRegistry`) hold the operator passphrase as a raw `string` and delegate PBKDF2 to the crypto package per ciphertext. This replaces the pre-M-8 design that pre-derived a single `[]byte` key at service construction and reused it for every row, which was the direct consequence of the fixed-salt KDF.
|
||||||
|
|
||||||
|
**Coverage gate.** CI enforces `internal/crypto/...` coverage ≥ 85% (observed 86.7%) — the encryption primitives are a security-critical gate, and the v2 format plus v1 fallback plus C-2 sentinel paths all need exhaustive coverage to avoid silent regressions.
|
||||||
|
|
||||||
### CORS
|
### CORS
|
||||||
|
|
||||||
CORS uses a **deny-by-default** posture: when `CERTCTL_CORS_ORIGINS` is empty, no CORS headers are set and only same-origin requests can read responses. Operators must explicitly configure allowed origins. This prevents accidental exposure of the API to cross-origin requests in production.
|
CORS uses a **deny-by-default** posture: when `CERTCTL_CORS_ORIGINS` is empty, no CORS headers are set and only same-origin requests can read responses. Operators must explicitly configure allowed origins. This prevents accidental exposure of the API to cross-origin requests in production.
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -120,7 +121,7 @@ func TestGetCertificate_PathInjection(t *testing.T) {
|
|||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
// Force a 404 so we can distinguish "service was called" from
|
// Force a 404 so we can distinguish "service was called" from
|
||||||
// "parser accepted the ID"; a 200 with null body is also fine.
|
// "parser accepted the ID"; a 200 with null body is also fine.
|
||||||
mock.GetCertificateFn = func(id string) (*domain.ManagedCertificate, error) {
|
mock.GetCertificateFn = func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +157,7 @@ func TestUpdateCertificate_PathInjection(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.UpdateCertificateFn = func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
mock.UpdateCertificateFn = func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,7 +185,7 @@ func TestArchiveCertificate_PathInjection(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ArchiveCertificateFn = func(id string) error { return ErrMockNotFound }
|
mock.ArchiveCertificateFn = func(_ context.Context, id string) error { return ErrMockNotFound }
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/x", nil)
|
req := httptest.NewRequest(http.MethodDelete, "/api/v1/certificates/x", nil)
|
||||||
req.URL.Path = "/api/v1/certificates/" + tc.input
|
req.URL.Path = "/api/v1/certificates/" + tc.input
|
||||||
@@ -227,7 +228,7 @@ func TestGetCertificateVersions_MultiSegment(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.GetCertificateVersionsFn = func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
mock.GetCertificateVersionsFn = func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||||
return []domain.CertificateVersion{}, 0, nil
|
return []domain.CertificateVersion{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,7 +278,7 @@ func TestHandleOCSP_MultiSegment(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.GetOCSPResponseFn = func(issuerID, serialHex string) ([]byte, error) {
|
mock.GetOCSPResponseFn = func(_ context.Context, issuerID, serialHex string) ([]byte, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,7 +312,7 @@ func TestGetDERCRL_IssuerPathInjection(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.GenerateDERCRLFn = func(issuerID string) ([]byte, error) {
|
mock.GenerateDERCRLFn = func(_ context.Context, issuerID string) ([]byte, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -76,7 +77,7 @@ func TestListCertificates_PaginationAbuse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
// Sanity: page/perPage on the filter must never be negative
|
// Sanity: page/perPage on the filter must never be negative
|
||||||
// and perPage must never exceed 500 after parsing.
|
// and perPage must never exceed 500 after parsing.
|
||||||
if filter.Page < 1 {
|
if filter.Page < 1 {
|
||||||
@@ -133,7 +134,7 @@ func TestListCertificates_SortAbuse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ func TestListCertificates_FieldsAbuse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +220,7 @@ func TestListCertificates_TimeRangeAbuse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,7 +264,7 @@ func TestListCertificates_CursorAbuse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,7 +315,7 @@ func TestListCertificates_FilterInjection(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.ListCertificatesWithFilterFn = func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
mock.ListCertificatesWithFilterFn = func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,7 +375,7 @@ func TestCreateCertificate_BodyAbuse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
// If we ever reach this, the handler accepted a malformed
|
// If we ever reach this, the handler accepted a malformed
|
||||||
// body. Return a sentinel that passes but flag it.
|
// body. Return a sentinel that passes but flag it.
|
||||||
c := cert
|
c := cert
|
||||||
@@ -419,7 +420,7 @@ func TestCreateCertificate_HugeBody(t *testing.T) {
|
|||||||
sb.WriteString(`]}`)
|
sb.WriteString(`]}`)
|
||||||
|
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.CreateCertificateFn = func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
mock.CreateCertificateFn = func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
c := cert
|
c := cert
|
||||||
c.ID = "mc-huge"
|
c.ID = "mc-huge"
|
||||||
return &c, nil
|
return &c, nil
|
||||||
@@ -476,7 +477,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) {
|
|||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
// The mock always returns "invalid revocation reason" so we
|
// The mock always returns "invalid revocation reason" so we
|
||||||
// verify the handler's errMsg→status mapping turns it into a 400.
|
// verify the handler's errMsg→status mapping turns it into a 400.
|
||||||
mock.RevokeCertificateFn = func(id string, reason string) error {
|
mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
|
||||||
// The service uses domain.IsValidRevocationReason. If we got
|
// The service uses domain.IsValidRevocationReason. If we got
|
||||||
// through to here with something bogus, simulate a real
|
// through to here with something bogus, simulate a real
|
||||||
// service error.
|
// service error.
|
||||||
@@ -500,7 +501,7 @@ func TestRevokeCertificate_ReasonAbuse(t *testing.T) {
|
|||||||
// service error message, which is fragile — this test catches regressions.
|
// service error message, which is fragile — this test catches regressions.
|
||||||
func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
|
func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.RevokeCertificateFn = func(id string, reason string) error {
|
mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
|
||||||
return fmt.Errorf("cannot revoke: certificate is already revoked")
|
return fmt.Errorf("cannot revoke: certificate is already revoked")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,7 +521,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
|
|||||||
// TestRevokeCertificate_NotFound verifies 404 mapping.
|
// TestRevokeCertificate_NotFound verifies 404 mapping.
|
||||||
func TestRevokeCertificate_NotFound(t *testing.T) {
|
func TestRevokeCertificate_NotFound(t *testing.T) {
|
||||||
handler, mock := newCertHandlerWithMock()
|
handler, mock := newCertHandlerWithMock()
|
||||||
mock.RevokeCertificateFn = func(id string, reason string) error {
|
mock.RevokeCertificateFn = func(_ context.Context, id string, reason string, _ string) error {
|
||||||
return fmt.Errorf("certificate not found")
|
return fmt.Errorf("certificate not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,8 +12,8 @@ import (
|
|||||||
|
|
||||||
// AuditService defines the service interface for audit event operations.
|
// AuditService defines the service interface for audit event operations.
|
||||||
type AuditService interface {
|
type AuditService interface {
|
||||||
ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error)
|
ListAuditEvents(ctx context.Context, page, perPage int) ([]domain.AuditEvent, int64, error)
|
||||||
GetAuditEvent(id string) (*domain.AuditEvent, error)
|
GetAuditEvent(ctx context.Context, id string) (*domain.AuditEvent, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuditHandler handles HTTP requests for audit event operations.
|
// AuditHandler handles HTTP requests for audit event operations.
|
||||||
@@ -49,7 +50,7 @@ func (h AuditHandler) ListAuditEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
events, total, err := h.svc.ListAuditEvents(page, perPage)
|
events, total, err := h.svc.ListAuditEvents(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list audit events", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list audit events", requestID)
|
||||||
return
|
return
|
||||||
@@ -83,7 +84,7 @@ func (h AuditHandler) GetAuditEvent(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
event, err := h.svc.GetAuditEvent(id)
|
event, err := h.svc.GetAuditEvent(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Audit event not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Audit event not found", requestID)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,14 +19,14 @@ type mockAuditService struct {
|
|||||||
getFunc func(id string) (*domain.AuditEvent, error)
|
getFunc func(id string) (*domain.AuditEvent, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) {
|
func (m *mockAuditService) ListAuditEvents(_ context.Context, page, perPage int) ([]domain.AuditEvent, int64, error) {
|
||||||
if m.listFunc != nil {
|
if m.listFunc != nil {
|
||||||
return m.listFunc(page, perPage)
|
return m.listFunc(page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) {
|
func (m *mockAuditService) GetAuditEvent(_ context.Context, id string) (*domain.AuditEvent, error) {
|
||||||
if m.getFunc != nil {
|
if m.getFunc != nil {
|
||||||
return m.getFunc(id)
|
return m.getFunc(id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,116 +17,116 @@ import (
|
|||||||
|
|
||||||
// MockCertificateService is a mock implementation of CertificateService interface.
|
// MockCertificateService is a mock implementation of CertificateService interface.
|
||||||
type MockCertificateService struct {
|
type MockCertificateService struct {
|
||||||
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
ListCertificatesFn func(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
||||||
ListCertificatesWithFilterFn func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
|
ListCertificatesWithFilterFn func(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
|
||||||
GetCertificateFn func(id string) (*domain.ManagedCertificate, error)
|
GetCertificateFn func(ctx context.Context, id string) (*domain.ManagedCertificate, error)
|
||||||
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
CreateCertificateFn func(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
UpdateCertificateFn func(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
ArchiveCertificateFn func(id string) error
|
ArchiveCertificateFn func(ctx context.Context, id string) error
|
||||||
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
GetCertificateVersionsFn func(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
||||||
TriggerRenewalFn func(certID string) error
|
TriggerRenewalFn func(ctx context.Context, certID string, actor string) error
|
||||||
TriggerDeploymentFn func(certID string, targetID string) error
|
TriggerDeploymentFn func(ctx context.Context, certID string, targetID string, actor string) error
|
||||||
RevokeCertificateFn func(certID string, reason string) error
|
RevokeCertificateFn func(ctx context.Context, certID string, reason string, actor string) error
|
||||||
GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error)
|
GetRevokedCertificatesFn func(ctx context.Context) ([]*domain.CertificateRevocation, error)
|
||||||
GenerateDERCRLFn func(issuerID string) ([]byte, error)
|
GenerateDERCRLFn func(ctx context.Context, issuerID string) ([]byte, error)
|
||||||
GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error)
|
GetOCSPResponseFn func(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
|
||||||
GetCertificateDeploymentsFn func(certID string) ([]domain.DeploymentTarget, error)
|
GetCertificateDeploymentsFn func(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
func (m *MockCertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||||
if m.ListCertificatesFn != nil {
|
if m.ListCertificatesFn != nil {
|
||||||
return m.ListCertificatesFn(status, environment, ownerID, teamID, issuerID, page, perPage)
|
return m.ListCertificatesFn(ctx, status, environment, ownerID, teamID, issuerID, page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) {
|
func (m *MockCertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
if m.GetCertificateFn != nil {
|
if m.GetCertificateFn != nil {
|
||||||
return m.GetCertificateFn(id)
|
return m.GetCertificateFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
func (m *MockCertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
if m.CreateCertificateFn != nil {
|
if m.CreateCertificateFn != nil {
|
||||||
return m.CreateCertificateFn(cert)
|
return m.CreateCertificateFn(ctx, cert)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
func (m *MockCertificateService) UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
if m.UpdateCertificateFn != nil {
|
if m.UpdateCertificateFn != nil {
|
||||||
return m.UpdateCertificateFn(id, cert)
|
return m.UpdateCertificateFn(ctx, id, cert)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) ArchiveCertificate(id string) error {
|
func (m *MockCertificateService) ArchiveCertificate(ctx context.Context, id string) error {
|
||||||
if m.ArchiveCertificateFn != nil {
|
if m.ArchiveCertificateFn != nil {
|
||||||
return m.ArchiveCertificateFn(id)
|
return m.ArchiveCertificateFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
func (m *MockCertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||||
if m.GetCertificateVersionsFn != nil {
|
if m.GetCertificateVersionsFn != nil {
|
||||||
return m.GetCertificateVersionsFn(certID, page, perPage)
|
return m.GetCertificateVersionsFn(ctx, certID, page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) TriggerRenewal(certID string) error {
|
func (m *MockCertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error {
|
||||||
if m.TriggerRenewalFn != nil {
|
if m.TriggerRenewalFn != nil {
|
||||||
return m.TriggerRenewalFn(certID)
|
return m.TriggerRenewalFn(ctx, certID, actor)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) TriggerDeployment(certID string, targetID string) error {
|
func (m *MockCertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error {
|
||||||
if m.TriggerDeploymentFn != nil {
|
if m.TriggerDeploymentFn != nil {
|
||||||
return m.TriggerDeploymentFn(certID, targetID)
|
return m.TriggerDeploymentFn(ctx, certID, targetID, actor)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) RevokeCertificate(certID string, reason string) error {
|
func (m *MockCertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error {
|
||||||
if m.RevokeCertificateFn != nil {
|
if m.RevokeCertificateFn != nil {
|
||||||
return m.RevokeCertificateFn(certID, reason)
|
return m.RevokeCertificateFn(ctx, certID, reason, actor)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
|
func (m *MockCertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
|
||||||
if m.GetRevokedCertificatesFn != nil {
|
if m.GetRevokedCertificatesFn != nil {
|
||||||
return m.GetRevokedCertificatesFn()
|
return m.GetRevokedCertificatesFn(ctx)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) GenerateDERCRL(issuerID string) ([]byte, error) {
|
func (m *MockCertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
|
||||||
if m.GenerateDERCRLFn != nil {
|
if m.GenerateDERCRLFn != nil {
|
||||||
return m.GenerateDERCRLFn(issuerID)
|
return m.GenerateDERCRLFn(ctx, issuerID)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
|
func (m *MockCertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
|
||||||
if m.GetOCSPResponseFn != nil {
|
if m.GetOCSPResponseFn != nil {
|
||||||
return m.GetOCSPResponseFn(issuerID, serialHex)
|
return m.GetOCSPResponseFn(ctx, issuerID, serialHex)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
func (m *MockCertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if m.ListCertificatesWithFilterFn != nil {
|
if m.ListCertificatesWithFilterFn != nil {
|
||||||
return m.ListCertificatesWithFilterFn(filter)
|
return m.ListCertificatesWithFilterFn(ctx, filter)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) {
|
func (m *MockCertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) {
|
||||||
if m.GetCertificateDeploymentsFn != nil {
|
if m.GetCertificateDeploymentsFn != nil {
|
||||||
return m.GetCertificateDeploymentsFn(certID)
|
return m.GetCertificateDeploymentsFn(ctx, certID)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -158,7 +158,7 @@ func TestListCertificates_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.Page == 1 && filter.PerPage == 50 {
|
if filter.Page == 1 && filter.PerPage == 50 {
|
||||||
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
|
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
|
||||||
}
|
}
|
||||||
@@ -197,7 +197,7 @@ func TestListCertificates_Success(t *testing.T) {
|
|||||||
// Test ListCertificates - with filters
|
// Test ListCertificates - with filters
|
||||||
func TestListCertificates_WithFilters(t *testing.T) {
|
func TestListCertificates_WithFilters(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.Status == "Active" && filter.Environment == "prod" {
|
if filter.Status == "Active" && filter.Environment == "prod" {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
@@ -236,7 +236,7 @@ func TestListCertificates_MethodNotAllowed(t *testing.T) {
|
|||||||
// Test ListCertificates - service error
|
// Test ListCertificates - service error
|
||||||
func TestListCertificates_ServiceError(t *testing.T) {
|
func TestListCertificates_ServiceError(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return nil, 0, ErrMockServiceFailed
|
return nil, 0, ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -266,7 +266,7 @@ func TestGetCertificate_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
|
GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
if id == "mc-prod-001" {
|
if id == "mc-prod-001" {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
@@ -298,7 +298,7 @@ func TestGetCertificate_Success(t *testing.T) {
|
|||||||
// Test GetCertificate - not found
|
// Test GetCertificate - not found
|
||||||
func TestGetCertificate_NotFound(t *testing.T) {
|
func TestGetCertificate_NotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateFn: func(id string) (*domain.ManagedCertificate, error) {
|
GetCertificateFn: func(_ context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -345,7 +345,7 @@ func TestCreateCertificate_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
return created, nil
|
return created, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -403,7 +403,7 @@ func TestCreateCertificate_InvalidBody(t *testing.T) {
|
|||||||
// Test CreateCertificate - service error
|
// Test CreateCertificate - service error
|
||||||
func TestCreateCertificate_ServiceError(t *testing.T) {
|
func TestCreateCertificate_ServiceError(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
CreateCertificateFn: func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
CreateCertificateFn: func(_ context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
return nil, ErrMockServiceFailed
|
return nil, ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -445,7 +445,7 @@ func TestUpdateCertificate_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
UpdateCertificateFn: func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
UpdateCertificateFn: func(_ context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
if id == "mc-prod-001" {
|
if id == "mc-prod-001" {
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
@@ -501,7 +501,7 @@ func TestUpdateCertificate_InvalidBody(t *testing.T) {
|
|||||||
// Test ArchiveCertificate - success case
|
// Test ArchiveCertificate - success case
|
||||||
func TestArchiveCertificate_Success(t *testing.T) {
|
func TestArchiveCertificate_Success(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ArchiveCertificateFn: func(id string) error {
|
ArchiveCertificateFn: func(_ context.Context, id string) error {
|
||||||
if id == "mc-prod-001" {
|
if id == "mc-prod-001" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -524,7 +524,7 @@ func TestArchiveCertificate_Success(t *testing.T) {
|
|||||||
// Test ArchiveCertificate - not found
|
// Test ArchiveCertificate - not found
|
||||||
func TestArchiveCertificate_NotFound(t *testing.T) {
|
func TestArchiveCertificate_NotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ArchiveCertificateFn: func(id string) error {
|
ArchiveCertificateFn: func(_ context.Context, id string) error {
|
||||||
return ErrMockNotFound
|
return ErrMockNotFound
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -554,7 +554,7 @@ func TestGetCertificateVersions_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||||
if certID == "mc-prod-001" {
|
if certID == "mc-prod-001" {
|
||||||
return []domain.CertificateVersion{ver1}, 1, nil
|
return []domain.CertificateVersion{ver1}, 1, nil
|
||||||
}
|
}
|
||||||
@@ -586,7 +586,7 @@ func TestGetCertificateVersions_Success(t *testing.T) {
|
|||||||
// Test GetCertificateVersions - not found
|
// Test GetCertificateVersions - not found
|
||||||
func TestGetCertificateVersions_NotFound(t *testing.T) {
|
func TestGetCertificateVersions_NotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateVersionsFn: func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
GetCertificateVersionsFn: func(_ context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||||
return nil, 0, ErrMockNotFound
|
return nil, 0, ErrMockNotFound
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -606,7 +606,7 @@ func TestGetCertificateVersions_NotFound(t *testing.T) {
|
|||||||
// Test TriggerRenewal - success case
|
// Test TriggerRenewal - success case
|
||||||
func TestTriggerRenewal_Success(t *testing.T) {
|
func TestTriggerRenewal_Success(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
TriggerRenewalFn: func(certID string) error {
|
TriggerRenewalFn: func(_ context.Context, certID string, _ string) error {
|
||||||
if certID == "mc-prod-001" {
|
if certID == "mc-prod-001" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -638,7 +638,7 @@ func TestTriggerRenewal_Success(t *testing.T) {
|
|||||||
// Test TriggerRenewal - service error
|
// Test TriggerRenewal - service error
|
||||||
func TestTriggerRenewal_ServiceError(t *testing.T) {
|
func TestTriggerRenewal_ServiceError(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
TriggerRenewalFn: func(certID string) error {
|
TriggerRenewalFn: func(_ context.Context, certID string, _ string) error {
|
||||||
return ErrMockServiceFailed
|
return ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -658,7 +658,7 @@ func TestTriggerRenewal_ServiceError(t *testing.T) {
|
|||||||
// Test TriggerDeployment - success case
|
// Test TriggerDeployment - success case
|
||||||
func TestTriggerDeployment_Success(t *testing.T) {
|
func TestTriggerDeployment_Success(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
TriggerDeploymentFn: func(certID string, targetID string) error {
|
TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error {
|
||||||
if certID == "mc-prod-001" {
|
if certID == "mc-prod-001" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -695,7 +695,7 @@ func TestTriggerDeployment_Success(t *testing.T) {
|
|||||||
// Test TriggerDeployment - without target ID
|
// Test TriggerDeployment - without target ID
|
||||||
func TestTriggerDeployment_NoTargetID(t *testing.T) {
|
func TestTriggerDeployment_NoTargetID(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
TriggerDeploymentFn: func(certID string, targetID string) error {
|
TriggerDeploymentFn: func(_ context.Context, certID string, targetID string, _ string) error {
|
||||||
// Should accept empty targetID (deploy to all)
|
// Should accept empty targetID (deploy to all)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -716,7 +716,7 @@ func TestTriggerDeployment_NoTargetID(t *testing.T) {
|
|||||||
// Test ListCertificates - invalid page parameter
|
// Test ListCertificates - invalid page parameter
|
||||||
func TestListCertificates_InvalidPageParam(t *testing.T) {
|
func TestListCertificates_InvalidPageParam(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
// Should default to page 1
|
// Should default to page 1
|
||||||
if filter.Page == 1 {
|
if filter.Page == 1 {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
@@ -740,7 +740,7 @@ func TestListCertificates_InvalidPageParam(t *testing.T) {
|
|||||||
// Test ListCertificates - per_page exceeds max
|
// Test ListCertificates - per_page exceeds max
|
||||||
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
|
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
// Should cap perPage at 500
|
// Should cap perPage at 500
|
||||||
if filter.PerPage == 50 { // defaults to 50 if > 500
|
if filter.PerPage == 50 { // defaults to 50 if > 500
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
@@ -765,7 +765,7 @@ func TestListCertificates_PerPageExceedsMax(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_Success(t *testing.T) {
|
func TestRevokeCertificate_Handler_Success(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
if certID != "mc-prod-001" {
|
if certID != "mc-prod-001" {
|
||||||
t.Errorf("expected certID mc-prod-001, got %s", certID)
|
t.Errorf("expected certID mc-prod-001, got %s", certID)
|
||||||
}
|
}
|
||||||
@@ -798,7 +798,7 @@ func TestRevokeCertificate_Handler_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
|
func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
// Empty reason is OK — service defaults to "unspecified"
|
// Empty reason is OK — service defaults to "unspecified"
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -818,7 +818,7 @@ func TestRevokeCertificate_Handler_NoBody(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
|
func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
return fmt.Errorf("certificate is already revoked")
|
return fmt.Errorf("certificate is already revoked")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -839,7 +839,7 @@ func TestRevokeCertificate_Handler_AlreadyRevoked(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
|
func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
return fmt.Errorf("failed to fetch certificate: not found")
|
return fmt.Errorf("failed to fetch certificate: not found")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -858,7 +858,7 @@ func TestRevokeCertificate_Handler_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) {
|
func TestRevokeCertificate_Handler_InvalidReason(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
return fmt.Errorf("invalid revocation reason: badReason")
|
return fmt.Errorf("invalid revocation reason: badReason")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -922,7 +922,7 @@ func TestRevokeCertificate_Handler_EmptyID(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
|
func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
return fmt.Errorf("cannot revoke archived certificate")
|
return fmt.Errorf("cannot revoke archived certificate")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -941,7 +941,7 @@ func TestRevokeCertificate_Handler_CannotRevokeArchived(t *testing.T) {
|
|||||||
|
|
||||||
func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
|
func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
RevokeCertificateFn: func(certID string, reason string) error {
|
RevokeCertificateFn: func(_ context.Context, certID string, reason string, _ string) error {
|
||||||
return fmt.Errorf("database connection lost")
|
return fmt.Errorf("database connection lost")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -962,7 +962,7 @@ func TestRevokeCertificate_Handler_ServerError(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetCRL_Success(t *testing.T) {
|
func TestGetCRL_Success(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
|
GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
|
||||||
return []*domain.CertificateRevocation{
|
return []*domain.CertificateRevocation{
|
||||||
{
|
{
|
||||||
ID: "rev-1",
|
ID: "rev-1",
|
||||||
@@ -1022,7 +1022,7 @@ func TestGetCRL_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetCRL_Empty(t *testing.T) {
|
func TestGetCRL_Empty(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
|
GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1047,7 +1047,7 @@ func TestGetCRL_Empty(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetCRL_ServiceError(t *testing.T) {
|
func TestGetCRL_ServiceError(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetRevokedCertificatesFn: func() ([]*domain.CertificateRevocation, error) {
|
GetRevokedCertificatesFn: func(_ context.Context) ([]*domain.CertificateRevocation, error) {
|
||||||
return nil, fmt.Errorf("revocation repository not configured")
|
return nil, fmt.Errorf("revocation repository not configured")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1083,7 +1083,7 @@ func TestGetCRL_MethodNotAllowed(t *testing.T) {
|
|||||||
func TestGetDERCRL_Success(t *testing.T) {
|
func TestGetDERCRL_Success(t *testing.T) {
|
||||||
derCRLData := []byte{0x30, 0x82, 0x01, 0x00} // Mock DER CRL bytes
|
derCRLData := []byte{0x30, 0x82, 0x01, 0x00} // Mock DER CRL bytes
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GenerateDERCRLFn: func(issuerID string) ([]byte, error) {
|
GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
|
||||||
if issuerID == "iss-local" {
|
if issuerID == "iss-local" {
|
||||||
return derCRLData, nil
|
return derCRLData, nil
|
||||||
}
|
}
|
||||||
@@ -1111,7 +1111,7 @@ func TestGetDERCRL_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetDERCRL_IssuerNotFound(t *testing.T) {
|
func TestGetDERCRL_IssuerNotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GenerateDERCRLFn: func(issuerID string) ([]byte, error) {
|
GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("issuer not found")
|
return nil, fmt.Errorf("issuer not found")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1130,7 +1130,7 @@ func TestGetDERCRL_IssuerNotFound(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetDERCRL_NotSupported(t *testing.T) {
|
func TestGetDERCRL_NotSupported(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GenerateDERCRLFn: func(issuerID string) ([]byte, error) {
|
GenerateDERCRLFn: func(_ context.Context, issuerID string) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("issuer does not support CRL generation")
|
return nil, fmt.Errorf("issuer does not support CRL generation")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1165,7 +1165,7 @@ func TestGetDERCRL_MethodNotAllowed(t *testing.T) {
|
|||||||
func TestHandleOCSP_Success(t *testing.T) {
|
func TestHandleOCSP_Success(t *testing.T) {
|
||||||
ocspResponseBytes := []byte{0x30, 0x82, 0x02, 0x00} // Mock OCSP response
|
ocspResponseBytes := []byte{0x30, 0x82, 0x02, 0x00} // Mock OCSP response
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) {
|
GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
|
||||||
if issuerID == "iss-local" && serialHex == "12345" {
|
if issuerID == "iss-local" && serialHex == "12345" {
|
||||||
return ocspResponseBytes, nil
|
return ocspResponseBytes, nil
|
||||||
}
|
}
|
||||||
@@ -1206,7 +1206,7 @@ func TestHandleOCSP_MissingSerial(t *testing.T) {
|
|||||||
|
|
||||||
func TestHandleOCSP_IssuerNotFound(t *testing.T) {
|
func TestHandleOCSP_IssuerNotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) {
|
GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("issuer not found")
|
return nil, fmt.Errorf("issuer not found")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1225,7 +1225,7 @@ func TestHandleOCSP_IssuerNotFound(t *testing.T) {
|
|||||||
|
|
||||||
func TestHandleOCSP_CertNotFound(t *testing.T) {
|
func TestHandleOCSP_CertNotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetOCSPResponseFn: func(issuerID string, serialHex string) ([]byte, error) {
|
GetOCSPResponseFn: func(_ context.Context, issuerID string, serialHex string) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("certificate not found")
|
return nil, fmt.Errorf("certificate not found")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1261,7 +1261,7 @@ func TestHandleOCSP_MethodNotAllowed(t *testing.T) {
|
|||||||
// TestListCertificates_SortParam tests sort parameter parsing and passing to service.
|
// TestListCertificates_SortParam tests sort parameter parsing and passing to service.
|
||||||
func TestListCertificates_SortParam(t *testing.T) {
|
func TestListCertificates_SortParam(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
// Handler strips the '-' prefix and sets SortDesc = true
|
// Handler strips the '-' prefix and sets SortDesc = true
|
||||||
if filter.Sort != "notAfter" || !filter.SortDesc {
|
if filter.Sort != "notAfter" || !filter.SortDesc {
|
||||||
t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
|
t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
|
||||||
@@ -1284,7 +1284,7 @@ func TestListCertificates_SortParam(t *testing.T) {
|
|||||||
// TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending).
|
// TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending).
|
||||||
func TestListCertificates_SortParam_Ascending(t *testing.T) {
|
func TestListCertificates_SortParam_Ascending(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.Sort != "createdAt" || filter.SortDesc {
|
if filter.Sort != "createdAt" || filter.SortDesc {
|
||||||
t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
|
t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
|
||||||
}
|
}
|
||||||
@@ -1309,7 +1309,7 @@ func TestListCertificates_TimeRangeFilters(t *testing.T) {
|
|||||||
after := time.Now().AddDate(0, 0, -90)
|
after := time.Now().AddDate(0, 0, -90)
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.ExpiresBefore == nil {
|
if filter.ExpiresBefore == nil {
|
||||||
t.Error("expected ExpiresBefore to be set")
|
t.Error("expected ExpiresBefore to be set")
|
||||||
}
|
}
|
||||||
@@ -1339,7 +1339,7 @@ func TestListCertificates_CreatedAfterFilter(t *testing.T) {
|
|||||||
past := time.Now().AddDate(-1, 0, 0)
|
past := time.Now().AddDate(-1, 0, 0)
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.CreatedAfter == nil {
|
if filter.CreatedAfter == nil {
|
||||||
t.Error("expected CreatedAfter to be set")
|
t.Error("expected CreatedAfter to be set")
|
||||||
}
|
}
|
||||||
@@ -1369,7 +1369,7 @@ func TestListCertificates_CursorPagination(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return []domain.ManagedCertificate{cert}, 1, nil
|
return []domain.ManagedCertificate{cert}, 1, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1409,7 +1409,7 @@ func TestListCertificates_SparseFields(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if len(filter.Fields) != 2 {
|
if len(filter.Fields) != 2 {
|
||||||
t.Errorf("expected 2 fields, got %d", len(filter.Fields))
|
t.Errorf("expected 2 fields, got %d", len(filter.Fields))
|
||||||
}
|
}
|
||||||
@@ -1456,7 +1456,7 @@ func TestListCertificates_SparseFields(t *testing.T) {
|
|||||||
// TestListCertificates_ProfileFilter tests profile_id filter.
|
// TestListCertificates_ProfileFilter tests profile_id filter.
|
||||||
func TestListCertificates_ProfileFilter(t *testing.T) {
|
func TestListCertificates_ProfileFilter(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.ProfileID != "prof-standard" {
|
if filter.ProfileID != "prof-standard" {
|
||||||
t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID)
|
t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID)
|
||||||
}
|
}
|
||||||
@@ -1479,7 +1479,7 @@ func TestListCertificates_ProfileFilter(t *testing.T) {
|
|||||||
// TestListCertificates_AgentIDFilter tests agent_id filter.
|
// TestListCertificates_AgentIDFilter tests agent_id filter.
|
||||||
func TestListCertificates_AgentIDFilter(t *testing.T) {
|
func TestListCertificates_AgentIDFilter(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.AgentID != "agent-prod-001" {
|
if filter.AgentID != "agent-prod-001" {
|
||||||
t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID)
|
t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID)
|
||||||
}
|
}
|
||||||
@@ -1502,7 +1502,7 @@ func TestListCertificates_AgentIDFilter(t *testing.T) {
|
|||||||
// TestListCertificates_CombinedFilters tests multiple filters together.
|
// TestListCertificates_CombinedFilters tests multiple filters together.
|
||||||
func TestListCertificates_CombinedFilters(t *testing.T) {
|
func TestListCertificates_CombinedFilters(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
ListCertificatesWithFilterFn: func(_ context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" {
|
if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" {
|
||||||
t.Error("expected all filters to be set")
|
t.Error("expected all filters to be set")
|
||||||
}
|
}
|
||||||
@@ -1540,7 +1540,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
|
GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
|
||||||
if certID != "mc-prod-001" {
|
if certID != "mc-prod-001" {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
}
|
}
|
||||||
@@ -1576,7 +1576,7 @@ func TestGetCertificateDeployments_Success(t *testing.T) {
|
|||||||
// TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate.
|
// TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate.
|
||||||
func TestGetCertificateDeployments_NotFound(t *testing.T) {
|
func TestGetCertificateDeployments_NotFound(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
|
GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
|
||||||
return nil, fmt.Errorf("certificate not found")
|
return nil, fmt.Errorf("certificate not found")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1596,7 +1596,7 @@ func TestGetCertificateDeployments_NotFound(t *testing.T) {
|
|||||||
// TestGetCertificateDeployments_Empty tests successful response with no deployments.
|
// TestGetCertificateDeployments_Empty tests successful response with no deployments.
|
||||||
func TestGetCertificateDeployments_Empty(t *testing.T) {
|
func TestGetCertificateDeployments_Empty(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
|
GetCertificateDeploymentsFn: func(_ context.Context, certID string) ([]domain.DeploymentTarget, error) {
|
||||||
if certID == "mc-no-deployments" {
|
if certID == "mc-no-deployments" {
|
||||||
return []domain.DeploymentTarget{}, nil
|
return []domain.DeploymentTarget{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,20 +16,20 @@ import (
|
|||||||
|
|
||||||
// CertificateService defines the service interface for certificate operations.
|
// CertificateService defines the service interface for certificate operations.
|
||||||
type CertificateService interface {
|
type CertificateService interface {
|
||||||
ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
||||||
ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
|
ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
|
||||||
GetCertificate(id string) (*domain.ManagedCertificate, error)
|
GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error)
|
||||||
CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
UpdateCertificate(ctx context.Context, id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
ArchiveCertificate(id string) error
|
ArchiveCertificate(ctx context.Context, id string) error
|
||||||
GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
||||||
TriggerRenewal(certID string) error
|
TriggerRenewal(ctx context.Context, certID string, actor string) error
|
||||||
TriggerDeployment(certID string, targetID string) error
|
TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error
|
||||||
RevokeCertificate(certID string, reason string) error
|
RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error
|
||||||
GetRevokedCertificates() ([]*domain.CertificateRevocation, error)
|
GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error)
|
||||||
GenerateDERCRL(issuerID string) ([]byte, error)
|
GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error)
|
||||||
GetOCSPResponse(issuerID string, serialHex string) ([]byte, error)
|
GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error)
|
||||||
GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error)
|
GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CertificateHandler handles HTTP requests for certificate operations.
|
// CertificateHandler handles HTTP requests for certificate operations.
|
||||||
@@ -128,7 +129,7 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ
|
|||||||
filter.Fields = strings.Split(fieldsStr, ",")
|
filter.Fields = strings.Split(fieldsStr, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
certs, total, err := h.svc.ListCertificatesWithFilter(filter)
|
certs, total, err := h.svc.ListCertificatesWithFilter(r.Context(), filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
|
||||||
return
|
return
|
||||||
@@ -186,7 +187,7 @@ func (h CertificateHandler) GetCertificate(w http.ResponseWriter, r *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := h.svc.GetCertificate(id)
|
cert, err := h.svc.GetCertificate(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -241,7 +242,7 @@ func (h CertificateHandler) CreateCertificate(w http.ResponseWriter, r *http.Req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreateCertificate(cert)
|
created, err := h.svc.CreateCertificate(r.Context(), cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name)
|
slog.Error("failed to create certificate", "error", err, "request_id", requestID, "common_name", cert.CommonName, "name", cert.Name)
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create certificate", requestID)
|
||||||
@@ -295,7 +296,7 @@ func (h CertificateHandler) UpdateCertificate(w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdateCertificate(id, cert)
|
updated, err := h.svc.UpdateCertificate(r.Context(), id, cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
||||||
@@ -325,7 +326,7 @@ func (h CertificateHandler) ArchiveCertificate(w http.ResponseWriter, r *http.Re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.ArchiveCertificate(id); err != nil {
|
if err := h.svc.ArchiveCertificate(r.Context(), id); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -370,7 +371,7 @@ func (h CertificateHandler) GetCertificateVersions(w http.ResponseWriter, r *htt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
versions, total, err := h.svc.GetCertificateVersions(certID, page, perPage)
|
versions, total, err := h.svc.GetCertificateVersions(r.Context(), certID, page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
||||||
@@ -410,7 +411,7 @@ func (h CertificateHandler) TriggerRenewal(w http.ResponseWriter, r *http.Reques
|
|||||||
}
|
}
|
||||||
certID := parts[0]
|
certID := parts[0]
|
||||||
|
|
||||||
if err := h.svc.TriggerRenewal(certID); err != nil {
|
if err := h.svc.TriggerRenewal(r.Context(), certID, "api"); err != nil {
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
if strings.Contains(errMsg, "not found") {
|
if strings.Contains(errMsg, "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
||||||
@@ -466,7 +467,7 @@ func (h CertificateHandler) TriggerDeployment(w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.TriggerDeployment(certID, req.TargetID); err != nil {
|
if err := h.svc.TriggerDeployment(r.Context(), certID, req.TargetID, "api"); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger deployment", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -508,7 +509,7 @@ func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.RevokeCertificate(certID, req.Reason); err != nil {
|
if err := h.svc.RevokeCertificate(r.Context(), certID, req.Reason, "api"); err != nil {
|
||||||
// Distinguish between client errors and server errors
|
// Distinguish between client errors and server errors
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
if strings.Contains(errMsg, "already revoked") ||
|
if strings.Contains(errMsg, "already revoked") ||
|
||||||
@@ -540,7 +541,7 @@ func (h CertificateHandler) GetCRL(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
requestID := middleware.GetRequestID(r.Context())
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
|
||||||
revocations, err := h.svc.GetRevokedCertificates()
|
revocations, err := h.svc.GetRevokedCertificates(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to generate CRL", requestID)
|
||||||
return
|
return
|
||||||
@@ -585,7 +586,7 @@ func (h CertificateHandler) GetDERCRL(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
derBytes, err := h.svc.GenerateDERCRL(issuerID)
|
derBytes, err := h.svc.GenerateDERCRL(r.Context(), issuerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
if strings.Contains(errMsg, "not found") {
|
if strings.Contains(errMsg, "not found") {
|
||||||
@@ -627,7 +628,7 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) {
|
|||||||
issuerID := parts[0]
|
issuerID := parts[0]
|
||||||
serialHex := parts[1]
|
serialHex := parts[1]
|
||||||
|
|
||||||
derBytes, err := h.svc.GetOCSPResponse(issuerID, serialHex)
|
derBytes, err := h.svc.GetOCSPResponse(r.Context(), issuerID, serialHex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
if strings.Contains(errMsg, "not found") {
|
if strings.Contains(errMsg, "not found") {
|
||||||
@@ -667,7 +668,7 @@ func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r *
|
|||||||
}
|
}
|
||||||
certID := parts[0]
|
certID := parts[0]
|
||||||
|
|
||||||
deployments, err := h.svc.GetCertificateDeployments(certID)
|
deployments, err := h.svc.GetCertificateDeployments(r.Context(), certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
if strings.Contains(errMsg, "not found") {
|
if strings.Contains(errMsg, "not found") {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,52 +16,52 @@ import (
|
|||||||
|
|
||||||
// MockIssuerService is a mock implementation of IssuerService interface.
|
// MockIssuerService is a mock implementation of IssuerService interface.
|
||||||
type MockIssuerService struct {
|
type MockIssuerService struct {
|
||||||
ListIssuersFn func(page, perPage int) ([]domain.Issuer, int64, error)
|
ListIssuersFn func(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error)
|
||||||
GetIssuerFn func(id string) (*domain.Issuer, error)
|
GetIssuerFn func(ctx context.Context, id string) (*domain.Issuer, error)
|
||||||
CreateIssuerFn func(issuer domain.Issuer) (*domain.Issuer, error)
|
CreateIssuerFn func(ctx context.Context, issuer domain.Issuer) (*domain.Issuer, error)
|
||||||
UpdateIssuerFn func(id string, issuer domain.Issuer) (*domain.Issuer, error)
|
UpdateIssuerFn func(ctx context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error)
|
||||||
DeleteIssuerFn func(id string) error
|
DeleteIssuerFn func(ctx context.Context, id string) error
|
||||||
TestConnectionFn func(id string) error
|
TestConnectionFn func(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockIssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) {
|
func (m *MockIssuerService) ListIssuers(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
|
||||||
if m.ListIssuersFn != nil {
|
if m.ListIssuersFn != nil {
|
||||||
return m.ListIssuersFn(page, perPage)
|
return m.ListIssuersFn(ctx, page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockIssuerService) GetIssuer(id string) (*domain.Issuer, error) {
|
func (m *MockIssuerService) GetIssuer(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||||
if m.GetIssuerFn != nil {
|
if m.GetIssuerFn != nil {
|
||||||
return m.GetIssuerFn(id)
|
return m.GetIssuerFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockIssuerService) CreateIssuer(issuer domain.Issuer) (*domain.Issuer, error) {
|
func (m *MockIssuerService) CreateIssuer(ctx context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
if m.CreateIssuerFn != nil {
|
if m.CreateIssuerFn != nil {
|
||||||
return m.CreateIssuerFn(issuer)
|
return m.CreateIssuerFn(ctx, issuer)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockIssuerService) UpdateIssuer(id string, issuer domain.Issuer) (*domain.Issuer, error) {
|
func (m *MockIssuerService) UpdateIssuer(ctx context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
if m.UpdateIssuerFn != nil {
|
if m.UpdateIssuerFn != nil {
|
||||||
return m.UpdateIssuerFn(id, issuer)
|
return m.UpdateIssuerFn(ctx, id, issuer)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockIssuerService) DeleteIssuer(id string) error {
|
func (m *MockIssuerService) DeleteIssuer(ctx context.Context, id string) error {
|
||||||
if m.DeleteIssuerFn != nil {
|
if m.DeleteIssuerFn != nil {
|
||||||
return m.DeleteIssuerFn(id)
|
return m.DeleteIssuerFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockIssuerService) TestConnection(id string) error {
|
func (m *MockIssuerService) TestConnection(ctx context.Context, id string) error {
|
||||||
if m.TestConnectionFn != nil {
|
if m.TestConnectionFn != nil {
|
||||||
return m.TestConnectionFn(id)
|
return m.TestConnectionFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -85,7 +86,7 @@ func TestListIssuers_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
ListIssuersFn: func(page, perPage int) ([]domain.Issuer, int64, error) {
|
ListIssuersFn: func(_ context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
|
||||||
return []domain.Issuer{iss1, iss2}, 2, nil
|
return []domain.Issuer{iss1, iss2}, 2, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -113,7 +114,7 @@ func TestListIssuers_Success(t *testing.T) {
|
|||||||
func TestListIssuers_Pagination(t *testing.T) {
|
func TestListIssuers_Pagination(t *testing.T) {
|
||||||
var capturedPage, capturedPerPage int
|
var capturedPage, capturedPerPage int
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
ListIssuersFn: func(page, perPage int) ([]domain.Issuer, int64, error) {
|
ListIssuersFn: func(_ context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
|
||||||
capturedPage = page
|
capturedPage = page
|
||||||
capturedPerPage = perPage
|
capturedPerPage = perPage
|
||||||
return []domain.Issuer{}, 0, nil
|
return []domain.Issuer{}, 0, nil
|
||||||
@@ -137,7 +138,7 @@ func TestListIssuers_Pagination(t *testing.T) {
|
|||||||
|
|
||||||
func TestListIssuers_ServiceError(t *testing.T) {
|
func TestListIssuers_ServiceError(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
ListIssuersFn: func(page, perPage int) ([]domain.Issuer, int64, error) {
|
ListIssuersFn: func(_ context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
|
||||||
return nil, 0, ErrMockServiceFailed
|
return nil, 0, ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -169,7 +170,7 @@ func TestListIssuers_MethodNotAllowed(t *testing.T) {
|
|||||||
func TestGetIssuer_Success(t *testing.T) {
|
func TestGetIssuer_Success(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
GetIssuerFn: func(id string) (*domain.Issuer, error) {
|
GetIssuerFn: func(_ context.Context, id string) (*domain.Issuer, error) {
|
||||||
return &domain.Issuer{
|
return &domain.Issuer{
|
||||||
ID: id,
|
ID: id,
|
||||||
Name: "Local CA",
|
Name: "Local CA",
|
||||||
@@ -195,7 +196,7 @@ func TestGetIssuer_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetIssuer_NotFound(t *testing.T) {
|
func TestGetIssuer_NotFound(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
GetIssuerFn: func(id string) (*domain.Issuer, error) {
|
GetIssuerFn: func(_ context.Context, id string) (*domain.Issuer, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -228,7 +229,7 @@ func TestGetIssuer_EmptyID(t *testing.T) {
|
|||||||
func TestCreateIssuer_Success(t *testing.T) {
|
func TestCreateIssuer_Success(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) {
|
CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
issuer.ID = "iss-new"
|
issuer.ID = "iss-new"
|
||||||
issuer.CreatedAt = now
|
issuer.CreatedAt = now
|
||||||
issuer.UpdatedAt = now
|
issuer.UpdatedAt = now
|
||||||
@@ -328,7 +329,7 @@ func TestCreateIssuer_NameTooLong(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateIssuer_DuplicateName(t *testing.T) {
|
func TestCreateIssuer_DuplicateName(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) {
|
CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
return nil, fmt.Errorf("failed to create issuer: duplicate key value violates unique constraint \"issuers_name_key\"")
|
return nil, fmt.Errorf("failed to create issuer: duplicate key value violates unique constraint \"issuers_name_key\"")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -361,7 +362,7 @@ func TestCreateIssuer_DuplicateName(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateIssuer_UnsupportedType(t *testing.T) {
|
func TestCreateIssuer_UnsupportedType(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) {
|
CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
return nil, fmt.Errorf("unsupported issuer type: FakeCA")
|
return nil, fmt.Errorf("unsupported issuer type: FakeCA")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -394,7 +395,7 @@ func TestCreateIssuer_UnsupportedType(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateIssuer_GenericServiceError(t *testing.T) {
|
func TestCreateIssuer_GenericServiceError(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
CreateIssuerFn: func(issuer domain.Issuer) (*domain.Issuer, error) {
|
CreateIssuerFn: func(_ context.Context, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
return nil, fmt.Errorf("failed to encrypt config: cipher error")
|
return nil, fmt.Errorf("failed to encrypt config: cipher error")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -419,7 +420,7 @@ func TestCreateIssuer_GenericServiceError(t *testing.T) {
|
|||||||
|
|
||||||
func TestUpdateIssuer_DuplicateName(t *testing.T) {
|
func TestUpdateIssuer_DuplicateName(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
UpdateIssuerFn: func(id string, issuer domain.Issuer) (*domain.Issuer, error) {
|
UpdateIssuerFn: func(_ context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
return nil, fmt.Errorf("failed to update issuer: duplicate key value violates unique constraint")
|
return nil, fmt.Errorf("failed to update issuer: duplicate key value violates unique constraint")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -445,7 +446,7 @@ func TestUpdateIssuer_DuplicateName(t *testing.T) {
|
|||||||
func TestDeleteIssuer_Success(t *testing.T) {
|
func TestDeleteIssuer_Success(t *testing.T) {
|
||||||
var deletedID string
|
var deletedID string
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
DeleteIssuerFn: func(id string) error {
|
DeleteIssuerFn: func(_ context.Context, id string) error {
|
||||||
deletedID = id
|
deletedID = id
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -468,7 +469,7 @@ func TestDeleteIssuer_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestDeleteIssuer_ServiceError(t *testing.T) {
|
func TestDeleteIssuer_ServiceError(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
DeleteIssuerFn: func(id string) error {
|
DeleteIssuerFn: func(_ context.Context, id string) error {
|
||||||
return ErrMockServiceFailed
|
return ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -487,7 +488,7 @@ func TestDeleteIssuer_ServiceError(t *testing.T) {
|
|||||||
|
|
||||||
func TestTestConnection_Success(t *testing.T) {
|
func TestTestConnection_Success(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
TestConnectionFn: func(id string) error {
|
TestConnectionFn: func(_ context.Context, id string) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -514,7 +515,7 @@ func TestTestConnection_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestTestConnection_Failure(t *testing.T) {
|
func TestTestConnection_Failure(t *testing.T) {
|
||||||
mock := &MockIssuerService{
|
mock := &MockIssuerService{
|
||||||
TestConnectionFn: func(id string) error {
|
TestConnectionFn: func(_ context.Context, id string) error {
|
||||||
return ErrMockServiceFailed
|
return ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -13,12 +14,12 @@ import (
|
|||||||
|
|
||||||
// IssuerService defines the service interface for issuer operations.
|
// IssuerService defines the service interface for issuer operations.
|
||||||
type IssuerService interface {
|
type IssuerService interface {
|
||||||
ListIssuers(page, perPage int) ([]domain.Issuer, int64, error)
|
ListIssuers(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error)
|
||||||
GetIssuer(id string) (*domain.Issuer, error)
|
GetIssuer(ctx context.Context, id string) (*domain.Issuer, error)
|
||||||
CreateIssuer(issuer domain.Issuer) (*domain.Issuer, error)
|
CreateIssuer(ctx context.Context, issuer domain.Issuer) (*domain.Issuer, error)
|
||||||
UpdateIssuer(id string, issuer domain.Issuer) (*domain.Issuer, error)
|
UpdateIssuer(ctx context.Context, id string, issuer domain.Issuer) (*domain.Issuer, error)
|
||||||
DeleteIssuer(id string) error
|
DeleteIssuer(ctx context.Context, id string) error
|
||||||
TestConnection(id string) error
|
TestConnection(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// IssuerHandler handles HTTP requests for issuer operations.
|
// IssuerHandler handles HTTP requests for issuer operations.
|
||||||
@@ -61,7 +62,7 @@ func (h IssuerHandler) ListIssuers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
issuers, total, err := h.svc.ListIssuers(page, perPage)
|
issuers, total, err := h.svc.ListIssuers(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list issuers", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list issuers", requestID)
|
||||||
return
|
return
|
||||||
@@ -93,7 +94,7 @@ func (h IssuerHandler) GetIssuer(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
issuer, err := h.svc.GetIssuer(id)
|
issuer, err := h.svc.GetIssuer(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Issuer not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Issuer not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -132,7 +133,7 @@ func (h IssuerHandler) CreateIssuer(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreateIssuer(issuer)
|
created, err := h.svc.CreateIssuer(r.Context(), issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("failed to create issuer", "error", err, "name", issuer.Name, "type", issuer.Type)
|
h.logger.Error("failed to create issuer", "error", err, "name", issuer.Name, "type", issuer.Type)
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
@@ -174,7 +175,7 @@ func (h IssuerHandler) UpdateIssuer(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdateIssuer(id, issuer)
|
updated, err := h.svc.UpdateIssuer(r.Context(), id, issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("failed to update issuer", "error", err, "id", id)
|
h.logger.Error("failed to update issuer", "error", err, "id", id)
|
||||||
errMsg := err.Error()
|
errMsg := err.Error()
|
||||||
@@ -208,7 +209,7 @@ func (h IssuerHandler) DeleteIssuer(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.DeleteIssuer(id); err != nil {
|
if err := h.svc.DeleteIssuer(r.Context(), id); err != nil {
|
||||||
if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") {
|
if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") {
|
||||||
ErrorWithRequestID(w, http.StatusConflict, "Cannot delete issuer: certificates are still using this issuer", requestID)
|
ErrorWithRequestID(w, http.StatusConflict, "Cannot delete issuer: certificates are still using this issuer", requestID)
|
||||||
} else if strings.Contains(err.Error(), "not found") {
|
} else if strings.Contains(err.Error(), "not found") {
|
||||||
@@ -241,7 +242,7 @@ func (h IssuerHandler) TestConnection(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
issuerID := parts[0]
|
issuerID := parts[0]
|
||||||
|
|
||||||
if err := h.svc.TestConnection(issuerID); err != nil {
|
if err := h.svc.TestConnection(r.Context(), issuerID); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Connection test failed", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Connection test failed", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -21,35 +22,35 @@ type MockJobService struct {
|
|||||||
RejectJobFn func(id string, reason string) error
|
RejectJobFn func(id string, reason string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockJobService) ListJobs(status, jobType string, page, perPage int) ([]domain.Job, int64, error) {
|
func (m *MockJobService) ListJobs(_ context.Context, status, jobType string, page, perPage int) ([]domain.Job, int64, error) {
|
||||||
if m.ListJobsFn != nil {
|
if m.ListJobsFn != nil {
|
||||||
return m.ListJobsFn(status, jobType, page, perPage)
|
return m.ListJobsFn(status, jobType, page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockJobService) GetJob(id string) (*domain.Job, error) {
|
func (m *MockJobService) GetJob(_ context.Context, id string) (*domain.Job, error) {
|
||||||
if m.GetJobFn != nil {
|
if m.GetJobFn != nil {
|
||||||
return m.GetJobFn(id)
|
return m.GetJobFn(id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockJobService) CancelJob(id string) error {
|
func (m *MockJobService) CancelJob(_ context.Context, id string) error {
|
||||||
if m.CancelJobFn != nil {
|
if m.CancelJobFn != nil {
|
||||||
return m.CancelJobFn(id)
|
return m.CancelJobFn(id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockJobService) ApproveJob(id string) error {
|
func (m *MockJobService) ApproveJob(_ context.Context, id string) error {
|
||||||
if m.ApproveJobFn != nil {
|
if m.ApproveJobFn != nil {
|
||||||
return m.ApproveJobFn(id)
|
return m.ApproveJobFn(id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockJobService) RejectJob(id string, reason string) error {
|
func (m *MockJobService) RejectJob(_ context.Context, id string, reason string) error {
|
||||||
if m.RejectJobFn != nil {
|
if m.RejectJobFn != nil {
|
||||||
return m.RejectJobFn(id, reason)
|
return m.RejectJobFn(id, reason)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -13,11 +14,11 @@ import (
|
|||||||
|
|
||||||
// JobService defines the service interface for job operations.
|
// JobService defines the service interface for job operations.
|
||||||
type JobService interface {
|
type JobService interface {
|
||||||
ListJobs(status, jobType string, page, perPage int) ([]domain.Job, int64, error)
|
ListJobs(ctx context.Context, status, jobType string, page, perPage int) ([]domain.Job, int64, error)
|
||||||
GetJob(id string) (*domain.Job, error)
|
GetJob(ctx context.Context, id string) (*domain.Job, error)
|
||||||
CancelJob(id string) error
|
CancelJob(ctx context.Context, id string) error
|
||||||
ApproveJob(id string) error
|
ApproveJob(ctx context.Context, id string) error
|
||||||
RejectJob(id string, reason string) error
|
RejectJob(ctx context.Context, id string, reason string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobHandler handles HTTP requests for job operations.
|
// JobHandler handles HTTP requests for job operations.
|
||||||
@@ -57,7 +58,7 @@ func (h JobHandler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
jobs, total, err := h.svc.ListJobs(status, jobType, page, perPage)
|
jobs, total, err := h.svc.ListJobs(r.Context(), status, jobType, page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list jobs", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list jobs", requestID)
|
||||||
return
|
return
|
||||||
@@ -91,7 +92,7 @@ func (h JobHandler) GetJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
job, err := h.svc.GetJob(id)
|
job, err := h.svc.GetJob(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -119,7 +120,7 @@ func (h JobHandler) CancelJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
jobID := parts[0]
|
jobID := parts[0]
|
||||||
|
|
||||||
if err := h.svc.CancelJob(jobID); err != nil {
|
if err := h.svc.CancelJob(r.Context(), jobID); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to cancel job", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to cancel job", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -149,7 +150,7 @@ func (h JobHandler) ApproveJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
jobID := parts[0]
|
jobID := parts[0]
|
||||||
|
|
||||||
if err := h.svc.ApproveJob(jobID); err != nil {
|
if err := h.svc.ApproveJob(r.Context(), jobID); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -193,7 +194,7 @@ func (h JobHandler) RejectJob(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.RejectJob(jobID, body.Reason); err != nil {
|
if err := h.svc.RejectJob(r.Context(), jobID, body.Reason); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Job not found", requestID)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -17,21 +18,21 @@ type MockNotificationService struct {
|
|||||||
MarkAsReadFn func(id string) error
|
MarkAsReadFn func(id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockNotificationService) ListNotifications(page, perPage int) ([]domain.NotificationEvent, int64, error) {
|
func (m *MockNotificationService) ListNotifications(_ context.Context, page, perPage int) ([]domain.NotificationEvent, int64, error) {
|
||||||
if m.ListNotificationsFn != nil {
|
if m.ListNotificationsFn != nil {
|
||||||
return m.ListNotificationsFn(page, perPage)
|
return m.ListNotificationsFn(page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockNotificationService) GetNotification(id string) (*domain.NotificationEvent, error) {
|
func (m *MockNotificationService) GetNotification(_ context.Context, id string) (*domain.NotificationEvent, error) {
|
||||||
if m.GetNotificationFn != nil {
|
if m.GetNotificationFn != nil {
|
||||||
return m.GetNotificationFn(id)
|
return m.GetNotificationFn(id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockNotificationService) MarkAsRead(id string) error {
|
func (m *MockNotificationService) MarkAsRead(_ context.Context, id string) error {
|
||||||
if m.MarkAsReadFn != nil {
|
if m.MarkAsReadFn != nil {
|
||||||
return m.MarkAsReadFn(id)
|
return m.MarkAsReadFn(id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,9 +12,9 @@ import (
|
|||||||
|
|
||||||
// NotificationService defines the service interface for notification operations.
|
// NotificationService defines the service interface for notification operations.
|
||||||
type NotificationService interface {
|
type NotificationService interface {
|
||||||
ListNotifications(page, perPage int) ([]domain.NotificationEvent, int64, error)
|
ListNotifications(ctx context.Context, page, perPage int) ([]domain.NotificationEvent, int64, error)
|
||||||
GetNotification(id string) (*domain.NotificationEvent, error)
|
GetNotification(ctx context.Context, id string) (*domain.NotificationEvent, error)
|
||||||
MarkAsRead(id string) error
|
MarkAsRead(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotificationHandler handles HTTP requests for notification operations.
|
// NotificationHandler handles HTTP requests for notification operations.
|
||||||
@@ -50,7 +51,7 @@ func (h NotificationHandler) ListNotifications(w http.ResponseWriter, r *http.Re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
notifications, total, err := h.svc.ListNotifications(page, perPage)
|
notifications, total, err := h.svc.ListNotifications(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list notifications", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list notifications", requestID)
|
||||||
return
|
return
|
||||||
@@ -84,7 +85,7 @@ func (h NotificationHandler) GetNotification(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
notification, err := h.svc.GetNotification(id)
|
notification, err := h.svc.GetNotification(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Notification not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Notification not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -112,7 +113,7 @@ func (h NotificationHandler) MarkAsRead(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
notificationID := parts[0]
|
notificationID := parts[0]
|
||||||
|
|
||||||
if err := h.svc.MarkAsRead(notificationID); err != nil {
|
if err := h.svc.MarkAsRead(r.Context(), notificationID); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to mark notification as read", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to mark notification as read", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -20,35 +21,35 @@ type MockOwnerService struct {
|
|||||||
DeleteOwnerFn func(id string) error
|
DeleteOwnerFn func(id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) {
|
func (m *MockOwnerService) ListOwners(_ context.Context, page, perPage int) ([]domain.Owner, int64, error) {
|
||||||
if m.ListOwnersFn != nil {
|
if m.ListOwnersFn != nil {
|
||||||
return m.ListOwnersFn(page, perPage)
|
return m.ListOwnersFn(page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOwnerService) GetOwner(id string) (*domain.Owner, error) {
|
func (m *MockOwnerService) GetOwner(_ context.Context, id string) (*domain.Owner, error) {
|
||||||
if m.GetOwnerFn != nil {
|
if m.GetOwnerFn != nil {
|
||||||
return m.GetOwnerFn(id)
|
return m.GetOwnerFn(id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
|
func (m *MockOwnerService) CreateOwner(_ context.Context, owner domain.Owner) (*domain.Owner, error) {
|
||||||
if m.CreateOwnerFn != nil {
|
if m.CreateOwnerFn != nil {
|
||||||
return m.CreateOwnerFn(owner)
|
return m.CreateOwnerFn(owner)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) {
|
func (m *MockOwnerService) UpdateOwner(_ context.Context, id string, owner domain.Owner) (*domain.Owner, error) {
|
||||||
if m.UpdateOwnerFn != nil {
|
if m.UpdateOwnerFn != nil {
|
||||||
return m.UpdateOwnerFn(id, owner)
|
return m.UpdateOwnerFn(id, owner)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOwnerService) DeleteOwner(id string) error {
|
func (m *MockOwnerService) DeleteOwner(_ context.Context, id string) error {
|
||||||
if m.DeleteOwnerFn != nil {
|
if m.DeleteOwnerFn != nil {
|
||||||
return m.DeleteOwnerFn(id)
|
return m.DeleteOwnerFn(id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,11 +13,11 @@ import (
|
|||||||
|
|
||||||
// OwnerService defines the service interface for owner operations.
|
// OwnerService defines the service interface for owner operations.
|
||||||
type OwnerService interface {
|
type OwnerService interface {
|
||||||
ListOwners(page, perPage int) ([]domain.Owner, int64, error)
|
ListOwners(ctx context.Context, page, perPage int) ([]domain.Owner, int64, error)
|
||||||
GetOwner(id string) (*domain.Owner, error)
|
GetOwner(ctx context.Context, id string) (*domain.Owner, error)
|
||||||
CreateOwner(owner domain.Owner) (*domain.Owner, error)
|
CreateOwner(ctx context.Context, owner domain.Owner) (*domain.Owner, error)
|
||||||
UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error)
|
UpdateOwner(ctx context.Context, id string, owner domain.Owner) (*domain.Owner, error)
|
||||||
DeleteOwner(id string) error
|
DeleteOwner(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// OwnerHandler handles HTTP requests for owner operations.
|
// OwnerHandler handles HTTP requests for owner operations.
|
||||||
@@ -53,7 +54,7 @@ func (h OwnerHandler) ListOwners(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
owners, total, err := h.svc.ListOwners(page, perPage)
|
owners, total, err := h.svc.ListOwners(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list owners", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list owners", requestID)
|
||||||
return
|
return
|
||||||
@@ -87,7 +88,7 @@ func (h OwnerHandler) GetOwner(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
owner, err := h.svc.GetOwner(id)
|
owner, err := h.svc.GetOwner(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Owner not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Owner not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -122,7 +123,7 @@ func (h OwnerHandler) CreateOwner(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreateOwner(owner)
|
created, err := h.svc.CreateOwner(r.Context(), owner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create owner", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create owner", requestID)
|
||||||
return
|
return
|
||||||
@@ -155,7 +156,7 @@ func (h OwnerHandler) UpdateOwner(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdateOwner(id, owner)
|
updated, err := h.svc.UpdateOwner(r.Context(), id, owner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update owner", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update owner", requestID)
|
||||||
return
|
return
|
||||||
@@ -182,7 +183,7 @@ func (h OwnerHandler) DeleteOwner(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
if err := h.svc.DeleteOwner(id); err != nil {
|
if err := h.svc.DeleteOwner(r.Context(), id); err != nil {
|
||||||
if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") {
|
if strings.Contains(err.Error(), "violates foreign key") || strings.Contains(err.Error(), "RESTRICT") {
|
||||||
ErrorWithRequestID(w, http.StatusConflict, "Cannot delete owner: certificates are still assigned to this owner", requestID)
|
ErrorWithRequestID(w, http.StatusConflict, "Cannot delete owner: certificates are still assigned to this owner", requestID)
|
||||||
} else if strings.Contains(err.Error(), "not found") {
|
} else if strings.Contains(err.Error(), "not found") {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,12 +13,12 @@ import (
|
|||||||
|
|
||||||
// PolicyService defines the service interface for policy rule operations.
|
// PolicyService defines the service interface for policy rule operations.
|
||||||
type PolicyService interface {
|
type PolicyService interface {
|
||||||
ListPolicies(page, perPage int) ([]domain.PolicyRule, int64, error)
|
ListPolicies(ctx context.Context, page, perPage int) ([]domain.PolicyRule, int64, error)
|
||||||
GetPolicy(id string) (*domain.PolicyRule, error)
|
GetPolicy(ctx context.Context, id string) (*domain.PolicyRule, error)
|
||||||
CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRule, error)
|
CreatePolicy(ctx context.Context, policy domain.PolicyRule) (*domain.PolicyRule, error)
|
||||||
UpdatePolicy(id string, policy domain.PolicyRule) (*domain.PolicyRule, error)
|
UpdatePolicy(ctx context.Context, id string, policy domain.PolicyRule) (*domain.PolicyRule, error)
|
||||||
DeletePolicy(id string) error
|
DeletePolicy(ctx context.Context, id string) error
|
||||||
ListViolations(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error)
|
ListViolations(ctx context.Context, policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PolicyHandler handles HTTP requests for policy rule operations.
|
// PolicyHandler handles HTTP requests for policy rule operations.
|
||||||
@@ -54,7 +55,7 @@ func (h PolicyHandler) ListPolicies(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
policies, total, err := h.svc.ListPolicies(page, perPage)
|
policies, total, err := h.svc.ListPolicies(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list policies", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list policies", requestID)
|
||||||
return
|
return
|
||||||
@@ -88,7 +89,7 @@ func (h PolicyHandler) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
policy, err := h.svc.GetPolicy(id)
|
policy, err := h.svc.GetPolicy(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Policy not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Policy not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -127,7 +128,7 @@ func (h PolicyHandler) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreatePolicy(policy)
|
created, err := h.svc.CreatePolicy(r.Context(), policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create policy", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create policy", requestID)
|
||||||
return
|
return
|
||||||
@@ -174,7 +175,7 @@ func (h PolicyHandler) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdatePolicy(id, policy)
|
updated, err := h.svc.UpdatePolicy(r.Context(), id, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update policy", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update policy", requestID)
|
||||||
return
|
return
|
||||||
@@ -201,7 +202,7 @@ func (h PolicyHandler) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
if err := h.svc.DeletePolicy(id); err != nil {
|
if err := h.svc.DeletePolicy(r.Context(), id); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete policy", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete policy", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -242,7 +243,7 @@ func (h PolicyHandler) ListViolations(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
violations, total, err := h.svc.ListViolations(policyID, page, perPage)
|
violations, total, err := h.svc.ListViolations(r.Context(), policyID, page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list violations", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list violations", requestID)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -21,42 +22,42 @@ type MockPolicyService struct {
|
|||||||
ListViolationsFn func(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error)
|
ListViolationsFn func(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, int64, error) {
|
func (m *MockPolicyService) ListPolicies(_ context.Context, page, perPage int) ([]domain.PolicyRule, int64, error) {
|
||||||
if m.ListPoliciesFn != nil {
|
if m.ListPoliciesFn != nil {
|
||||||
return m.ListPoliciesFn(page, perPage)
|
return m.ListPoliciesFn(page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPolicyService) GetPolicy(id string) (*domain.PolicyRule, error) {
|
func (m *MockPolicyService) GetPolicy(_ context.Context, id string) (*domain.PolicyRule, error) {
|
||||||
if m.GetPolicyFn != nil {
|
if m.GetPolicyFn != nil {
|
||||||
return m.GetPolicyFn(id)
|
return m.GetPolicyFn(id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPolicyService) CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
func (m *MockPolicyService) CreatePolicy(_ context.Context, policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
||||||
if m.CreatePolicyFn != nil {
|
if m.CreatePolicyFn != nil {
|
||||||
return m.CreatePolicyFn(policy)
|
return m.CreatePolicyFn(policy)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPolicyService) UpdatePolicy(id string, policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
func (m *MockPolicyService) UpdatePolicy(_ context.Context, id string, policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
||||||
if m.UpdatePolicyFn != nil {
|
if m.UpdatePolicyFn != nil {
|
||||||
return m.UpdatePolicyFn(id, policy)
|
return m.UpdatePolicyFn(id, policy)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPolicyService) DeletePolicy(id string) error {
|
func (m *MockPolicyService) DeletePolicy(_ context.Context, id string) error {
|
||||||
if m.DeletePolicyFn != nil {
|
if m.DeletePolicyFn != nil {
|
||||||
return m.DeletePolicyFn(id)
|
return m.DeletePolicyFn(id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockPolicyService) ListViolations(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) {
|
func (m *MockPolicyService) ListViolations(_ context.Context, policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) {
|
||||||
if m.ListViolationsFn != nil {
|
if m.ListViolationsFn != nil {
|
||||||
return m.ListViolationsFn(policyID, page, perPage)
|
return m.ListViolationsFn(policyID, page, perPage)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -20,35 +21,35 @@ type MockProfileService struct {
|
|||||||
DeleteProfileFn func(id string) error
|
DeleteProfileFn func(id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) {
|
func (m *MockProfileService) ListProfiles(_ context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error) {
|
||||||
if m.ListProfilesFn != nil {
|
if m.ListProfilesFn != nil {
|
||||||
return m.ListProfilesFn(page, perPage)
|
return m.ListProfilesFn(page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) {
|
func (m *MockProfileService) GetProfile(_ context.Context, id string) (*domain.CertificateProfile, error) {
|
||||||
if m.GetProfileFn != nil {
|
if m.GetProfileFn != nil {
|
||||||
return m.GetProfileFn(id)
|
return m.GetProfileFn(id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
func (m *MockProfileService) CreateProfile(_ context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
||||||
if m.CreateProfileFn != nil {
|
if m.CreateProfileFn != nil {
|
||||||
return m.CreateProfileFn(profile)
|
return m.CreateProfileFn(profile)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
func (m *MockProfileService) UpdateProfile(_ context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
||||||
if m.UpdateProfileFn != nil {
|
if m.UpdateProfileFn != nil {
|
||||||
return m.UpdateProfileFn(id, profile)
|
return m.UpdateProfileFn(id, profile)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockProfileService) DeleteProfile(id string) error {
|
func (m *MockProfileService) DeleteProfile(_ context.Context, id string) error {
|
||||||
if m.DeleteProfileFn != nil {
|
if m.DeleteProfileFn != nil {
|
||||||
return m.DeleteProfileFn(id)
|
return m.DeleteProfileFn(id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,11 +13,11 @@ import (
|
|||||||
|
|
||||||
// ProfileService defines the service interface for certificate profile operations.
|
// ProfileService defines the service interface for certificate profile operations.
|
||||||
type ProfileService interface {
|
type ProfileService interface {
|
||||||
ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error)
|
ListProfiles(ctx context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error)
|
||||||
GetProfile(id string) (*domain.CertificateProfile, error)
|
GetProfile(ctx context.Context, id string) (*domain.CertificateProfile, error)
|
||||||
CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error)
|
CreateProfile(ctx context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error)
|
||||||
UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error)
|
UpdateProfile(ctx context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error)
|
||||||
DeleteProfile(id string) error
|
DeleteProfile(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProfileHandler handles HTTP requests for certificate profile operations.
|
// ProfileHandler handles HTTP requests for certificate profile operations.
|
||||||
@@ -53,7 +54,7 @@ func (h ProfileHandler) ListProfiles(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles, total, err := h.svc.ListProfiles(page, perPage)
|
profiles, total, err := h.svc.ListProfiles(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list profiles", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list profiles", requestID)
|
||||||
return
|
return
|
||||||
@@ -85,7 +86,7 @@ func (h ProfileHandler) GetProfile(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := h.svc.GetProfile(id)
|
profile, err := h.svc.GetProfile(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -120,7 +121,7 @@ func (h ProfileHandler) CreateProfile(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreateProfile(profile)
|
created, err := h.svc.CreateProfile(r.Context(), profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check if it's a validation error from the service
|
// Check if it's a validation error from the service
|
||||||
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") ||
|
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "required") ||
|
||||||
@@ -159,7 +160,7 @@ func (h ProfileHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdateProfile(id, profile)
|
updated, err := h.svc.UpdateProfile(r.Context(), id, profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
|
||||||
@@ -193,7 +194,7 @@ func (h ProfileHandler) DeleteProfile(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.DeleteProfile(id); err != nil {
|
if err := h.svc.DeleteProfile(r.Context(), id); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Profile not found", requestID)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -13,52 +14,52 @@ import (
|
|||||||
|
|
||||||
// MockTargetService is a mock implementation of TargetService interface.
|
// MockTargetService is a mock implementation of TargetService interface.
|
||||||
type MockTargetService struct {
|
type MockTargetService struct {
|
||||||
ListTargetsFn func(page, perPage int) ([]domain.DeploymentTarget, int64, error)
|
ListTargetsFn func(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error)
|
||||||
GetTargetFn func(id string) (*domain.DeploymentTarget, error)
|
GetTargetFn func(ctx context.Context, id string) (*domain.DeploymentTarget, error)
|
||||||
CreateTargetFn func(target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
CreateTargetFn func(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
||||||
UpdateTargetFn func(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
UpdateTargetFn func(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
||||||
DeleteTargetFn func(id string) error
|
DeleteTargetFn func(ctx context.Context, id string) error
|
||||||
TestTargetConnectionFn func(id string) error
|
TestConnectionFn func(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
func (m *MockTargetService) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
if m.ListTargetsFn != nil {
|
if m.ListTargetsFn != nil {
|
||||||
return m.ListTargetsFn(page, perPage)
|
return m.ListTargetsFn(ctx, page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTargetService) GetTarget(id string) (*domain.DeploymentTarget, error) {
|
func (m *MockTargetService) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
if m.GetTargetFn != nil {
|
if m.GetTargetFn != nil {
|
||||||
return m.GetTargetFn(id)
|
return m.GetTargetFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
func (m *MockTargetService) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
if m.CreateTargetFn != nil {
|
if m.CreateTargetFn != nil {
|
||||||
return m.CreateTargetFn(target)
|
return m.CreateTargetFn(ctx, target)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
func (m *MockTargetService) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
if m.UpdateTargetFn != nil {
|
if m.UpdateTargetFn != nil {
|
||||||
return m.UpdateTargetFn(id, target)
|
return m.UpdateTargetFn(ctx, id, target)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTargetService) DeleteTarget(id string) error {
|
func (m *MockTargetService) DeleteTarget(ctx context.Context, id string) error {
|
||||||
if m.DeleteTargetFn != nil {
|
if m.DeleteTargetFn != nil {
|
||||||
return m.DeleteTargetFn(id)
|
return m.DeleteTargetFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTargetService) TestTargetConnection(id string) error {
|
func (m *MockTargetService) TestConnection(ctx context.Context, id string) error {
|
||||||
if m.TestTargetConnectionFn != nil {
|
if m.TestConnectionFn != nil {
|
||||||
return m.TestTargetConnectionFn(id)
|
return m.TestConnectionFn(ctx, id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -85,7 +86,7 @@ func TestListTargets_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
ListTargetsFn: func(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
ListTargetsFn: func(_ context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
return []domain.DeploymentTarget{t1, t2}, 2, nil
|
return []domain.DeploymentTarget{t1, t2}, 2, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -113,7 +114,7 @@ func TestListTargets_Success(t *testing.T) {
|
|||||||
func TestListTargets_Pagination(t *testing.T) {
|
func TestListTargets_Pagination(t *testing.T) {
|
||||||
var capturedPage, capturedPerPage int
|
var capturedPage, capturedPerPage int
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
ListTargetsFn: func(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
ListTargetsFn: func(_ context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
capturedPage = page
|
capturedPage = page
|
||||||
capturedPerPage = perPage
|
capturedPerPage = perPage
|
||||||
return []domain.DeploymentTarget{}, 0, nil
|
return []domain.DeploymentTarget{}, 0, nil
|
||||||
@@ -137,7 +138,7 @@ func TestListTargets_Pagination(t *testing.T) {
|
|||||||
|
|
||||||
func TestListTargets_ServiceError(t *testing.T) {
|
func TestListTargets_ServiceError(t *testing.T) {
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
ListTargetsFn: func(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
ListTargetsFn: func(_ context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
return nil, 0, ErrMockServiceFailed
|
return nil, 0, ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -169,7 +170,7 @@ func TestListTargets_MethodNotAllowed(t *testing.T) {
|
|||||||
func TestGetTarget_Success(t *testing.T) {
|
func TestGetTarget_Success(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
GetTargetFn: func(id string) (*domain.DeploymentTarget, error) {
|
GetTargetFn: func(_ context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
return &domain.DeploymentTarget{
|
return &domain.DeploymentTarget{
|
||||||
ID: id,
|
ID: id,
|
||||||
Name: "NGINX Proxy",
|
Name: "NGINX Proxy",
|
||||||
@@ -196,7 +197,7 @@ func TestGetTarget_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetTarget_NotFound(t *testing.T) {
|
func TestGetTarget_NotFound(t *testing.T) {
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
GetTargetFn: func(id string) (*domain.DeploymentTarget, error) {
|
GetTargetFn: func(_ context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
return nil, ErrMockNotFound
|
return nil, ErrMockNotFound
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -229,7 +230,7 @@ func TestGetTarget_EmptyID(t *testing.T) {
|
|||||||
func TestCreateTarget_Success(t *testing.T) {
|
func TestCreateTarget_Success(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
CreateTargetFn: func(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
CreateTargetFn: func(_ context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
target.ID = "t-new"
|
target.ID = "t-new"
|
||||||
target.CreatedAt = now
|
target.CreatedAt = now
|
||||||
target.UpdatedAt = now
|
target.UpdatedAt = now
|
||||||
@@ -342,7 +343,7 @@ func TestCreateTarget_MethodNotAllowed(t *testing.T) {
|
|||||||
func TestUpdateTarget_Success(t *testing.T) {
|
func TestUpdateTarget_Success(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
UpdateTargetFn: func(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
UpdateTargetFn: func(_ context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
return &domain.DeploymentTarget{
|
return &domain.DeploymentTarget{
|
||||||
ID: id,
|
ID: id,
|
||||||
Name: target.Name,
|
Name: target.Name,
|
||||||
@@ -375,7 +376,7 @@ func TestUpdateTarget_Success(t *testing.T) {
|
|||||||
func TestDeleteTarget_Success(t *testing.T) {
|
func TestDeleteTarget_Success(t *testing.T) {
|
||||||
var deletedID string
|
var deletedID string
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
DeleteTargetFn: func(id string) error {
|
DeleteTargetFn: func(_ context.Context, id string) error {
|
||||||
deletedID = id
|
deletedID = id
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -398,7 +399,7 @@ func TestDeleteTarget_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestDeleteTarget_ServiceError(t *testing.T) {
|
func TestDeleteTarget_ServiceError(t *testing.T) {
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
DeleteTargetFn: func(id string) error {
|
DeleteTargetFn: func(_ context.Context, id string) error {
|
||||||
return ErrMockServiceFailed
|
return ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -430,7 +431,7 @@ func TestDeleteTarget_EmptyID(t *testing.T) {
|
|||||||
|
|
||||||
func TestTestTargetConnection_Success(t *testing.T) {
|
func TestTestTargetConnection_Success(t *testing.T) {
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
TestTargetConnectionFn: func(id string) error {
|
TestConnectionFn: func(_ context.Context, id string) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -457,7 +458,7 @@ func TestTestTargetConnection_Success(t *testing.T) {
|
|||||||
|
|
||||||
func TestTestTargetConnection_Failed(t *testing.T) {
|
func TestTestTargetConnection_Failed(t *testing.T) {
|
||||||
mock := &MockTargetService{
|
mock := &MockTargetService{
|
||||||
TestTargetConnectionFn: func(id string) error {
|
TestConnectionFn: func(_ context.Context, id string) error {
|
||||||
return ErrMockServiceFailed
|
return ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,12 +13,12 @@ import (
|
|||||||
|
|
||||||
// TargetService defines the service interface for deployment target operations.
|
// TargetService defines the service interface for deployment target operations.
|
||||||
type TargetService interface {
|
type TargetService interface {
|
||||||
ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error)
|
ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error)
|
||||||
GetTarget(id string) (*domain.DeploymentTarget, error)
|
GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error)
|
||||||
CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
||||||
UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error)
|
||||||
DeleteTarget(id string) error
|
DeleteTarget(ctx context.Context, id string) error
|
||||||
TestTargetConnection(id string) error
|
TestConnection(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// TargetHandler handles HTTP requests for deployment target operations.
|
// TargetHandler handles HTTP requests for deployment target operations.
|
||||||
@@ -54,7 +55,7 @@ func (h TargetHandler) ListTargets(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
targets, total, err := h.svc.ListTargets(page, perPage)
|
targets, total, err := h.svc.ListTargets(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list targets", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list targets", requestID)
|
||||||
return
|
return
|
||||||
@@ -86,7 +87,7 @@ func (h TargetHandler) GetTarget(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target, err := h.svc.GetTarget(id)
|
target, err := h.svc.GetTarget(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Target not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Target not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -125,7 +126,7 @@ func (h TargetHandler) CreateTarget(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreateTarget(target)
|
created, err := h.svc.CreateTarget(r.Context(), target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create target", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create target", requestID)
|
||||||
return
|
return
|
||||||
@@ -158,7 +159,7 @@ func (h TargetHandler) UpdateTarget(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdateTarget(id, target)
|
updated, err := h.svc.UpdateTarget(r.Context(), id, target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update target", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update target", requestID)
|
||||||
return
|
return
|
||||||
@@ -183,7 +184,7 @@ func (h TargetHandler) DeleteTarget(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.svc.DeleteTarget(id); err != nil {
|
if err := h.svc.DeleteTarget(r.Context(), id); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete target", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete target", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -210,7 +211,7 @@ func (h TargetHandler) TestTargetConnection(w http.ResponseWriter, r *http.Reque
|
|||||||
}
|
}
|
||||||
id := parts[0]
|
id := parts[0]
|
||||||
|
|
||||||
if err := h.svc.TestTargetConnection(id); err != nil {
|
if err := h.svc.TestConnection(r.Context(), id); err != nil {
|
||||||
JSON(w, http.StatusOK, map[string]interface{}{
|
JSON(w, http.StatusOK, map[string]interface{}{
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -20,35 +21,35 @@ type MockTeamService struct {
|
|||||||
DeleteTeamFn func(id string) error
|
DeleteTeamFn func(id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) {
|
func (m *MockTeamService) ListTeams(_ context.Context, page, perPage int) ([]domain.Team, int64, error) {
|
||||||
if m.ListTeamsFn != nil {
|
if m.ListTeamsFn != nil {
|
||||||
return m.ListTeamsFn(page, perPage)
|
return m.ListTeamsFn(page, perPage)
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTeamService) GetTeam(id string) (*domain.Team, error) {
|
func (m *MockTeamService) GetTeam(_ context.Context, id string) (*domain.Team, error) {
|
||||||
if m.GetTeamFn != nil {
|
if m.GetTeamFn != nil {
|
||||||
return m.GetTeamFn(id)
|
return m.GetTeamFn(id)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
|
func (m *MockTeamService) CreateTeam(_ context.Context, team domain.Team) (*domain.Team, error) {
|
||||||
if m.CreateTeamFn != nil {
|
if m.CreateTeamFn != nil {
|
||||||
return m.CreateTeamFn(team)
|
return m.CreateTeamFn(team)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) {
|
func (m *MockTeamService) UpdateTeam(_ context.Context, id string, team domain.Team) (*domain.Team, error) {
|
||||||
if m.UpdateTeamFn != nil {
|
if m.UpdateTeamFn != nil {
|
||||||
return m.UpdateTeamFn(id, team)
|
return m.UpdateTeamFn(id, team)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockTeamService) DeleteTeam(id string) error {
|
func (m *MockTeamService) DeleteTeam(_ context.Context, id string) error {
|
||||||
if m.DeleteTeamFn != nil {
|
if m.DeleteTeamFn != nil {
|
||||||
return m.DeleteTeamFn(id)
|
return m.DeleteTeamFn(id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,11 +13,11 @@ import (
|
|||||||
|
|
||||||
// TeamService defines the service interface for team operations.
|
// TeamService defines the service interface for team operations.
|
||||||
type TeamService interface {
|
type TeamService interface {
|
||||||
ListTeams(page, perPage int) ([]domain.Team, int64, error)
|
ListTeams(ctx context.Context, page, perPage int) ([]domain.Team, int64, error)
|
||||||
GetTeam(id string) (*domain.Team, error)
|
GetTeam(ctx context.Context, id string) (*domain.Team, error)
|
||||||
CreateTeam(team domain.Team) (*domain.Team, error)
|
CreateTeam(ctx context.Context, team domain.Team) (*domain.Team, error)
|
||||||
UpdateTeam(id string, team domain.Team) (*domain.Team, error)
|
UpdateTeam(ctx context.Context, id string, team domain.Team) (*domain.Team, error)
|
||||||
DeleteTeam(id string) error
|
DeleteTeam(ctx context.Context, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// TeamHandler handles HTTP requests for team operations.
|
// TeamHandler handles HTTP requests for team operations.
|
||||||
@@ -53,7 +54,7 @@ func (h TeamHandler) ListTeams(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
teams, total, err := h.svc.ListTeams(page, perPage)
|
teams, total, err := h.svc.ListTeams(r.Context(), page, perPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list teams", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list teams", requestID)
|
||||||
return
|
return
|
||||||
@@ -87,7 +88,7 @@ func (h TeamHandler) GetTeam(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
team, err := h.svc.GetTeam(id)
|
team, err := h.svc.GetTeam(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusNotFound, "Team not found", requestID)
|
ErrorWithRequestID(w, http.StatusNotFound, "Team not found", requestID)
|
||||||
return
|
return
|
||||||
@@ -122,7 +123,7 @@ func (h TeamHandler) CreateTeam(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.svc.CreateTeam(team)
|
created, err := h.svc.CreateTeam(r.Context(), team)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create team", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to create team", requestID)
|
||||||
return
|
return
|
||||||
@@ -155,7 +156,7 @@ func (h TeamHandler) UpdateTeam(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := h.svc.UpdateTeam(id, team)
|
updated, err := h.svc.UpdateTeam(r.Context(), id, team)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update team", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update team", requestID)
|
||||||
return
|
return
|
||||||
@@ -182,7 +183,7 @@ func (h TeamHandler) DeleteTeam(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
id = parts[0]
|
id = parts[0]
|
||||||
|
|
||||||
if err := h.svc.DeleteTeam(id); err != nil {
|
if err := h.svc.DeleteTeam(r.Context(), id); err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete team", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to delete team", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,16 +4,22 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuditRecorder is the interface that the audit middleware uses to record API calls.
|
// AuditRecorder is the interface that the audit middleware uses to record API calls.
|
||||||
// This avoids importing the service package directly, maintaining dependency inversion.
|
// This avoids importing the service package directly, maintaining dependency inversion.
|
||||||
|
//
|
||||||
|
// Implementations may perform I/O (e.g., database writes). The middleware invokes
|
||||||
|
// RecordAPICall from a tracked goroutine so that callers can drain in-flight
|
||||||
|
// recordings during graceful shutdown via AuditMiddleware.Flush.
|
||||||
type AuditRecorder interface {
|
type AuditRecorder interface {
|
||||||
RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error
|
RecordAPICall(ctx context.Context, method, path, actor string, bodyHash string, status int, latencyMs int64) error
|
||||||
}
|
}
|
||||||
@@ -26,10 +32,42 @@ type AuditConfig struct {
|
|||||||
Logger *slog.Logger
|
Logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuditLog creates a middleware that records every API call to the audit trail.
|
// ErrAuditFlushTimeout is returned by AuditMiddleware.Flush when in-flight audit
|
||||||
// It captures method, path, authenticated actor, request body hash, response status, and latency.
|
// recordings do not complete before the provided context is cancelled or its
|
||||||
// Audit recording is best-effort — failures are logged but don't affect the HTTP response.
|
// deadline elapses. It mirrors scheduler.ErrSchedulerShutdownTimeout so callers
|
||||||
func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) http.Handler {
|
// can branch on graceful-shutdown timeouts consistently across subsystems.
|
||||||
|
var ErrAuditFlushTimeout = errors.New("audit middleware flush timeout")
|
||||||
|
|
||||||
|
// AuditMiddleware is the handle returned by NewAuditLog. It wraps the audit
|
||||||
|
// logging HTTP middleware and tracks the goroutines spawned to record each API
|
||||||
|
// call, so that callers can drain them during graceful shutdown (M-1, CWE-662
|
||||||
|
// / CWE-400). The goroutines themselves still run detached from the request
|
||||||
|
// context — the shutdown-drain signal flows through this struct's WaitGroup
|
||||||
|
// instead of the per-request context.
|
||||||
|
type AuditMiddleware struct {
|
||||||
|
recorder AuditRecorder
|
||||||
|
logger *slog.Logger
|
||||||
|
excludeSet map[string]bool
|
||||||
|
|
||||||
|
// wg tracks every audit-recording goroutine spawned by Middleware so Flush
|
||||||
|
// can block until they complete before the DB pool is torn down.
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditLog constructs the API audit logging middleware. The returned
|
||||||
|
// *AuditMiddleware exposes the HTTP middleware via the Middleware method value
|
||||||
|
// (same func(http.Handler) http.Handler shape) and a Flush method that the
|
||||||
|
// process shutdown path must call after the HTTP server has stopped accepting
|
||||||
|
// new requests but before the audit recorder's backing store (e.g., the
|
||||||
|
// database connection pool) is closed.
|
||||||
|
//
|
||||||
|
// The middleware records method, path, authenticated actor, request body hash,
|
||||||
|
// response status, and latency. Recording is best-effort — individual failures
|
||||||
|
// are logged and do not affect the HTTP response. Shutdown is NOT best-effort:
|
||||||
|
// Flush must succeed (or time out, returning ErrAuditFlushTimeout) so that
|
||||||
|
// in-flight events are not lost when the audit recorder's connection pool is
|
||||||
|
// closed out from under the goroutines.
|
||||||
|
func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) *AuditMiddleware {
|
||||||
excludeSet := make(map[string]bool, len(cfg.ExcludePaths))
|
excludeSet := make(map[string]bool, len(cfg.ExcludePaths))
|
||||||
for _, p := range cfg.ExcludePaths {
|
for _, p := range cfg.ExcludePaths {
|
||||||
excludeSet[p] = true
|
excludeSet[p] = true
|
||||||
@@ -40,10 +78,20 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt
|
|||||||
logger = slog.Default()
|
logger = slog.Default()
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return &AuditMiddleware{
|
||||||
|
recorder: recorder,
|
||||||
|
logger: logger,
|
||||||
|
excludeSet: excludeSet,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware is the http.Handler wrapper. It has the standard
|
||||||
|
// func(http.Handler) http.Handler middleware signature so it can be composed
|
||||||
|
// into an existing middleware chain via a method value (auditMiddleware.Middleware).
|
||||||
|
func (a *AuditMiddleware) Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Skip excluded paths (health, readiness probes)
|
// Skip excluded paths (health, readiness probes)
|
||||||
for prefix := range excludeSet {
|
for prefix := range a.excludeSet {
|
||||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
@@ -78,31 +126,84 @@ func NewAuditLog(recorder AuditRecorder, cfg AuditConfig) func(http.Handler) htt
|
|||||||
|
|
||||||
latency := time.Since(start).Milliseconds()
|
latency := time.Since(start).Milliseconds()
|
||||||
|
|
||||||
|
// Snapshot request-derived inputs so the goroutine does not race with
|
||||||
|
// the http.Server reusing r after this handler returns.
|
||||||
|
method := r.Method
|
||||||
|
path := r.URL.Path
|
||||||
|
status := wrapped.statusCode
|
||||||
|
|
||||||
|
// Derive a detached context that preserves request-scoped values
|
||||||
|
// (trace IDs, auth info carried via context keys) but is not cancelled
|
||||||
|
// when the HTTP server finalizes the request. Using r.Context()
|
||||||
|
// directly would cause the async audit write to observe ctx.Done()
|
||||||
|
// as soon as the response completes; using context.Background() would
|
||||||
|
// discard useful observability metadata. WithoutCancel gives us both
|
||||||
|
// (M-2 / D-3).
|
||||||
|
auditCtx := context.WithoutCancel(r.Context())
|
||||||
|
|
||||||
// Record audit event asynchronously (best-effort, don't block response).
|
// Record audit event asynchronously (best-effort, don't block response).
|
||||||
// SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI)
|
// SECURITY: We intentionally use r.URL.Path (not r.URL.String() or r.RequestURI)
|
||||||
// to prevent query parameters from being recorded in the immutable audit trail.
|
// to prevent query parameters from being recorded in the immutable audit trail.
|
||||||
// Query strings may contain cursor tokens, API keys passed as params, or other
|
// Query strings may contain cursor tokens, API keys passed as params, or other
|
||||||
// sensitive filter values. Since the audit trail is append-only with no deletion
|
// sensitive filter values. Since the audit trail is append-only with no deletion
|
||||||
// capability, any sensitive data recorded would persist permanently.
|
// capability, any sensitive data recorded would persist permanently.
|
||||||
|
//
|
||||||
|
// The goroutine is tracked in a.wg so AuditMiddleware.Flush can drain
|
||||||
|
// in-flight recordings during graceful shutdown. Without this (M-1,
|
||||||
|
// CWE-662 / CWE-400), SIGTERM would close the DB pool while recordings
|
||||||
|
// were still mid-flight, silently dropping audit events.
|
||||||
|
a.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := recorder.RecordAPICall(
|
defer a.wg.Done()
|
||||||
context.Background(),
|
if err := a.recorder.RecordAPICall(
|
||||||
r.Method,
|
auditCtx,
|
||||||
r.URL.Path,
|
method,
|
||||||
|
path,
|
||||||
actor,
|
actor,
|
||||||
bodyHash,
|
bodyHash,
|
||||||
wrapped.statusCode,
|
status,
|
||||||
latency,
|
latency,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
logger.Error("failed to record API audit event",
|
a.logger.Error("failed to record API audit event",
|
||||||
"error", err,
|
"error", err,
|
||||||
"method", r.Method,
|
"method", method,
|
||||||
"path", r.URL.Path,
|
"path", path,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush blocks until every audit-recording goroutine spawned by Middleware has
|
||||||
|
// completed, or until ctx is cancelled / its deadline elapses. It must be
|
||||||
|
// called from the process shutdown path after http.Server.Shutdown has
|
||||||
|
// returned (so no new requests are being accepted) but before the backing
|
||||||
|
// audit recorder's resources (DB pool, etc.) are torn down.
|
||||||
|
//
|
||||||
|
// On timeout or cancellation Flush returns ErrAuditFlushTimeout wrapped with
|
||||||
|
// any context error; in-flight goroutines continue to run and may still write
|
||||||
|
// to the recorder once they unblock — the caller is responsible for deciding
|
||||||
|
// whether to proceed with teardown anyway or surface the error.
|
||||||
|
//
|
||||||
|
// Flush mirrors the idiom used by scheduler.Scheduler.WaitForCompletion so
|
||||||
|
// that the two subsystems drain identically at shutdown.
|
||||||
|
func (a *AuditMiddleware) Flush(ctx context.Context) error {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
a.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
a.logger.Info("audit middleware flush complete")
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
a.logger.Warn("audit middleware flush did not complete before context cancellation",
|
||||||
|
"error", ctx.Err(),
|
||||||
|
)
|
||||||
|
return fmt.Errorf("%w: %w", ErrAuditFlushTimeout, ctx.Err())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuditServiceAdapter adapts the AuditService to the AuditRecorder interface.
|
// AuditServiceAdapter adapts the AuditService to the AuditRecorder interface.
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -17,6 +18,7 @@ type mockAuditRecorder struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
calls []auditCall
|
calls []auditCall
|
||||||
err error // if non-nil, RecordAPICall returns this
|
err error // if non-nil, RecordAPICall returns this
|
||||||
|
block chan struct{} // if non-nil, RecordAPICall blocks on receive before returning
|
||||||
}
|
}
|
||||||
|
|
||||||
type auditCall struct {
|
type auditCall struct {
|
||||||
@@ -29,6 +31,13 @@ type auditCall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
|
func (m *mockAuditRecorder) RecordAPICall(ctx context.Context, method, path, actor, bodyHash string, status int, latencyMs int64) error {
|
||||||
|
// Optional: block the recorder until a signal is received so tests can
|
||||||
|
// exercise the shutdown-drain path deterministically. The block happens
|
||||||
|
// before any state mutation so Flush-timeout tests see the call
|
||||||
|
// "in-flight" (wg counter > 0) with no recorded entries yet.
|
||||||
|
if m.block != nil {
|
||||||
|
<-m.block
|
||||||
|
}
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
m.calls = append(m.calls, auditCall{
|
m.calls = append(m.calls, auditCall{
|
||||||
@@ -90,7 +99,7 @@ func (w *waitableAuditRecorder) Wait(timeout time.Duration) bool {
|
|||||||
|
|
||||||
func TestAuditLog_RecordsAPICall(t *testing.T) {
|
func TestAuditLog_RecordsAPICall(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -130,7 +139,7 @@ func TestAuditLog_RecordsAPICall(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_CapturesStatusCode(t *testing.T) {
|
func TestAuditLog_CapturesStatusCode(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
@@ -157,7 +166,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) {
|
|||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{
|
mw := NewAuditLog(recorder, AuditConfig{
|
||||||
ExcludePaths: []string{"/health", "/ready"},
|
ExcludePaths: []string{"/health", "/ready"},
|
||||||
})
|
}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -193,7 +202,7 @@ func TestAuditLog_ExcludesHealth(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_HashesRequestBody(t *testing.T) {
|
func TestAuditLog_HashesRequestBody(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
// Handler verifies body was restored
|
// Handler verifies body was restored
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -228,7 +237,7 @@ func TestAuditLog_HashesRequestBody(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
|
func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -253,7 +262,7 @@ func TestAuditLog_EmptyBodyNoHash(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
|
func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -285,7 +294,7 @@ func TestAuditLog_ExtractsAuthenticatedActor(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
|
func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
|
||||||
recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")}
|
recorder := &mockAuditRecorder{err: fmt.Errorf("db connection lost")}
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -304,7 +313,7 @@ func TestAuditLog_RecorderErrorDoesNotBreakResponse(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_CapturesLatency(t *testing.T) {
|
func TestAuditLog_CapturesLatency(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
@@ -330,7 +339,7 @@ func TestAuditLog_CapturesLatency(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) {
|
func TestAuditLog_ExcludesQueryParamsFromPath(t *testing.T) {
|
||||||
recorder := newWaitableAuditRecorder()
|
recorder := newWaitableAuditRecorder()
|
||||||
mw := NewAuditLog(recorder, AuditConfig{})
|
mw := NewAuditLog(recorder, AuditConfig{}).Middleware
|
||||||
|
|
||||||
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -429,3 +438,112 @@ func TestAuditServiceAdapter_PropagatesError(t *testing.T) {
|
|||||||
t.Errorf("expected database error, got %v", err)
|
t.Errorf("expected database error, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAuditLog_FlushDrainsInFlightGoroutines verifies the M-1 shutdown-drain
|
||||||
|
// contract: Flush blocks until every audit-recording goroutine spawned by the
|
||||||
|
// middleware completes, then returns nil. Without the drain (pre-M-1 code),
|
||||||
|
// the DB pool would be closed while in-flight goroutines were still calling
|
||||||
|
// RecordAPICall, silently dropping audit events (CWE-662 / CWE-400).
|
||||||
|
func TestAuditLog_FlushDrainsInFlightGoroutines(t *testing.T) {
|
||||||
|
// Recorder blocks on `unblock` until the test releases it. This simulates
|
||||||
|
// a slow DB write still in flight when shutdown begins.
|
||||||
|
unblock := make(chan struct{})
|
||||||
|
recorder := &mockAuditRecorder{block: unblock}
|
||||||
|
auditMW := NewAuditLog(recorder, AuditConfig{})
|
||||||
|
|
||||||
|
handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Fire a request. Handler returns immediately; recorder goroutine is
|
||||||
|
// parked on the `unblock` channel inside RecordAPICall.
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Start Flush in a goroutine — it must block on the WaitGroup until we
|
||||||
|
// release the recorder.
|
||||||
|
flushDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
flushDone <- auditMW.Flush(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Confirm Flush is actually blocked (not returning immediately).
|
||||||
|
select {
|
||||||
|
case err := <-flushDone:
|
||||||
|
t.Fatalf("Flush returned before recorder unblocked: err=%v", err)
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
// expected: Flush is blocked on wg.Wait
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the recorder. Flush should now observe wg counter drop to 0
|
||||||
|
// and return nil.
|
||||||
|
close(unblock)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-flushDone:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil from Flush after drain, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Flush did not return after recorder unblocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the audit event was actually recorded (i.e., the goroutine
|
||||||
|
// completed its write — not just that Flush unblocked).
|
||||||
|
calls := recorder.getCalls()
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 recorded audit call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Path != "/api/v1/certificates" {
|
||||||
|
t.Errorf("expected path /api/v1/certificates, got %s", calls[0].Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout verifies that Flush
|
||||||
|
// respects its context: when in-flight goroutines exceed the shutdown budget,
|
||||||
|
// Flush returns an error wrapping ErrAuditFlushTimeout plus ctx.Err(). The
|
||||||
|
// caller can then decide whether to proceed with teardown anyway.
|
||||||
|
func TestAuditLog_FlushTimeoutReturnsErrAuditFlushTimeout(t *testing.T) {
|
||||||
|
// Recorder will never unblock on its own — we unblock at end of test for
|
||||||
|
// a clean race-safe teardown.
|
||||||
|
unblock := make(chan struct{})
|
||||||
|
recorder := &mockAuditRecorder{block: unblock}
|
||||||
|
auditMW := NewAuditLog(recorder, AuditConfig{})
|
||||||
|
|
||||||
|
handler := auditMW.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Flush with a tiny deadline — must time out.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
err := auditMW.Flush(ctx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// Release the blocked goroutine before failing so the race detector
|
||||||
|
// doesn't trip on teardown.
|
||||||
|
close(unblock)
|
||||||
|
t.Fatal("expected Flush to return an error on timeout, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrAuditFlushTimeout) {
|
||||||
|
close(unblock)
|
||||||
|
t.Fatalf("expected error to wrap ErrAuditFlushTimeout, got %v", err)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
close(unblock)
|
||||||
|
t.Fatalf("expected error to wrap context.DeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Race-safe teardown: unblock the recorder goroutine so it exits cleanly
|
||||||
|
// before the test returns. The goroutine itself is still detached and
|
||||||
|
// will record to the mock even after Flush timed out — that's the
|
||||||
|
// documented behavior (Flush surfaces the timeout; caller decides).
|
||||||
|
close(unblock)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -78,10 +79,17 @@ func NewLogging(logger *slog.Logger) func(http.Handler) http.Handler {
|
|||||||
// Recovery middleware recovers from panics and returns a 500 error.
|
// Recovery middleware recovers from panics and returns a 500 error.
|
||||||
func Recovery(next http.Handler) http.Handler {
|
func Recovery(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
requestID := getRequestID(r.Context())
|
requestID := getRequestID(ctx)
|
||||||
log.Printf("[%s] PANIC: %v", requestID, err)
|
// Use slog.ErrorContext so the panic log carries the same
|
||||||
|
// request-scoped trace/auth metadata as normal request logs
|
||||||
|
// (M-2 / D-3 — preserve ctx propagation on the panic path).
|
||||||
|
slog.ErrorContext(ctx, "panic recovered in HTTP handler",
|
||||||
|
"request_id", requestID,
|
||||||
|
"panic", fmt.Sprintf("%v", err),
|
||||||
|
)
|
||||||
http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError)
|
http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -547,7 +547,11 @@ func (c *Connector) solveAuthorizationsHTTP01(ctx context.Context, authzURLs []s
|
|||||||
return fmt.Errorf("failed to start challenge server: %w", err)
|
return fmt.Errorf("failed to start challenge server: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
// Derive the challenge-server shutdown context from the parent ctx so
|
||||||
|
// values (trace IDs, deadlines) propagate, but detach from its
|
||||||
|
// cancellation so Shutdown always gets its full budget even when the
|
||||||
|
// parent was cancelled (M-2 / D-3).
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_ = srv.Shutdown(shutdownCtx)
|
_ = srv.Shutdown(shutdownCtx)
|
||||||
c.logger.Debug("challenge server stopped")
|
c.logger.Debug("challenge server stopped")
|
||||||
|
|||||||
@@ -359,6 +359,25 @@ func (c *Connector) loadCAFromDisk() error {
|
|||||||
return fmt.Errorf("loaded CA certificate does not have KeyUsageCertSign")
|
return fmt.Errorf("loaded CA certificate does not have KeyUsageCertSign")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate CA certificate validity window (M-5, CWE-672).
|
||||||
|
// An expired or not-yet-valid sub-CA produces child certificates that any
|
||||||
|
// RFC 5280 path-validator will reject. Fail closed at load time so operators
|
||||||
|
// learn about it at startup, not at 3am when a renewal cycle silently
|
||||||
|
// starts minting broken certs. See audit finding M-5.
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(caCert.NotAfter) {
|
||||||
|
return fmt.Errorf("CA certificate %q has expired (not_after=%s, now=%s)",
|
||||||
|
caCert.Subject.CommonName,
|
||||||
|
caCert.NotAfter.UTC().Format(time.RFC3339),
|
||||||
|
now.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if now.Before(caCert.NotBefore) {
|
||||||
|
return fmt.Errorf("CA certificate %q is not yet valid (not_before=%s, now=%s)",
|
||||||
|
caCert.Subject.CommonName,
|
||||||
|
caCert.NotBefore.UTC().Format(time.RFC3339),
|
||||||
|
now.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
// Load CA private key (supports RSA and ECDSA)
|
// Load CA private key (supports RSA and ECDSA)
|
||||||
keyPEM, err := os.ReadFile(c.config.CAKeyPath)
|
keyPEM, err := os.ReadFile(c.config.CAKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -360,6 +361,114 @@ func TestSubCAMode(t *testing.T) {
|
|||||||
t.Logf("Correctly rejected non-CA cert: %v", err)
|
t.Logf("Correctly rejected non-CA cert: %v", err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("SubCA_ExpiredCert_IsRejected", func(t *testing.T) {
|
||||||
|
// Sub-CA expired 1 hour ago. M-5: loadCAFromDisk must fail closed
|
||||||
|
// instead of minting child certs that immediately fail path validation
|
||||||
|
// at every relying party (CWE-672).
|
||||||
|
notBefore := time.Now().AddDate(-1, 0, 0)
|
||||||
|
notAfter := time.Now().Add(-1 * time.Hour)
|
||||||
|
certPath, keyPath := generateTestSubCAWithValidity(t, "rsa", notBefore, notAfter)
|
||||||
|
|
||||||
|
config := &local.Config{
|
||||||
|
ValidityDays: 30,
|
||||||
|
CACertPath: certPath,
|
||||||
|
CAKeyPath: keyPath,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
_, csrPEM, err := generateTestCSR("app.internal.corp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate CSR: %v", err)
|
||||||
|
}
|
||||||
|
req := issuer.IssuanceRequest{
|
||||||
|
CommonName: "app.internal.corp",
|
||||||
|
CSRPEM: csrPEM,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = connector.IssueCertificate(ctx, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when loading expired sub-CA; got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "expired") {
|
||||||
|
t.Errorf("Expected error to mention 'expired'; got: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "Test Sub-CA") {
|
||||||
|
t.Errorf("Expected error to include CA subject CN 'Test Sub-CA'; got: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Correctly rejected expired sub-CA: %v", err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SubCA_NotYetValid_IsRejected", func(t *testing.T) {
|
||||||
|
// Sub-CA is not valid for another hour (clock skew or operator error
|
||||||
|
// pushing a pre-production CA into prod). M-5: loadCAFromDisk must
|
||||||
|
// fail closed.
|
||||||
|
notBefore := time.Now().Add(1 * time.Hour)
|
||||||
|
notAfter := time.Now().AddDate(5, 0, 0)
|
||||||
|
certPath, keyPath := generateTestSubCAWithValidity(t, "rsa", notBefore, notAfter)
|
||||||
|
|
||||||
|
config := &local.Config{
|
||||||
|
ValidityDays: 30,
|
||||||
|
CACertPath: certPath,
|
||||||
|
CAKeyPath: keyPath,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
_, csrPEM, err := generateTestCSR("app.internal.corp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate CSR: %v", err)
|
||||||
|
}
|
||||||
|
req := issuer.IssuanceRequest{
|
||||||
|
CommonName: "app.internal.corp",
|
||||||
|
CSRPEM: csrPEM,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = connector.IssueCertificate(ctx, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when loading not-yet-valid sub-CA; got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not yet valid") {
|
||||||
|
t.Errorf("Expected error to mention 'not yet valid'; got: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "Test Sub-CA") {
|
||||||
|
t.Errorf("Expected error to include CA subject CN 'Test Sub-CA'; got: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Correctly rejected not-yet-valid sub-CA: %v", err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SubCA_BarelyValid_IsAccepted", func(t *testing.T) {
|
||||||
|
// Sub-CA valid from 1 minute ago to 1 hour from now. Edge case:
|
||||||
|
// proves the M-5 window check doesn't over-reject CAs that are
|
||||||
|
// legitimately live but close to the boundaries.
|
||||||
|
notBefore := time.Now().Add(-1 * time.Minute)
|
||||||
|
notAfter := time.Now().Add(1 * time.Hour)
|
||||||
|
certPath, keyPath := generateTestSubCAWithValidity(t, "rsa", notBefore, notAfter)
|
||||||
|
|
||||||
|
config := &local.Config{
|
||||||
|
ValidityDays: 30,
|
||||||
|
CACertPath: certPath,
|
||||||
|
CAKeyPath: keyPath,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
_, csrPEM, err := generateTestCSR("app.internal.corp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate CSR: %v", err)
|
||||||
|
}
|
||||||
|
req := issuer.IssuanceRequest{
|
||||||
|
CommonName: "app.internal.corp",
|
||||||
|
CSRPEM: csrPEM,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := connector.IssueCertificate(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Barely-valid sub-CA was wrongly rejected: %v", err)
|
||||||
|
}
|
||||||
|
if result.CertPEM == "" {
|
||||||
|
t.Error("CertPEM is empty")
|
||||||
|
}
|
||||||
|
t.Logf("Correctly accepted barely-valid sub-CA: serial=%s", result.Serial)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("SubCA_RenewCertificate", func(t *testing.T) {
|
t.Run("SubCA_RenewCertificate", func(t *testing.T) {
|
||||||
certPath, keyPath := generateTestSubCA(t, "rsa")
|
certPath, keyPath := generateTestSubCA(t, "rsa")
|
||||||
defer os.Remove(certPath)
|
defer os.Remove(certPath)
|
||||||
@@ -396,8 +505,16 @@ func TestSubCAMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateTestSubCA creates a self-signed CA cert+key pair and writes them to temp files.
|
// generateTestSubCA creates a self-signed CA cert+key pair and writes them to temp files.
|
||||||
// keyType can be "rsa" or "ecdsa".
|
// keyType can be "rsa" or "ecdsa". Validity window is [now, now+5y].
|
||||||
func generateTestSubCA(t *testing.T, keyType string) (certPath, keyPath string) {
|
func generateTestSubCA(t *testing.T, keyType string) (certPath, keyPath string) {
|
||||||
|
t.Helper()
|
||||||
|
return generateTestSubCAWithValidity(t, keyType, time.Now(), time.Now().AddDate(5, 0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestSubCAWithValidity creates a self-signed CA cert+key pair with an
|
||||||
|
// explicit NotBefore/NotAfter window. Used by M-5 tests that exercise expired
|
||||||
|
// and not-yet-valid CA rejection in loadCAFromDisk.
|
||||||
|
func generateTestSubCAWithValidity(t *testing.T, keyType string, notBefore, notAfter time.Time) (certPath, keyPath string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
certPath = filepath.Join(tmpDir, "ca.pem")
|
certPath = filepath.Join(tmpDir, "ca.pem")
|
||||||
@@ -445,8 +562,8 @@ func generateTestSubCA(t *testing.T, keyType string) (certPath, keyPath string)
|
|||||||
CommonName: "Test Sub-CA",
|
CommonName: "Test Sub-CA",
|
||||||
Organization: []string{"CertCtl Test"},
|
Organization: []string{"CertCtl Test"},
|
||||||
},
|
},
|
||||||
NotBefore: time.Now(),
|
NotBefore: notBefore,
|
||||||
NotAfter: time.Now().AddDate(5, 0, 0),
|
NotAfter: notAfter,
|
||||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
BasicConstraintsValid: true,
|
BasicConstraintsValid: true,
|
||||||
IsCA: true,
|
IsCA: true,
|
||||||
|
|||||||
+197
-37
@@ -1,4 +1,31 @@
|
|||||||
// Package crypto provides AES-256-GCM encryption for sensitive configuration data.
|
// Package crypto provides AES-256-GCM encryption for sensitive configuration data.
|
||||||
|
//
|
||||||
|
// The on-disk format for blobs produced by [EncryptIfKeySet] is versioned. Two
|
||||||
|
// versions coexist and both can be read by [DecryptIfKeySet]:
|
||||||
|
//
|
||||||
|
// v2 (current, M-8)
|
||||||
|
// magic(0x02) || salt(16) || nonce(12) || ciphertext+tag
|
||||||
|
// — 32-byte AES-256 key derived via PBKDF2-SHA256 from the operator
|
||||||
|
// passphrase and the per-ciphertext random salt.
|
||||||
|
//
|
||||||
|
// v1 (legacy, pre-M-8)
|
||||||
|
// nonce(12) || ciphertext+tag
|
||||||
|
// — 32-byte AES-256 key derived via PBKDF2-SHA256 from the operator
|
||||||
|
// passphrase and the package-level fixed salt
|
||||||
|
// "certctl-config-encryption-v1".
|
||||||
|
//
|
||||||
|
// v1 blobs are accepted by the read path for backward compatibility with rows
|
||||||
|
// persisted before the M-8 remediation. They are never produced by the write
|
||||||
|
// path. Any row that is updated after M-8 is re-sealed as v2 in-place via the
|
||||||
|
// normal UPDATE flow.
|
||||||
|
//
|
||||||
|
// Rationale for the per-ciphertext salt (see M-8 / CWE-916 / CWE-329): the
|
||||||
|
// pre-M-8 design reused a single 28-byte fixed salt for every ciphertext, which
|
||||||
|
// (a) removes one defense-in-depth layer against passphrase-space brute force
|
||||||
|
// and (b) makes every encrypted column across every row share the exact same
|
||||||
|
// derived key. v2 replaces the fixed salt with 16 fresh random bytes per write
|
||||||
|
// and stores the salt alongside the ciphertext. Derived keys now differ per
|
||||||
|
// row and per re-encryption.
|
||||||
package crypto
|
package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -14,7 +41,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ErrEncryptionKeyRequired is returned by EncryptIfKeySet and DecryptIfKeySet when
|
// ErrEncryptionKeyRequired is returned by EncryptIfKeySet and DecryptIfKeySet when
|
||||||
// the caller provides an empty key but the data on the wire requires protection.
|
// the caller provides an empty passphrase but the data on the wire requires
|
||||||
|
// protection.
|
||||||
//
|
//
|
||||||
// Historically these helpers silently returned plaintext when no key was configured,
|
// Historically these helpers silently returned plaintext when no key was configured,
|
||||||
// which produced a data-at-rest confidentiality bypass (CWE-311): sensitive fields
|
// which produced a data-at-rest confidentiality bypass (CWE-311): sensitive fields
|
||||||
@@ -24,16 +52,58 @@ import (
|
|||||||
// and plaintext branches at runtime, so the only visible signal was a warning
|
// and plaintext branches at runtime, so the only visible signal was a warning
|
||||||
// line emitted once at startup.
|
// line emitted once at startup.
|
||||||
//
|
//
|
||||||
// The fix is to fail closed: EncryptIfKeySet/DecryptIfKeySet now require a key
|
// The fix (C-2, commit fb4ce1a) is to fail closed: EncryptIfKeySet/DecryptIfKeySet
|
||||||
// whenever they are invoked on sensitive material, and the server refuses to
|
// now require a passphrase whenever they are invoked on sensitive material, and
|
||||||
// start if any source='database' rows already exist without a configured key.
|
// the server refuses to start if any source='database' rows already exist without
|
||||||
|
// a configured passphrase.
|
||||||
var ErrEncryptionKeyRequired = errors.New("crypto: CERTCTL_CONFIG_ENCRYPTION_KEY is required to encrypt or decrypt sensitive config")
|
var ErrEncryptionKeyRequired = errors.New("crypto: CERTCTL_CONFIG_ENCRYPTION_KEY is required to encrypt or decrypt sensitive config")
|
||||||
|
|
||||||
|
// v2Magic is the first byte of every v2-format ciphertext blob. It distinguishes
|
||||||
|
// v2 blobs (per-ciphertext random salt, embedded in the blob) from v1 legacy
|
||||||
|
// blobs (no magic byte, fixed package-level salt).
|
||||||
|
//
|
||||||
|
// The choice of 0x02 is deliberate: v1 blobs begin with a random 12-byte AES-GCM
|
||||||
|
// nonce. A v1 nonce can coincidentally start with 0x02 with probability 1/256,
|
||||||
|
// which makes a pure magic-byte dispatch ambiguous. [DecryptIfKeySet] resolves
|
||||||
|
// the ambiguity by falling back to the v1 path when v2 AEAD verification fails.
|
||||||
|
const v2Magic byte = 0x02
|
||||||
|
|
||||||
|
// v2SaltSize is the length in bytes of the per-ciphertext salt embedded in a
|
||||||
|
// v2 blob. 16 bytes (128 bits) matches the lower bound recommended in NIST
|
||||||
|
// SP 800-132 §5.1 for PBKDF2 salts and is sufficient given the one-shot-per-row
|
||||||
|
// nature of the derivation.
|
||||||
|
const v2SaltSize = 16
|
||||||
|
|
||||||
|
// pbkdf2Iterations is the PBKDF2-SHA256 work factor applied uniformly to both
|
||||||
|
// v1 and v2 key derivations. The value is preserved from the pre-M-8 design so
|
||||||
|
// that v1 fallback reads stay bit-identical.
|
||||||
|
const pbkdf2Iterations = 100000
|
||||||
|
|
||||||
|
// aes256KeySize is the output length in bytes of both [DeriveKey] and
|
||||||
|
// [deriveKeyWithSalt]. It is also the only AES key length accepted by [Encrypt]
|
||||||
|
// and [Decrypt].
|
||||||
|
const aes256KeySize = 32
|
||||||
|
|
||||||
|
// legacyV1Salt is the fixed salt used by pre-M-8 config encryption. It is
|
||||||
|
// retained exclusively to preserve the v1 read path — any v1 blob that pre-dates
|
||||||
|
// M-8 remediation must be decryptable with a key derived from (passphrase,
|
||||||
|
// legacyV1Salt). The write path never uses this salt.
|
||||||
|
//
|
||||||
|
// Exposed as a package-level var rather than a local so that tests can reason
|
||||||
|
// about v1 fixture bytes symbolically.
|
||||||
|
var legacyV1Salt = []byte("certctl-config-encryption-v1")
|
||||||
|
|
||||||
// Encrypt encrypts plaintext using AES-256-GCM with a random 12-byte nonce prepended to the output.
|
// Encrypt encrypts plaintext using AES-256-GCM with a random 12-byte nonce prepended to the output.
|
||||||
// The key must be exactly 32 bytes (AES-256). Returns [12-byte nonce][ciphertext+tag].
|
// The key must be exactly 32 bytes (AES-256). Returns [12-byte nonce][ciphertext+tag].
|
||||||
|
//
|
||||||
|
// Encrypt is a low-level primitive. It is intentionally kept byte-identical to
|
||||||
|
// the pre-M-8 implementation so that existing v1 blobs on disk remain
|
||||||
|
// decryptable via [Decrypt] when paired with a [DeriveKey]-derived key. New
|
||||||
|
// callers should prefer [EncryptIfKeySet], which handles key derivation and
|
||||||
|
// emits the v2 wire format.
|
||||||
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
|
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
|
||||||
if len(key) != 32 {
|
if len(key) != aes256KeySize {
|
||||||
return nil, fmt.Errorf("encryption key must be exactly 32 bytes, got %d", len(key))
|
return nil, fmt.Errorf("encryption key must be exactly %d bytes, got %d", aes256KeySize, len(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
block, err := aes.NewCipher(key)
|
||||||
@@ -57,9 +127,14 @@ func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
|
|||||||
|
|
||||||
// Decrypt decrypts ciphertext that was encrypted with Encrypt.
|
// Decrypt decrypts ciphertext that was encrypted with Encrypt.
|
||||||
// Expects format: [12-byte nonce][ciphertext+tag]. Key must be exactly 32 bytes.
|
// Expects format: [12-byte nonce][ciphertext+tag]. Key must be exactly 32 bytes.
|
||||||
|
//
|
||||||
|
// Decrypt is a low-level primitive. It is intentionally kept byte-identical to
|
||||||
|
// the pre-M-8 implementation so that [DecryptIfKeySet] can delegate to it for
|
||||||
|
// both the v2 inner blob (after stripping the magic byte + embedded salt) and
|
||||||
|
// the v1 legacy blob (unmodified).
|
||||||
func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
|
func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
|
||||||
if len(key) != 32 {
|
if len(key) != aes256KeySize {
|
||||||
return nil, fmt.Errorf("encryption key must be exactly 32 bytes, got %d", len(key))
|
return nil, fmt.Errorf("encryption key must be exactly %d bytes, got %d", aes256KeySize, len(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
block, err := aes.NewCipher(key)
|
||||||
@@ -86,48 +161,133 @@ func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
|
|||||||
return plaintext, nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeriveKey derives a 32-byte AES-256 key from a passphrase using PBKDF2-SHA256.
|
// DeriveKey derives a 32-byte AES-256 key from a passphrase using PBKDF2-SHA256
|
||||||
// Uses a fixed application-specific salt and 100,000 iterations for resistance
|
// with the legacy v1 fixed salt.
|
||||||
// to brute-force attacks on weak passphrases.
|
//
|
||||||
|
// This helper is preserved byte-identical to the pre-M-8 implementation so that
|
||||||
|
// v1 ciphertexts persisted before the M-8 remediation remain decryptable
|
||||||
|
// unchanged. New code paths should prefer [EncryptIfKeySet] and
|
||||||
|
// [DecryptIfKeySet], which use a per-ciphertext random salt.
|
||||||
func DeriveKey(passphrase string) []byte {
|
func DeriveKey(passphrase string) []byte {
|
||||||
// Fixed salt is acceptable here because:
|
return deriveKeyWithSalt(passphrase, legacyV1Salt)
|
||||||
// 1. Each certctl instance has its own passphrase
|
|
||||||
// 2. The salt prevents generic rainbow table attacks
|
|
||||||
// 3. Per-user salts are unnecessary (single server key, not user passwords)
|
|
||||||
salt := []byte("certctl-config-encryption-v1")
|
|
||||||
return pbkdf2.Key([]byte(passphrase), salt, 100000, 32, sha256.New)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncryptIfKeySet encrypts plaintext with the supplied 32-byte AES-256 key.
|
// deriveKeyWithSalt derives a 32-byte AES-256 key from a passphrase and an
|
||||||
|
// explicit salt using PBKDF2-SHA256 with [pbkdf2Iterations] rounds.
|
||||||
|
//
|
||||||
|
// The per-ciphertext random salt path (v2) calls this directly with a fresh
|
||||||
|
// 16-byte random salt embedded in the ciphertext blob. The legacy path
|
||||||
|
// ([DeriveKey]) calls it with the package-level fixed salt [legacyV1Salt].
|
||||||
|
func deriveKeyWithSalt(passphrase string, salt []byte) []byte {
|
||||||
|
return pbkdf2.Key([]byte(passphrase), salt, pbkdf2Iterations, aes256KeySize, sha256.New)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLegacyFormat reports whether blob is in the v1 legacy wire format (no magic
|
||||||
|
// byte, fixed-salt derivation) as opposed to the v2 wire format
|
||||||
|
// (magic(0x02) || salt(16) || nonce(12) || ciphertext+tag).
|
||||||
|
//
|
||||||
|
// A return value of false is a necessary but not sufficient condition for a
|
||||||
|
// blob to be a valid v2 ciphertext: the shortest possible v2 blob is
|
||||||
|
// 1 + v2SaltSize + 12 = 29 bytes, and even a 29+ byte blob that starts with
|
||||||
|
// 0x02 may turn out to be a v1 ciphertext whose random nonce happens to begin
|
||||||
|
// with 0x02 (probability 1/256). [DecryptIfKeySet] resolves this ambiguity at
|
||||||
|
// decrypt time by falling back to v1 when v2 AEAD verification fails; callers
|
||||||
|
// of IsLegacyFormat should use it only as a heuristic (e.g. migration
|
||||||
|
// tooling, log annotation).
|
||||||
|
func IsLegacyFormat(blob []byte) bool {
|
||||||
|
if len(blob) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return blob[0] != v2Magic
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptIfKeySet encrypts plaintext with the supplied passphrase and emits a
|
||||||
|
// v2 wire-format blob: magic(0x02) || salt(16) || nonce(12) || ciphertext+tag.
|
||||||
|
//
|
||||||
|
// Key derivation is performed internally per invocation with a fresh 16-byte
|
||||||
|
// random salt, producing a distinct AES-256 key for every ciphertext. The
|
||||||
|
// operator-supplied passphrase is the only cross-ciphertext shared secret.
|
||||||
//
|
//
|
||||||
// The second return value is always true when err == nil — the "wasEncrypted"
|
// The second return value is always true when err == nil — the "wasEncrypted"
|
||||||
// flag is retained for source-compatibility with callers that previously used it
|
// flag is retained for source-compatibility with callers that previously used
|
||||||
// to log provenance. Callers MUST handle err: passing an empty key now returns
|
// it to log provenance. Callers MUST handle err: passing an empty passphrase
|
||||||
// ErrEncryptionKeyRequired rather than silently emitting plaintext. See the
|
// returns [ErrEncryptionKeyRequired] rather than silently emitting plaintext.
|
||||||
// package-level ErrEncryptionKeyRequired documentation for the history behind
|
// See the package-level [ErrEncryptionKeyRequired] documentation for the
|
||||||
// this behavior change.
|
// history behind this behavior change (C-2).
|
||||||
func EncryptIfKeySet(plaintext []byte, key []byte) ([]byte, bool, error) {
|
//
|
||||||
if len(key) == 0 {
|
// The write path never produces a v1 blob. v1 blobs are read-only legacy
|
||||||
|
// state — see [DecryptIfKeySet] for the compatibility fallback.
|
||||||
|
func EncryptIfKeySet(plaintext []byte, passphrase string) ([]byte, bool, error) {
|
||||||
|
if passphrase == "" {
|
||||||
return nil, false, ErrEncryptionKeyRequired
|
return nil, false, ErrEncryptionKeyRequired
|
||||||
}
|
}
|
||||||
encrypted, err := Encrypt(plaintext, key)
|
|
||||||
|
salt := make([]byte, v2SaltSize)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||||
|
return nil, false, fmt.Errorf("failed to generate v2 salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := deriveKeyWithSalt(passphrase, salt)
|
||||||
|
inner, err := Encrypt(plaintext, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
return encrypted, true, nil
|
|
||||||
|
// v2 blob layout: magic(1) || salt(v2SaltSize) || inner
|
||||||
|
blob := make([]byte, 0, 1+v2SaltSize+len(inner))
|
||||||
|
blob = append(blob, v2Magic)
|
||||||
|
blob = append(blob, salt...)
|
||||||
|
blob = append(blob, inner...)
|
||||||
|
return blob, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptIfKeySet decrypts ciphertext with the supplied 32-byte AES-256 key.
|
// DecryptIfKeySet decrypts blob with the supplied passphrase, supporting both
|
||||||
|
// v2 (M-8 and later) and v1 (legacy) on-disk formats.
|
||||||
//
|
//
|
||||||
// Passing an empty key now returns ErrEncryptionKeyRequired. Callers that
|
// Dispatch is first-byte magic + AEAD fallback. If blob starts with
|
||||||
// legitimately store plaintext (e.g. env-seeded source='env' rows that keep
|
// [v2Magic] and is long enough to contain a v2 header plus an AEAD-authenticated
|
||||||
// the raw JSON in the unencrypted `config` column) must branch on the presence
|
// inner ciphertext, a v2 decrypt is attempted using a key derived from the
|
||||||
// of the ciphertext themselves rather than relying on this helper to silently
|
// embedded salt. If that succeeds, its plaintext is returned. If v2 AEAD
|
||||||
// pass bytes through. See the package-level ErrEncryptionKeyRequired
|
// verification fails — which covers both the "wrong passphrase" case and the
|
||||||
// documentation for the history behind this behavior change.
|
// 1/256 case where a v1 blob's first byte happens to be 0x02 — the function
|
||||||
func DecryptIfKeySet(ciphertext []byte, key []byte) ([]byte, error) {
|
// falls through to the v1 path and attempts decryption using a key derived
|
||||||
if len(key) == 0 {
|
// from the package-level fixed salt [legacyV1Salt].
|
||||||
|
//
|
||||||
|
// Passing an empty passphrase returns [ErrEncryptionKeyRequired]. Callers that
|
||||||
|
// legitimately store plaintext (e.g. env-seeded source='env' rows that keep the
|
||||||
|
// raw JSON in the unencrypted `config` column) must branch on the presence of
|
||||||
|
// the ciphertext themselves rather than relying on this helper to silently
|
||||||
|
// pass bytes through. See the package-level [ErrEncryptionKeyRequired]
|
||||||
|
// documentation for the history behind this behavior change (C-2).
|
||||||
|
//
|
||||||
|
// The function never re-encrypts in place. A v1 blob that is successfully
|
||||||
|
// decrypted is returned to the caller as plaintext; re-sealing as v2 happens
|
||||||
|
// naturally on the next UPDATE via [EncryptIfKeySet].
|
||||||
|
func DecryptIfKeySet(blob []byte, passphrase string) ([]byte, error) {
|
||||||
|
if passphrase == "" {
|
||||||
return nil, ErrEncryptionKeyRequired
|
return nil, ErrEncryptionKeyRequired
|
||||||
}
|
}
|
||||||
return Decrypt(ciphertext, key)
|
if len(blob) == 0 {
|
||||||
|
return nil, fmt.Errorf("ciphertext is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// v2 path: magic || salt(16) || nonce(12) || ciphertext+tag (min 29 bytes
|
||||||
|
// ignoring the GCM tag; the AEAD verify inside Decrypt enforces the tag).
|
||||||
|
if blob[0] == v2Magic && len(blob) >= 1+v2SaltSize+12 {
|
||||||
|
salt := blob[1 : 1+v2SaltSize]
|
||||||
|
sealed := blob[1+v2SaltSize:]
|
||||||
|
key := deriveKeyWithSalt(passphrase, salt)
|
||||||
|
if plaintext, err := Decrypt(sealed, key); err == nil {
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
// v2 AEAD verification failed. Fall through to v1 so that a v1 blob
|
||||||
|
// whose first byte happens to be 0x02 (1/256 probability) is still
|
||||||
|
// decryptable. If this is truly a v2 blob with the wrong passphrase,
|
||||||
|
// the v1 attempt below will also fail and the v1 error is returned.
|
||||||
|
}
|
||||||
|
|
||||||
|
// v1 legacy path: blob is the full ciphertext with no header and was
|
||||||
|
// sealed with a key derived from (passphrase, legacyV1Salt).
|
||||||
|
key := DeriveKey(passphrase)
|
||||||
|
return Decrypt(blob, key)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package crypto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -126,21 +128,20 @@ func TestDeriveKeyDifferentPassphrases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEncryptIfKeySet_WithKey(t *testing.T) {
|
func TestEncryptIfKeySet_WithKey(t *testing.T) {
|
||||||
key := DeriveKey("test-key")
|
|
||||||
plaintext := []byte("config data")
|
plaintext := []byte("config data")
|
||||||
|
|
||||||
result, wasEncrypted, err := EncryptIfKeySet(plaintext, key)
|
result, wasEncrypted, err := EncryptIfKeySet(plaintext, "test-passphrase")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
}
|
}
|
||||||
if !wasEncrypted {
|
if !wasEncrypted {
|
||||||
t.Fatal("expected wasEncrypted=true when key provided")
|
t.Fatal("expected wasEncrypted=true when passphrase provided")
|
||||||
}
|
}
|
||||||
if bytes.Equal(result, plaintext) {
|
if bytes.Equal(result, plaintext) {
|
||||||
t.Fatal("result should be encrypted")
|
t.Fatal("result should be encrypted")
|
||||||
}
|
}
|
||||||
|
|
||||||
decrypted, err := DecryptIfKeySet(result, key)
|
decrypted, err := DecryptIfKeySet(result, "test-passphrase")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DecryptIfKeySet failed: %v", err)
|
t.Fatalf("DecryptIfKeySet failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -150,24 +151,14 @@ func TestEncryptIfKeySet_WithKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TestEncryptIfKeySet_EmptyKeyFailsClosed asserts the C-2 regression guard:
|
// TestEncryptIfKeySet_EmptyKeyFailsClosed asserts the C-2 regression guard:
|
||||||
// EncryptIfKeySet must refuse to silently emit plaintext when no key is configured.
|
// EncryptIfKeySet must refuse to silently emit plaintext when no passphrase is
|
||||||
// The pre-fix behavior was to return plaintext with wasEncrypted=false, which
|
// configured. The pre-fix behavior was to return plaintext with
|
||||||
// produced a data-at-rest confidentiality bypass (CWE-311) for GUI-created
|
// wasEncrypted=false, which produced a data-at-rest confidentiality bypass
|
||||||
// issuer and target configs.
|
// (CWE-311) for GUI-created issuer and target configs.
|
||||||
func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
||||||
plaintext := []byte("config data")
|
plaintext := []byte("config data")
|
||||||
|
|
||||||
cases := []struct {
|
result, wasEncrypted, err := EncryptIfKeySet(plaintext, "")
|
||||||
name string
|
|
||||||
key []byte
|
|
||||||
}{
|
|
||||||
{"nil_key", nil},
|
|
||||||
{"empty_key", []byte{}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result, wasEncrypted, err := EncryptIfKeySet(plaintext, tc.key)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected ErrEncryptionKeyRequired, got nil")
|
t.Fatal("expected ErrEncryptionKeyRequired, got nil")
|
||||||
}
|
}
|
||||||
@@ -180,27 +171,15 @@ func TestEncryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
|||||||
if result != nil {
|
if result != nil {
|
||||||
t.Fatalf("expected nil result on error, got %q", result)
|
t.Fatalf("expected nil result on error, got %q", result)
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDecryptIfKeySet_EmptyKeyFailsClosed asserts the matching C-2 regression
|
// TestDecryptIfKeySet_EmptyKeyFailsClosed asserts the matching C-2 regression
|
||||||
// guard on the read path: DecryptIfKeySet must refuse to pass ciphertext
|
// guard on the read path: DecryptIfKeySet must refuse to pass ciphertext
|
||||||
// through as plaintext when no key is configured.
|
// through as plaintext when no passphrase is configured.
|
||||||
func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
||||||
data := []byte("plaintext config data")
|
data := []byte("plaintext config data")
|
||||||
|
|
||||||
cases := []struct {
|
result, err := DecryptIfKeySet(data, "")
|
||||||
name string
|
|
||||||
key []byte
|
|
||||||
}{
|
|
||||||
{"nil_key", nil},
|
|
||||||
{"empty_key", []byte{}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result, err := DecryptIfKeySet(data, tc.key)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected ErrEncryptionKeyRequired, got nil")
|
t.Fatal("expected ErrEncryptionKeyRequired, got nil")
|
||||||
}
|
}
|
||||||
@@ -210,29 +189,26 @@ func TestDecryptIfKeySet_EmptyKeyFailsClosed(t *testing.T) {
|
|||||||
if result != nil {
|
if result != nil {
|
||||||
t.Fatalf("expected nil result on error, got %q", result)
|
t.Fatalf("expected nil result on error, got %q", result)
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext proves the
|
// TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext proves the
|
||||||
// "if set" helpers produce real AES-GCM output (not plaintext) and that a full
|
// "if set" helpers produce real AES-GCM output (not plaintext) and that a full
|
||||||
// round-trip through both helpers recovers the original bytes.
|
// round-trip through both helpers recovers the original bytes.
|
||||||
func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.T) {
|
func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.T) {
|
||||||
key := DeriveKey("round-trip-key")
|
|
||||||
plaintext := []byte(`{"api_key":"s3cr3t","token":"abc"}`)
|
plaintext := []byte(`{"api_key":"s3cr3t","token":"abc"}`)
|
||||||
|
|
||||||
encrypted, wasEncrypted, err := EncryptIfKeySet(plaintext, key)
|
encrypted, wasEncrypted, err := EncryptIfKeySet(plaintext, "round-trip-key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
}
|
}
|
||||||
if !wasEncrypted {
|
if !wasEncrypted {
|
||||||
t.Fatal("wasEncrypted must be true when key is present")
|
t.Fatal("wasEncrypted must be true when passphrase is present")
|
||||||
}
|
}
|
||||||
if bytes.Equal(encrypted, plaintext) {
|
if bytes.Equal(encrypted, plaintext) {
|
||||||
t.Fatal("EncryptIfKeySet returned plaintext — would regress C-2")
|
t.Fatal("EncryptIfKeySet returned plaintext — would regress C-2")
|
||||||
}
|
}
|
||||||
|
|
||||||
decrypted, err := DecryptIfKeySet(encrypted, key)
|
decrypted, err := DecryptIfKeySet(encrypted, "round-trip-key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DecryptIfKeySet failed: %v", err)
|
t.Fatalf("DecryptIfKeySet failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -242,22 +218,24 @@ func TestEncryptDecryptIfKeySet_RoundTripProducesDifferentCiphertext(t *testing.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TestDecryptIfKeySet_RejectsTamperedCiphertext confirms the AEAD auth tag
|
// TestDecryptIfKeySet_RejectsTamperedCiphertext confirms the AEAD auth tag
|
||||||
// still rejects modified ciphertext when routed through the helper.
|
// still rejects modified ciphertext when routed through the helper. The v2
|
||||||
|
// wire format is magic(1) || salt(16) || nonce(12) || ciphertext+tag, so
|
||||||
|
// flipping a byte anywhere past offset 29 lands squarely inside the AEAD body.
|
||||||
func TestDecryptIfKeySet_RejectsTamperedCiphertext(t *testing.T) {
|
func TestDecryptIfKeySet_RejectsTamperedCiphertext(t *testing.T) {
|
||||||
key := DeriveKey("tamper-test-key")
|
|
||||||
plaintext := []byte("authenticated data")
|
plaintext := []byte("authenticated data")
|
||||||
|
|
||||||
encrypted, _, err := EncryptIfKeySet(plaintext, key)
|
encrypted, _, err := EncryptIfKeySet(plaintext, "tamper-test-key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
}
|
}
|
||||||
// Flip a byte inside the GCM body (past the 12-byte nonce) to invalidate the tag.
|
// Flip a byte past the v2 header (1 + 16 + 12 = 29) to invalidate the tag.
|
||||||
if len(encrypted) <= 13 {
|
const minV2HeaderLen = 1 + v2SaltSize + 12
|
||||||
|
if len(encrypted) <= minV2HeaderLen {
|
||||||
t.Fatalf("ciphertext too short to tamper: %d bytes", len(encrypted))
|
t.Fatalf("ciphertext too short to tamper: %d bytes", len(encrypted))
|
||||||
}
|
}
|
||||||
encrypted[13] ^= 0xFF
|
encrypted[minV2HeaderLen] ^= 0xFF
|
||||||
|
|
||||||
if _, err := DecryptIfKeySet(encrypted, key); err == nil {
|
if _, err := DecryptIfKeySet(encrypted, "tamper-test-key"); err == nil {
|
||||||
t.Fatal("DecryptIfKeySet accepted tampered ciphertext — AEAD tag check bypassed")
|
t.Fatal("DecryptIfKeySet accepted tampered ciphertext — AEAD tag check bypassed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -296,3 +274,217 @@ func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
|
|||||||
t.Fatal("encrypting same plaintext twice should produce different ciphertexts (random nonce)")
|
t.Fatal("encrypting same plaintext twice should produce different ciphertexts (random nonce)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// M-8 additions: per-ciphertext salt + v2 wire format + v1 backward compat.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestDeriveKey_DifferentSaltsProduceDifferentKeys asserts that
|
||||||
|
// deriveKeyWithSalt fans out distinct 32-byte keys for the same passphrase
|
||||||
|
// across different salts. This is the core M-8 defense-in-depth property: even
|
||||||
|
// if an attacker obtains two v2 ciphertexts encrypted with the same master
|
||||||
|
// passphrase, the derived AES keys differ, and a brute-force attempt on one
|
||||||
|
// blob cannot be amortized across the other.
|
||||||
|
func TestDeriveKey_DifferentSaltsProduceDifferentKeys(t *testing.T) {
|
||||||
|
passphrase := "master-passphrase"
|
||||||
|
saltA := bytes.Repeat([]byte{0xAA}, v2SaltSize)
|
||||||
|
saltB := bytes.Repeat([]byte{0xBB}, v2SaltSize)
|
||||||
|
|
||||||
|
keyA := deriveKeyWithSalt(passphrase, saltA)
|
||||||
|
keyB := deriveKeyWithSalt(passphrase, saltB)
|
||||||
|
|
||||||
|
if len(keyA) != aes256KeySize || len(keyB) != aes256KeySize {
|
||||||
|
t.Fatalf("derived key length wrong: %d / %d", len(keyA), len(keyB))
|
||||||
|
}
|
||||||
|
if bytes.Equal(keyA, keyB) {
|
||||||
|
t.Fatal("deriveKeyWithSalt must produce different keys for different salts")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity-check that deterministic behaviour is preserved under a fixed salt.
|
||||||
|
keyA2 := deriveKeyWithSalt(passphrase, saltA)
|
||||||
|
if !bytes.Equal(keyA, keyA2) {
|
||||||
|
t.Fatal("deriveKeyWithSalt must be deterministic for a fixed (passphrase, salt)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEncryptIfKeySet_ProducesV2Format asserts the exact v2 wire-format bytes:
|
||||||
|
// magic(0x02) || salt(16) || nonce(12) || ciphertext+tag.
|
||||||
|
func TestEncryptIfKeySet_ProducesV2Format(t *testing.T) {
|
||||||
|
blob, _, err := EncryptIfKeySet([]byte("hello"), "any-passphrase")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const minLen = 1 + v2SaltSize + 12 + 16 // magic + salt + nonce + GCM tag (16)
|
||||||
|
if len(blob) < minLen {
|
||||||
|
t.Fatalf("v2 blob too short: got %d, want >= %d", len(blob), minLen)
|
||||||
|
}
|
||||||
|
if blob[0] != v2Magic {
|
||||||
|
t.Fatalf("v2 blob must start with magic byte 0x%02x, got 0x%02x", v2Magic, blob[0])
|
||||||
|
}
|
||||||
|
if IsLegacyFormat(blob) {
|
||||||
|
t.Fatal("IsLegacyFormat must return false for a freshly produced v2 blob")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEncryptIfKeySet_SaltIsRandom asserts that two calls with the same
|
||||||
|
// passphrase and plaintext produce distinct embedded salts.
|
||||||
|
func TestEncryptIfKeySet_SaltIsRandom(t *testing.T) {
|
||||||
|
plaintext := []byte("same plaintext")
|
||||||
|
passphrase := "same-passphrase"
|
||||||
|
|
||||||
|
blob1, _, err := EncryptIfKeySet(plaintext, passphrase)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EncryptIfKeySet #1 failed: %v", err)
|
||||||
|
}
|
||||||
|
blob2, _, err := EncryptIfKeySet(plaintext, passphrase)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EncryptIfKeySet #2 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
salt1 := blob1[1 : 1+v2SaltSize]
|
||||||
|
salt2 := blob2[1 : 1+v2SaltSize]
|
||||||
|
if bytes.Equal(salt1, salt2) {
|
||||||
|
t.Fatal("two EncryptIfKeySet invocations must produce distinct per-ciphertext salts")
|
||||||
|
}
|
||||||
|
if bytes.Equal(blob1, blob2) {
|
||||||
|
t.Fatal("two v2 blobs with same (passphrase, plaintext) must differ end-to-end")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecryptIfKeySet_V1BackwardCompat builds a deterministic v1-format
|
||||||
|
// ciphertext using the pre-M-8 recipe (DeriveKey with the fixed salt, then
|
||||||
|
// Encrypt with an all-zero nonce for reproducibility) and asserts that
|
||||||
|
// DecryptIfKeySet still decrypts it correctly. This is the migration guarantee:
|
||||||
|
// v1 blobs persisted before M-8 must remain decryptable.
|
||||||
|
func TestDecryptIfKeySet_V1BackwardCompat(t *testing.T) {
|
||||||
|
passphrase := "legacy-passphrase"
|
||||||
|
plaintext := []byte(`{"api_key":"legacy","org_id":"789"}`)
|
||||||
|
|
||||||
|
// Build a deterministic v1 blob directly: nonce(12 zero bytes) || ct+tag.
|
||||||
|
// This matches the exact wire shape that Encrypt produces, minus the random
|
||||||
|
// nonce, so the test is stable rather than 1/256 flaky.
|
||||||
|
key := DeriveKey(passphrase) // fixed-salt derivation (pre-M-8 behavior)
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("aes.NewCipher: %v", err)
|
||||||
|
}
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cipher.NewGCM: %v", err)
|
||||||
|
}
|
||||||
|
nonce := make([]byte, gcm.NonceSize()) // all zeros → first byte != v2Magic
|
||||||
|
v1Blob := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||||
|
if v1Blob[0] == v2Magic {
|
||||||
|
t.Fatalf("fixture nonce collided with v2 magic byte — test design error")
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := DecryptIfKeySet(v1Blob, passphrase)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecryptIfKeySet(v1) failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(decrypted, plaintext) {
|
||||||
|
t.Fatalf("v1 decrypt mismatch: got %q, want %q", decrypted, plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cross-check: IsLegacyFormat should flag this as legacy.
|
||||||
|
if !IsLegacyFormat(v1Blob) {
|
||||||
|
t.Fatal("IsLegacyFormat must return true for a v1 blob whose first byte != v2Magic")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecryptIfKeySet_V1MagicByteCollisionFallsThrough covers the 1/256 edge
|
||||||
|
// case where a v1 ciphertext's random 12-byte nonce happens to begin with
|
||||||
|
// 0x02. The dispatch must attempt v2, see AEAD failure, and fall through to
|
||||||
|
// v1 — never return a decrypt error when the passphrase is correct.
|
||||||
|
func TestDecryptIfKeySet_V1MagicByteCollisionFallsThrough(t *testing.T) {
|
||||||
|
passphrase := "collision-passphrase"
|
||||||
|
plaintext := []byte("colliding v1 blob")
|
||||||
|
|
||||||
|
// Craft a v1 blob whose first byte equals v2Magic by choosing a nonce
|
||||||
|
// starting with 0x02 and sealing manually.
|
||||||
|
key := DeriveKey(passphrase)
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("aes.NewCipher: %v", err)
|
||||||
|
}
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cipher.NewGCM: %v", err)
|
||||||
|
}
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
nonce[0] = v2Magic // force collision
|
||||||
|
v1Blob := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||||
|
if v1Blob[0] != v2Magic {
|
||||||
|
t.Fatal("fixture construction bug: first byte must equal v2Magic")
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := DecryptIfKeySet(v1Blob, passphrase)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecryptIfKeySet must fall through to v1 on AEAD failure, got err: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(decrypted, plaintext) {
|
||||||
|
t.Fatalf("v1-via-fallback decrypt mismatch: got %q, want %q", decrypted, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecryptIfKeySet_V2WithWrongPassphraseFails asserts that a v2 blob
|
||||||
|
// sealed under passphrase A cannot be decrypted under passphrase B. Both the
|
||||||
|
// v2 AEAD verify (with salt from the blob + passphrase B) and the v1 fallback
|
||||||
|
// (with fixed salt + passphrase B) must fail, and an error must be returned
|
||||||
|
// rather than silently-corrupt plaintext.
|
||||||
|
func TestDecryptIfKeySet_V2WithWrongPassphraseFails(t *testing.T) {
|
||||||
|
blob, _, err := EncryptIfKeySet([]byte("secret"), "passphrase-A")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := DecryptIfKeySet(blob, "passphrase-B")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("DecryptIfKeySet must return error for wrong passphrase, got plaintext %q", got)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("result must be nil on decrypt error, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecryptIfKeySet_TruncatedV2Blob asserts that a blob starting with the v2
|
||||||
|
// magic byte but too short to contain a full v2 header does not trip an
|
||||||
|
// out-of-bounds slice and does not succeed. It either returns an error (v1
|
||||||
|
// fallback on the short bytes fails with "ciphertext too short") or at minimum
|
||||||
|
// never returns plaintext.
|
||||||
|
func TestDecryptIfKeySet_TruncatedV2Blob(t *testing.T) {
|
||||||
|
truncated := []byte{v2Magic, 0x00, 0x01, 0x02, 0x03} // 5 bytes — well below the 29-byte v2 minimum
|
||||||
|
got, err := DecryptIfKeySet(truncated, "any-passphrase")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("DecryptIfKeySet must reject a truncated v2 blob, got plaintext %q", got)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("result must be nil on decrypt error, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsLegacyFormat covers the three branches of the public magic-byte
|
||||||
|
// heuristic: v2 blob → false, v1 blob → true, empty blob → false.
|
||||||
|
func TestIsLegacyFormat(t *testing.T) {
|
||||||
|
v2Blob, _, err := EncryptIfKeySet([]byte("data"), "p")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
|
}
|
||||||
|
if IsLegacyFormat(v2Blob) {
|
||||||
|
t.Fatal("v2 blob must not be flagged as legacy")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any blob whose first byte isn't v2Magic should be reported as legacy.
|
||||||
|
v1Shape := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0xFF}
|
||||||
|
if !IsLegacyFormat(v1Shape) {
|
||||||
|
t.Fatal("non-v2-magic blob must be flagged as legacy")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsLegacyFormat(nil) {
|
||||||
|
t.Fatal("nil blob must not be flagged as legacy (undefined)")
|
||||||
|
}
|
||||||
|
if IsLegacyFormat([]byte{}) {
|
||||||
|
t.Fatal("empty blob must not be flagged as legacy (undefined)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func TestCertificateLifecycle(t *testing.T) {
|
|||||||
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
|
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
|
||||||
// must supply a real key so the encrypt path runs instead of returning
|
// must supply a real key so the encrypt path runs instead of returning
|
||||||
// ErrEncryptionKeyRequired.
|
// ErrEncryptionKeyRequired.
|
||||||
testEncryptionKey := []byte("0123456789abcdef0123456789abcdef")
|
testEncryptionKey := "0123456789abcdef0123456789abcdef"
|
||||||
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, slog.Default())
|
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, slog.Default())
|
||||||
|
|
||||||
// Initialize handlers
|
// Initialize handlers
|
||||||
@@ -772,6 +772,14 @@ func (m *mockAgentRepository) Create(ctx context.Context, agent *domain.Agent) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAgentRepository) CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error) {
|
||||||
|
if _, exists := m.agents[agent.ID]; exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
m.agents[agent.ID] = agent
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
func (m *mockAgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
||||||
m.agents[agent.ID] = agent
|
m.agents[agent.ID] = agent
|
||||||
return nil
|
return nil
|
||||||
@@ -1028,8 +1036,8 @@ type mockTargetService struct {
|
|||||||
auditService *service.AuditService
|
auditService *service.AuditService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
func (m *mockTargetService) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
targets, err := m.targetRepo.List(context.Background())
|
targets, err := m.targetRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -1040,99 +1048,99 @@ func (m *mockTargetService) ListTargets(page, perPage int) ([]domain.DeploymentT
|
|||||||
return result, int64(len(result)), nil
|
return result, int64(len(result)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTargetService) GetTarget(id string) (*domain.DeploymentTarget, error) {
|
func (m *mockTargetService) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
return m.targetRepo.Get(context.Background(), id)
|
return m.targetRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
func (m *mockTargetService) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
if err := m.targetRepo.Create(context.Background(), &target); err != nil {
|
if err := m.targetRepo.Create(ctx, &target); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &target, nil
|
return &target, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
func (m *mockTargetService) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
target.ID = id
|
target.ID = id
|
||||||
if err := m.targetRepo.Update(context.Background(), &target); err != nil {
|
if err := m.targetRepo.Update(ctx, &target); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &target, nil
|
return &target, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTargetService) DeleteTarget(id string) error {
|
func (m *mockTargetService) DeleteTarget(ctx context.Context, id string) error {
|
||||||
return m.targetRepo.Delete(context.Background(), id)
|
return m.targetRepo.Delete(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTargetService) TestTargetConnection(id string) error {
|
func (m *mockTargetService) TestConnection(ctx context.Context, id string) error {
|
||||||
return nil // No-op for integration tests
|
return nil // No-op for integration tests
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockTeamService struct{}
|
type mockTeamService struct{}
|
||||||
|
|
||||||
func (m *mockTeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) {
|
func (m *mockTeamService) ListTeams(_ context.Context, page, perPage int) ([]domain.Team, int64, error) {
|
||||||
return []domain.Team{}, 0, nil
|
return []domain.Team{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTeamService) GetTeam(id string) (*domain.Team, error) {
|
func (m *mockTeamService) GetTeam(_ context.Context, id string) (*domain.Team, error) {
|
||||||
return nil, fmt.Errorf("team not found")
|
return nil, fmt.Errorf("team not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
|
func (m *mockTeamService) CreateTeam(_ context.Context, team domain.Team) (*domain.Team, error) {
|
||||||
return &team, nil
|
return &team, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) {
|
func (m *mockTeamService) UpdateTeam(_ context.Context, id string, team domain.Team) (*domain.Team, error) {
|
||||||
team.ID = id
|
team.ID = id
|
||||||
return &team, nil
|
return &team, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTeamService) DeleteTeam(id string) error {
|
func (m *mockTeamService) DeleteTeam(_ context.Context, id string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockOwnerService struct{}
|
type mockOwnerService struct{}
|
||||||
|
|
||||||
func (m *mockOwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) {
|
func (m *mockOwnerService) ListOwners(_ context.Context, page, perPage int) ([]domain.Owner, int64, error) {
|
||||||
return []domain.Owner{}, 0, nil
|
return []domain.Owner{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockOwnerService) GetOwner(id string) (*domain.Owner, error) {
|
func (m *mockOwnerService) GetOwner(_ context.Context, id string) (*domain.Owner, error) {
|
||||||
return nil, fmt.Errorf("owner not found")
|
return nil, fmt.Errorf("owner not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockOwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
|
func (m *mockOwnerService) CreateOwner(_ context.Context, owner domain.Owner) (*domain.Owner, error) {
|
||||||
return &owner, nil
|
return &owner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockOwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) {
|
func (m *mockOwnerService) UpdateOwner(_ context.Context, id string, owner domain.Owner) (*domain.Owner, error) {
|
||||||
owner.ID = id
|
owner.ID = id
|
||||||
return &owner, nil
|
return &owner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockOwnerService) DeleteOwner(id string) error {
|
func (m *mockOwnerService) DeleteOwner(_ context.Context, id string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockProfileService struct{}
|
type mockProfileService struct{}
|
||||||
|
|
||||||
func (m *mockProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) {
|
func (m *mockProfileService) ListProfiles(_ context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error) {
|
||||||
return []domain.CertificateProfile{}, 0, nil
|
return []domain.CertificateProfile{}, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProfileService) GetProfile(id string) (*domain.CertificateProfile, error) {
|
func (m *mockProfileService) GetProfile(_ context.Context, id string) (*domain.CertificateProfile, error) {
|
||||||
return nil, fmt.Errorf("profile not found")
|
return nil, fmt.Errorf("profile not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
func (m *mockProfileService) CreateProfile(_ context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
||||||
return &profile, nil
|
return &profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
func (m *mockProfileService) UpdateProfile(_ context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
||||||
profile.ID = id
|
profile.ID = id
|
||||||
return &profile, nil
|
return &profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProfileService) DeleteProfile(id string) error {
|
func (m *mockProfileService) DeleteProfile(_ context.Context, id string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository
|
|||||||
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
|
// without a configured CERTCTL_CONFIG_ENCRYPTION_KEY. Happy-path CRUD tests
|
||||||
// must supply a real key so the encrypt path runs instead of returning
|
// must supply a real key so the encrypt path runs instead of returning
|
||||||
// ErrEncryptionKeyRequired.
|
// ErrEncryptionKeyRequired.
|
||||||
testEncryptionKey := []byte("0123456789abcdef0123456789abcdef")
|
testEncryptionKey := "0123456789abcdef0123456789abcdef"
|
||||||
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, logger)
|
issuerService := service.NewIssuerService(issuerRepo, auditService, issuerRegistry, testEncryptionKey, logger)
|
||||||
|
|
||||||
certificateHandler := handler.NewCertificateHandler(certificateService)
|
certificateHandler := handler.NewCertificateHandler(certificateService)
|
||||||
|
|||||||
@@ -90,8 +90,18 @@ type AgentRepository interface {
|
|||||||
List(ctx context.Context) ([]*domain.Agent, error)
|
List(ctx context.Context) ([]*domain.Agent, error)
|
||||||
// Get retrieves an agent by ID.
|
// Get retrieves an agent by ID.
|
||||||
Get(ctx context.Context, id string) (*domain.Agent, error)
|
Get(ctx context.Context, id string) (*domain.Agent, error)
|
||||||
// Create stores a new agent.
|
// Create stores a new agent. Callers that want duplicate-key errors surfaced
|
||||||
|
// (e.g. real-agent registration) must use this method; sentinel/bootstrap
|
||||||
|
// paths that expect the row to already exist on restart should call
|
||||||
|
// CreateIfNotExists instead (M-6, CWE-662).
|
||||||
Create(ctx context.Context, agent *domain.Agent) error
|
Create(ctx context.Context, agent *domain.Agent) error
|
||||||
|
// CreateIfNotExists creates an agent only if the ID doesn't already exist
|
||||||
|
// (INSERT ... ON CONFLICT (id) DO NOTHING). Returns true if the row was
|
||||||
|
// newly inserted, false if a row with the same ID already existed. Used
|
||||||
|
// by the sentinel-agent bootstrap path in cmd/server/main.go so restarts
|
||||||
|
// and upgrades are idempotent without swallowing unrelated database
|
||||||
|
// failures (M-6, CWE-662).
|
||||||
|
CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error)
|
||||||
// Update modifies an existing agent.
|
// Update modifies an existing agent.
|
||||||
Update(ctx context.Context, agent *domain.Agent) error
|
Update(ctx context.Context, agent *domain.Agent) error
|
||||||
// Delete removes an agent.
|
// Delete removes an agent.
|
||||||
|
|||||||
@@ -70,7 +70,9 @@ func (r *AgentRepository) Get(ctx context.Context, id string) (*domain.Agent, er
|
|||||||
return agent, nil
|
return agent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create stores a new agent
|
// Create stores a new agent. Duplicate-key errors surface to the caller —
|
||||||
|
// real-agent registration paths rely on this to detect collisions. Use
|
||||||
|
// CreateIfNotExists for sentinel/bootstrap paths where re-inserts are expected.
|
||||||
func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
|
func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
|
||||||
if agent.ID == "" {
|
if agent.ID == "" {
|
||||||
agent.ID = uuid.New().String()
|
agent.ID = uuid.New().String()
|
||||||
@@ -92,6 +94,44 @@ func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateIfNotExists creates an agent only if the ID doesn't already exist.
|
||||||
|
// Used for sentinel agents (server-scanner, cloud-aws-sm, cloud-azure-kv,
|
||||||
|
// cloud-gcp-sm) on first boot AND on every subsequent restart/upgrade — the
|
||||||
|
// pre-M-6 code used plain INSERT, swallowed the duplicate-key error, and so
|
||||||
|
// silently swallowed every other database failure too (CWE-662 /
|
||||||
|
// CWE-209-adjacent). ON CONFLICT (id) DO NOTHING + RETURNING id +
|
||||||
|
// sql.ErrNoRows distinguishes "row already existed" (created=false, err=nil)
|
||||||
|
// from genuine errors (connectivity, permission, constraint violations
|
||||||
|
// other than the id primary key) which still surface. Returns true if the
|
||||||
|
// row was newly inserted, false if a row with the same ID already existed.
|
||||||
|
func (r *AgentRepository) CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error) {
|
||||||
|
if agent.ID == "" {
|
||||||
|
agent.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
var id string
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash,
|
||||||
|
os, architecture, ip_address, version)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||||
|
ON CONFLICT (id) DO NOTHING
|
||||||
|
RETURNING id
|
||||||
|
`, agent.ID, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt,
|
||||||
|
agent.RegisteredAt, agent.APIKeyHash,
|
||||||
|
agent.OS, agent.Architecture, agent.IPAddress, agent.Version).Scan(&id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// ON CONFLICT DO NOTHING — a row with this ID already existed.
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("failed to create agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
agent.ID = id
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Update modifies an existing agent
|
// Update modifies an existing agent
|
||||||
func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
||||||
result, err := r.db.ExecContext(ctx, `
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
|||||||
@@ -190,18 +190,65 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
|
|||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var certs []*domain.ManagedCertificate
|
var certs []*domain.ManagedCertificate
|
||||||
|
var certIDs []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
cert, err := scanCertificate(rows)
|
var cert domain.ManagedCertificate
|
||||||
|
var tagsJSON []byte
|
||||||
|
var sans pq.StringArray
|
||||||
|
var profileID sql.NullString
|
||||||
|
var revocationReason sql.NullString
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
||||||
|
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
|
||||||
|
&cert.Status, &cert.ExpiresAt, &tagsJSON,
|
||||||
|
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
|
||||||
|
&cert.CreatedAt, &cert.UpdatedAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, fmt.Errorf("failed to scan certificate: %w", err)
|
||||||
}
|
}
|
||||||
certs = append(certs, cert)
|
|
||||||
|
cert.SANs = []string(sans)
|
||||||
|
if profileID.Valid {
|
||||||
|
cert.CertificateProfileID = profileID.String
|
||||||
|
}
|
||||||
|
if revocationReason.Valid {
|
||||||
|
cert.RevocationReason = revocationReason.String
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal tags
|
||||||
|
if len(tagsJSON) > 0 {
|
||||||
|
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to unmarshal tags: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cert.Tags = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
certs = append(certs, &cert)
|
||||||
|
certIDs = append(certIDs, cert.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
|
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch target IDs for all certificates in a single query (avoid N+1)
|
||||||
|
if len(certIDs) > 0 {
|
||||||
|
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
for _, cert := range certs {
|
||||||
|
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
|
||||||
|
cert.TargetIDs = targetIDs
|
||||||
|
} else {
|
||||||
|
cert.TargetIDs = []string{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return certs, total, nil
|
return certs, total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,7 +261,7 @@ func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.Man
|
|||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`, id)
|
`, id)
|
||||||
|
|
||||||
cert, err := scanCertificate(row)
|
cert, err := r.scanCertificate(ctx, row)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, fmt.Errorf("certificate not found")
|
return nil, fmt.Errorf("certificate not found")
|
||||||
@@ -421,18 +468,65 @@ func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, bef
|
|||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var certs []*domain.ManagedCertificate
|
var certs []*domain.ManagedCertificate
|
||||||
|
var certIDs []string
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
cert, err := scanCertificate(rows)
|
var cert domain.ManagedCertificate
|
||||||
|
var tagsJSON []byte
|
||||||
|
var sans pq.StringArray
|
||||||
|
var profileID sql.NullString
|
||||||
|
var revocationReason sql.NullString
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
||||||
|
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
|
||||||
|
&cert.Status, &cert.ExpiresAt, &tagsJSON,
|
||||||
|
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
|
||||||
|
&cert.CreatedAt, &cert.UpdatedAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to scan certificate: %w", err)
|
||||||
}
|
}
|
||||||
certs = append(certs, cert)
|
|
||||||
|
cert.SANs = []string(sans)
|
||||||
|
if profileID.Valid {
|
||||||
|
cert.CertificateProfileID = profileID.String
|
||||||
|
}
|
||||||
|
if revocationReason.Valid {
|
||||||
|
cert.RevocationReason = revocationReason.String
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal tags
|
||||||
|
if len(tagsJSON) > 0 {
|
||||||
|
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cert.Tags = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
certs = append(certs, &cert)
|
||||||
|
certIDs = append(certIDs, cert.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
|
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch target IDs for all certificates in a single query (avoid N+1)
|
||||||
|
if len(certIDs) > 0 {
|
||||||
|
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, cert := range certs {
|
||||||
|
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
|
||||||
|
cert.TargetIDs = targetIDs
|
||||||
|
} else {
|
||||||
|
cert.TargetIDs = []string{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return certs, nil
|
return certs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -462,8 +556,76 @@ func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID str
|
|||||||
return &v, nil
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanCertificate scans a certificate from a row or rows
|
// getTargetIDs retrieves all target IDs for a given certificate from the junction table.
|
||||||
func scanCertificate(scanner interface {
|
// Returns an empty slice (not nil) if no targets are found.
|
||||||
|
func (r *CertificateRepository) getTargetIDs(ctx context.Context, certID string) ([]string, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT target_id FROM certificate_target_mappings
|
||||||
|
WHERE certificate_id = $1
|
||||||
|
ORDER BY target_id ASC
|
||||||
|
`, certID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query target mappings: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var targetIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var targetID string
|
||||||
|
if err := rows.Scan(&targetID); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan target ID: %w", err)
|
||||||
|
}
|
||||||
|
targetIDs = append(targetIDs, targetID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating target ID rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return empty slice instead of nil for consistency with JSON marshaling
|
||||||
|
if targetIDs == nil {
|
||||||
|
targetIDs = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return targetIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTargetIDsForCertificates retrieves target IDs for multiple certificates in a single query.
|
||||||
|
// Returns a map of certificate_id -> []target_id.
|
||||||
|
func (r *CertificateRepository) getTargetIDsForCertificates(ctx context.Context, certIDs []string) (map[string][]string, error) {
|
||||||
|
if len(certIDs) == 0 {
|
||||||
|
return make(map[string][]string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT certificate_id, target_id FROM certificate_target_mappings
|
||||||
|
WHERE certificate_id = ANY($1)
|
||||||
|
ORDER BY certificate_id, target_id ASC
|
||||||
|
`, pq.Array(certIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query target mappings: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
targetIDsMap := make(map[string][]string)
|
||||||
|
for rows.Next() {
|
||||||
|
var certID, targetID string
|
||||||
|
if err := rows.Scan(&certID, &targetID); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan target mapping: %w", err)
|
||||||
|
}
|
||||||
|
targetIDsMap[certID] = append(targetIDsMap[certID], targetID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating target mapping rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return targetIDsMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanCertificate scans a certificate from a row or rows and populates its TargetIDs
|
||||||
|
// by querying the certificate_target_mappings junction table.
|
||||||
|
func (r *CertificateRepository) scanCertificate(ctx context.Context, scanner interface {
|
||||||
Scan(...interface{}) error
|
Scan(...interface{}) error
|
||||||
}) (*domain.ManagedCertificate, error) {
|
}) (*domain.ManagedCertificate, error) {
|
||||||
var cert domain.ManagedCertificate
|
var cert domain.ManagedCertificate
|
||||||
@@ -500,6 +662,13 @@ func scanCertificate(scanner interface {
|
|||||||
cert.Tags = make(map[string]string)
|
cert.Tags = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Populate TargetIDs from junction table
|
||||||
|
targetIDs, err := r.getTargetIDs(ctx, cert.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.TargetIDs = targetIDs
|
||||||
|
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,322 @@
|
|||||||
|
// Package postgres_test — integration tests for M-7: Certificate.TargetIDs
|
||||||
|
// must be populated from certificate_target_mappings on read.
|
||||||
|
//
|
||||||
|
// Before M-7 the repository scan helper never consulted the junction table, so
|
||||||
|
// Get / List / GetExpiringCertificates always returned empty TargetIDs even when
|
||||||
|
// rows existed in certificate_target_mappings. These tests exercise all three
|
||||||
|
// read paths end-to-end against a real PostgreSQL 16 container.
|
||||||
|
//
|
||||||
|
// Runs against the shared testcontainer from testutil_test.go. Skipped when
|
||||||
|
// `-short` is set (CI uses short mode; local runs pick it up by default).
|
||||||
|
package postgres_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository/postgres"
|
||||||
|
)
|
||||||
|
|
||||||
|
// insertAgentAndTargetsRaw creates one agent and N deployment_targets, returns
|
||||||
|
// the agent ID and the list of target IDs (in insertion order).
|
||||||
|
func insertAgentAndTargetsRaw(t *testing.T, db *sql.DB, ctx context.Context, suffix string, n int) (agentID string, targetIDs []string) {
|
||||||
|
t.Helper()
|
||||||
|
now := time.Now().Truncate(time.Microsecond)
|
||||||
|
agentID = "agent-" + suffix
|
||||||
|
|
||||||
|
_, err := db.ExecContext(ctx, `
|
||||||
|
INSERT INTO agents (id, name, hostname, status, registered_at, api_key_hash)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
`, agentID, "agent-"+suffix, "host-"+suffix, "online", now, "hash-"+suffix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insertAgent failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
tid := "t-" + suffix + "-" + intToStr(i)
|
||||||
|
_, err := db.ExecContext(ctx, `
|
||||||
|
INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
`, tid, tid, "NGINX", agentID, []byte(`{}`), true, now, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insertTarget %d failed: %v", i, err)
|
||||||
|
}
|
||||||
|
targetIDs = append(targetIDs, tid)
|
||||||
|
}
|
||||||
|
return agentID, targetIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
// intToStr converts a non-negative int to its decimal string.
|
||||||
|
// Local helper to avoid importing strconv for a single use.
|
||||||
|
func intToStr(n int) string {
|
||||||
|
if n == 0 {
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
var buf [20]byte
|
||||||
|
i := len(buf)
|
||||||
|
for n > 0 {
|
||||||
|
i--
|
||||||
|
buf[i] = byte('0' + n%10)
|
||||||
|
n /= 10
|
||||||
|
}
|
||||||
|
return string(buf[i:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertCertificateRow writes a minimal managed_certificates row via raw SQL.
|
||||||
|
// Bypasses the repository Create so we can isolate read-path tests from any
|
||||||
|
// write-path behavior. managed_certificates.sans is TEXT[], written here as an
|
||||||
|
// empty array literal.
|
||||||
|
func insertCertificateRow(t *testing.T, db *sql.DB, ctx context.Context, certID, ownerID, teamID, issuerID, policyID string, expiresAt time.Time) {
|
||||||
|
t.Helper()
|
||||||
|
now := time.Now().Truncate(time.Microsecond)
|
||||||
|
_, err := db.ExecContext(ctx, `
|
||||||
|
INSERT INTO managed_certificates (
|
||||||
|
id, name, common_name, sans, environment,
|
||||||
|
owner_id, team_id, issuer_id, renewal_policy_id,
|
||||||
|
status, expires_at, tags,
|
||||||
|
created_at, updated_at
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, ARRAY[]::TEXT[], $4,
|
||||||
|
$5, $6, $7, $8,
|
||||||
|
$9, $10, $11,
|
||||||
|
$12, $13
|
||||||
|
)
|
||||||
|
`,
|
||||||
|
certID, certID, certID+".example.com", "production",
|
||||||
|
ownerID, teamID, issuerID, policyID,
|
||||||
|
string(domain.CertificateStatusActive), expiresAt, []byte(`{}`),
|
||||||
|
now, now,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insertCertificateRow failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertMapping writes a single row into certificate_target_mappings via raw SQL.
|
||||||
|
func insertMapping(t *testing.T, db *sql.DB, ctx context.Context, certID, targetID string) {
|
||||||
|
t.Helper()
|
||||||
|
_, err := db.ExecContext(ctx,
|
||||||
|
`INSERT INTO certificate_target_mappings (certificate_id, target_id) VALUES ($1, $2)`,
|
||||||
|
certID, targetID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insertMapping(%s, %s) failed: %v", certID, targetID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
// Get() — single-cert read path
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestGet_PopulatesTargetIDs_NoMappings: no mapping rows → TargetIDs must be
|
||||||
|
// an empty slice, not nil, so JSON serialisation emits "[]".
|
||||||
|
func TestGet_PopulatesTargetIDs_NoMappings(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewCertificateRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "getnone")
|
||||||
|
certID := "mc-getnone"
|
||||||
|
insertCertificateRow(t, db, ctx, certID, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
|
||||||
|
|
||||||
|
got, err := repo.Get(ctx, certID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
if got.TargetIDs == nil {
|
||||||
|
t.Fatalf("TargetIDs = nil, want empty slice (JSON serialises nil as null and [] as [])")
|
||||||
|
}
|
||||||
|
if len(got.TargetIDs) != 0 {
|
||||||
|
t.Errorf("len(TargetIDs) = %d, want 0; got %v", len(got.TargetIDs), got.TargetIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGet_PopulatesTargetIDs_SingleTarget: one mapping → one entry.
|
||||||
|
func TestGet_PopulatesTargetIDs_SingleTarget(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewCertificateRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "getone")
|
||||||
|
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "getone", 1)
|
||||||
|
|
||||||
|
certID := "mc-getone"
|
||||||
|
insertCertificateRow(t, db, ctx, certID, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
|
||||||
|
insertMapping(t, db, ctx, certID, targets[0])
|
||||||
|
|
||||||
|
got, err := repo.Get(ctx, certID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(got.TargetIDs) != 1 {
|
||||||
|
t.Fatalf("len(TargetIDs) = %d, want 1; got %v", len(got.TargetIDs), got.TargetIDs)
|
||||||
|
}
|
||||||
|
if got.TargetIDs[0] != targets[0] {
|
||||||
|
t.Errorf("TargetIDs[0] = %q, want %q", got.TargetIDs[0], targets[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGet_PopulatesTargetIDs_MultipleTargets: many mappings → sorted by target_id ASC.
|
||||||
|
func TestGet_PopulatesTargetIDs_MultipleTargets(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewCertificateRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "getmany")
|
||||||
|
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "getmany", 3)
|
||||||
|
|
||||||
|
certID := "mc-getmany"
|
||||||
|
insertCertificateRow(t, db, ctx, certID, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
|
||||||
|
// Insert mappings in reverse order to confirm ORDER BY target_id ASC in the query.
|
||||||
|
insertMapping(t, db, ctx, certID, targets[2])
|
||||||
|
insertMapping(t, db, ctx, certID, targets[0])
|
||||||
|
insertMapping(t, db, ctx, certID, targets[1])
|
||||||
|
|
||||||
|
got, err := repo.Get(ctx, certID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(got.TargetIDs) != 3 {
|
||||||
|
t.Fatalf("len(TargetIDs) = %d, want 3; got %v", len(got.TargetIDs), got.TargetIDs)
|
||||||
|
}
|
||||||
|
// Ascending order: t-getmany-0, t-getmany-1, t-getmany-2
|
||||||
|
want := []string{targets[0], targets[1], targets[2]}
|
||||||
|
for i, tid := range want {
|
||||||
|
if got.TargetIDs[i] != tid {
|
||||||
|
t.Errorf("TargetIDs[%d] = %q, want %q (full: %v)", i, got.TargetIDs[i], tid, got.TargetIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
// List() — batch read path, must avoid N+1
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestList_PopulatesTargetIDs_BatchFetch: three certs with different mapping counts;
|
||||||
|
// all must have their TargetIDs populated correctly, and the cert with no mapping
|
||||||
|
// must get an empty (non-nil) slice.
|
||||||
|
func TestList_PopulatesTargetIDs_BatchFetch(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewCertificateRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "listbatch")
|
||||||
|
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "listbatch", 3)
|
||||||
|
|
||||||
|
certA := "mc-list-a"
|
||||||
|
certB := "mc-list-b"
|
||||||
|
certC := "mc-list-c"
|
||||||
|
insertCertificateRow(t, db, ctx, certA, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
|
||||||
|
insertCertificateRow(t, db, ctx, certB, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
|
||||||
|
insertCertificateRow(t, db, ctx, certC, ownerID, teamID, issuerID, policyID, time.Now().Add(30*24*time.Hour))
|
||||||
|
|
||||||
|
// certA → 2 targets (t-0, t-1)
|
||||||
|
insertMapping(t, db, ctx, certA, targets[0])
|
||||||
|
insertMapping(t, db, ctx, certA, targets[1])
|
||||||
|
// certB → 1 target (t-2)
|
||||||
|
insertMapping(t, db, ctx, certB, targets[2])
|
||||||
|
// certC → 0 targets
|
||||||
|
|
||||||
|
got, total, err := repo.List(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
if total < 3 {
|
||||||
|
t.Fatalf("total = %d, want >= 3", total)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := map[string][]string{
|
||||||
|
certA: {targets[0], targets[1]},
|
||||||
|
certB: {targets[2]},
|
||||||
|
certC: {},
|
||||||
|
}
|
||||||
|
seen := map[string]bool{}
|
||||||
|
for _, c := range got {
|
||||||
|
exp, ok := want[c.ID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[c.ID] = true
|
||||||
|
if c.TargetIDs == nil {
|
||||||
|
t.Errorf("cert %s: TargetIDs = nil, want %v", c.ID, exp)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(c.TargetIDs) != len(exp) {
|
||||||
|
t.Errorf("cert %s: len(TargetIDs) = %d, want %d (got %v, want %v)", c.ID, len(c.TargetIDs), len(exp), c.TargetIDs, exp)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for i, tid := range exp {
|
||||||
|
if c.TargetIDs[i] != tid {
|
||||||
|
t.Errorf("cert %s: TargetIDs[%d] = %q, want %q", c.ID, i, c.TargetIDs[i], tid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id := range want {
|
||||||
|
if !seen[id] {
|
||||||
|
t.Errorf("cert %s missing from List() result", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
// GetExpiringCertificates() — scheduler read path
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestGetExpiringCertificates_PopulatesTargetIDs: expiring certs must also carry
|
||||||
|
// their mapping information so renewal-triggered deployments can route work.
|
||||||
|
func TestGetExpiringCertificates_PopulatesTargetIDs(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewCertificateRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "expiring")
|
||||||
|
_, targets := insertAgentAndTargetsRaw(t, db, ctx, "expiring", 2)
|
||||||
|
|
||||||
|
// Two expiring certs (expires in 3 days). Threshold = 7 days → both selected.
|
||||||
|
certA := "mc-exp-a"
|
||||||
|
certB := "mc-exp-b"
|
||||||
|
expiresSoon := time.Now().Add(3 * 24 * time.Hour)
|
||||||
|
insertCertificateRow(t, db, ctx, certA, ownerID, teamID, issuerID, policyID, expiresSoon)
|
||||||
|
insertCertificateRow(t, db, ctx, certB, ownerID, teamID, issuerID, policyID, expiresSoon)
|
||||||
|
|
||||||
|
insertMapping(t, db, ctx, certA, targets[0])
|
||||||
|
insertMapping(t, db, ctx, certA, targets[1])
|
||||||
|
// certB has no mappings.
|
||||||
|
|
||||||
|
threshold := time.Now().Add(7 * 24 * time.Hour)
|
||||||
|
got, err := repo.GetExpiringCertificates(ctx, threshold)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetExpiringCertificates failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := map[string]*domain.ManagedCertificate{}
|
||||||
|
for _, c := range got {
|
||||||
|
found[c.ID] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
a, ok := found[certA]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("cert %s not in expiring list", certA)
|
||||||
|
}
|
||||||
|
if len(a.TargetIDs) != 2 || a.TargetIDs[0] != targets[0] || a.TargetIDs[1] != targets[1] {
|
||||||
|
t.Errorf("cert %s: TargetIDs = %v, want %v", certA, a.TargetIDs, []string{targets[0], targets[1]})
|
||||||
|
}
|
||||||
|
|
||||||
|
b, ok := found[certB]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("cert %s not in expiring list", certB)
|
||||||
|
}
|
||||||
|
if b.TargetIDs == nil {
|
||||||
|
t.Errorf("cert %s: TargetIDs = nil, want empty slice", certB)
|
||||||
|
}
|
||||||
|
if len(b.TargetIDs) != 0 {
|
||||||
|
t.Errorf("cert %s: len(TargetIDs) = %d, want 0", certB, len(b.TargetIDs))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -457,6 +457,193 @@ func TestAgentRepository_Delete_NotFound(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAgentRepository_CreateIfNotExists_FirstInsert verifies that a brand-new
|
||||||
|
// sentinel agent row is inserted and the helper reports created=true (M-6).
|
||||||
|
func TestAgentRepository_CreateIfNotExists_FirstInsert(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewAgentRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
now := time.Now().Truncate(time.Microsecond)
|
||||||
|
agent := &domain.Agent{
|
||||||
|
ID: "server-scanner",
|
||||||
|
Name: "Network Scanner (Server-Side)",
|
||||||
|
Status: domain.AgentStatusOnline,
|
||||||
|
RegisteredAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := repo.CreateIfNotExists(ctx, agent)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateIfNotExists failed: %v", err)
|
||||||
|
}
|
||||||
|
if !created {
|
||||||
|
t.Error("created = false on first insert, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := repo.Get(ctx, "server-scanner")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
if got.Name != "Network Scanner (Server-Side)" {
|
||||||
|
t.Errorf("Name = %q, want %q", got.Name, "Network Scanner (Server-Side)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAgentRepository_CreateIfNotExists_Idempotent verifies that a second
|
||||||
|
// call with the same ID returns created=false and err=nil without mutating
|
||||||
|
// the existing row — the core M-6 upgrade/restart scenario (CWE-662).
|
||||||
|
func TestAgentRepository_CreateIfNotExists_Idempotent(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewAgentRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
now := time.Now().Truncate(time.Microsecond)
|
||||||
|
first := &domain.Agent{
|
||||||
|
ID: "cloud-aws-sm",
|
||||||
|
Name: "AWS Secrets Manager Discovery",
|
||||||
|
Status: domain.AgentStatusOnline,
|
||||||
|
RegisteredAt: now,
|
||||||
|
}
|
||||||
|
created, err := repo.CreateIfNotExists(ctx, first)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first CreateIfNotExists failed: %v", err)
|
||||||
|
}
|
||||||
|
if !created {
|
||||||
|
t.Fatal("first created = false, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second call with the same ID but a different name must be a no-op.
|
||||||
|
second := &domain.Agent{
|
||||||
|
ID: "cloud-aws-sm",
|
||||||
|
Name: "Overwritten Name Should Not Persist",
|
||||||
|
Status: domain.AgentStatusOffline,
|
||||||
|
RegisteredAt: now.Add(time.Hour),
|
||||||
|
}
|
||||||
|
created, err = repo.CreateIfNotExists(ctx, second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second CreateIfNotExists failed: %v", err)
|
||||||
|
}
|
||||||
|
if created {
|
||||||
|
t.Error("second created = true, want false (row already existed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row must still reflect the original insert.
|
||||||
|
got, err := repo.Get(ctx, "cloud-aws-sm")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
if got.Name != "AWS Secrets Manager Discovery" {
|
||||||
|
t.Errorf("Name = %q, want %q (ON CONFLICT DO NOTHING must preserve original row)", got.Name, "AWS Secrets Manager Discovery")
|
||||||
|
}
|
||||||
|
if got.Status != domain.AgentStatusOnline {
|
||||||
|
t.Errorf("Status = %q, want %q", got.Status, domain.AgentStatusOnline)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAgentRepository_CreateIfNotExists_ConcurrentRace fires N concurrent
|
||||||
|
// inserts for the same sentinel ID. Exactly one goroutine must see
|
||||||
|
// created=true; every other must see created=false and err=nil. No panics,
|
||||||
|
// no duplicate rows, no swallowed errors. This is the scenario that the
|
||||||
|
// pre-M-6 plain-INSERT path masked with a blanket error log.
|
||||||
|
func TestAgentRepository_CreateIfNotExists_ConcurrentRace(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewAgentRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
const N = 16
|
||||||
|
now := time.Now().Truncate(time.Microsecond)
|
||||||
|
|
||||||
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
createdCount int64
|
||||||
|
errorCount int64
|
||||||
|
)
|
||||||
|
wg.Add(N)
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
agent := &domain.Agent{
|
||||||
|
ID: "cloud-gcp-sm",
|
||||||
|
Name: "GCP Secret Manager Discovery",
|
||||||
|
Status: domain.AgentStatusOnline,
|
||||||
|
RegisteredAt: now,
|
||||||
|
}
|
||||||
|
created, err := repo.CreateIfNotExists(ctx, agent)
|
||||||
|
if err != nil {
|
||||||
|
atomic.AddInt64(&errorCount, 1)
|
||||||
|
t.Errorf("CreateIfNotExists returned error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if created {
|
||||||
|
atomic.AddInt64(&createdCount, 1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if errorCount != 0 {
|
||||||
|
t.Fatalf("errorCount = %d, want 0", errorCount)
|
||||||
|
}
|
||||||
|
if createdCount != 1 {
|
||||||
|
t.Errorf("createdCount = %d, want exactly 1 (only one goroutine may win the insert)", createdCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exactly one row must exist.
|
||||||
|
agents, err := repo.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
count := 0
|
||||||
|
for _, a := range agents {
|
||||||
|
if a.ID == "cloud-gcp-sm" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("row count for cloud-gcp-sm = %d, want 1", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAgentRepository_CreateIfNotExists_GenericErrorSurfaces verifies that
|
||||||
|
// failures other than the primary-key duplicate (the only collision
|
||||||
|
// ON CONFLICT (id) absorbs) propagate to the caller instead of being
|
||||||
|
// swallowed. This is the security property that M-6 restores: the
|
||||||
|
// pre-fix plain-INSERT path logged every error at Debug level, so a
|
||||||
|
// connectivity or permission failure would vanish into the log without
|
||||||
|
// the server surfacing a problem on startup (CWE-662 / CWE-209-adjacent).
|
||||||
|
//
|
||||||
|
// Uses a pre-cancelled context to force QueryRowContext to fail with
|
||||||
|
// context.Canceled — a non-duplicate error class that must surface.
|
||||||
|
// Does NOT close the shared sql.DB (that would break sibling tests).
|
||||||
|
func TestAgentRepository_CreateIfNotExists_GenericErrorSurfaces(t *testing.T) {
|
||||||
|
tdb := getTestDB(t)
|
||||||
|
db := tdb.freshSchema(t)
|
||||||
|
repo := postgres.NewAgentRepository(db)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // pre-cancel so the driver round-trip fails immediately.
|
||||||
|
|
||||||
|
agent := &domain.Agent{
|
||||||
|
ID: "server-scanner",
|
||||||
|
Name: "Network Scanner (Server-Side)",
|
||||||
|
Status: domain.AgentStatusOnline,
|
||||||
|
RegisteredAt: time.Now(),
|
||||||
|
}
|
||||||
|
created, err := repo.CreateIfNotExists(ctx, agent)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error on cancelled context, got nil (error would have been swallowed pre-M-6)")
|
||||||
|
}
|
||||||
|
if created {
|
||||||
|
t.Error("created = true on failure, want false")
|
||||||
|
}
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
t.Error("got sql.ErrNoRows, want a real connection/context error (ErrNoRows is the duplicate-row sentinel)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
// Issuer Repository Tests
|
// Issuer Repository Tests
|
||||||
// ============================================================
|
// ============================================================
|
||||||
|
|||||||
@@ -91,8 +91,8 @@ func (s *AgentService) Register(ctx context.Context, name string, hostname strin
|
|||||||
return agent, apiKey, nil
|
return agent, apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HeartbeatWithContext updates an agent's last seen time, status, and metadata.
|
// Heartbeat updates an agent's last seen time, status, and metadata.
|
||||||
func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
|
func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
|
||||||
agent, err := s.agentRepo.Get(ctx, agentID)
|
agent, err := s.agentRepo.Get(ctx, agentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch agent: %w", err)
|
return fmt.Errorf("failed to fetch agent: %w", err)
|
||||||
@@ -114,12 +114,6 @@ func (s *AgentService) HeartbeatWithContext(ctx context.Context, agentID string,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat updates agent heartbeat (handler interface method).
|
|
||||||
// Note: This method is called from handlers which have a context; callers should prefer HeartbeatWithContext.
|
|
||||||
func (s *AgentService) Heartbeat(ctx context.Context, agentID string, metadata *domain.AgentMetadata) error {
|
|
||||||
return s.HeartbeatWithContext(ctx, agentID, metadata)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SubmitCSR validates and processes a Certificate Signing Request from an agent.
|
// SubmitCSR validates and processes a Certificate Signing Request from an agent.
|
||||||
// In agent keygen mode, this completes an AwaitingCSR renewal job by signing the CSR
|
// In agent keygen mode, this completes an AwaitingCSR renewal job by signing the CSR
|
||||||
// and storing the cert version. The private key stays on the agent — only the CSR
|
// and storing the cert version. The private key stays on the agent — only the CSR
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func TestHeartbeat(t *testing.T) {
|
|||||||
|
|
||||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
||||||
|
|
||||||
err := agentService.HeartbeatWithContext(ctx, "agent-001", nil)
|
err := agentService.Heartbeat(ctx, "agent-001", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Heartbeat failed: %v", err)
|
t.Fatalf("Heartbeat failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -125,7 +125,7 @@ func TestHeartbeat_NotFound(t *testing.T) {
|
|||||||
|
|
||||||
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
agentService := NewAgentService(agentRepo, certRepo, jobRepo, targetRepo, auditService, issuerRegistry, nil)
|
||||||
|
|
||||||
err := agentService.HeartbeatWithContext(ctx, "nonexistent", nil)
|
err := agentService.Heartbeat(ctx, "nonexistent", nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for nonexistent agent")
|
t.Fatal("expected error for nonexistent agent")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func (s *AuditService) ListByAction(ctx context.Context, action string, from, to
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListAuditEvents returns paginated audit events (handler interface method).
|
// ListAuditEvents returns paginated audit events (handler interface method).
|
||||||
func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) {
|
func (s *AuditService) ListAuditEvents(ctx context.Context, page, perPage int) ([]domain.AuditEvent, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -123,7 +123,7 @@ func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent,
|
|||||||
PerPage: perPage,
|
PerPage: perPage,
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := s.auditRepo.List(context.Background(), filter)
|
events, err := s.auditRepo.List(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list audit events: %w", err)
|
return nil, 0, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
}
|
}
|
||||||
@@ -143,13 +143,13 @@ func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAuditEvent returns a single audit event (handler interface method).
|
// GetAuditEvent returns a single audit event (handler interface method).
|
||||||
func (s *AuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) {
|
func (s *AuditService) GetAuditEvent(ctx context.Context, id string) (*domain.AuditEvent, error) {
|
||||||
filter := &repository.AuditFilter{
|
filter := &repository.AuditFilter{
|
||||||
ResourceID: id,
|
ResourceID: id,
|
||||||
PerPage: 1,
|
PerPage: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := s.auditRepo.List(context.Background(), filter)
|
events, err := s.auditRepo.List(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get audit event: %w", err)
|
return nil, fmt.Errorf("failed to get audit event: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func (s *CAOperationsSvc) SetIssuerRegistry(registry *IssuerRegistry) {
|
|||||||
|
|
||||||
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
|
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
|
||||||
// Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL.
|
// Short-lived certificates (profile TTL < 1 hour) are excluded from the CRL.
|
||||||
func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
|
func (s *CAOperationsSvc) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
|
||||||
if s.revocationRepo == nil {
|
if s.revocationRepo == nil {
|
||||||
return nil, fmt.Errorf("revocation repository not configured")
|
return nil, fmt.Errorf("revocation repository not configured")
|
||||||
}
|
}
|
||||||
@@ -54,7 +54,7 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("issuer not found: %s", issuerID)
|
return nil, fmt.Errorf("issuer not found: %s", issuerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
revocations, err := s.revocationRepo.ListAll(context.Background())
|
revocations, err := s.revocationRepo.ListAll(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to list revocations: %w", err)
|
return nil, fmt.Errorf("failed to list revocations: %w", err)
|
||||||
}
|
}
|
||||||
@@ -69,9 +69,9 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
|
|||||||
|
|
||||||
// Check short-lived exemption: look up the cert's profile
|
// Check short-lived exemption: look up the cert's profile
|
||||||
if s.profileRepo != nil && s.certRepo != nil {
|
if s.profileRepo != nil && s.certRepo != nil {
|
||||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
cert, err := s.certRepo.Get(ctx, rev.CertificateID)
|
||||||
if err == nil && cert.CertificateProfileID != "" {
|
if err == nil && cert.CertificateProfileID != "" {
|
||||||
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
|
profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
|
||||||
if err == nil && profile.IsShortLived() {
|
if err == nil && profile.IsShortLived() {
|
||||||
slog.Debug("skipping short-lived cert from CRL",
|
slog.Debug("skipping short-lived cert from CRL",
|
||||||
"certificate_id", rev.CertificateID,
|
"certificate_id", rev.CertificateID,
|
||||||
@@ -92,11 +92,11 @@ func (s *CAOperationsSvc) GenerateDERCRL(issuerID string) ([]byte, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return issuerConn.GenerateCRL(context.Background(), entries)
|
return issuerConn.GenerateCRL(ctx, entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
|
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
|
||||||
func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
|
func (s *CAOperationsSvc) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
|
||||||
if s.revocationRepo == nil {
|
if s.revocationRepo == nil {
|
||||||
return nil, fmt.Errorf("revocation repository not configured")
|
return nil, fmt.Errorf("revocation repository not configured")
|
||||||
}
|
}
|
||||||
@@ -120,13 +120,13 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
|
|||||||
// Look up cert by (issuer_id, serial) — per RFC 5280 §5.2.3, serial numbers
|
// Look up cert by (issuer_id, serial) — per RFC 5280 §5.2.3, serial numbers
|
||||||
// are unique only within a single issuer. The OCSP URL path carries issuer_id,
|
// are unique only within a single issuer. The OCSP URL path carries issuer_id,
|
||||||
// so we scope the lookup to avoid cross-issuer collisions.
|
// so we scope the lookup to avoid cross-issuer collisions.
|
||||||
rev, _ := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex)
|
rev, _ := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
|
||||||
if rev != nil {
|
if rev != nil {
|
||||||
cert, err := s.certRepo.Get(context.Background(), rev.CertificateID)
|
cert, err := s.certRepo.Get(ctx, rev.CertificateID)
|
||||||
if err == nil && cert.CertificateProfileID != "" {
|
if err == nil && cert.CertificateProfileID != "" {
|
||||||
profile, err := s.profileRepo.Get(context.Background(), cert.CertificateProfileID)
|
profile, err := s.profileRepo.Get(ctx, cert.CertificateProfileID)
|
||||||
if err == nil && profile.IsShortLived() {
|
if err == nil && profile.IsShortLived() {
|
||||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
|
||||||
CertSerial: serial,
|
CertSerial: serial,
|
||||||
CertStatus: 0, // good — short-lived exemption
|
CertStatus: 0, // good — short-lived exemption
|
||||||
ThisUpdate: now,
|
ThisUpdate: now,
|
||||||
@@ -138,10 +138,10 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping.
|
// Check if this (issuer_id, serial) is revoked — RFC 5280 §5.2.3 scoping.
|
||||||
rev, err := s.revocationRepo.GetByIssuerAndSerial(context.Background(), issuerID, serialHex)
|
rev, err := s.revocationRepo.GetByIssuerAndSerial(ctx, issuerID, serialHex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Not revoked — return "good" status
|
// Not revoked — return "good" status
|
||||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
|
||||||
CertSerial: serial,
|
CertSerial: serial,
|
||||||
CertStatus: 0, // good
|
CertStatus: 0, // good
|
||||||
ThisUpdate: now,
|
ThisUpdate: now,
|
||||||
@@ -150,7 +150,7 @@ func (s *CAOperationsSvc) GetOCSPResponse(issuerID string, serialHex string) ([]
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Revoked
|
// Revoked
|
||||||
return issuerConn.SignOCSPResponse(context.Background(), OCSPSignRequest{
|
return issuerConn.SignOCSPResponse(ctx, OCSPSignRequest{
|
||||||
CertSerial: serial,
|
CertSerial: serial,
|
||||||
CertStatus: 1, // revoked
|
CertStatus: 1, // revoked
|
||||||
RevokedAt: rev.RevokedAt,
|
RevokedAt: rev.RevokedAt,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -48,7 +49,7 @@ func TestCAOperationsSvc_GenerateDERCRL_Success(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
crl, err := caSvc.GenerateDERCRL("iss-local")
|
crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -71,7 +72,7 @@ func TestCAOperationsSvc_GenerateDERCRL_EmptyCRL(t *testing.T) {
|
|||||||
// No revoked certs for this issuer
|
// No revoked certs for this issuer
|
||||||
revocationRepo.Revocations = []*domain.CertificateRevocation{}
|
revocationRepo.Revocations = []*domain.CertificateRevocation{}
|
||||||
|
|
||||||
crl, err := caSvc.GenerateDERCRL("iss-local")
|
crl, err := caSvc.GenerateDERCRL(context.Background(), "iss-local")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -112,7 +113,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Good(t *testing.T) {
|
|||||||
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
|
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
|
||||||
|
|
||||||
// Request OCSP response for good cert
|
// Request OCSP response for good cert
|
||||||
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-GOOD-001")
|
resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -165,7 +166,7 @@ func TestCAOperationsSvc_GetOCSPResponse_Revoked(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Request OCSP response for revoked cert
|
// Request OCSP response for revoked cert
|
||||||
resp, err := caSvc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001")
|
resp, err := caSvc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ func (s *CertificateService) List(ctx context.Context, filter *repository.Certif
|
|||||||
|
|
||||||
// ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20).
|
// ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20).
|
||||||
// This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers).
|
// This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers).
|
||||||
func (s *CertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
func (s *CertificateService) ListCertificatesWithFilter(ctx context.Context, filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
certs, total, err := s.certRepo.List(context.Background(), filter)
|
certs, total, err := s.certRepo.List(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err)
|
return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err)
|
||||||
}
|
}
|
||||||
@@ -206,10 +206,10 @@ func (s *CertificateService) GetVersions(ctx context.Context, certID string) ([]
|
|||||||
return versions, nil
|
return versions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerRenewalWithActor initiates a renewal job if the certificate is eligible.
|
// TriggerRenewal initiates a renewal job if the certificate is eligible.
|
||||||
// Creates a Renewal job (or Issuance for new certs) so the scheduler's job processor
|
// Creates a Renewal job (or Issuance for new certs) so the scheduler's job processor
|
||||||
// can pick it up and route it through the issuer connector.
|
// can pick it up and route it through the issuer connector.
|
||||||
func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID string, actor string) error {
|
func (s *CertificateService) TriggerRenewal(ctx context.Context, certID string, actor string) error {
|
||||||
cert, err := s.certRepo.Get(ctx, certID)
|
cert, err := s.certRepo.Get(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch certificate: %w", err)
|
return fmt.Errorf("failed to fetch certificate: %w", err)
|
||||||
@@ -283,8 +283,11 @@ func (s *CertificateService) TriggerRenewalWithActor(ctx context.Context, certID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerDeploymentWithActor creates deployment jobs for all targets of a certificate.
|
// TriggerDeployment creates deployment jobs for all targets of a certificate.
|
||||||
func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, certID string, actor string) error {
|
// The targetID parameter is accepted from the handler interface but currently unused;
|
||||||
|
// deployment coordination happens per-certificate across all of its targets.
|
||||||
|
func (s *CertificateService) TriggerDeployment(ctx context.Context, certID string, targetID string, actor string) error {
|
||||||
|
_ = targetID
|
||||||
cert, err := s.certRepo.Get(ctx, certID)
|
cert, err := s.certRepo.Get(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch certificate: %w", err)
|
return fmt.Errorf("failed to fetch certificate: %w", err)
|
||||||
@@ -306,7 +309,7 @@ func (s *CertificateService) TriggerDeploymentWithActor(ctx context.Context, cer
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListCertificates returns paginated certificates with optional filtering (handler interface method).
|
// ListCertificates returns paginated certificates with optional filtering (handler interface method).
|
||||||
func (s *CertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
func (s *CertificateService) ListCertificates(ctx context.Context, status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -325,7 +328,7 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team
|
|||||||
PerPage: perPage,
|
PerPage: perPage,
|
||||||
}
|
}
|
||||||
|
|
||||||
certs, total, err := s.certRepo.List(context.Background(), filter)
|
certs, total, err := s.certRepo.List(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list certificates: %w", err)
|
return nil, 0, fmt.Errorf("failed to list certificates: %w", err)
|
||||||
}
|
}
|
||||||
@@ -341,12 +344,12 @@ func (s *CertificateService) ListCertificates(status, environment, ownerID, team
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificate returns a single certificate (handler interface method).
|
// GetCertificate returns a single certificate (handler interface method).
|
||||||
func (s *CertificateService) GetCertificate(id string) (*domain.ManagedCertificate, error) {
|
func (s *CertificateService) GetCertificate(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
return s.certRepo.Get(context.Background(), id)
|
return s.certRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateCertificate creates a new certificate (handler interface method).
|
// CreateCertificate creates a new certificate (handler interface method).
|
||||||
func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
func (s *CertificateService) CreateCertificate(ctx context.Context, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
if cert.ID == "" {
|
if cert.ID == "" {
|
||||||
cert.ID = generateID("cert")
|
cert.ID = generateID("cert")
|
||||||
}
|
}
|
||||||
@@ -365,16 +368,14 @@ func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (
|
|||||||
if cert.Tags == nil {
|
if cert.Tags == nil {
|
||||||
cert.Tags = make(map[string]string)
|
cert.Tags = make(map[string]string)
|
||||||
}
|
}
|
||||||
if err := s.certRepo.Create(context.Background(), &cert); err != nil {
|
if err := s.certRepo.Create(ctx, &cert); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||||
}
|
}
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateCertificate modifies a certificate (handler interface method).
|
// UpdateCertificate modifies a certificate (handler interface method).
|
||||||
func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
func (s *CertificateService) UpdateCertificate(ctx context.Context, id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Fetch existing certificate so partial updates don't zero out fields
|
// Fetch existing certificate so partial updates don't zero out fields
|
||||||
existing, err := s.certRepo.Get(ctx, id)
|
existing, err := s.certRepo.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -425,12 +426,12 @@ func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ArchiveCertificate marks a certificate as archived (handler interface method).
|
// ArchiveCertificate marks a certificate as archived (handler interface method).
|
||||||
func (s *CertificateService) ArchiveCertificate(id string) error {
|
func (s *CertificateService) ArchiveCertificate(ctx context.Context, id string) error {
|
||||||
return s.certRepo.Archive(context.Background(), id)
|
return s.certRepo.Archive(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificateVersions returns certificate versions (handler interface method).
|
// GetCertificateVersions returns certificate versions (handler interface method).
|
||||||
func (s *CertificateService) GetCertificateVersions(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
func (s *CertificateService) GetCertificateVersions(ctx context.Context, certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -438,7 +439,7 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
versions, err := s.certRepo.ListVersions(context.Background(), certID)
|
versions, err := s.certRepo.ListVersions(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list certificate versions: %w", err)
|
return nil, 0, fmt.Errorf("failed to list certificate versions: %w", err)
|
||||||
}
|
}
|
||||||
@@ -463,24 +464,8 @@ func (s *CertificateService) GetCertificateVersions(certID string, page, perPage
|
|||||||
return result, total, nil
|
return result, total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerRenewal initiates renewal (handler interface method).
|
// RevokeCertificate performs revocation with actor tracking. Delegates to RevocationSvc.
|
||||||
func (s *CertificateService) TriggerRenewal(certID string) error {
|
func (s *CertificateService) RevokeCertificate(ctx context.Context, certID string, reason string, actor string) error {
|
||||||
return s.TriggerRenewalWithActor(context.Background(), certID, "api")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TriggerDeployment triggers deployment (handler interface method).
|
|
||||||
func (s *CertificateService) TriggerDeployment(certID string, targetID string) error {
|
|
||||||
return s.TriggerDeploymentWithActor(context.Background(), certID, "api")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RevokeCertificate revokes a certificate with the given reason (handler interface method).
|
|
||||||
func (s *CertificateService) RevokeCertificate(certID string, reason string) error {
|
|
||||||
return s.RevokeCertificateWithActor(context.Background(), certID, reason, "api")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RevokeCertificateWithActor performs revocation with actor tracking.
|
|
||||||
// Delegates to RevocationSvc.
|
|
||||||
func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, certID string, reason string, actor string) error {
|
|
||||||
if s.revSvc == nil {
|
if s.revSvc == nil {
|
||||||
return fmt.Errorf("revocation service not configured")
|
return fmt.Errorf("revocation service not configured")
|
||||||
}
|
}
|
||||||
@@ -489,35 +474,35 @@ func (s *CertificateService) RevokeCertificateWithActor(ctx context.Context, cer
|
|||||||
|
|
||||||
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
|
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
|
||||||
// Delegates to RevocationSvc.
|
// Delegates to RevocationSvc.
|
||||||
func (s *CertificateService) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
|
func (s *CertificateService) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
|
||||||
if s.revSvc == nil {
|
if s.revSvc == nil {
|
||||||
return nil, fmt.Errorf("revocation service not configured")
|
return nil, fmt.Errorf("revocation service not configured")
|
||||||
}
|
}
|
||||||
return s.revSvc.GetRevokedCertificates()
|
return s.revSvc.GetRevokedCertificates(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
|
// GenerateDERCRL generates a DER-encoded X.509 CRL for the given issuer.
|
||||||
// Delegates to CAOperationsSvc.
|
// Delegates to CAOperationsSvc.
|
||||||
func (s *CertificateService) GenerateDERCRL(issuerID string) ([]byte, error) {
|
func (s *CertificateService) GenerateDERCRL(ctx context.Context, issuerID string) ([]byte, error) {
|
||||||
if s.caSvc == nil {
|
if s.caSvc == nil {
|
||||||
return nil, fmt.Errorf("CA operations service not configured")
|
return nil, fmt.Errorf("CA operations service not configured")
|
||||||
}
|
}
|
||||||
return s.caSvc.GenerateDERCRL(issuerID)
|
return s.caSvc.GenerateDERCRL(ctx, issuerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
|
// GetOCSPResponse generates a signed OCSP response for the given certificate serial.
|
||||||
// Delegates to CAOperationsSvc.
|
// Delegates to CAOperationsSvc.
|
||||||
func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) {
|
func (s *CertificateService) GetOCSPResponse(ctx context.Context, issuerID string, serialHex string) ([]byte, error) {
|
||||||
if s.caSvc == nil {
|
if s.caSvc == nil {
|
||||||
return nil, fmt.Errorf("CA operations service not configured")
|
return nil, fmt.Errorf("CA operations service not configured")
|
||||||
}
|
}
|
||||||
return s.caSvc.GetOCSPResponse(issuerID, serialHex)
|
return s.caSvc.GetOCSPResponse(ctx, issuerID, serialHex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificateDeployments returns all deployment targets for a certificate (M20).
|
// GetCertificateDeployments returns all deployment targets for a certificate (M20).
|
||||||
func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) {
|
func (s *CertificateService) GetCertificateDeployments(ctx context.Context, certID string) ([]domain.DeploymentTarget, error) {
|
||||||
// Verify certificate exists
|
// Verify certificate exists
|
||||||
_, err := s.certRepo.Get(context.Background(), certID)
|
_, err := s.certRepo.Get(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("certificate not found: %w", err)
|
return nil, fmt.Errorf("certificate not found: %w", err)
|
||||||
}
|
}
|
||||||
@@ -527,7 +512,7 @@ func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get targets from repository
|
// Get targets from repository
|
||||||
targets, err := s.targetRepo.ListByCertificate(context.Background(), certID)
|
targets, err := s.targetRepo.ListByCertificate(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to list deployment targets: %w", err)
|
return nil, fmt.Errorf("failed to list deployment targets: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func TestCertificateService_RevokeCertificate_RevocationSvcNil(t *testing.T) {
|
|||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
|
|
||||||
// Call RevokeCertificateWithActor with nil RevocationSvc
|
// Call RevokeCertificateWithActor with nil RevocationSvc
|
||||||
err := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
|
err := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
|
||||||
|
|
||||||
// Assert: Should return error, NOT panic
|
// Assert: Should return error, NOT panic
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -64,7 +64,7 @@ func TestCertificateService_GenerateDERCRL_CAOpsSvcNil(t *testing.T) {
|
|||||||
// Note: NOT calling certService.SetCAOperationsSvc(...)
|
// Note: NOT calling certService.SetCAOperationsSvc(...)
|
||||||
|
|
||||||
// Call GenerateDERCRL with nil CAOperationsSvc
|
// Call GenerateDERCRL with nil CAOperationsSvc
|
||||||
_, err := certService.GenerateDERCRL("iss-local")
|
_, err := certService.GenerateDERCRL(context.Background(), "iss-local")
|
||||||
|
|
||||||
// Assert: Should return error, NOT panic
|
// Assert: Should return error, NOT panic
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -94,7 +94,7 @@ func TestCertificateService_GetOCSPResponse_CAOpsSvcNil(t *testing.T) {
|
|||||||
// Note: NOT calling certService.SetCAOperationsSvc(...)
|
// Note: NOT calling certService.SetCAOperationsSvc(...)
|
||||||
|
|
||||||
// Call GetOCSPResponse with nil CAOperationsSvc
|
// Call GetOCSPResponse with nil CAOperationsSvc
|
||||||
_, err := certService.GetOCSPResponse("iss-local", "serial123")
|
_, err := certService.GetOCSPResponse(context.Background(), "iss-local", "serial123")
|
||||||
|
|
||||||
// Assert: Should return error, NOT panic
|
// Assert: Should return error, NOT panic
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -124,7 +124,7 @@ func TestCertificateService_GetRevokedCertificates_RevocationSvcNil(t *testing.T
|
|||||||
// Note: NOT calling certService.SetRevocationSvc(...)
|
// Note: NOT calling certService.SetRevocationSvc(...)
|
||||||
|
|
||||||
// Call GetRevokedCertificates with nil RevocationSvc
|
// Call GetRevokedCertificates with nil RevocationSvc
|
||||||
_, err := certService.GetRevokedCertificates()
|
_, err := certService.GetRevokedCertificates(context.Background())
|
||||||
|
|
||||||
// Assert: Should return error, NOT panic
|
// Assert: Should return error, NOT panic
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -177,7 +177,7 @@ func TestCertificateService_GetCertificateDeployments_Success(t *testing.T) {
|
|||||||
targetRepo.AddTarget(target2)
|
targetRepo.AddTarget(target2)
|
||||||
|
|
||||||
// Call GetCertificateDeployments
|
// Call GetCertificateDeployments
|
||||||
deployments, err := certService.GetCertificateDeployments("cert-1")
|
deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
|
||||||
|
|
||||||
// Assert: Should return deployment list successfully
|
// Assert: Should return deployment list successfully
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -218,7 +218,7 @@ func TestCertificateService_GetCertificateDeployments_RepositoryError(t *testing
|
|||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
|
|
||||||
// Call GetCertificateDeployments with repo error
|
// Call GetCertificateDeployments with repo error
|
||||||
_, err := certService.GetCertificateDeployments("cert-1")
|
_, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
|
||||||
|
|
||||||
// Assert: Should return error, NOT panic
|
// Assert: Should return error, NOT panic
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -247,7 +247,7 @@ func TestCertificateService_GetCertificateDeployments_CertNotFound(t *testing.T)
|
|||||||
certService.SetTargetRepo(targetRepo)
|
certService.SetTargetRepo(targetRepo)
|
||||||
|
|
||||||
// Call GetCertificateDeployments with nonexistent certificate
|
// Call GetCertificateDeployments with nonexistent certificate
|
||||||
_, err := certService.GetCertificateDeployments("nonexistent-cert")
|
_, err := certService.GetCertificateDeployments(context.Background(), "nonexistent-cert")
|
||||||
|
|
||||||
// Assert: Should return error
|
// Assert: Should return error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -283,7 +283,7 @@ func TestCertificateService_GetCertificateDeployments_NilTargetRepo(t *testing.T
|
|||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
|
|
||||||
// Call GetCertificateDeployments with nil TargetRepo
|
// Call GetCertificateDeployments with nil TargetRepo
|
||||||
deployments, err := certService.GetCertificateDeployments("cert-1")
|
deployments, err := certService.GetCertificateDeployments(context.Background(), "cert-1")
|
||||||
|
|
||||||
// Assert: Should return empty list gracefully (not panic)
|
// Assert: Should return empty list gracefully (not panic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -337,19 +337,19 @@ func TestCertificateService_Multiple_NilSafetyChecks(t *testing.T) {
|
|||||||
revSvc.SetIssuerRegistry(registry)
|
revSvc.SetIssuerRegistry(registry)
|
||||||
|
|
||||||
// Test 1: RevokeCertificateWithActor should succeed (RevocationSvc is set)
|
// Test 1: RevokeCertificateWithActor should succeed (RevocationSvc is set)
|
||||||
errRevoke := certService.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
|
errRevoke := certService.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
|
||||||
if errRevoke != nil {
|
if errRevoke != nil {
|
||||||
t.Fatalf("RevokeCertificateWithActor failed unexpectedly: %v", errRevoke)
|
t.Fatalf("RevokeCertificateWithActor failed unexpectedly: %v", errRevoke)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test 2: GenerateDERCRL should fail gracefully (CAOperationsSvc is nil)
|
// Test 2: GenerateDERCRL should fail gracefully (CAOperationsSvc is nil)
|
||||||
_, errCRL := certService.GenerateDERCRL("iss-local")
|
_, errCRL := certService.GenerateDERCRL(context.Background(), "iss-local")
|
||||||
if errCRL == nil {
|
if errCRL == nil {
|
||||||
t.Fatal("GenerateDERCRL expected error, got nil")
|
t.Fatal("GenerateDERCRL expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test 3: GetOCSPResponse should fail gracefully (CAOperationsSvc is nil)
|
// Test 3: GetOCSPResponse should fail gracefully (CAOperationsSvc is nil)
|
||||||
_, errOCSP := certService.GetOCSPResponse("iss-local", "ABC123")
|
_, errOCSP := certService.GetOCSPResponse(context.Background(), "iss-local", "ABC123")
|
||||||
if errOCSP == nil {
|
if errOCSP == nil {
|
||||||
t.Fatal("GetOCSPResponse expected error, got nil")
|
t.Fatal("GetOCSPResponse expected error, got nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ func TestTriggerRenewal(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||||
|
|
||||||
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
|
err := certService.TriggerRenewal(ctx, "cert-001", "user-1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TriggerRenewal failed: %v", err)
|
t.Fatalf("TriggerRenewal failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -333,13 +333,14 @@ func TestTriggerRenewal_Archived(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||||
|
|
||||||
err := certService.TriggerRenewalWithActor(ctx, "cert-001", "user-1")
|
err := certService.TriggerRenewal(ctx, "cert-001", "user-1")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for archived certificate")
|
t.Fatal("expected error for archived certificate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListCertificates(t *testing.T) {
|
func TestListCertificates(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
cert1 := &domain.ManagedCertificate{
|
cert1 := &domain.ManagedCertificate{
|
||||||
ID: "cert-001",
|
ID: "cert-001",
|
||||||
@@ -369,7 +370,7 @@ func TestListCertificates(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
certService := NewCertificateService(certRepo, policyService, auditService)
|
certService := NewCertificateService(certRepo, policyService, auditService)
|
||||||
|
|
||||||
certs, total, err := certService.ListCertificates("", "", "", "", "", 1, 50)
|
certs, total, err := certService.ListCertificates(ctx, "", "", "", "", "", 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListCertificates failed: %v", err)
|
t.Fatalf("ListCertificates failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ func TestConcurrentAgentHeartbeats(t *testing.T) {
|
|||||||
Architecture: "x86_64",
|
Architecture: "x86_64",
|
||||||
}
|
}
|
||||||
|
|
||||||
err := agentSvc.HeartbeatWithContext(ctx, agentID, metadata)
|
err := agentSvc.Heartbeat(ctx, agentID, metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("goroutine %d: failed heartbeat for agent %s: %w", idx, agentID, err)
|
errChan <- fmt.Errorf("goroutine %d: failed heartbeat for agent %s: %w", idx, agentID, err)
|
||||||
return
|
return
|
||||||
@@ -194,7 +194,7 @@ func TestConcurrentTargetCRUD(t *testing.T) {
|
|||||||
Targets: make(map[string]*domain.DeploymentTarget),
|
Targets: make(map[string]*domain.DeploymentTarget),
|
||||||
}
|
}
|
||||||
|
|
||||||
targetSvc := NewTargetService(mockTargetRepo, nil, nil, nil, slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
targetSvc := NewTargetService(mockTargetRepo, nil, nil, "", slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||||
|
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
createdTargets := make([]string, 0)
|
createdTargets := make([]string, 0)
|
||||||
@@ -403,7 +403,7 @@ func TestConcurrentMixedOperations(t *testing.T) {
|
|||||||
// Setup services
|
// Setup services
|
||||||
auditSvc := &AuditService{auditRepo: mockAuditRepo}
|
auditSvc := &AuditService{auditRepo: mockAuditRepo}
|
||||||
certSvc := NewCertificateService(mockCertRepo, nil, auditSvc)
|
certSvc := NewCertificateService(mockCertRepo, nil, auditSvc)
|
||||||
targetSvc := NewTargetService(mockTargetRepo, auditSvc, nil, nil, slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
targetSvc := NewTargetService(mockTargetRepo, auditSvc, nil, "", slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errChan := make(chan error, 30)
|
errChan := make(chan error, 30)
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ func TestTargetService_ListWithCancelledContext(t *testing.T) {
|
|||||||
mockTargetRepo := &mockTargetRepo{
|
mockTargetRepo := &mockTargetRepo{
|
||||||
Targets: make(map[string]*domain.DeploymentTarget),
|
Targets: make(map[string]*domain.DeploymentTarget),
|
||||||
}
|
}
|
||||||
targetSvc := NewTargetService(mockTargetRepo, nil, nil, nil, slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
targetSvc := NewTargetService(mockTargetRepo, nil, nil, "", slog.New(slog.NewTextHandler(os.Stderr, nil)))
|
||||||
|
|
||||||
_, _, err := targetSvc.List(ctx, 1, 50)
|
_, _, err := targetSvc.List(ctx, 1, 50)
|
||||||
|
|
||||||
@@ -176,13 +176,13 @@ func TestAgentService_HeartbeatWithCancelledContext(t *testing.T) {
|
|||||||
nil, // renewalService
|
nil, // renewalService
|
||||||
)
|
)
|
||||||
|
|
||||||
err := agentSvc.HeartbeatWithContext(ctx, "agent-1", &domain.AgentMetadata{})
|
err := agentSvc.Heartbeat(ctx, "agent-1", &domain.AgentMetadata{})
|
||||||
|
|
||||||
// Service should handle cancelled context
|
// Service should handle cancelled context
|
||||||
if err == nil || ctx.Err() == context.Canceled {
|
if err == nil || ctx.Err() == context.Canceled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Logf("HeartbeatWithContext with cancelled context returned: %v", err)
|
t.Logf("Heartbeat with cancelled context returned: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with timeout context (should trigger deadline exceeded)
|
// Test with timeout context (should trigger deadline exceeded)
|
||||||
@@ -229,11 +229,11 @@ func TestAgentService_HeartbeatWithDeadlineExceeded(t *testing.T) {
|
|||||||
|
|
||||||
time.Sleep(10 * time.Millisecond) // Ensure deadline is exceeded
|
time.Sleep(10 * time.Millisecond) // Ensure deadline is exceeded
|
||||||
|
|
||||||
err := agentSvc.HeartbeatWithContext(ctx, "agent-1", &domain.AgentMetadata{})
|
err := agentSvc.Heartbeat(ctx, "agent-1", &domain.AgentMetadata{})
|
||||||
|
|
||||||
// Service should handle deadline exceeded
|
// Service should handle deadline exceeded
|
||||||
if err == nil || ctx.Err() == context.DeadlineExceeded {
|
if err == nil || ctx.Err() == context.DeadlineExceeded {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Logf("HeartbeatWithContext with deadline exceeded returned: %v", err)
|
t.Logf("Heartbeat with deadline exceeded returned: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+25
-23
@@ -17,20 +17,27 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// IssuerService provides business logic for certificate issuer management.
|
// IssuerService provides business logic for certificate issuer management.
|
||||||
|
//
|
||||||
|
// The encryptionKey field holds the raw passphrase (not a pre-derived 32-byte
|
||||||
|
// key). Per-ciphertext salt derivation is performed inside
|
||||||
|
// [crypto.EncryptIfKeySet] / [crypto.DecryptIfKeySet] on each call. See M-8
|
||||||
|
// in certctl-audit-report.md.
|
||||||
type IssuerService struct {
|
type IssuerService struct {
|
||||||
issuerRepo repository.IssuerRepository
|
issuerRepo repository.IssuerRepository
|
||||||
auditService *AuditService
|
auditService *AuditService
|
||||||
registry *IssuerRegistry
|
registry *IssuerRegistry
|
||||||
encryptionKey []byte
|
encryptionKey string
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIssuerService creates a new issuer service.
|
// NewIssuerService creates a new issuer service. The encryptionKey is the raw
|
||||||
|
// passphrase; it MUST NOT be pre-derived via crypto.DeriveKey (that was the
|
||||||
|
// v1 behavior, replaced in M-8 with per-ciphertext random salt).
|
||||||
func NewIssuerService(
|
func NewIssuerService(
|
||||||
issuerRepo repository.IssuerRepository,
|
issuerRepo repository.IssuerRepository,
|
||||||
auditService *AuditService,
|
auditService *AuditService,
|
||||||
registry *IssuerRegistry,
|
registry *IssuerRegistry,
|
||||||
encryptionKey []byte,
|
encryptionKey string,
|
||||||
logger *slog.Logger,
|
logger *slog.Logger,
|
||||||
) *IssuerService {
|
) *IssuerService {
|
||||||
return &IssuerService{
|
return &IssuerService{
|
||||||
@@ -253,9 +260,9 @@ func (s *IssuerService) Delete(ctx context.Context, id string, actor string) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConnectionWithContext tests the connection to an issuer by instantiating a throwaway
|
// TestConnection tests the connection to an issuer by instantiating a throwaway
|
||||||
// connector and calling ValidateConfig. Records the result in the database.
|
// connector and calling ValidateConfig. Records the result in the database.
|
||||||
func (s *IssuerService) TestConnectionWithContext(ctx context.Context, id string) error {
|
func (s *IssuerService) TestConnection(ctx context.Context, id string) error {
|
||||||
iss, err := s.issuerRepo.Get(ctx, id)
|
iss, err := s.issuerRepo.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("issuer not found: %w", err)
|
return fmt.Errorf("issuer not found: %w", err)
|
||||||
@@ -284,11 +291,6 @@ func (s *IssuerService) TestConnectionWithContext(ctx context.Context, id string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestConnection verifies the issuer connection (handler interface method).
|
|
||||||
func (s *IssuerService) TestConnection(id string) error {
|
|
||||||
return s.TestConnectionWithContext(context.Background(), id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildRegistry loads all enabled issuers from the database and rebuilds the dynamic registry.
|
// BuildRegistry loads all enabled issuers from the database and rebuilds the dynamic registry.
|
||||||
// Called at server startup. Partial failures (individual issuers failing to load) are logged
|
// Called at server startup. Partial failures (individual issuers failing to load) are logged
|
||||||
// as warnings but don't prevent the server from starting.
|
// as warnings but don't prevent the server from starting.
|
||||||
@@ -626,7 +628,7 @@ func (s *IssuerService) buildEnvVarSeeds(cfg *config.Config) []*domain.Issuer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListIssuers returns paginated issuers (handler interface method).
|
// ListIssuers returns paginated issuers (handler interface method).
|
||||||
func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) {
|
func (s *IssuerService) ListIssuers(ctx context.Context, page, perPage int) ([]domain.Issuer, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -634,7 +636,7 @@ func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64,
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
issuers, err := s.issuerRepo.List(context.Background())
|
issuers, err := s.issuerRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list issuers: %w", err)
|
return nil, 0, fmt.Errorf("failed to list issuers: %w", err)
|
||||||
}
|
}
|
||||||
@@ -651,12 +653,12 @@ func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetIssuer returns a single issuer (handler interface method).
|
// GetIssuer returns a single issuer (handler interface method).
|
||||||
func (s *IssuerService) GetIssuer(id string) (*domain.Issuer, error) {
|
func (s *IssuerService) GetIssuer(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||||
return s.issuerRepo.Get(context.Background(), id)
|
return s.issuerRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateIssuer creates a new issuer (handler interface method).
|
// CreateIssuer creates a new issuer (handler interface method).
|
||||||
func (s *IssuerService) CreateIssuer(iss domain.Issuer) (*domain.Issuer, error) {
|
func (s *IssuerService) CreateIssuer(ctx context.Context, iss domain.Issuer) (*domain.Issuer, error) {
|
||||||
iss.Type = normalizeIssuerType(iss.Type)
|
iss.Type = normalizeIssuerType(iss.Type)
|
||||||
if !isValidIssuerType(iss.Type) {
|
if !isValidIssuerType(iss.Type) {
|
||||||
return nil, fmt.Errorf("unsupported issuer type: %s", iss.Type)
|
return nil, fmt.Errorf("unsupported issuer type: %s", iss.Type)
|
||||||
@@ -693,26 +695,26 @@ func (s *IssuerService) CreateIssuer(iss domain.Issuer) (*domain.Issuer, error)
|
|||||||
iss.Config = redactConfigJSON(iss.Config)
|
iss.Config = redactConfigJSON(iss.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.issuerRepo.Create(context.Background(), &iss); err != nil {
|
if err := s.issuerRepo.Create(ctx, &iss); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create issuer: %w", err)
|
return nil, fmt.Errorf("failed to create issuer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rebuild registry
|
// Rebuild registry
|
||||||
if iss.Enabled {
|
if iss.Enabled {
|
||||||
s.rebuildRegistryQuiet(context.Background())
|
s.rebuildRegistryQuiet(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &iss, nil
|
return &iss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateIssuer modifies an issuer (handler interface method).
|
// UpdateIssuer modifies an issuer (handler interface method).
|
||||||
func (s *IssuerService) UpdateIssuer(id string, iss domain.Issuer) (*domain.Issuer, error) {
|
func (s *IssuerService) UpdateIssuer(ctx context.Context, id string, iss domain.Issuer) (*domain.Issuer, error) {
|
||||||
iss.ID = id
|
iss.ID = id
|
||||||
iss.UpdatedAt = time.Now()
|
iss.UpdatedAt = time.Now()
|
||||||
|
|
||||||
// Merge redacted fields with existing config
|
// Merge redacted fields with existing config
|
||||||
if len(iss.Config) > 0 {
|
if len(iss.Config) > 0 {
|
||||||
mergedConfig, err := s.mergeRedactedConfig(context.Background(), id, iss.Config)
|
mergedConfig, err := s.mergeRedactedConfig(ctx, id, iss.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to merge config: %w", err)
|
return nil, fmt.Errorf("failed to merge config: %w", err)
|
||||||
}
|
}
|
||||||
@@ -725,18 +727,18 @@ func (s *IssuerService) UpdateIssuer(id string, iss domain.Issuer) (*domain.Issu
|
|||||||
iss.Config = redactConfigJSON(json.RawMessage(mergedConfig))
|
iss.Config = redactConfigJSON(json.RawMessage(mergedConfig))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.issuerRepo.Update(context.Background(), &iss); err != nil {
|
if err := s.issuerRepo.Update(ctx, &iss); err != nil {
|
||||||
return nil, fmt.Errorf("failed to update issuer: %w", err)
|
return nil, fmt.Errorf("failed to update issuer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.rebuildRegistryQuiet(context.Background())
|
s.rebuildRegistryQuiet(ctx)
|
||||||
|
|
||||||
return &iss, nil
|
return &iss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteIssuer removes an issuer (handler interface method).
|
// DeleteIssuer removes an issuer (handler interface method).
|
||||||
func (s *IssuerService) DeleteIssuer(id string) error {
|
func (s *IssuerService) DeleteIssuer(ctx context.Context, id string) error {
|
||||||
if err := s.issuerRepo.Delete(context.Background(), id); err != nil {
|
if err := s.issuerRepo.Delete(ctx, id); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.registry != nil {
|
if s.registry != nil {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func TestBuildEnvVarSeeds_ACMEConfig(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
// Call buildEnvVarSeeds (unexported method, but testable from same package)
|
// Call buildEnvVarSeeds (unexported method, but testable from same package)
|
||||||
seeds := service.buildEnvVarSeeds(cfg)
|
seeds := service.buildEnvVarSeeds(cfg)
|
||||||
@@ -82,7 +82,7 @@ func TestBuildEnvVarSeeds_VaultConfig(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
seeds := service.buildEnvVarSeeds(cfg)
|
seeds := service.buildEnvVarSeeds(cfg)
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ func TestBuildEnvVarSeeds_NoConfig(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
seeds := service.buildEnvVarSeeds(cfg)
|
seeds := service.buildEnvVarSeeds(cfg)
|
||||||
|
|
||||||
@@ -186,7 +186,7 @@ func TestBuildEnvVarSeeds_MultipleConfigs(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
seeds := service.buildEnvVarSeeds(cfg)
|
seeds := service.buildEnvVarSeeds(cfg)
|
||||||
|
|
||||||
@@ -232,7 +232,7 @@ func TestSeedFromEnvVars_Empty(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
// Call SeedFromEnvVars on empty repo
|
// Call SeedFromEnvVars on empty repo
|
||||||
service.SeedFromEnvVars(ctx, cfg)
|
service.SeedFromEnvVars(ctx, cfg)
|
||||||
@@ -280,7 +280,7 @@ func TestSeedFromEnvVars_AlreadyExists(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
// Get count before seeding
|
// Get count before seeding
|
||||||
beforeSeeding, _ := repo.List(ctx)
|
beforeSeeding, _ := repo.List(ctx)
|
||||||
@@ -328,7 +328,7 @@ func TestBuildRegistry_Success(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
// Call BuildRegistry
|
// Call BuildRegistry
|
||||||
err := service.BuildRegistry(ctx)
|
err := service.BuildRegistry(ctx)
|
||||||
@@ -351,7 +351,7 @@ func TestBuildRegistry_EmptyDatabase(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
// Call BuildRegistry on empty database
|
// Call BuildRegistry on empty database
|
||||||
err := service.BuildRegistry(ctx)
|
err := service.BuildRegistry(ctx)
|
||||||
|
|||||||
@@ -72,7 +72,12 @@ func (r *IssuerRegistry) Len() int {
|
|||||||
// For each enabled issuer, it decrypts the config (if encryption key is set),
|
// For each enabled issuer, it decrypts the config (if encryption key is set),
|
||||||
// instantiates a connector via the factory, wraps it in an adapter, and
|
// instantiates a connector via the factory, wraps it in an adapter, and
|
||||||
// atomically swaps the entire map.
|
// atomically swaps the entire map.
|
||||||
func (r *IssuerRegistry) Rebuild(configs []*domain.Issuer, encryptionKey []byte) error {
|
//
|
||||||
|
// The encryption passphrase is passed as a string; per-ciphertext salt derivation
|
||||||
|
// for v2 blobs is performed inside [crypto.DecryptIfKeySet]. Empty passphrase
|
||||||
|
// fails closed via [crypto.ErrEncryptionKeyRequired] when encrypted configs
|
||||||
|
// are encountered. See M-8 in certctl-audit-report.md.
|
||||||
|
func (r *IssuerRegistry) Rebuild(configs []*domain.Issuer, encryptionKey string) error {
|
||||||
newIssuers := make(map[string]IssuerConnector)
|
newIssuers := make(map[string]IssuerConnector)
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func TestIssuerRegistry_Rebuild_Enabled(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := reg.Rebuild(configs, nil)
|
err := reg.Rebuild(configs, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Rebuild failed: %v", err)
|
t.Fatalf("Rebuild failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -124,11 +124,12 @@ func TestIssuerRegistry_Rebuild_Enabled(t *testing.T) {
|
|||||||
func TestIssuerRegistry_Rebuild_WithEncryption(t *testing.T) {
|
func TestIssuerRegistry_Rebuild_WithEncryption(t *testing.T) {
|
||||||
reg := NewIssuerRegistry(registryTestLogger())
|
reg := NewIssuerRegistry(registryTestLogger())
|
||||||
|
|
||||||
key := crypto.DeriveKey("test-key")
|
|
||||||
configJSON := []byte(`{"ca_common_name":"Encrypted CA"}`)
|
configJSON := []byte(`{"ca_common_name":"Encrypted CA"}`)
|
||||||
encrypted, err := crypto.Encrypt(configJSON, key)
|
// M-8: EncryptIfKeySet now emits v2 (magic 0x02 || per-ciphertext salt || sealed).
|
||||||
|
// IssuerRegistry.Rebuild accepts the raw passphrase and delegates PBKDF2 to crypto.DecryptIfKeySet.
|
||||||
|
encrypted, _, err := crypto.EncryptIfKeySet(configJSON, "test-key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("encrypt failed: %v", err)
|
t.Fatalf("EncryptIfKeySet failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
configs := []*domain.Issuer{
|
configs := []*domain.Issuer{
|
||||||
@@ -141,7 +142,7 @@ func TestIssuerRegistry_Rebuild_WithEncryption(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = reg.Rebuild(configs, key)
|
err = reg.Rebuild(configs, "test-key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Rebuild with encryption failed: %v", err)
|
t.Fatalf("Rebuild with encryption failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -165,10 +166,11 @@ func TestIssuerRegistry_Rebuild_NilKeyFallback(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// nil key should work — falls back to config column
|
// Empty passphrase is safe when no EncryptedConfig is present — falls back to config column.
|
||||||
err := reg.Rebuild(configs, nil)
|
// The C-2 fail-closed sentinel only fires when EncryptedConfig is non-empty.
|
||||||
|
err := reg.Rebuild(configs, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Rebuild with nil key failed: %v", err)
|
t.Fatalf("Rebuild with empty key failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := reg.Get("iss-plain")
|
_, ok := reg.Get("iss-plain")
|
||||||
@@ -198,7 +200,7 @@ func TestIssuerRegistry_Rebuild_InvalidConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should return an error indicating partial failure, but still load valid issuers
|
// Should return an error indicating partial failure, but still load valid issuers
|
||||||
err := reg.Rebuild(configs, nil)
|
err := reg.Rebuild(configs, "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Rebuild should return error when some issuers fail to load")
|
t.Fatal("Rebuild should return error when some issuers fail to load")
|
||||||
}
|
}
|
||||||
@@ -230,7 +232,7 @@ func TestIssuerRegistry_Rebuild_ReplacesExisting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := reg.Rebuild(configs, nil)
|
err := reg.Rebuild(configs, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Rebuild failed: %v", err)
|
t.Fatalf("Rebuild failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -275,7 +277,7 @@ func TestIssuerRegistry_Rebuild_Empty(t *testing.T) {
|
|||||||
|
|
||||||
reg.Set("iss-existing", &mockIssuerConnector{})
|
reg.Set("iss-existing", &mockIssuerConnector{})
|
||||||
|
|
||||||
err := reg.Rebuild([]*domain.Issuer{}, nil)
|
err := reg.Rebuild([]*domain.Issuer{}, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Rebuild with empty configs failed: %v", err)
|
t.Fatalf("Rebuild with empty configs failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func TestIssuerService_List(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
issuers, total, err := service.List(ctx, 1, 2)
|
issuers, total, err := service.List(ctx, 1, 2)
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ func TestIssuerService_List_DefaultPagination(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
// Call with invalid page and perPage
|
// Call with invalid page and perPage
|
||||||
issuers, total, err := service.List(ctx, 0, 0)
|
issuers, total, err := service.List(ctx, 0, 0)
|
||||||
@@ -115,7 +115,7 @@ func TestIssuerService_List_RepositoryError(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
_, _, err := service.List(ctx, 1, 50)
|
_, _, err := service.List(ctx, 1, 50)
|
||||||
|
|
||||||
@@ -137,7 +137,7 @@ func TestIssuerService_List_EmptyResult(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
issuers, total, err := service.List(ctx, 1, 50)
|
issuers, total, err := service.List(ctx, 1, 50)
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ func TestIssuerService_Get(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
retrieved, err := service.Get(ctx, "iss-acme-prod")
|
retrieved, err := service.Get(ctx, "iss-acme-prod")
|
||||||
|
|
||||||
@@ -199,7 +199,7 @@ func TestIssuerService_Get_NotFound(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
_, err := service.Get(ctx, "nonexistent-issuer")
|
_, err := service.Get(ctx, "nonexistent-issuer")
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ func TestIssuerService_Create_EmptyName(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
issuer := &domain.Issuer{
|
issuer := &domain.Issuer{
|
||||||
Name: "",
|
Name: "",
|
||||||
@@ -314,7 +314,7 @@ func TestIssuerService_Create_RepositoryError(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
issuer := &domain.Issuer{
|
issuer := &domain.Issuer{
|
||||||
Name: "Test Issuer",
|
Name: "Test Issuer",
|
||||||
@@ -387,7 +387,7 @@ func TestIssuerService_Update_EmptyName(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
issuer := &domain.Issuer{
|
issuer := &domain.Issuer{
|
||||||
Name: "",
|
Name: "",
|
||||||
@@ -415,7 +415,7 @@ func TestIssuerService_Delete(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
err := service.Delete(ctx, "iss-to-delete", "user-frank")
|
err := service.Delete(ctx, "iss-to-delete", "user-frank")
|
||||||
|
|
||||||
@@ -447,7 +447,7 @@ func TestIssuerService_Delete_RepositoryError(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
err := service.Delete(ctx, "iss-bad-id", "user-grace")
|
err := service.Delete(ctx, "iss-bad-id", "user-grace")
|
||||||
|
|
||||||
@@ -482,12 +482,12 @@ func TestIssuerService_TestConnection_Success(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
svc := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
svc := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
err := svc.TestConnectionWithContext(ctx, "iss-test-conn")
|
err := svc.TestConnection(ctx, "iss-test-conn")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TestConnectionWithContext failed: %v", err)
|
t.Fatalf("TestConnection failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,9 +500,9 @@ func TestIssuerService_TestConnection_NotFound(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
err := service.TestConnectionWithContext(ctx, "nonexistent-issuer")
|
err := service.TestConnection(ctx, "nonexistent-issuer")
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for nonexistent issuer")
|
t.Fatal("expected error for nonexistent issuer")
|
||||||
@@ -540,9 +540,10 @@ func TestIssuerService_ListIssuers_HandlerInterface(t *testing.T) {
|
|||||||
auditRepo := newMockAuditRepository()
|
auditRepo := newMockAuditRepository()
|
||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), nil, slog.Default())
|
service := NewIssuerService(repo, auditService, NewIssuerRegistry(slog.Default()), "", slog.Default())
|
||||||
|
|
||||||
issuers, total, err := service.ListIssuers(1, 50)
|
ctx := context.Background()
|
||||||
|
issuers, total, err := service.ListIssuers(ctx, 1, 50)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListIssuers failed: %v", err)
|
t.Fatalf("ListIssuers failed: %v", err)
|
||||||
@@ -580,7 +581,8 @@ func TestIssuerService_CreateIssuer_HandlerInterface(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := service.CreateIssuer(issuer)
|
ctx := context.Background()
|
||||||
|
result, err := service.CreateIssuer(ctx, issuer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreateIssuer failed: %v", err)
|
t.Fatalf("CreateIssuer failed: %v", err)
|
||||||
@@ -606,9 +608,10 @@ func TestIssuerService_DeleteIssuer_HandlerInterface(t *testing.T) {
|
|||||||
auditService := NewAuditService(auditRepo)
|
auditService := NewAuditService(auditRepo)
|
||||||
|
|
||||||
registry := NewIssuerRegistry(slog.Default())
|
registry := NewIssuerRegistry(slog.Default())
|
||||||
service := NewIssuerService(repo, auditService, registry, nil, slog.Default())
|
service := NewIssuerService(repo, auditService, registry, "", slog.Default())
|
||||||
|
|
||||||
err := service.DeleteIssuer("iss-handler-delete")
|
ctx := context.Background()
|
||||||
|
err := service.DeleteIssuer(ctx, "iss-handler-delete")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeleteIssuer failed: %v", err)
|
t.Fatalf("DeleteIssuer failed: %v", err)
|
||||||
@@ -722,7 +725,8 @@ func TestIssuerService_CreateIssuer_LowercaseType(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := service.CreateIssuer(issuer)
|
ctx := context.Background()
|
||||||
|
result, err := service.CreateIssuer(ctx, issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreateIssuer with lowercase 'stepca' should succeed, got: %v", err)
|
t.Fatalf("CreateIssuer with lowercase 'stepca' should succeed, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+8
-15
@@ -189,8 +189,8 @@ func (s *JobService) GetJobStatus(ctx context.Context, jobID string) (*domain.Jo
|
|||||||
return job, nil
|
return job, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelJobWithContext cancels a pending or running job.
|
// CancelJob cancels a pending or running job (handler interface method).
|
||||||
func (s *JobService) CancelJobWithContext(ctx context.Context, jobID string) error {
|
func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
|
||||||
job, err := s.jobRepo.Get(ctx, jobID)
|
job, err := s.jobRepo.Get(ctx, jobID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch job: %w", err)
|
return fmt.Errorf("failed to fetch job: %w", err)
|
||||||
@@ -208,13 +208,8 @@ func (s *JobService) CancelJobWithContext(ctx context.Context, jobID string) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CancelJob cancels a job (handler interface method).
|
|
||||||
func (s *JobService) CancelJob(id string) error {
|
|
||||||
return s.CancelJobWithContext(context.Background(), id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListJobs returns paginated jobs with optional filtering (handler interface method).
|
// ListJobs returns paginated jobs with optional filtering (handler interface method).
|
||||||
func (s *JobService) ListJobs(status, jobType string, page, perPage int) ([]domain.Job, int64, error) {
|
func (s *JobService) ListJobs(ctx context.Context, status, jobType string, page, perPage int) ([]domain.Job, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -222,7 +217,7 @@ func (s *JobService) ListJobs(status, jobType string, page, perPage int) ([]doma
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
allJobs, err := s.jobRepo.List(context.Background())
|
allJobs, err := s.jobRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list jobs: %w", err)
|
return nil, 0, fmt.Errorf("failed to list jobs: %w", err)
|
||||||
}
|
}
|
||||||
@@ -263,14 +258,13 @@ func (s *JobService) ListJobs(status, jobType string, page, perPage int) ([]doma
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetJob returns a single job (handler interface method).
|
// GetJob returns a single job (handler interface method).
|
||||||
func (s *JobService) GetJob(id string) (*domain.Job, error) {
|
func (s *JobService) GetJob(ctx context.Context, id string) (*domain.Job, error) {
|
||||||
return s.jobRepo.Get(context.Background(), id)
|
return s.jobRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApproveJob approves a renewal job that is awaiting approval.
|
// ApproveJob approves a renewal job that is awaiting approval.
|
||||||
// Transitions the job from AwaitingApproval to Pending so the scheduler picks it up.
|
// Transitions the job from AwaitingApproval to Pending so the scheduler picks it up.
|
||||||
func (s *JobService) ApproveJob(id string) error {
|
func (s *JobService) ApproveJob(ctx context.Context, id string) error {
|
||||||
ctx := context.Background()
|
|
||||||
job, err := s.jobRepo.Get(ctx, id)
|
job, err := s.jobRepo.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("job not found: %w", err)
|
return fmt.Errorf("job not found: %w", err)
|
||||||
@@ -290,8 +284,7 @@ func (s *JobService) ApproveJob(id string) error {
|
|||||||
|
|
||||||
// RejectJob rejects a renewal job that is awaiting approval.
|
// RejectJob rejects a renewal job that is awaiting approval.
|
||||||
// Transitions the job to Cancelled with a rejection reason.
|
// Transitions the job to Cancelled with a rejection reason.
|
||||||
func (s *JobService) RejectJob(id string, reason string) error {
|
func (s *JobService) RejectJob(ctx context.Context, id string, reason string) error {
|
||||||
ctx := context.Background()
|
|
||||||
job, err := s.jobRepo.Get(ctx, id)
|
job, err := s.jobRepo.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("job not found: %w", err)
|
return fmt.Errorf("job not found: %w", err)
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func TestCancelJob(t *testing.T) {
|
|||||||
|
|
||||||
jobService := newTestJobService(jobRepo)
|
jobService := newTestJobService(jobRepo)
|
||||||
|
|
||||||
err := jobService.CancelJobWithContext(ctx, "job-001")
|
err := jobService.CancelJob(ctx, "job-001")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CancelJob failed: %v", err)
|
t.Fatalf("CancelJob failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -129,13 +129,15 @@ func TestCancelJob_AlreadyCompleted(t *testing.T) {
|
|||||||
|
|
||||||
jobService := newTestJobService(jobRepo)
|
jobService := newTestJobService(jobRepo)
|
||||||
|
|
||||||
err := jobService.CancelJobWithContext(ctx, "job-001")
|
err := jobService.CancelJob(ctx, "job-001")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for completed job")
|
t.Fatal("expected error for completed job")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetJob(t *testing.T) {
|
func TestGetJob(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
job := &domain.Job{
|
job := &domain.Job{
|
||||||
ID: "job-001",
|
ID: "job-001",
|
||||||
@@ -153,7 +155,7 @@ func TestGetJob(t *testing.T) {
|
|||||||
|
|
||||||
jobService := newTestJobService(jobRepo)
|
jobService := newTestJobService(jobRepo)
|
||||||
|
|
||||||
retrieved, err := jobService.GetJob("job-001")
|
retrieved, err := jobService.GetJob(ctx, "job-001")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetJob failed: %v", err)
|
t.Fatalf("GetJob failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -167,6 +169,8 @@ func TestGetJob(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestListJobs(t *testing.T) {
|
func TestListJobs(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
job1 := &domain.Job{
|
job1 := &domain.Job{
|
||||||
ID: "job-001",
|
ID: "job-001",
|
||||||
@@ -192,7 +196,7 @@ func TestListJobs(t *testing.T) {
|
|||||||
|
|
||||||
jobService := newTestJobService(jobRepo)
|
jobService := newTestJobService(jobRepo)
|
||||||
|
|
||||||
jobs, total, err := jobService.ListJobs("", "", 1, 50)
|
jobs, total, err := jobService.ListJobs(ctx, "", "", 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListJobs failed: %v", err)
|
t.Fatalf("ListJobs failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -206,6 +210,8 @@ func TestListJobs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestListJobs_FilterByStatus(t *testing.T) {
|
func TestListJobs_FilterByStatus(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
job1 := &domain.Job{
|
job1 := &domain.Job{
|
||||||
ID: "job-001",
|
ID: "job-001",
|
||||||
@@ -231,7 +237,7 @@ func TestListJobs_FilterByStatus(t *testing.T) {
|
|||||||
|
|
||||||
jobService := newTestJobService(jobRepo)
|
jobService := newTestJobService(jobRepo)
|
||||||
|
|
||||||
jobs, total, err := jobService.ListJobs(string(domain.JobStatusPending), "", 1, 50)
|
jobs, total, err := jobService.ListJobs(ctx, string(domain.JobStatusPending), "", 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListJobs failed: %v", err)
|
t.Fatalf("ListJobs failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -235,21 +235,19 @@ func (s *NetworkScanService) scanTarget(ctx context.Context, target *domain.Netw
|
|||||||
timeout := time.Duration(target.TimeoutMs) * time.Millisecond
|
timeout := time.Duration(target.TimeoutMs) * time.Millisecond
|
||||||
results := s.scanEndpoints(ctx, endpoints, timeout)
|
results := s.scanEndpoints(ctx, endpoints, timeout)
|
||||||
|
|
||||||
// Collect discovered cert entries
|
// Collect discovered cert entries and per-endpoint errors.
|
||||||
var entries []domain.DiscoveredCertEntry
|
//
|
||||||
var scanErrors []string
|
// M-9 (operator-observability): before this fix, scanErrors was declared
|
||||||
for _, result := range results {
|
// but never appended to, so the "errors" count in the summary Info log
|
||||||
if result.Error != "" {
|
// and the Errors field on the DiscoveryReport were always zero/nil —
|
||||||
// Only log connection errors at debug level (many hosts won't have TLS)
|
// silently hiding per-endpoint failures from operators and from the
|
||||||
if s.logger != nil {
|
// downstream scan history record. Per-endpoint failures are still logged
|
||||||
s.logger.Debug("scan endpoint error",
|
// at Debug (sweep scans generate high connection-refused noise by design
|
||||||
"address", result.Address,
|
// — most hosts in a CIDR won't have TLS on the probed port), but the
|
||||||
"error", result.Error)
|
// aggregate count and the report's Errors field now reflect reality so
|
||||||
}
|
// operators can see, via the scan summary and the stored scan record,
|
||||||
continue
|
// how many endpoints failed without having to enable Debug logging.
|
||||||
}
|
entries, scanErrors := s.collectScanResults(results)
|
||||||
entries = append(entries, result.Certs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
scanDuration := time.Since(startTime)
|
scanDuration := time.Since(startTime)
|
||||||
if s.logger != nil {
|
if s.logger != nil {
|
||||||
@@ -385,6 +383,44 @@ func incrementIP(ip net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// collectScanResults partitions per-endpoint scan results into discovered
|
||||||
|
// certificate entries and a list of per-endpoint error strings.
|
||||||
|
//
|
||||||
|
// M-9 (operator-observability): the summary Info log and the DiscoveryReport
|
||||||
|
// both report the count of endpoints that failed to probe. Before this helper
|
||||||
|
// existed, the caller accumulated entries but never populated the errors
|
||||||
|
// slice, so the aggregate error count was always zero and the scan record's
|
||||||
|
// Errors field was always nil — silently hiding per-endpoint failures.
|
||||||
|
//
|
||||||
|
// Per-endpoint errors remain logged at Debug (sweep scans generate high
|
||||||
|
// connection-refused noise by design — most hosts in a CIDR won't have TLS
|
||||||
|
// on the probed port). Aggregation surfaces the count at Info, preserving
|
||||||
|
// Debug-level detail for operators who want it without creating log spam
|
||||||
|
// at default verbosity.
|
||||||
|
func (s *NetworkScanService) collectScanResults(results []domain.NetworkScanResult) ([]domain.DiscoveredCertEntry, []string) {
|
||||||
|
var entries []domain.DiscoveredCertEntry
|
||||||
|
var scanErrors []string
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Error != "" {
|
||||||
|
// Debug-level is intentional: a sweep scan of a /24 typically
|
||||||
|
// produces 200+ connection-refused results, and logging each
|
||||||
|
// at Warn would create log spam at default verbosity. The
|
||||||
|
// aggregate count in the Info-level scan-completed log surfaces
|
||||||
|
// the failure volume to operators; Debug provides the detail
|
||||||
|
// when diagnosing a specific endpoint.
|
||||||
|
if s.logger != nil {
|
||||||
|
s.logger.Debug("scan endpoint error",
|
||||||
|
"address", result.Address,
|
||||||
|
"error", result.Error)
|
||||||
|
}
|
||||||
|
scanErrors = append(scanErrors, fmt.Sprintf("%s: %s", result.Address, result.Error))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entries = append(entries, result.Certs...)
|
||||||
|
}
|
||||||
|
return entries, scanErrors
|
||||||
|
}
|
||||||
|
|
||||||
// scanEndpoints probes TLS endpoints concurrently and returns results.
|
// scanEndpoints probes TLS endpoints concurrently and returns results.
|
||||||
func (s *NetworkScanService) scanEndpoints(ctx context.Context, endpoints []string, timeout time.Duration) []domain.NetworkScanResult {
|
func (s *NetworkScanService) scanEndpoints(ctx context.Context, endpoints []string, timeout time.Duration) []domain.NetworkScanResult {
|
||||||
results := make([]domain.NetworkScanResult, len(endpoints))
|
results := make([]domain.NetworkScanResult, len(endpoints))
|
||||||
|
|||||||
@@ -491,3 +491,113 @@ func TestExpandCIDR_SingleLinkLocalIP(t *testing.T) {
|
|||||||
t.Errorf("expected empty for cloud metadata IP, got %v", ips)
|
t.Errorf("expected empty for cloud metadata IP, got %v", ips)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCollectScanResults_AggregatesErrors is the M-9 regression guard:
|
||||||
|
// per-endpoint probe failures must accumulate into the errors slice so the
|
||||||
|
// summary Info log and the DiscoveryReport reflect the true failure count.
|
||||||
|
// Before the M-9 fix, scanErrors was declared but never appended to, so the
|
||||||
|
// aggregate count was always zero and the scan record's Errors field was
|
||||||
|
// always nil — silently hiding per-endpoint failures from operators.
|
||||||
|
func TestCollectScanResults_AggregatesErrors(t *testing.T) {
|
||||||
|
svc := &NetworkScanService{}
|
||||||
|
results := []domain.NetworkScanResult{
|
||||||
|
{Address: "203.0.113.1:443", Error: "connection refused"},
|
||||||
|
{Address: "203.0.113.2:443", Certs: []domain.DiscoveredCertEntry{
|
||||||
|
{CommonName: "example.com"},
|
||||||
|
}},
|
||||||
|
{Address: "203.0.113.3:443", Error: "tls handshake failure"},
|
||||||
|
{Address: "203.0.113.4:443", Certs: []domain.DiscoveredCertEntry{
|
||||||
|
{CommonName: "internal.example.com"},
|
||||||
|
}},
|
||||||
|
{Address: "203.0.113.5:443", Error: "i/o timeout"},
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, errs := svc.collectScanResults(results)
|
||||||
|
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Errorf("expected 2 entries (one per successful probe), got %d", len(entries))
|
||||||
|
}
|
||||||
|
if len(errs) != 3 {
|
||||||
|
t.Fatalf("expected 3 error strings (one per failed probe), got %d: %v", len(errs), errs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each error string must be non-empty and include the endpoint address so
|
||||||
|
// the scan record lets operators correlate failures back to endpoints
|
||||||
|
// without needing Debug logging enabled.
|
||||||
|
for i, e := range errs {
|
||||||
|
if e == "" {
|
||||||
|
t.Errorf("error[%d]: expected non-empty error string", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spot-check that address is threaded through the error strings.
|
||||||
|
if want := "203.0.113.1:443"; errs[0] == "" || errs[0][:len(want)] != want {
|
||||||
|
t.Errorf("errs[0] should start with %q, got %q", want, errs[0])
|
||||||
|
}
|
||||||
|
if want := "203.0.113.3:443"; errs[1] == "" || errs[1][:len(want)] != want {
|
||||||
|
t.Errorf("errs[1] should start with %q, got %q", want, errs[1])
|
||||||
|
}
|
||||||
|
if want := "203.0.113.5:443"; errs[2] == "" || errs[2][:len(want)] != want {
|
||||||
|
t.Errorf("errs[2] should start with %q, got %q", want, errs[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCollectScanResults_AllSuccess exercises the happy path: a scan where
|
||||||
|
// every endpoint returned certificates. The errors slice must be nil (not an
|
||||||
|
// empty non-nil slice) so the downstream DiscoveryReport.Errors field stays
|
||||||
|
// nil as well, preserving the JSON-omitempty behavior that callers rely on.
|
||||||
|
func TestCollectScanResults_AllSuccess(t *testing.T) {
|
||||||
|
svc := &NetworkScanService{}
|
||||||
|
results := []domain.NetworkScanResult{
|
||||||
|
{Address: "203.0.113.10:443", Certs: []domain.DiscoveredCertEntry{
|
||||||
|
{CommonName: "a.example.com"},
|
||||||
|
}},
|
||||||
|
{Address: "203.0.113.11:443", Certs: []domain.DiscoveredCertEntry{
|
||||||
|
{CommonName: "b.example.com"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, errs := svc.collectScanResults(results)
|
||||||
|
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Errorf("expected 2 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if errs != nil {
|
||||||
|
t.Errorf("expected nil errors slice on all-success, got %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCollectScanResults_AllFailed exercises the worst-case sweep: every
|
||||||
|
// endpoint failed to probe. Entries must be nil, and every failure must be
|
||||||
|
// recorded in the errors slice so the scan record is complete.
|
||||||
|
func TestCollectScanResults_AllFailed(t *testing.T) {
|
||||||
|
svc := &NetworkScanService{}
|
||||||
|
results := []domain.NetworkScanResult{
|
||||||
|
{Address: "203.0.113.20:443", Error: "connection refused"},
|
||||||
|
{Address: "203.0.113.21:443", Error: "connection refused"},
|
||||||
|
{Address: "203.0.113.22:443", Error: "connection refused"},
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, errs := svc.collectScanResults(results)
|
||||||
|
|
||||||
|
if entries != nil {
|
||||||
|
t.Errorf("expected nil entries on all-failed, got %v", entries)
|
||||||
|
}
|
||||||
|
if len(errs) != 3 {
|
||||||
|
t.Errorf("expected 3 error strings, got %d: %v", len(errs), errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCollectScanResults_Empty guards against a degenerate empty-input case
|
||||||
|
// (scanEndpoints returns no results, e.g. if ctx was cancelled before the
|
||||||
|
// first probe ran). Both return slices must be nil.
|
||||||
|
func TestCollectScanResults_Empty(t *testing.T) {
|
||||||
|
svc := &NetworkScanService{}
|
||||||
|
entries, errs := svc.collectScanResults(nil)
|
||||||
|
if entries != nil {
|
||||||
|
t.Errorf("expected nil entries for empty input, got %v", entries)
|
||||||
|
}
|
||||||
|
if errs != nil {
|
||||||
|
t.Errorf("expected nil errors for empty input, got %v", errs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -319,7 +319,7 @@ func (s *NotificationService) GetNotificationHistory(ctx context.Context, certID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListNotifications returns paginated notifications (handler interface method).
|
// ListNotifications returns paginated notifications (handler interface method).
|
||||||
func (s *NotificationService) ListNotifications(page, perPage int) ([]domain.NotificationEvent, int64, error) {
|
func (s *NotificationService) ListNotifications(ctx context.Context, page, perPage int) ([]domain.NotificationEvent, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -332,7 +332,7 @@ func (s *NotificationService) ListNotifications(page, perPage int) ([]domain.Not
|
|||||||
PerPage: perPage,
|
PerPage: perPage,
|
||||||
}
|
}
|
||||||
|
|
||||||
notifications, err := s.notifRepo.List(context.Background(), filter)
|
notifications, err := s.notifRepo.List(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list notifications: %w", err)
|
return nil, 0, fmt.Errorf("failed to list notifications: %w", err)
|
||||||
}
|
}
|
||||||
@@ -349,12 +349,12 @@ func (s *NotificationService) ListNotifications(page, perPage int) ([]domain.Not
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetNotification returns a single notification (handler interface method).
|
// GetNotification returns a single notification (handler interface method).
|
||||||
func (s *NotificationService) GetNotification(id string) (*domain.NotificationEvent, error) {
|
func (s *NotificationService) GetNotification(ctx context.Context, id string) (*domain.NotificationEvent, error) {
|
||||||
filter := &repository.NotificationFilter{
|
filter := &repository.NotificationFilter{
|
||||||
PerPage: 1,
|
PerPage: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
notifications, err := s.notifRepo.List(context.Background(), filter)
|
notifications, err := s.notifRepo.List(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get notification: %w", err)
|
return nil, fmt.Errorf("failed to get notification: %w", err)
|
||||||
}
|
}
|
||||||
@@ -370,6 +370,6 @@ func (s *NotificationService) GetNotification(id string) (*domain.NotificationEv
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkAsRead marks a notification as read (handler interface method).
|
// MarkAsRead marks a notification as read (handler interface method).
|
||||||
func (s *NotificationService) MarkAsRead(id string) error {
|
func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
|
||||||
return s.notifRepo.UpdateStatus(context.Background(), id, "read", time.Now())
|
return s.notifRepo.UpdateStatus(ctx, id, "read", time.Now())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -370,7 +370,7 @@ func TestListNotifications(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List with pagination
|
// List with pagination
|
||||||
notifs, total, err := svc.ListNotifications(1, 3)
|
notifs, total, err := svc.ListNotifications(context.Background(), 1, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListNotifications failed: %v", err)
|
t.Fatalf("ListNotifications failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -404,7 +404,7 @@ func TestMarkAsRead(t *testing.T) {
|
|||||||
notifRepo.AddNotification(notif)
|
notifRepo.AddNotification(notif)
|
||||||
|
|
||||||
// Mark as read
|
// Mark as read
|
||||||
err := svc.MarkAsRead(notif.ID)
|
err := svc.MarkAsRead(context.Background(), notif.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("MarkAsRead failed: %v", err)
|
t.Fatalf("MarkAsRead failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -434,7 +434,7 @@ func TestGetNotification(t *testing.T) {
|
|||||||
notifRepo.AddNotification(notif)
|
notifRepo.AddNotification(notif)
|
||||||
|
|
||||||
// Get the notification
|
// Get the notification
|
||||||
retrieved, err := svc.GetNotification(notif.ID)
|
retrieved, err := svc.GetNotification(context.Background(), notif.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetNotification failed: %v", err)
|
t.Fatalf("GetNotification failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+10
-10
@@ -126,7 +126,7 @@ func (s *OwnerService) Delete(ctx context.Context, id string, actor string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListOwners returns paginated owners (handler interface method).
|
// ListOwners returns paginated owners (handler interface method).
|
||||||
func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) {
|
func (s *OwnerService) ListOwners(ctx context.Context, page, perPage int) ([]domain.Owner, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, err
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
owners, err := s.ownerRepo.List(context.Background())
|
owners, err := s.ownerRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list owners: %w", err)
|
return nil, 0, fmt.Errorf("failed to list owners: %w", err)
|
||||||
}
|
}
|
||||||
@@ -151,12 +151,12 @@ func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetOwner returns a single owner (handler interface method).
|
// GetOwner returns a single owner (handler interface method).
|
||||||
func (s *OwnerService) GetOwner(id string) (*domain.Owner, error) {
|
func (s *OwnerService) GetOwner(ctx context.Context, id string) (*domain.Owner, error) {
|
||||||
return s.ownerRepo.Get(context.Background(), id)
|
return s.ownerRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateOwner creates a new owner (handler interface method).
|
// CreateOwner creates a new owner (handler interface method).
|
||||||
func (s *OwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
|
func (s *OwnerService) CreateOwner(ctx context.Context, owner domain.Owner) (*domain.Owner, error) {
|
||||||
if owner.ID == "" {
|
if owner.ID == "" {
|
||||||
owner.ID = generateID("owner")
|
owner.ID = generateID("owner")
|
||||||
}
|
}
|
||||||
@@ -167,22 +167,22 @@ func (s *OwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
|
|||||||
if owner.UpdatedAt.IsZero() {
|
if owner.UpdatedAt.IsZero() {
|
||||||
owner.UpdatedAt = now
|
owner.UpdatedAt = now
|
||||||
}
|
}
|
||||||
if err := s.ownerRepo.Create(context.Background(), &owner); err != nil {
|
if err := s.ownerRepo.Create(ctx, &owner); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create owner: %w", err)
|
return nil, fmt.Errorf("failed to create owner: %w", err)
|
||||||
}
|
}
|
||||||
return &owner, nil
|
return &owner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOwner modifies an owner (handler interface method).
|
// UpdateOwner modifies an owner (handler interface method).
|
||||||
func (s *OwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) {
|
func (s *OwnerService) UpdateOwner(ctx context.Context, id string, owner domain.Owner) (*domain.Owner, error) {
|
||||||
owner.ID = id
|
owner.ID = id
|
||||||
if err := s.ownerRepo.Update(context.Background(), &owner); err != nil {
|
if err := s.ownerRepo.Update(ctx, &owner); err != nil {
|
||||||
return nil, fmt.Errorf("failed to update owner: %w", err)
|
return nil, fmt.Errorf("failed to update owner: %w", err)
|
||||||
}
|
}
|
||||||
return &owner, nil
|
return &owner, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteOwner removes an owner (handler interface method).
|
// DeleteOwner removes an owner (handler interface method).
|
||||||
func (s *OwnerService) DeleteOwner(id string) error {
|
func (s *OwnerService) DeleteOwner(ctx context.Context, id string) error {
|
||||||
return s.ownerRepo.Delete(context.Background(), id)
|
return s.ownerRepo.Delete(ctx, id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -638,7 +638,7 @@ func TestOwnerService_ListOwners_HandlerInterface(t *testing.T) {
|
|||||||
|
|
||||||
ownerService := NewOwnerService(ownerRepo, auditService)
|
ownerService := NewOwnerService(ownerRepo, auditService)
|
||||||
|
|
||||||
owners, total, err := ownerService.ListOwners(1, 50)
|
owners, total, err := ownerService.ListOwners(context.Background(), 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListOwners failed: %v", err)
|
t.Fatalf("ListOwners failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -678,7 +678,7 @@ func TestOwnerService_GetOwner_HandlerInterface(t *testing.T) {
|
|||||||
|
|
||||||
ownerService := NewOwnerService(ownerRepo, auditService)
|
ownerService := NewOwnerService(ownerRepo, auditService)
|
||||||
|
|
||||||
retrieved, err := ownerService.GetOwner("owner-001")
|
retrieved, err := ownerService.GetOwner(context.Background(), "owner-001")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetOwner failed: %v", err)
|
t.Fatalf("GetOwner failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -702,7 +702,7 @@ func TestOwnerService_CreateOwner_HandlerInterface(t *testing.T) {
|
|||||||
TeamID: "team-001",
|
TeamID: "team-001",
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := ownerService.CreateOwner(owner)
|
created, err := ownerService.CreateOwner(context.Background(), owner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreateOwner failed: %v", err)
|
t.Fatalf("CreateOwner failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -752,7 +752,7 @@ func TestOwnerService_UpdateOwner_HandlerInterface(t *testing.T) {
|
|||||||
TeamID: "team-002",
|
TeamID: "team-002",
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := ownerService.UpdateOwner("owner-001", updatedOwner)
|
updated, err := ownerService.UpdateOwner(context.Background(), "owner-001", updatedOwner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("UpdateOwner failed: %v", err)
|
t.Fatalf("UpdateOwner failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -798,7 +798,7 @@ func TestOwnerService_DeleteOwner_HandlerInterface(t *testing.T) {
|
|||||||
|
|
||||||
ownerService := NewOwnerService(ownerRepo, auditService)
|
ownerService := NewOwnerService(ownerRepo, auditService)
|
||||||
|
|
||||||
err := ownerService.DeleteOwner("owner-001")
|
err := ownerService.DeleteOwner(context.Background(), "owner-001")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeleteOwner failed: %v", err)
|
t.Fatalf("DeleteOwner failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+12
-12
@@ -230,7 +230,7 @@ func (s *PolicyService) ListViolationsWithContext(ctx context.Context, filter *r
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListPolicies returns paginated policies (handler interface method).
|
// ListPolicies returns paginated policies (handler interface method).
|
||||||
func (s *PolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, int64, error) {
|
func (s *PolicyService) ListPolicies(ctx context.Context, page, perPage int) ([]domain.PolicyRule, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -238,7 +238,7 @@ func (s *PolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, in
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := s.policyRepo.ListRules(context.Background())
|
rules, err := s.policyRepo.ListRules(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list policies: %w", err)
|
return nil, 0, fmt.Errorf("failed to list policies: %w", err)
|
||||||
}
|
}
|
||||||
@@ -264,12 +264,12 @@ func (s *PolicyService) ListPolicies(page, perPage int) ([]domain.PolicyRule, in
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPolicy returns a single policy (handler interface method).
|
// GetPolicy returns a single policy (handler interface method).
|
||||||
func (s *PolicyService) GetPolicy(id string) (*domain.PolicyRule, error) {
|
func (s *PolicyService) GetPolicy(ctx context.Context, id string) (*domain.PolicyRule, error) {
|
||||||
return s.policyRepo.GetRule(context.Background(), id)
|
return s.policyRepo.GetRule(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePolicy creates a new policy (handler interface method).
|
// CreatePolicy creates a new policy (handler interface method).
|
||||||
func (s *PolicyService) CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
func (s *PolicyService) CreatePolicy(ctx context.Context, policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
||||||
if policy.ID == "" {
|
if policy.ID == "" {
|
||||||
policy.ID = generateID("rule")
|
policy.ID = generateID("rule")
|
||||||
}
|
}
|
||||||
@@ -277,30 +277,30 @@ func (s *PolicyService) CreatePolicy(policy domain.PolicyRule) (*domain.PolicyRu
|
|||||||
policy.CreatedAt = time.Now()
|
policy.CreatedAt = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.policyRepo.CreateRule(context.Background(), &policy); err != nil {
|
if err := s.policyRepo.CreateRule(ctx, &policy); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create policy: %w", err)
|
return nil, fmt.Errorf("failed to create policy: %w", err)
|
||||||
}
|
}
|
||||||
return &policy, nil
|
return &policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePolicy modifies a policy (handler interface method).
|
// UpdatePolicy modifies a policy (handler interface method).
|
||||||
func (s *PolicyService) UpdatePolicy(id string, policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
func (s *PolicyService) UpdatePolicy(ctx context.Context, id string, policy domain.PolicyRule) (*domain.PolicyRule, error) {
|
||||||
policy.ID = id
|
policy.ID = id
|
||||||
policy.UpdatedAt = time.Now()
|
policy.UpdatedAt = time.Now()
|
||||||
|
|
||||||
if err := s.policyRepo.UpdateRule(context.Background(), &policy); err != nil {
|
if err := s.policyRepo.UpdateRule(ctx, &policy); err != nil {
|
||||||
return nil, fmt.Errorf("failed to update policy: %w", err)
|
return nil, fmt.Errorf("failed to update policy: %w", err)
|
||||||
}
|
}
|
||||||
return &policy, nil
|
return &policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePolicy removes a policy (handler interface method).
|
// DeletePolicy removes a policy (handler interface method).
|
||||||
func (s *PolicyService) DeletePolicy(id string) error {
|
func (s *PolicyService) DeletePolicy(ctx context.Context, id string) error {
|
||||||
return s.policyRepo.DeleteRule(context.Background(), id)
|
return s.policyRepo.DeleteRule(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListViolations returns policy violations with pagination (handler interface method).
|
// ListViolations returns policy violations with pagination (handler interface method).
|
||||||
func (s *PolicyService) ListViolations(policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) {
|
func (s *PolicyService) ListViolations(ctx context.Context, policyID string, page, perPage int) ([]domain.PolicyViolation, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -313,7 +313,7 @@ func (s *PolicyService) ListViolations(policyID string, page, perPage int) ([]do
|
|||||||
PerPage: 1000, // Get all violations for the policy
|
PerPage: 1000, // Get all violations for the policy
|
||||||
}
|
}
|
||||||
|
|
||||||
violations, err := s.policyRepo.ListViolations(context.Background(), filter)
|
violations, err := s.policyRepo.ListViolations(ctx, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list violations: %w", err)
|
return nil, 0, fmt.Errorf("failed to list violations: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ func TestListPolicies(t *testing.T) {
|
|||||||
|
|
||||||
policyService := NewPolicyService(policyRepo, auditService)
|
policyService := NewPolicyService(policyRepo, auditService)
|
||||||
|
|
||||||
policies, total, err := policyService.ListPolicies(1, 50)
|
policies, total, err := policyService.ListPolicies(context.Background(), 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListPolicies failed: %v", err)
|
t.Fatalf("ListPolicies failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -407,7 +407,7 @@ func TestCreatePolicy(t *testing.T) {
|
|||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := policyService.CreatePolicy(policy)
|
created, err := policyService.CreatePolicy(context.Background(), policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreatePolicy failed: %v", err)
|
t.Fatalf("CreatePolicy failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+13
-13
@@ -28,7 +28,7 @@ func NewProfileService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListProfiles returns all profiles (handler interface method).
|
// ListProfiles returns all profiles (handler interface method).
|
||||||
func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificateProfile, int64, error) {
|
func (s *ProfileService) ListProfiles(ctx context.Context, page, perPage int) ([]domain.CertificateProfile, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -36,7 +36,7 @@ func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificatePr
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles, err := s.profileRepo.List(context.Background())
|
profiles, err := s.profileRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list profiles: %w", err)
|
return nil, 0, fmt.Errorf("failed to list profiles: %w", err)
|
||||||
}
|
}
|
||||||
@@ -53,12 +53,12 @@ func (s *ProfileService) ListProfiles(page, perPage int) ([]domain.CertificatePr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetProfile returns a single profile (handler interface method).
|
// GetProfile returns a single profile (handler interface method).
|
||||||
func (s *ProfileService) GetProfile(id string) (*domain.CertificateProfile, error) {
|
func (s *ProfileService) GetProfile(ctx context.Context, id string) (*domain.CertificateProfile, error) {
|
||||||
return s.profileRepo.Get(context.Background(), id)
|
return s.profileRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateProfile creates a new profile with validation (handler interface method).
|
// CreateProfile creates a new profile with validation (handler interface method).
|
||||||
func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
func (s *ProfileService) CreateProfile(ctx context.Context, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
||||||
if err := validateProfile(&profile); err != nil {
|
if err := validateProfile(&profile); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -82,12 +82,12 @@ func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*doma
|
|||||||
profile.AllowedEKUs = domain.DefaultEKUs()
|
profile.AllowedEKUs = domain.DefaultEKUs()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.profileRepo.Create(context.Background(), &profile); err != nil {
|
if err := s.profileRepo.Create(ctx, &profile); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create profile: %w", err)
|
return nil, fmt.Errorf("failed to create profile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.auditService != nil {
|
if s.auditService != nil {
|
||||||
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
|
if auditErr := s.auditService.RecordEvent(context.WithoutCancel(ctx), "api", domain.ActorTypeUser,
|
||||||
"create_profile", "certificate_profile", profile.ID, nil); auditErr != nil {
|
"create_profile", "certificate_profile", profile.ID, nil); auditErr != nil {
|
||||||
slog.Error("failed to record audit event", "error", auditErr)
|
slog.Error("failed to record audit event", "error", auditErr)
|
||||||
}
|
}
|
||||||
@@ -97,18 +97,18 @@ func (s *ProfileService) CreateProfile(profile domain.CertificateProfile) (*doma
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProfile modifies an existing profile (handler interface method).
|
// UpdateProfile modifies an existing profile (handler interface method).
|
||||||
func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
func (s *ProfileService) UpdateProfile(ctx context.Context, id string, profile domain.CertificateProfile) (*domain.CertificateProfile, error) {
|
||||||
if err := validateProfile(&profile); err != nil {
|
if err := validateProfile(&profile); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
profile.ID = id
|
profile.ID = id
|
||||||
if err := s.profileRepo.Update(context.Background(), &profile); err != nil {
|
if err := s.profileRepo.Update(ctx, &profile); err != nil {
|
||||||
return nil, fmt.Errorf("failed to update profile: %w", err)
|
return nil, fmt.Errorf("failed to update profile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.auditService != nil {
|
if s.auditService != nil {
|
||||||
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
|
if auditErr := s.auditService.RecordEvent(context.WithoutCancel(ctx), "api", domain.ActorTypeUser,
|
||||||
"update_profile", "certificate_profile", id, nil); auditErr != nil {
|
"update_profile", "certificate_profile", id, nil); auditErr != nil {
|
||||||
slog.Error("failed to record audit event", "error", auditErr)
|
slog.Error("failed to record audit event", "error", auditErr)
|
||||||
}
|
}
|
||||||
@@ -118,13 +118,13 @@ func (s *ProfileService) UpdateProfile(id string, profile domain.CertificateProf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteProfile removes a profile (handler interface method).
|
// DeleteProfile removes a profile (handler interface method).
|
||||||
func (s *ProfileService) DeleteProfile(id string) error {
|
func (s *ProfileService) DeleteProfile(ctx context.Context, id string) error {
|
||||||
if err := s.profileRepo.Delete(context.Background(), id); err != nil {
|
if err := s.profileRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("failed to delete profile: %w", err)
|
return fmt.Errorf("failed to delete profile: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.auditService != nil {
|
if s.auditService != nil {
|
||||||
if auditErr := s.auditService.RecordEvent(context.Background(), "api", domain.ActorTypeUser,
|
if auditErr := s.auditService.RecordEvent(context.WithoutCancel(ctx), "api", domain.ActorTypeUser,
|
||||||
"delete_profile", "certificate_profile", id, nil); auditErr != nil {
|
"delete_profile", "certificate_profile", id, nil); auditErr != nil {
|
||||||
slog.Error("failed to record audit event", "error", auditErr)
|
slog.Error("failed to record audit event", "error", auditErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func TestProfileService_ListProfiles(t *testing.T) {
|
|||||||
repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true})
|
repo.AddProfile(&domain.CertificateProfile{ID: "prof-2", Name: "Internal mTLS", Enabled: true})
|
||||||
|
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
profiles, total, err := svc.ListProfiles(1, 50)
|
profiles, total, err := svc.ListProfiles(context.Background(), 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -98,7 +98,7 @@ func TestProfileService_ListProfiles_Empty(t *testing.T) {
|
|||||||
repo := newMockProfileRepository()
|
repo := newMockProfileRepository()
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
profiles, total, err := svc.ListProfiles(1, 50)
|
profiles, total, err := svc.ListProfiles(context.Background(), 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -115,7 +115,7 @@ func TestProfileService_ListProfiles_RepoError(t *testing.T) {
|
|||||||
repo.ListErr = errors.New("db error")
|
repo.ListErr = errors.New("db error")
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
_, _, err := svc.ListProfiles(1, 50)
|
_, _, err := svc.ListProfiles(context.Background(), 1, 50)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -126,7 +126,7 @@ func TestProfileService_GetProfile(t *testing.T) {
|
|||||||
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"})
|
repo.AddProfile(&domain.CertificateProfile{ID: "prof-1", Name: "Standard TLS"})
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
profile, err := svc.GetProfile("prof-1")
|
profile, err := svc.GetProfile(context.Background(), "prof-1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -139,7 +139,7 @@ func TestProfileService_GetProfile_NotFound(t *testing.T) {
|
|||||||
repo := newMockProfileRepository()
|
repo := newMockProfileRepository()
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
_, err := svc.GetProfile("nonexistent")
|
_, err := svc.GetProfile(context.Background(), "nonexistent")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -156,7 +156,7 @@ func TestProfileService_CreateProfile_Defaults(t *testing.T) {
|
|||||||
MaxTTLSeconds: 86400,
|
MaxTTLSeconds: 86400,
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := svc.CreateProfile(profile)
|
created, err := svc.CreateProfile(context.Background(), profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -258,7 +258,7 @@ func TestProfileService_CreateProfile_ValidationErrors(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := svc.CreateProfile(tt.profile)
|
_, err := svc.CreateProfile(context.Background(), tt.profile)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error containing %q, got nil", tt.errMsg)
|
t.Fatalf("expected error containing %q, got nil", tt.errMsg)
|
||||||
}
|
}
|
||||||
@@ -274,7 +274,7 @@ func TestProfileService_CreateProfile_RepoError(t *testing.T) {
|
|||||||
repo.CreateErr = errors.New("db create failed")
|
repo.CreateErr = errors.New("db create failed")
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
_, err := svc.CreateProfile(domain.CertificateProfile{Name: "Valid"})
|
_, err := svc.CreateProfile(context.Background(), domain.CertificateProfile{Name: "Valid"})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -287,7 +287,7 @@ func TestProfileService_UpdateProfile(t *testing.T) {
|
|||||||
auditSvc := NewAuditService(auditRepo)
|
auditSvc := NewAuditService(auditRepo)
|
||||||
svc := NewProfileService(repo, auditSvc)
|
svc := NewProfileService(repo, auditSvc)
|
||||||
|
|
||||||
updated, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{
|
updated, err := svc.UpdateProfile(context.Background(), "prof-1", domain.CertificateProfile{
|
||||||
Name: "Updated",
|
Name: "Updated",
|
||||||
MaxTTLSeconds: 43200,
|
MaxTTLSeconds: 43200,
|
||||||
})
|
})
|
||||||
@@ -306,7 +306,7 @@ func TestProfileService_UpdateProfile_ValidationError(t *testing.T) {
|
|||||||
repo := newMockProfileRepository()
|
repo := newMockProfileRepository()
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
_, err := svc.UpdateProfile("prof-1", domain.CertificateProfile{Name: ""})
|
_, err := svc.UpdateProfile(context.Background(), "prof-1", domain.CertificateProfile{Name: ""})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected validation error, got nil")
|
t.Fatal("expected validation error, got nil")
|
||||||
}
|
}
|
||||||
@@ -319,7 +319,7 @@ func TestProfileService_DeleteProfile(t *testing.T) {
|
|||||||
auditSvc := NewAuditService(auditRepo)
|
auditSvc := NewAuditService(auditRepo)
|
||||||
svc := NewProfileService(repo, auditSvc)
|
svc := NewProfileService(repo, auditSvc)
|
||||||
|
|
||||||
err := svc.DeleteProfile("prof-1")
|
err := svc.DeleteProfile(context.Background(), "prof-1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -333,7 +333,7 @@ func TestProfileService_DeleteProfile_RepoError(t *testing.T) {
|
|||||||
repo.DeleteErr = errors.New("db delete failed")
|
repo.DeleteErr = errors.New("db delete failed")
|
||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
err := svc.DeleteProfile("prof-1")
|
err := svc.DeleteProfile(context.Background(), "prof-1")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -344,7 +344,7 @@ func TestProfileService_CreateProfile_ValidShortLived(t *testing.T) {
|
|||||||
svc := NewProfileService(repo, nil)
|
svc := NewProfileService(repo, nil)
|
||||||
|
|
||||||
// Short-lived with TTL under 1 hour should succeed
|
// Short-lived with TTL under 1 hour should succeed
|
||||||
created, err := svc.CreateProfile(domain.CertificateProfile{
|
created, err := svc.CreateProfile(context.Background(), domain.CertificateProfile{
|
||||||
Name: "CI Ephemeral",
|
Name: "CI Ephemeral",
|
||||||
AllowShortLived: true,
|
AllowShortLived: true,
|
||||||
MaxTTLSeconds: 300, // 5 minutes
|
MaxTTLSeconds: 300, // 5 minutes
|
||||||
|
|||||||
@@ -151,9 +151,9 @@ func (s *RevocationSvc) RevokeCertificateWithActor(ctx context.Context, certID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
|
// GetRevokedCertificates returns all revoked certificate records (for CRL generation).
|
||||||
func (s *RevocationSvc) GetRevokedCertificates() ([]*domain.CertificateRevocation, error) {
|
func (s *RevocationSvc) GetRevokedCertificates(ctx context.Context) ([]*domain.CertificateRevocation, error) {
|
||||||
if s.revocationRepo == nil {
|
if s.revocationRepo == nil {
|
||||||
return nil, fmt.Errorf("revocation repository not configured")
|
return nil, fmt.Errorf("revocation repository not configured")
|
||||||
}
|
}
|
||||||
return s.revocationRepo.ListAll(context.Background())
|
return s.revocationRepo.ListAll(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func TestRevocationSvc_GetRevokedCertificates_Success(t *testing.T) {
|
|||||||
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
|
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
revocations, err := revSvc.GetRevokedCertificates()
|
revocations, err := revSvc.GetRevokedCertificates(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func TestRevokeCertificate_Success(t *testing.T) {
|
|||||||
certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version}
|
certRepo.Versions["cert-1"] = []*domain.CertificateVersion{version}
|
||||||
|
|
||||||
// Revoke
|
// Revoke
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-1", "keyCompromise", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-1", "keyCompromise", "admin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -125,7 +125,7 @@ func TestRevokeCertificate_DefaultReason(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Revoke with empty reason — should default to "unspecified"
|
// Revoke with empty reason — should default to "unspecified"
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-2", "", "api")
|
err := svc.RevokeCertificate(context.Background(), "cert-2", "", "api")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -158,7 +158,7 @@ func TestRevokeCertificate_AlreadyRevoked(t *testing.T) {
|
|||||||
}
|
}
|
||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-3", "superseded", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-3", "superseded", "admin")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for already revoked certificate")
|
t.Fatal("expected error for already revoked certificate")
|
||||||
}
|
}
|
||||||
@@ -179,7 +179,7 @@ func TestRevokeCertificate_ArchivedCert(t *testing.T) {
|
|||||||
}
|
}
|
||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-4", "keyCompromise", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-4", "keyCompromise", "admin")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for archived certificate")
|
t.Fatal("expected error for archived certificate")
|
||||||
}
|
}
|
||||||
@@ -200,7 +200,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) {
|
|||||||
}
|
}
|
||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-5", "notAValidReason", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-5", "notAValidReason", "admin")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for invalid reason")
|
t.Fatal("expected error for invalid reason")
|
||||||
}
|
}
|
||||||
@@ -212,7 +212,7 @@ func TestRevokeCertificate_InvalidReason(t *testing.T) {
|
|||||||
func TestRevokeCertificate_NotFound(t *testing.T) {
|
func TestRevokeCertificate_NotFound(t *testing.T) {
|
||||||
svc, _, _, _ := newRevocationTestService()
|
svc, _, _, _ := newRevocationTestService()
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "nonexistent-cert", "keyCompromise", "admin")
|
err := svc.RevokeCertificate(context.Background(), "nonexistent-cert", "keyCompromise", "admin")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for nonexistent certificate")
|
t.Fatal("expected error for nonexistent certificate")
|
||||||
}
|
}
|
||||||
@@ -231,7 +231,7 @@ func TestRevokeCertificate_NoVersion(t *testing.T) {
|
|||||||
certRepo.AddCert(cert)
|
certRepo.AddCert(cert)
|
||||||
// No versions added — should fail
|
// No versions added — should fail
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-6", "keyCompromise", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-6", "keyCompromise", "admin")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error when no certificate version exists")
|
t.Fatal("expected error when no certificate version exists")
|
||||||
}
|
}
|
||||||
@@ -258,7 +258,7 @@ func TestRevokeCertificate_WithIssuerNotification(t *testing.T) {
|
|||||||
{ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()},
|
{ID: "ver-7", CertificateID: "cert-7", SerialNumber: "GHI789", CreatedAt: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-7", "cessationOfOperation", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-7", "cessationOfOperation", "admin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -293,7 +293,7 @@ func TestRevokeCertificate_WithNotificationService(t *testing.T) {
|
|||||||
{ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()},
|
{ID: "ver-8", CertificateID: "cert-8", SerialNumber: "JKL012", CreatedAt: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-8", "keyCompromise", "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-8", "keyCompromise", "admin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -336,7 +336,7 @@ func TestRevokeCertificate_AllValidReasons(t *testing.T) {
|
|||||||
{ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()},
|
{ID: "ver-" + reason, CertificateID: "cert-" + reason, SerialNumber: "SER-" + reason, CreatedAt: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := svc.RevokeCertificateWithActor(context.Background(), "cert-"+reason, reason, "admin")
|
err := svc.RevokeCertificate(context.Background(), "cert-"+reason, reason, "admin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error for reason %s, got: %v", reason, err)
|
t.Fatalf("expected no error for reason %s, got: %v", reason, err)
|
||||||
}
|
}
|
||||||
@@ -358,7 +358,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) {
|
|||||||
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
|
{ID: "rev-2", CertificateID: "cert-2", SerialNumber: "SER-2", Reason: "superseded", RevokedAt: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
revocations, err := svc.GetRevokedCertificates()
|
revocations, err := svc.GetRevokedCertificates(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -370,7 +370,7 @@ func TestGetRevokedCertificates_Success(t *testing.T) {
|
|||||||
func TestGetRevokedCertificates_Empty(t *testing.T) {
|
func TestGetRevokedCertificates_Empty(t *testing.T) {
|
||||||
svc, _, _, _ := newRevocationTestService()
|
svc, _, _, _ := newRevocationTestService()
|
||||||
|
|
||||||
revocations, err := svc.GetRevokedCertificates()
|
revocations, err := svc.GetRevokedCertificates(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -390,7 +390,7 @@ func TestGetRevokedCertificates_NoRepo(t *testing.T) {
|
|||||||
svc := NewCertificateService(certRepo, policyService, auditService)
|
svc := NewCertificateService(certRepo, policyService, auditService)
|
||||||
// Do NOT set revocation repo
|
// Do NOT set revocation repo
|
||||||
|
|
||||||
_, err := svc.GetRevokedCertificates()
|
_, err := svc.GetRevokedCertificates(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error when revocation repo not configured")
|
t.Fatal("expected error when revocation repo not configured")
|
||||||
}
|
}
|
||||||
@@ -411,8 +411,8 @@ func TestRevokeCertificate_HandlerInterfaceMethod(t *testing.T) {
|
|||||||
{ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()},
|
{ID: "ver-handler", CertificateID: "cert-handler", SerialNumber: "SER-HANDLER", CreatedAt: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test the handler interface method (no actor param)
|
// Test the handler interface method (actor collapsed to required positional arg per D-2)
|
||||||
err := svc.RevokeCertificate("cert-handler", "superseded")
|
err := svc.RevokeCertificate(context.Background(), "cert-handler", "superseded", "api")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
}
|
}
|
||||||
@@ -449,7 +449,7 @@ func TestGenerateDERCRL_Success(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
crl, err := svc.GenerateDERCRL("iss-local")
|
crl, err := svc.GenerateDERCRL(context.Background(), "iss-local")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -472,7 +472,7 @@ func TestGenerateDERCRL_EmptyCRL(t *testing.T) {
|
|||||||
// No revoked certs for this issuer
|
// No revoked certs for this issuer
|
||||||
revocationRepo.Revocations = []*domain.CertificateRevocation{}
|
revocationRepo.Revocations = []*domain.CertificateRevocation{}
|
||||||
|
|
||||||
crl, err := svc.GenerateDERCRL("iss-local")
|
crl, err := svc.GenerateDERCRL(context.Background(), "iss-local")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -493,7 +493,7 @@ func TestGenerateDERCRL_IssuerNotFound(t *testing.T) {
|
|||||||
svc, _, _, _ := newRevocationTestService()
|
svc, _, _, _ := newRevocationTestService()
|
||||||
|
|
||||||
// Try to generate CRL for unknown issuer
|
// Try to generate CRL for unknown issuer
|
||||||
crl, err := svc.GenerateDERCRL("iss-unknown")
|
crl, err := svc.GenerateDERCRL(context.Background(), "iss-unknown")
|
||||||
|
|
||||||
// Should return error or nil CRL depending on implementation
|
// Should return error or nil CRL depending on implementation
|
||||||
if crl != nil && err == nil {
|
if crl != nil && err == nil {
|
||||||
@@ -527,7 +527,7 @@ func TestGetOCSPResponse_Good(t *testing.T) {
|
|||||||
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
|
certRepo.Versions["cert-ocsp-good"] = []*domain.CertificateVersion{version}
|
||||||
|
|
||||||
// Request OCSP response for good cert
|
// Request OCSP response for good cert
|
||||||
resp, err := svc.GetOCSPResponse("iss-local", "OCSP-GOOD-001")
|
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-GOOD-001")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -580,7 +580,7 @@ func TestGetOCSPResponse_Revoked(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Request OCSP response for revoked cert
|
// Request OCSP response for revoked cert
|
||||||
resp, err := svc.GetOCSPResponse("iss-local", "OCSP-REVOKED-001")
|
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "OCSP-REVOKED-001")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
@@ -597,7 +597,7 @@ func TestGetOCSPResponse_Unknown(t *testing.T) {
|
|||||||
svc, _, _, _ := newRevocationTestService()
|
svc, _, _, _ := newRevocationTestService()
|
||||||
|
|
||||||
// Request OCSP response for unknown cert
|
// Request OCSP response for unknown cert
|
||||||
resp, err := svc.GetOCSPResponse("iss-local", "UNKNOWN-SERIAL")
|
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "UNKNOWN-SERIAL")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error (should return unknown status), got: %v", err)
|
t.Fatalf("expected no error (should return unknown status), got: %v", err)
|
||||||
@@ -615,7 +615,7 @@ func TestGetOCSPResponse_IssuerNotFound(t *testing.T) {
|
|||||||
svc, _, _, _ := newRevocationTestService()
|
svc, _, _, _ := newRevocationTestService()
|
||||||
|
|
||||||
// Request OCSP response for unknown issuer
|
// Request OCSP response for unknown issuer
|
||||||
resp, err := svc.GetOCSPResponse("iss-unknown", "SOME-SERIAL")
|
resp, err := svc.GetOCSPResponse(context.Background(), "iss-unknown", "SOME-SERIAL")
|
||||||
|
|
||||||
// Should return error since issuer doesn't exist
|
// Should return error since issuer doesn't exist
|
||||||
if err == nil && resp != nil {
|
if err == nil && resp != nil {
|
||||||
@@ -629,7 +629,7 @@ func TestGetOCSPResponse_InvalidSerial(t *testing.T) {
|
|||||||
svc, _, _, _ := newRevocationTestService()
|
svc, _, _, _ := newRevocationTestService()
|
||||||
|
|
||||||
// Request OCSP response with invalid serial format
|
// Request OCSP response with invalid serial format
|
||||||
resp, err := svc.GetOCSPResponse("iss-local", "")
|
resp, err := svc.GetOCSPResponse(context.Background(), "iss-local", "")
|
||||||
|
|
||||||
if err == nil && resp != nil {
|
if err == nil && resp != nil {
|
||||||
// Empty serial might return unknown status; that's ok
|
// Empty serial might return unknown status; that's ok
|
||||||
|
|||||||
+21
-19
@@ -36,20 +36,27 @@ func isValidTargetType(t domain.TargetType) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TargetService provides business logic for deployment target management.
|
// TargetService provides business logic for deployment target management.
|
||||||
|
//
|
||||||
|
// The encryptionKey field holds the raw passphrase (not a pre-derived 32-byte
|
||||||
|
// key). Per-ciphertext salt derivation is performed inside
|
||||||
|
// [crypto.EncryptIfKeySet] / [crypto.DecryptIfKeySet] on each call. See M-8
|
||||||
|
// in certctl-audit-report.md.
|
||||||
type TargetService struct {
|
type TargetService struct {
|
||||||
targetRepo repository.TargetRepository
|
targetRepo repository.TargetRepository
|
||||||
agentRepo repository.AgentRepository
|
agentRepo repository.AgentRepository
|
||||||
auditService *AuditService
|
auditService *AuditService
|
||||||
encryptionKey []byte
|
encryptionKey string
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTargetService creates a new target service.
|
// NewTargetService creates a new target service. The encryptionKey is the raw
|
||||||
|
// passphrase; it MUST NOT be pre-derived via crypto.DeriveKey (that was the
|
||||||
|
// v1 behavior, replaced in M-8 with per-ciphertext random salt).
|
||||||
func NewTargetService(
|
func NewTargetService(
|
||||||
targetRepo repository.TargetRepository,
|
targetRepo repository.TargetRepository,
|
||||||
auditService *AuditService,
|
auditService *AuditService,
|
||||||
agentRepo repository.AgentRepository,
|
agentRepo repository.AgentRepository,
|
||||||
encryptionKey []byte,
|
encryptionKey string,
|
||||||
logger *slog.Logger,
|
logger *slog.Logger,
|
||||||
) *TargetService {
|
) *TargetService {
|
||||||
return &TargetService{
|
return &TargetService{
|
||||||
@@ -235,7 +242,7 @@ func (s *TargetService) TestConnection(ctx context.Context, id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListTargets returns paginated targets (handler interface method).
|
// ListTargets returns paginated targets (handler interface method).
|
||||||
func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
func (s *TargetService) ListTargets(ctx context.Context, page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -243,7 +250,7 @@ func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarge
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
targets, err := s.targetRepo.List(context.Background())
|
targets, err := s.targetRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list targets: %w", err)
|
return nil, 0, fmt.Errorf("failed to list targets: %w", err)
|
||||||
}
|
}
|
||||||
@@ -260,12 +267,12 @@ func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarge
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetTarget returns a single target (handler interface method).
|
// GetTarget returns a single target (handler interface method).
|
||||||
func (s *TargetService) GetTarget(id string) (*domain.DeploymentTarget, error) {
|
func (s *TargetService) GetTarget(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
return s.targetRepo.Get(context.Background(), id)
|
return s.targetRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTarget creates a new target (handler interface method).
|
// CreateTarget creates a new target (handler interface method).
|
||||||
func (s *TargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
func (s *TargetService) CreateTarget(ctx context.Context, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
if !isValidTargetType(target.Type) {
|
if !isValidTargetType(target.Type) {
|
||||||
return nil, fmt.Errorf("unsupported target type: %s", target.Type)
|
return nil, fmt.Errorf("unsupported target type: %s", target.Type)
|
||||||
}
|
}
|
||||||
@@ -301,20 +308,20 @@ func (s *TargetService) CreateTarget(target domain.DeploymentTarget) (*domain.De
|
|||||||
target.Config = redactConfigJSON(target.Config)
|
target.Config = redactConfigJSON(target.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.targetRepo.Create(context.Background(), &target); err != nil {
|
if err := s.targetRepo.Create(ctx, &target); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create target: %w", err)
|
return nil, fmt.Errorf("failed to create target: %w", err)
|
||||||
}
|
}
|
||||||
return &target, nil
|
return &target, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTarget modifies a target (handler interface method).
|
// UpdateTarget modifies a target (handler interface method).
|
||||||
func (s *TargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
func (s *TargetService) UpdateTarget(ctx context.Context, id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
target.ID = id
|
target.ID = id
|
||||||
target.UpdatedAt = time.Now()
|
target.UpdatedAt = time.Now()
|
||||||
|
|
||||||
// Merge redacted fields with existing config
|
// Merge redacted fields with existing config
|
||||||
if len(target.Config) > 0 {
|
if len(target.Config) > 0 {
|
||||||
mergedConfig, err := s.mergeRedactedConfig(context.Background(), id, target.Config)
|
mergedConfig, err := s.mergeRedactedConfig(ctx, id, target.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to merge config: %w", err)
|
return nil, fmt.Errorf("failed to merge config: %w", err)
|
||||||
}
|
}
|
||||||
@@ -327,20 +334,15 @@ func (s *TargetService) UpdateTarget(id string, target domain.DeploymentTarget)
|
|||||||
target.Config = redactConfigJSON(json.RawMessage(mergedConfig))
|
target.Config = redactConfigJSON(json.RawMessage(mergedConfig))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.targetRepo.Update(context.Background(), &target); err != nil {
|
if err := s.targetRepo.Update(ctx, &target); err != nil {
|
||||||
return nil, fmt.Errorf("failed to update target: %w", err)
|
return nil, fmt.Errorf("failed to update target: %w", err)
|
||||||
}
|
}
|
||||||
return &target, nil
|
return &target, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteTarget removes a target (handler interface method).
|
// DeleteTarget removes a target (handler interface method).
|
||||||
func (s *TargetService) DeleteTarget(id string) error {
|
func (s *TargetService) DeleteTarget(ctx context.Context, id string) error {
|
||||||
return s.targetRepo.Delete(context.Background(), id)
|
return s.targetRepo.Delete(ctx, id)
|
||||||
}
|
|
||||||
|
|
||||||
// TestTargetConnection tests target connectivity (handler interface method).
|
|
||||||
func (s *TargetService) TestTargetConnection(id string) error {
|
|
||||||
return s.TestConnection(context.Background(), id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Internal helpers ---
|
// --- Internal helpers ---
|
||||||
|
|||||||
@@ -344,7 +344,8 @@ func TestTargetService_ListTargets_Success(t *testing.T) {
|
|||||||
targetRepo.AddTarget(target2)
|
targetRepo.AddTarget(target2)
|
||||||
|
|
||||||
// Call handler-interface method
|
// Call handler-interface method
|
||||||
targets, total, err := svc.ListTargets(1, 50)
|
ctx := context.Background()
|
||||||
|
targets, total, err := svc.ListTargets(ctx, 1, 50)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -364,7 +365,8 @@ func TestTargetService_GetTarget_Success(t *testing.T) {
|
|||||||
target := &domain.DeploymentTarget{ID: "t-1", Name: "Target 1", Type: domain.TargetTypeNGINX}
|
target := &domain.DeploymentTarget{ID: "t-1", Name: "Target 1", Type: domain.TargetTypeNGINX}
|
||||||
targetRepo.AddTarget(target)
|
targetRepo.AddTarget(target)
|
||||||
|
|
||||||
result, err := svc.GetTarget("t-1")
|
ctx := context.Background()
|
||||||
|
result, err := svc.GetTarget(ctx, "t-1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -382,7 +384,8 @@ func TestTargetService_CreateTarget_Success(t *testing.T) {
|
|||||||
Type: domain.TargetTypeNGINX,
|
Type: domain.TargetTypeNGINX,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.CreateTarget(target)
|
ctx := context.Background()
|
||||||
|
result, err := svc.CreateTarget(ctx, target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -405,7 +408,8 @@ func TestTargetService_CreateTarget_InvalidType(t *testing.T) {
|
|||||||
Type: domain.TargetType("Unknown"),
|
Type: domain.TargetType("Unknown"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := svc.CreateTarget(target)
|
ctx := context.Background()
|
||||||
|
_, err := svc.CreateTarget(ctx, target)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error for invalid type, got nil")
|
t.Fatalf("expected error for invalid type, got nil")
|
||||||
}
|
}
|
||||||
@@ -424,7 +428,8 @@ func TestTargetService_UpdateTarget_Success(t *testing.T) {
|
|||||||
Type: domain.TargetTypeApache,
|
Type: domain.TargetTypeApache,
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.UpdateTarget("t-1", updated)
|
ctx := context.Background()
|
||||||
|
result, err := svc.UpdateTarget(ctx, "t-1", updated)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -442,7 +447,8 @@ func TestTargetService_DeleteTarget_Success(t *testing.T) {
|
|||||||
targetRepo.AddTarget(target)
|
targetRepo.AddTarget(target)
|
||||||
|
|
||||||
// Delete it
|
// Delete it
|
||||||
err := svc.DeleteTarget("t-1")
|
ctx := context.Background()
|
||||||
|
err := svc.DeleteTarget(ctx, "t-1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+10
-10
@@ -126,7 +126,7 @@ func (s *TeamService) Delete(ctx context.Context, id string, actor string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListTeams returns paginated teams (handler interface method).
|
// ListTeams returns paginated teams (handler interface method).
|
||||||
func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) {
|
func (s *TeamService) ListTeams(ctx context.Context, page, perPage int) ([]domain.Team, int64, error) {
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error)
|
|||||||
perPage = 50
|
perPage = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
teams, err := s.teamRepo.List(context.Background())
|
teams, err := s.teamRepo.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to list teams: %w", err)
|
return nil, 0, fmt.Errorf("failed to list teams: %w", err)
|
||||||
}
|
}
|
||||||
@@ -151,12 +151,12 @@ func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetTeam returns a single team (handler interface method).
|
// GetTeam returns a single team (handler interface method).
|
||||||
func (s *TeamService) GetTeam(id string) (*domain.Team, error) {
|
func (s *TeamService) GetTeam(ctx context.Context, id string) (*domain.Team, error) {
|
||||||
return s.teamRepo.Get(context.Background(), id)
|
return s.teamRepo.Get(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTeam creates a new team (handler interface method).
|
// CreateTeam creates a new team (handler interface method).
|
||||||
func (s *TeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
|
func (s *TeamService) CreateTeam(ctx context.Context, team domain.Team) (*domain.Team, error) {
|
||||||
if team.ID == "" {
|
if team.ID == "" {
|
||||||
team.ID = generateID("team")
|
team.ID = generateID("team")
|
||||||
}
|
}
|
||||||
@@ -167,22 +167,22 @@ func (s *TeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
|
|||||||
if team.UpdatedAt.IsZero() {
|
if team.UpdatedAt.IsZero() {
|
||||||
team.UpdatedAt = now
|
team.UpdatedAt = now
|
||||||
}
|
}
|
||||||
if err := s.teamRepo.Create(context.Background(), &team); err != nil {
|
if err := s.teamRepo.Create(ctx, &team); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create team: %w", err)
|
return nil, fmt.Errorf("failed to create team: %w", err)
|
||||||
}
|
}
|
||||||
return &team, nil
|
return &team, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTeam modifies a team (handler interface method).
|
// UpdateTeam modifies a team (handler interface method).
|
||||||
func (s *TeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) {
|
func (s *TeamService) UpdateTeam(ctx context.Context, id string, team domain.Team) (*domain.Team, error) {
|
||||||
team.ID = id
|
team.ID = id
|
||||||
if err := s.teamRepo.Update(context.Background(), &team); err != nil {
|
if err := s.teamRepo.Update(ctx, &team); err != nil {
|
||||||
return nil, fmt.Errorf("failed to update team: %w", err)
|
return nil, fmt.Errorf("failed to update team: %w", err)
|
||||||
}
|
}
|
||||||
return &team, nil
|
return &team, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteTeam removes a team (handler interface method).
|
// DeleteTeam removes a team (handler interface method).
|
||||||
func (s *TeamService) DeleteTeam(id string) error {
|
func (s *TeamService) DeleteTeam(ctx context.Context, id string) error {
|
||||||
return s.teamRepo.Delete(context.Background(), id)
|
return s.teamRepo.Delete(ctx, id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -544,7 +544,7 @@ func TestTeamService_ListTeams_HandlerInterface(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
teams, total, err := teamService.ListTeams(1, 2)
|
teams, total, err := teamService.ListTeams(context.Background(), 1, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -571,7 +571,7 @@ func TestTeamService_GetTeam_HandlerInterface(t *testing.T) {
|
|||||||
}
|
}
|
||||||
mockTeamRepo.AddTeam(testTeam)
|
mockTeamRepo.AddTeam(testTeam)
|
||||||
|
|
||||||
team, err := teamService.GetTeam("handler-team")
|
team, err := teamService.GetTeam(context.Background(), "handler-team")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -593,7 +593,7 @@ func TestTeamService_CreateTeam_HandlerInterface(t *testing.T) {
|
|||||||
Description: "Created via handler",
|
Description: "Created via handler",
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := teamService.CreateTeam(team)
|
result, err := teamService.CreateTeam(context.Background(), team)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -629,7 +629,7 @@ func TestTeamService_UpdateTeam_HandlerInterface(t *testing.T) {
|
|||||||
Description: "Handler update",
|
Description: "Handler update",
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := teamService.UpdateTeam("handler-update-team", updateTeam)
|
result, err := teamService.UpdateTeam(context.Background(), "handler-update-team", updateTeam)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -656,7 +656,7 @@ func TestTeamService_DeleteTeam_HandlerInterface(t *testing.T) {
|
|||||||
Name: "To Delete",
|
Name: "To Delete",
|
||||||
})
|
})
|
||||||
|
|
||||||
err := teamService.DeleteTeam("handler-delete-team")
|
err := teamService.DeleteTeam(context.Background(), "handler-delete-team")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,12 +12,15 @@ import (
|
|||||||
|
|
||||||
var errNotFound = errors.New("not found")
|
var errNotFound = errors.New("not found")
|
||||||
|
|
||||||
// testEncryptionKey is a deterministic 32-byte AES-256 key for unit tests that
|
// testEncryptionKey is a deterministic passphrase for unit tests that
|
||||||
// exercise IssuerService/TargetService write paths. After the C-2 remediation
|
// exercise IssuerService/TargetService write paths. After the C-2 remediation
|
||||||
// these services fail closed when no key is configured, so happy-path tests
|
// these services fail closed when no key is configured, so happy-path tests
|
||||||
// must supply a real key. Using a constant keeps wire-format assertions stable
|
// must supply a real passphrase. M-8 reshaped the type from []byte to string
|
||||||
// across runs and avoids flaky PBKDF2 timing.
|
// because services now hold the raw passphrase and delegate PBKDF2 to
|
||||||
var testEncryptionKey = []byte("0123456789abcdef0123456789abcdef") // 32 bytes
|
// crypto.EncryptIfKeySet / crypto.DecryptIfKeySet (which apply a fresh random
|
||||||
|
// salt per ciphertext). Using a constant keeps wire-format assertions stable
|
||||||
|
// across runs.
|
||||||
|
var testEncryptionKey = "0123456789abcdef0123456789abcdef"
|
||||||
|
|
||||||
// mockCertRepo is a test implementation of CertificateRepository
|
// mockCertRepo is a test implementation of CertificateRepository
|
||||||
type mockCertRepo struct {
|
type mockCertRepo struct {
|
||||||
@@ -599,6 +602,19 @@ func (m *mockAgentRepo) Create(ctx context.Context, agent *domain.Agent) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAgentRepo) CreateIfNotExists(ctx context.Context, agent *domain.Agent) (bool, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.CreateErr != nil {
|
||||||
|
return false, m.CreateErr
|
||||||
|
}
|
||||||
|
if _, exists := m.Agents[agent.ID]; exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
m.Agents[agent.ID] = agent
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAgentRepo) Update(ctx context.Context, agent *domain.Agent) error {
|
func (m *mockAgentRepo) Update(ctx context.Context, agent *domain.Agent) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { useState } from 'react';
|
|||||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query';
|
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query';
|
||||||
import { useNavigate, Link } from 'react-router-dom';
|
import { useNavigate, Link } from 'react-router-dom';
|
||||||
import {
|
import {
|
||||||
getIssuers, getAgents, getProfiles,
|
getIssuers, getAgents, getProfiles, getOwners,
|
||||||
createIssuer, testIssuerConnection,
|
createIssuer, testIssuerConnection,
|
||||||
createCertificate, triggerRenewal,
|
createCertificate, triggerRenewal,
|
||||||
getApiKey,
|
getApiKey,
|
||||||
@@ -404,12 +404,14 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
|
|||||||
const [sans, setSans] = useState('');
|
const [sans, setSans] = useState('');
|
||||||
const [issuerId, setIssuerId] = useState(createdIssuerId || '');
|
const [issuerId, setIssuerId] = useState(createdIssuerId || '');
|
||||||
const [profileId, setProfileId] = useState('');
|
const [profileId, setProfileId] = useState('');
|
||||||
|
const [ownerId, setOwnerId] = useState('');
|
||||||
const [error, setError] = useState('');
|
const [error, setError] = useState('');
|
||||||
const [created, setCreated] = useState(false);
|
const [created, setCreated] = useState(false);
|
||||||
|
|
||||||
const { data: issuers } = useQuery({ queryKey: ['issuers'], queryFn: () => getIssuers() });
|
const { data: issuers } = useQuery({ queryKey: ['issuers'], queryFn: () => getIssuers() });
|
||||||
const { data: profiles } = useQuery({ queryKey: ['profiles'], queryFn: () => getProfiles() });
|
const { data: profiles } = useQuery({ queryKey: ['profiles'], queryFn: () => getProfiles() });
|
||||||
const { data: agents } = useQuery({ queryKey: ['agents'], queryFn: () => getAgents() });
|
const { data: agents } = useQuery({ queryKey: ['agents'], queryFn: () => getAgents() });
|
||||||
|
const { data: owners } = useQuery({ queryKey: ['owners'], queryFn: () => getOwners() });
|
||||||
|
|
||||||
const hasAgents = (agents?.data?.length ?? 0) > 0;
|
const hasAgents = (agents?.data?.length ?? 0) > 0;
|
||||||
|
|
||||||
@@ -421,6 +423,7 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
|
|||||||
sans: sanList,
|
sans: sanList,
|
||||||
issuer_id: issuerId,
|
issuer_id: issuerId,
|
||||||
certificate_profile_id: profileId || undefined,
|
certificate_profile_id: profileId || undefined,
|
||||||
|
owner_id: ownerId,
|
||||||
environment: 'production',
|
environment: 'production',
|
||||||
});
|
});
|
||||||
// Trigger issuance
|
// Trigger issuance
|
||||||
@@ -521,6 +524,29 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
|
|||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-ink mb-2">
|
||||||
|
Owner <span className="text-red-600">*</span>
|
||||||
|
</label>
|
||||||
|
<select
|
||||||
|
value={ownerId}
|
||||||
|
onChange={e => setOwnerId(e.target.value)}
|
||||||
|
className="w-full px-3 py-2 bg-surface border border-surface-border rounded text-ink focus:outline-none focus:border-brand-500 transition-colors"
|
||||||
|
>
|
||||||
|
<option value="">Select owner...</option>
|
||||||
|
{owners?.data?.map(o => (
|
||||||
|
<option key={o.id} value={o.id}>
|
||||||
|
{o.name}{o.email ? ` (${o.email})` : ''}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
{(owners?.data?.length ?? 0) === 0 && (
|
||||||
|
<p className="mt-1 text-xs text-ink-muted">
|
||||||
|
No owners yet — create one from the <Link to="/owners" className="underline hover:text-ink">Owners page</Link> first, then return here.
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Discovery hint */}
|
{/* Discovery hint */}
|
||||||
@@ -547,7 +573,7 @@ function CertificateStep({ onNext, onSkip, createdIssuerId }: {
|
|||||||
onSkip={onSkip}
|
onSkip={onSkip}
|
||||||
onNext={() => createMutation.mutate()}
|
onNext={() => createMutation.mutate()}
|
||||||
nextLabel={createMutation.isPending ? 'Creating...' : 'Issue Certificate'}
|
nextLabel={createMutation.isPending ? 'Creating...' : 'Issue Certificate'}
|
||||||
nextDisabled={!commonName || !issuerId || createMutation.isPending}
|
nextDisabled={!commonName || !issuerId || !ownerId || createMutation.isPending}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user