certctl/internal/repository/postgres/repo_test.go
shankar0123 a3d8b9c607 fix(deploy,db,handler): close fresh-clone postgres init failure + 4 ride-along audit findings (U-3 master)
GitHub #10 reopened: operator mikeakasully cloned v2.0.50 fresh and ran the
canonical quickstart (docker compose -f deploy/docker-compose.yml up -d --build);
postgres reported unhealthy indefinitely, dependent containers never started.

Root cause: deploy/docker-compose.yml mounted a hand-curated subset of
migrations/*.up.sql + seed.sql into postgres /docker-entrypoint-initdb.d/.
Postgres applied them at initdb time. Once seed.sql referenced columns added
by migrations *after* the mounted cutoff (e.g., policy_rules.severity from
migration 000013), initdb crashed mid-seed and the container loop wedged.
Two sources of truth (compose mount list vs in-tree migration ladder)
diverged the moment a seed-touching migration shipped, and the only thing
that fixed it was hand-editing the compose file every release.

Fix: remove the dual source. Postgres boots empty; the server applies
migrations + seed at startup via RunMigrations + RunSeed. Helm has used
this pattern since day one (postgres-init emptyDir); compose now matches.

Bundled with four ride-along audit findings whose fixes share the same
schema/db code surface, so operators take the schema-change pain only once:

  cat-u-seed_initdb_schema_drift           [P1, primary] — initdb-mount fix
  cat-o-retry_interval_unit_mismatch       [P1] — column rename minutes→seconds
  cat-o-notification_created_at_dead_field [P2] — add column + populate
  cat-o-health_check_column_orphans        [P1] — drop unwired columns
  cat-u-no_version_endpoint                [P2] — add /api/v1/version

Single migration (000017_db_coupling_cleanup) bundles the three schema
changes under a DO $$ guard so re-application is safe; this reduces the
operator-visible 'schema-change releases' from four to one.

Backend
- internal/repository/postgres/db.go: add RunSeed (baseline) + RunDemoSeed
  (gated by CERTCTL_DEMO_SEED). Both idempotent (ON CONFLICT DO NOTHING in
  every shipped INSERT) so repeated boots are safe; missing-file is no-op
  so custom packaging that strips seeds still boots cleanly.
- cmd/server/main.go: invoke RunSeed (always) + RunDemoSeed (when flag set)
  immediately after RunMigrations.
- internal/repository/postgres/notification.go: NotificationRepository.Create
  now sets created_at (with time.Now() fallback when caller leaves it zero);
  scanNotification reads it back; List + ListRetryEligible SELECT extended.
- internal/repository/postgres/renewal_policy.go: column references updated
  to retry_interval_seconds across SELECT/INSERT/UPDATE sites.
- internal/api/handler/version.go: new VersionHandler exposes
  {version, commit, modified, build_time, go_version} from
  runtime/debug.ReadBuildInfo() with ldflags-supplied Version override.
- internal/api/router/router.go: register GET /api/v1/version through the
  no-auth chain (CORS + ContentType) alongside /health, /ready,
  /api/v1/auth/info.
- cmd/server/main.go: add /api/v1/version to no-auth dispatch + audit
  ExcludePaths so rollout polling doesn't dominate the audit trail.
- internal/config/config.go: add DatabaseConfig.DemoSeed +
  CERTCTL_DEMO_SEED env var.

Migration
- migrations/000017_db_coupling_cleanup.up.sql + .down.sql:
    (1) renewal_policies.retry_interval_minutes → retry_interval_seconds
        (DO $$ guard, idempotent re-application)
    (2) notification_events ADD COLUMN created_at TIMESTAMPTZ
        NOT NULL DEFAULT NOW()
    (3) network_scan_targets DROP orphan health_check_enabled +
        health_check_interval_seconds
- migrations/seed.sql: column reference updated to retry_interval_seconds.
- migrations/seed_demo.sql: same column rename + applied at runtime now via
  RunDemoSeed (no longer initdb-mounted).

Compose
- deploy/docker-compose.yml: drop ALL initdb mounts (10 migration files +
  seed.sql); add start_period: 30s to postgres + certctl-server healthchecks
  to absorb the runtime migration + seed application window on first boot.
- deploy/docker-compose.test.yml: same drop (+ ghost seed_test.sql mount
  removed; that file never existed); same healthcheck start_period.
- deploy/docker-compose.demo.yml: replace seed_demo.sql initdb mount with
  CERTCTL_DEMO_SEED=true env var on certctl-server.
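How internal/config/config.go parses CERTCTL_DEMO_SEED is not shown in this commit message. The sketch below is one plausible approach. It assumes strconv.ParseBool semantics and treats an unset or empty variable as false. Rejecting unparseable values, rather than quietly defaulting to false, is a design choice of this sketch and not confirmed behaviour.

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// demoSeedEnabled reports whether the demo seed should run. lookup is
// os.LookupEnv in production and a map-backed fake in tests.
func demoSeedEnabled(lookup func(string) (string, bool)) (bool, error) {
	v, ok := lookup("CERTCTL_DEMO_SEED")
	if !ok || v == "" {
		return false, nil // unset or empty: baseline seed only
	}
	b, err := strconv.ParseBool(v) // accepts 1/0, t/f, true/false (any case form ParseBool allows)
	if err != nil {
		return false, fmt.Errorf("CERTCTL_DEMO_SEED=%q: %w", v, err)
	}
	return b, nil
}

func main() {
	enabled, err := demoSeedEnabled(os.LookupEnv)
	fmt.Println(enabled, err)
}
```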

Tests
- internal/api/handler/version_handler_test.go: TestVersion_ReturnsBuildInfo,
  TestVersion_RejectsNonGet, TestVersion_LdflagsOverride.
- internal/repository/postgres/seed_test.go: TestRunSeed_AppliesIdempotently,
  TestRunSeed_MissingFileIsNoOp, TestRunDemoSeed_AppliesIdempotently,
  TestMigration000017_RetryIntervalRename,
  TestMigration000017_NotificationCreatedAt,
  TestMigration000017_HealthCheckOrphansDropped (testcontainers, -short skips).
- internal/repository/postgres/notification_test.go:
  TestNotificationRepository_CreatedAt_IsPersisted +
  TestNotificationRepository_CreatedAt_DefaultsToNow.

CI guardrail
- .github/workflows/ci.yml: new 'Forbidden migration mount in compose initdb
  (U-3)' step grep-fails the build if any migrations/*.sql or seed*.sql
  re-appears in /docker-entrypoint-initdb.d in any compose file. Catches
  future drift before a fresh-clone operator hits it.

Spec / Docs
- api/openapi.yaml: add /api/v1/version operation under Health tag.
- docs/architecture.md: replace the 'initdb may run the same SQL' paragraph
  with a post-U-3 single-source-of-truth explanation.
- CHANGELOG.md: full unreleased-section entry covering all 5 closures,
  breaking changes, and the new env var.

Audit doc
- coverage-gap-audit-2026-04-24-v5/unified-audit.md: add new P1 #14
  cat-u-seed_initdb_schema_drift; flip the 4 ride-along findings to
  RESOLVED with closure prose pointing at this commit.

Verification: build/vet/test -short -race all clean across all touched
packages locally; govulncheck reports 0 vulnerabilities affecting our
code; OpenAPI YAML parses; CI U-3 grep guardrail clears against the
post-fix tree.
2026-04-25 13:29:23 +00:00


// Package postgres_test provides repository integration tests covering 17 of 17
// PostgreSQL repository files. Each test function exercises CRUD operations,
// edge cases, and deduplication logic against a real database. HealthCheck
// and RenewalPolicy integration tests live in sibling *_test.go files in this
// package (see health_check_test.go and renewal_policy_test.go).
package postgres_test

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/repository"
	"github.com/shankar0123/certctl/internal/repository/postgres"
)

// Shared test database handle, created lazily by getTestDB. Each test
// creates its own schema (freshSchema) for isolation.
var sharedDB *testDB

func TestMain(m *testing.M) {
	// Note: the container cannot be started here because setupTestDB needs a
	// *testing.T. Instead, each top-level test function calls getTestDB, which
	// starts the container when sharedDB is nil.
	m.Run()
}
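
// getTestDB lazily initializes the shared container.
//
// The cleanup below is registered on the calling test's t, so it runs when
// that test finishes, not at the end of the whole package run. The next test
// to call getTestDB then finds sharedDB == nil and calls setupTestDB again.
// The helper is also not safe to call from tests that use t.Parallel.
func getTestDB(t *testing.T) *testDB {
	t.Helper()
	if sharedDB == nil {
		sharedDB = setupTestDB(t)
		// Tear down when the test that created the container completes.
		t.Cleanup(func() {
			sharedDB.teardown(t)
			sharedDB = nil
		})
	}
	return sharedDB
}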

// insertCertPrereqsRaw creates prerequisite FK records using raw SQL on the *sql.DB.
func insertCertPrereqsRaw(t *testing.T, db *sql.DB, ctx context.Context, suffix string) (ownerID, teamID, issuerID, policyID string) {
	t.Helper()
	teamID = "team-" + suffix
	ownerID = "o-" + suffix
	issuerID = "iss-" + suffix
	policyID = "pol-" + suffix
	now := time.Now().Truncate(time.Microsecond)

	// Create team
	_, err := db.ExecContext(ctx, `INSERT INTO teams (id, name, created_at, updated_at) VALUES ($1, $2, $3, $4)`,
		teamID, "Team "+suffix, now, now)
	if err != nil {
		t.Fatalf("insertCertPrereqs: create team failed: %v", err)
	}

	// Create owner (requires team)
	_, err = db.ExecContext(ctx, `INSERT INTO owners (id, name, email, team_id, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`,
		ownerID, "Owner "+suffix, suffix+"@example.com", teamID, now, now)
	if err != nil {
		t.Fatalf("insertCertPrereqs: create owner failed: %v", err)
	}

	// Create issuer
	_, err = db.ExecContext(ctx, `INSERT INTO issuers (id, name, type, enabled, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)`,
		issuerID, "Issuer "+suffix, "generic-ca", true, now, now)
	if err != nil {
		t.Fatalf("insertCertPrereqs: create issuer failed: %v", err)
	}

	// Create renewal policy
	_, err = db.ExecContext(ctx, `INSERT INTO renewal_policies (id, name, renewal_window_days, auto_renew, max_retries, retry_interval_seconds, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
		policyID, "Policy "+suffix, 30, true, 3, 60, now, now)
	if err != nil {
		t.Fatalf("insertCertPrereqs: create renewal_policy failed: %v", err)
	}
	return
}

// ============================================================
// Certificate Repository Tests
// ============================================================

func TestCertificateRepository_CRUD(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewCertificateRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)
	expires := now.Add(90 * 24 * time.Hour)
	ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "crud")

	cert := &domain.ManagedCertificate{
		ID:              "mc-test-crud",
		Name:            "test-cert",
		CommonName:      "test.example.com",
		SANs:            []string{"test.example.com", "www.test.example.com"},
		Environment:     "production",
		OwnerID:         ownerID,
		TeamID:          teamID,
		IssuerID:        issuerID,
		RenewalPolicyID: policyID,
		Status:          domain.CertificateStatusActive,
		ExpiresAt:       expires,
		Tags:            map[string]string{"team": "platform"},
		CreatedAt:       now,
		UpdatedAt:       now,
	}

	// Create
	err := repo.Create(ctx, cert)
	if err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Get
	got, err := repo.Get(ctx, "mc-test-crud")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.CommonName != "test.example.com" {
		t.Errorf("CommonName = %q, want %q", got.CommonName, "test.example.com")
	}
	if len(got.SANs) != 2 {
		t.Errorf("SANs length = %d, want 2", len(got.SANs))
	}
	if got.Tags["team"] != "platform" {
		t.Errorf("Tags[team] = %q, want %q", got.Tags["team"], "platform")
	}

	// Update
	cert.Status = domain.CertificateStatusExpiring
	cert.UpdatedAt = time.Now().Truncate(time.Microsecond)
	err = repo.Update(ctx, cert)
	if err != nil {
		t.Fatalf("Update failed: %v", err)
	}
	got, err = repo.Get(ctx, "mc-test-crud")
	if err != nil {
		t.Fatalf("Get after update failed: %v", err)
	}
	if got.Status != domain.CertificateStatusExpiring {
		t.Errorf("Status = %q, want %q", got.Status, domain.CertificateStatusExpiring)
	}

	// Archive
	err = repo.Archive(ctx, "mc-test-crud")
	if err != nil {
		t.Fatalf("Archive failed: %v", err)
	}
	got, err = repo.Get(ctx, "mc-test-crud")
	if err != nil {
		t.Fatalf("Get after archive failed: %v", err)
	}
	if got.Status != domain.CertificateStatusArchived {
		t.Errorf("Status after archive = %q, want %q", got.Status, domain.CertificateStatusArchived)
	}
}

func TestCertificateRepository_List_Filtering(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewCertificateRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)
	ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "listfilt")

	// Create test certs in different states
	for _, tc := range []struct {
		id     string
		status domain.CertificateStatus
		env    string
	}{
		{"mc-list-1", domain.CertificateStatusActive, "production"},
		{"mc-list-2", domain.CertificateStatusActive, "staging"},
		{"mc-list-3", domain.CertificateStatusExpired, "production"},
	} {
		cert := &domain.ManagedCertificate{
			ID:              tc.id,
			Name:            tc.id,
			CommonName:      tc.id + ".example.com",
			SANs:            []string{},
			Environment:     tc.env,
			OwnerID:         ownerID,
			TeamID:          teamID,
			IssuerID:        issuerID,
			RenewalPolicyID: policyID,
			Status:          tc.status,
			ExpiresAt:       now.Add(30 * 24 * time.Hour),
			Tags:            map[string]string{},
			CreatedAt:       now,
			UpdatedAt:       now,
		}
		if err := repo.Create(ctx, cert); err != nil {
			t.Fatalf("Create %s failed: %v", tc.id, err)
		}
	}

	// Filter by status
	certs, total, err := repo.List(ctx, &repository.CertificateFilter{Status: "Active"})
	if err != nil {
		t.Fatalf("List with status filter failed: %v", err)
	}
	if total != 2 {
		t.Errorf("total Active = %d, want 2", total)
	}
	if len(certs) != 2 {
		t.Errorf("len(certs) = %d, want 2", len(certs))
	}

	// Filter by environment
	_, total, err = repo.List(ctx, &repository.CertificateFilter{Environment: "production"})
	if err != nil {
		t.Fatalf("List with env filter failed: %v", err)
	}
	if total != 2 {
		t.Errorf("total production = %d, want 2", total)
	}

	// Nil filter returns all
	_, total, err = repo.List(ctx, nil)
	if err != nil {
		t.Fatalf("List with nil filter failed: %v", err)
	}
	if total != 3 {
		t.Errorf("total all = %d, want 3", total)
	}
}

func TestCertificateRepository_Versions(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewCertificateRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)
	ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "ver")

	// Create parent cert
	cert := &domain.ManagedCertificate{
		ID: "mc-ver-test", Name: "ver-test", CommonName: "ver.example.com",
		SANs: []string{}, OwnerID: ownerID, TeamID: teamID, IssuerID: issuerID,
		RenewalPolicyID: policyID, Status: domain.CertificateStatusActive,
		ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
		CreatedAt: now, UpdatedAt: now,
	}
	if err := repo.Create(ctx, cert); err != nil {
		t.Fatalf("Create cert failed: %v", err)
	}

	// Create two versions
	v1 := &domain.CertificateVersion{
		ID: "v-1", CertificateID: "mc-ver-test", SerialNumber: "AABB01",
		NotBefore: now, NotAfter: now.Add(90 * 24 * time.Hour),
		FingerprintSHA256: "sha256-v1", PEMChain: "---BEGIN---", CSRPEM: "---CSR---",
		CreatedAt: now,
	}
	v2 := &domain.CertificateVersion{
		ID: "v-2", CertificateID: "mc-ver-test", SerialNumber: "AABB02",
		NotBefore: now, NotAfter: now.Add(180 * 24 * time.Hour),
		FingerprintSHA256: "sha256-v2", PEMChain: "---BEGIN2---", CSRPEM: "---CSR2---",
		CreatedAt: now.Add(1 * time.Second),
	}
	if err := repo.CreateVersion(ctx, v1); err != nil {
		t.Fatalf("CreateVersion v1 failed: %v", err)
	}
	if err := repo.CreateVersion(ctx, v2); err != nil {
		t.Fatalf("CreateVersion v2 failed: %v", err)
	}

	// ListVersions
	versions, err := repo.ListVersions(ctx, "mc-ver-test")
	if err != nil {
		t.Fatalf("ListVersions failed: %v", err)
	}
	if len(versions) != 2 {
		t.Errorf("len(versions) = %d, want 2", len(versions))
	}

	// GetLatestVersion
	latest, err := repo.GetLatestVersion(ctx, "mc-ver-test")
	if err != nil {
		t.Fatalf("GetLatestVersion failed: %v", err)
	}
	if latest.SerialNumber != "AABB02" {
		t.Errorf("latest serial = %q, want %q", latest.SerialNumber, "AABB02")
	}
}

func TestCertificateRepository_GetExpiringCertificates(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewCertificateRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)
	ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "exp")

	// One expiring soon, one far out
	for _, tc := range []struct {
		id      string
		expires time.Time
	}{
		{"mc-exp-soon", now.Add(5 * 24 * time.Hour)},
		{"mc-exp-far", now.Add(365 * 24 * time.Hour)},
	} {
		cert := &domain.ManagedCertificate{
			ID: tc.id, Name: tc.id, CommonName: tc.id + ".example.com",
			SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
			IssuerID: issuerID, RenewalPolicyID: policyID,
			Status:    domain.CertificateStatusActive,
			ExpiresAt: tc.expires, Tags: map[string]string{},
			CreatedAt: now, UpdatedAt: now,
		}
		if err := repo.Create(ctx, cert); err != nil {
			t.Fatalf("Create %s failed: %v", tc.id, err)
		}
	}

	expiring, err := repo.GetExpiringCertificates(ctx, now.Add(30*24*time.Hour))
	if err != nil {
		t.Fatalf("GetExpiringCertificates failed: %v", err)
	}
	if len(expiring) != 1 {
		t.Errorf("len(expiring) = %d, want 1", len(expiring))
	}
	if len(expiring) > 0 && expiring[0].ID != "mc-exp-soon" {
		t.Errorf("expiring[0].ID = %q, want %q", expiring[0].ID, "mc-exp-soon")
	}
}

func TestCertificateRepository_Get_NotFound(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewCertificateRepository(db)
	_, err := repo.Get(context.Background(), "nonexistent")
	if err == nil {
		t.Error("expected error for nonexistent cert, got nil")
	}
}

func TestCertificateRepository_Update_NotFound(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewCertificateRepository(db)
	err := repo.Update(context.Background(), &domain.ManagedCertificate{
		ID: "nonexistent", Tags: map[string]string{},
	})
	if err == nil {
		t.Error("expected error for nonexistent update, got nil")
	}
}

// ============================================================
// Agent Repository Tests
// ============================================================

func TestAgentRepository_CRUD(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewAgentRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)

	agent := &domain.Agent{
		ID:              "agent-test-1",
		Name:            "test-agent",
		Hostname:        "host1.local",
		Status:          domain.AgentStatusOnline,
		RegisteredAt:    now,
		LastHeartbeatAt: &now,
		APIKeyHash:      "abc123hash",
		OS:              "linux",
		Architecture:    "amd64",
		IPAddress:       "10.0.0.1",
		Version:         "1.0.0",
	}

	// Create
	if err := repo.Create(ctx, agent); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Get
	got, err := repo.Get(ctx, "agent-test-1")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Hostname != "host1.local" {
		t.Errorf("Hostname = %q, want %q", got.Hostname, "host1.local")
	}
	if got.OS != "linux" {
		t.Errorf("OS = %q, want %q", got.OS, "linux")
	}

	// List
	agents, err := repo.List(ctx)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if len(agents) != 1 {
		t.Errorf("len(agents) = %d, want 1", len(agents))
	}

	// UpdateHeartbeat with metadata
	metadata := &domain.AgentMetadata{
		OS: "linux", Architecture: "arm64", Hostname: "host1-updated.local",
		IPAddress: "10.0.0.2", Version: "1.1.0",
	}
	if err := repo.UpdateHeartbeat(ctx, "agent-test-1", metadata); err != nil {
		t.Fatalf("UpdateHeartbeat failed: %v", err)
	}
	got, err = repo.Get(ctx, "agent-test-1")
	if err != nil {
		t.Fatalf("Get after heartbeat failed: %v", err)
	}
	if got.Architecture != "arm64" {
		t.Errorf("Architecture after heartbeat = %q, want %q", got.Architecture, "arm64")
	}

	// GetByAPIKey
	got, err = repo.GetByAPIKey(ctx, "abc123hash")
	if err != nil {
		t.Fatalf("GetByAPIKey failed: %v", err)
	}
	if got.ID != "agent-test-1" {
		t.Errorf("GetByAPIKey ID = %q, want %q", got.ID, "agent-test-1")
	}

	// Delete
	if err := repo.Delete(ctx, "agent-test-1"); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
	_, err = repo.Get(ctx, "agent-test-1")
	if err == nil {
		t.Error("expected error after delete, got nil")
	}
}

func TestAgentRepository_Delete_NotFound(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewAgentRepository(db)
	err := repo.Delete(context.Background(), "nonexistent")
	if err == nil {
		t.Error("expected error for nonexistent delete, got nil")
	}
}

// TestAgentRepository_CreateIfNotExists_FirstInsert verifies that a brand-new
// sentinel agent row is inserted and the helper reports created=true (M-6).
func TestAgentRepository_CreateIfNotExists_FirstInsert(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewAgentRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)

	agent := &domain.Agent{
		ID:           "server-scanner",
		Name:         "Network Scanner (Server-Side)",
		Status:       domain.AgentStatusOnline,
		RegisteredAt: now,
	}
	created, err := repo.CreateIfNotExists(ctx, agent)
	if err != nil {
		t.Fatalf("CreateIfNotExists failed: %v", err)
	}
	if !created {
		t.Error("created = false on first insert, want true")
	}

	got, err := repo.Get(ctx, "server-scanner")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Name != "Network Scanner (Server-Side)" {
		t.Errorf("Name = %q, want %q", got.Name, "Network Scanner (Server-Side)")
	}
}

// TestAgentRepository_CreateIfNotExists_Idempotent verifies that a second
// call with the same ID returns created=false and err=nil without mutating
// the existing row — the core M-6 upgrade/restart scenario (CWE-662).
func TestAgentRepository_CreateIfNotExists_Idempotent(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewAgentRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)

	first := &domain.Agent{
		ID:           "cloud-aws-sm",
		Name:         "AWS Secrets Manager Discovery",
		Status:       domain.AgentStatusOnline,
		RegisteredAt: now,
	}
	created, err := repo.CreateIfNotExists(ctx, first)
	if err != nil {
		t.Fatalf("first CreateIfNotExists failed: %v", err)
	}
	if !created {
		t.Fatal("first created = false, want true")
	}

	// Second call with the same ID but a different name must be a no-op.
	second := &domain.Agent{
		ID:           "cloud-aws-sm",
		Name:         "Overwritten Name Should Not Persist",
		Status:       domain.AgentStatusOffline,
		RegisteredAt: now.Add(time.Hour),
	}
	created, err = repo.CreateIfNotExists(ctx, second)
	if err != nil {
		t.Fatalf("second CreateIfNotExists failed: %v", err)
	}
	if created {
		t.Error("second created = true, want false (row already existed)")
	}

	// Row must still reflect the original insert.
	got, err := repo.Get(ctx, "cloud-aws-sm")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Name != "AWS Secrets Manager Discovery" {
		t.Errorf("Name = %q, want %q (ON CONFLICT DO NOTHING must preserve original row)", got.Name, "AWS Secrets Manager Discovery")
	}
	if got.Status != domain.AgentStatusOnline {
		t.Errorf("Status = %q, want %q", got.Status, domain.AgentStatusOnline)
	}
}

// TestAgentRepository_CreateIfNotExists_ConcurrentRace fires N concurrent
// inserts for the same sentinel ID. Exactly one goroutine must see
// created=true; every other must see created=false and err=nil. No panics,
// no duplicate rows, no swallowed errors. This is the scenario that the
// pre-M-6 plain-INSERT path masked with a blanket error log.
func TestAgentRepository_CreateIfNotExists_ConcurrentRace(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewAgentRepository(db)
	ctx := context.Background()

	const N = 16
	now := time.Now().Truncate(time.Microsecond)

	var (
		wg           sync.WaitGroup
		createdCount int64
		errorCount   int64
	)
	wg.Add(N)
	for i := 0; i < N; i++ {
		go func() {
			defer wg.Done()
			agent := &domain.Agent{
				ID:           "cloud-gcp-sm",
				Name:         "GCP Secret Manager Discovery",
				Status:       domain.AgentStatusOnline,
				RegisteredAt: now,
			}
			created, err := repo.CreateIfNotExists(ctx, agent)
			if err != nil {
				atomic.AddInt64(&errorCount, 1)
				t.Errorf("CreateIfNotExists returned error: %v", err)
				return
			}
			if created {
				atomic.AddInt64(&createdCount, 1)
			}
		}()
	}
	wg.Wait()

	if errorCount != 0 {
		t.Fatalf("errorCount = %d, want 0", errorCount)
	}
	if createdCount != 1 {
		t.Errorf("createdCount = %d, want exactly 1 (only one goroutine may win the insert)", createdCount)
	}

	// Exactly one row must exist.
	agents, err := repo.List(ctx)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	count := 0
	for _, a := range agents {
		if a.ID == "cloud-gcp-sm" {
			count++
		}
	}
	if count != 1 {
		t.Errorf("row count for cloud-gcp-sm = %d, want 1", count)
	}
}

// TestAgentRepository_CreateIfNotExists_GenericErrorSurfaces verifies that
// failures other than the primary-key duplicate (the only collision
// ON CONFLICT (id) absorbs) propagate to the caller instead of being
// swallowed. This is the security property that M-6 restores: the
// pre-fix plain-INSERT path logged every error at Debug level, so a
// connectivity or permission failure would vanish into the log without
// the server surfacing a problem on startup (CWE-662 / CWE-209-adjacent).
//
// Uses a pre-cancelled context to force QueryRowContext to fail with
// context.Canceled — a non-duplicate error class that must surface.
// Does NOT close the shared sql.DB (that would break sibling tests).
func TestAgentRepository_CreateIfNotExists_GenericErrorSurfaces(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewAgentRepository(db)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // pre-cancel so the driver round-trip fails immediately.

	agent := &domain.Agent{
		ID:           "server-scanner",
		Name:         "Network Scanner (Server-Side)",
		Status:       domain.AgentStatusOnline,
		RegisteredAt: time.Now(),
	}
	created, err := repo.CreateIfNotExists(ctx, agent)
	if err == nil {
		t.Fatal("expected error on cancelled context, got nil (error would have been swallowed pre-M-6)")
	}
	if created {
		t.Error("created = true on failure, want false")
	}
	if err == sql.ErrNoRows {
		t.Error("got sql.ErrNoRows, want a real connection/context error (ErrNoRows is the duplicate-row sentinel)")
	}
}

// ============================================================
// Issuer Repository Tests
// ============================================================

func TestIssuerRepository_CRUD(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	repo := postgres.NewIssuerRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)

	config, _ := json.Marshal(map[string]string{"type": "local"})
	issuer := &domain.Issuer{
		ID: "iss-test", Name: "Test Issuer", Type: domain.IssuerTypeGenericCA,
		Config: config, Enabled: true, CreatedAt: now, UpdatedAt: now,
	}

	// Create
	if err := repo.Create(ctx, issuer); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Get
	got, err := repo.Get(ctx, "iss-test")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Name != "Test Issuer" {
		t.Errorf("Name = %q, want %q", got.Name, "Test Issuer")
	}

	// Update
	issuer.Enabled = false
	issuer.UpdatedAt = time.Now().Truncate(time.Microsecond)
	if err := repo.Update(ctx, issuer); err != nil {
		t.Fatalf("Update failed: %v", err)
	}
	got, err = repo.Get(ctx, "iss-test")
	if err != nil {
		t.Fatalf("Get after update failed: %v", err)
	}
	if got.Enabled {
		t.Error("expected Enabled=false after update")
	}

	// List
	issuers, err := repo.List(ctx)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if len(issuers) != 1 {
		t.Errorf("len(issuers) = %d, want 1", len(issuers))
	}

	// Delete
	if err := repo.Delete(ctx, "iss-test"); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
	_, err = repo.Get(ctx, "iss-test")
	if err == nil {
		t.Error("expected error after delete")
	}
}

// ============================================================
// Target Repository Tests
// ============================================================

func TestTargetRepository_CRUD(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	targetRepo := postgres.NewTargetRepository(db)
	agentRepo := postgres.NewAgentRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)

	// Create agent first (FK requirement); fail fast if it is missing.
	agent := &domain.Agent{
		ID: "agent-target-test", Name: "target-test-agent", Hostname: "host",
		Status: domain.AgentStatusOnline, RegisteredAt: now, APIKeyHash: "hash1",
	}
	if err := agentRepo.Create(ctx, agent); err != nil {
		t.Fatalf("Create agent failed: %v", err)
	}

	config, _ := json.Marshal(map[string]string{"cert_path": "/etc/nginx/ssl/cert.pem"})
	target := &domain.DeploymentTarget{
		ID: "t-test", Name: "Test Target", Type: domain.TargetTypeNGINX,
		AgentID: "agent-target-test", Config: config, Enabled: true,
		CreatedAt: now, UpdatedAt: now,
	}
	if err := targetRepo.Create(ctx, target); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	got, err := targetRepo.Get(ctx, "t-test")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Type != domain.TargetTypeNGINX {
		t.Errorf("Type = %q, want %q", got.Type, domain.TargetTypeNGINX)
	}

	targets, err := targetRepo.List(ctx)
	if err != nil {
		t.Fatalf("List failed: %v", err)
	}
	if len(targets) != 1 {
		t.Errorf("len(targets) = %d, want 1", len(targets))
	}

	if err := targetRepo.Delete(ctx, "t-test"); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
}

// ============================================================
// Job Repository Tests
// ============================================================

func TestJobRepository_CRUD(t *testing.T) {
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	jobRepo := postgres.NewJobRepository(db)
	certRepo := postgres.NewCertificateRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Microsecond)
	ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "job")

	// Create prerequisite cert
	cert := &domain.ManagedCertificate{
		ID: "mc-job-test", Name: "job-test", CommonName: "job.example.com",
		SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
		IssuerID: issuerID, RenewalPolicyID: policyID,
		Status:    domain.CertificateStatusActive,
		ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
		CreatedAt: now, UpdatedAt: now,
	}
	if err := certRepo.Create(ctx, cert); err != nil {
		t.Fatalf("Create cert failed: %v", err)
	}

	job := &domain.Job{
		ID: "job-test-1", Type: domain.JobTypeRenewal, CertificateID: "mc-job-test",
		Status: domain.JobStatusPending, Attempts: 0, MaxAttempts: 3,
		ScheduledAt: now, CreatedAt: now,
	}

	// Create
	if err := jobRepo.Create(ctx, job); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Get
	got, err := jobRepo.Get(ctx, "job-test-1")
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got.Type != domain.JobTypeRenewal {
		t.Errorf("Type = %q, want %q", got.Type, domain.JobTypeRenewal)
	}

	// ListByStatus
	pending, err := jobRepo.ListByStatus(ctx, domain.JobStatusPending)
	if err != nil {
		t.Fatalf("ListByStatus failed: %v", err)
	}
	if len(pending) != 1 {
		t.Errorf("len(pending) = %d, want 1", len(pending))
	}

	// UpdateStatus
	errMsg := "test error"
	if err := jobRepo.UpdateStatus(ctx, "job-test-1", domain.JobStatusFailed, errMsg); err != nil {
		t.Fatalf("UpdateStatus failed: %v", err)
	}
	got, err = jobRepo.Get(ctx, "job-test-1")
	if err != nil {
		t.Fatalf("Get after UpdateStatus failed: %v", err)
	}
	if got.Status != domain.JobStatusFailed {
		t.Errorf("Status after update = %q, want %q", got.Status, domain.JobStatusFailed)
	}

	// GetPendingJobs (should be empty now)
	pendingJobs, err := jobRepo.GetPendingJobs(ctx, domain.JobTypeRenewal)
	if err != nil {
		t.Fatalf("GetPendingJobs failed: %v", err)
	}
	if len(pendingJobs) != 0 {
		t.Errorf("len(pendingJobs) = %d, want 0 (job is now Failed)", len(pendingJobs))
	}

	// ListByCertificate
	certJobs, err := jobRepo.ListByCertificate(ctx, "mc-job-test")
	if err != nil {
		t.Fatalf("ListByCertificate failed: %v", err)
	}
	if len(certJobs) != 1 {
		t.Errorf("len(certJobs) = %d, want 1", len(certJobs))
	}

	// Delete
	if err := jobRepo.Delete(ctx, "job-test-1"); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
}
// ============================================================
// Revocation Repository Tests
// ============================================================
func TestRevocationRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewRevocationRepository(db)
certRepo := postgres.NewCertificateRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "rev")
// Create prerequisite cert
cert := &domain.ManagedCertificate{
ID: "mc-rev-test", Name: "rev-test", CommonName: "rev.example.com",
SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
IssuerID: issuerID, RenewalPolicyID: policyID,
Status: domain.CertificateStatusRevoked,
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
CreatedAt: now, UpdatedAt: now,
}
if err := certRepo.Create(ctx, cert); err != nil {
t.Fatalf("Create cert failed: %v", err)
}
revocation := &domain.CertificateRevocation{
ID: "rev-test-1", CertificateID: "mc-rev-test", SerialNumber: "DEADBEEF01",
Reason: "keyCompromise", RevokedBy: "admin", RevokedAt: now,
IssuerID: issuerID, CreatedAt: now,
}
// Create
if err := repo.Create(ctx, revocation); err != nil {
t.Fatalf("Create failed: %v", err)
}
// Idempotent create (ON CONFLICT DO NOTHING)
if err := repo.Create(ctx, revocation); err != nil {
t.Fatalf("Idempotent create failed: %v", err)
}
// GetByIssuerAndSerial — lookups are scoped to (issuer_id, serial) per RFC 5280 §5.2.3.
got, err := repo.GetByIssuerAndSerial(ctx, issuerID, "DEADBEEF01")
if err != nil {
t.Fatalf("GetByIssuerAndSerial failed: %v", err)
}
if got.Reason != "keyCompromise" {
t.Errorf("Reason = %q, want %q", got.Reason, "keyCompromise")
}
// ListAll
all, err := repo.ListAll(ctx)
if err != nil {
t.Fatalf("ListAll failed: %v", err)
}
if len(all) != 1 {
t.Errorf("len(all) = %d, want 1", len(all))
}
// ListByCertificate
certRevs, err := repo.ListByCertificate(ctx, "mc-rev-test")
if err != nil {
t.Fatalf("ListByCertificate failed: %v", err)
}
if len(certRevs) != 1 {
t.Errorf("len(certRevs) = %d, want 1", len(certRevs))
}
// MarkIssuerNotified
if err := repo.MarkIssuerNotified(ctx, "rev-test-1"); err != nil {
t.Fatalf("MarkIssuerNotified failed: %v", err)
}
got, err = repo.GetByIssuerAndSerial(ctx, issuerID, "DEADBEEF01")
if err != nil {
t.Fatalf("GetByIssuerAndSerial after MarkIssuerNotified failed: %v", err)
}
if !got.IssuerNotified {
t.Error("expected IssuerNotified=true after marking")
}
}
// TestRevocationRepository_CrossIssuerSerialCollision verifies that the same
// serial number can coexist under two different issuers. RFC 5280 §4.1.2.2
// requires serial numbers to be unique only within a single CA, and certctl supports
// multi-issuer deployments where serial collisions across issuers are
// legitimate (e.g., Local CA serial 0x01 and Vault PKI serial 0x01).
//
// This test locks in the behavior change from migration 000012: the unique
// index is on (issuer_id, serial_number), not on serial_number alone.
func TestRevocationRepository_CrossIssuerSerialCollision(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewRevocationRepository(db)
certRepo := postgres.NewCertificateRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
// First issuer + cert + revocation with serial "CAFEBABE01".
ownerID1, teamID1, issuerID1, policyID1 := insertCertPrereqsRaw(t, db, ctx, "dup-a")
cert1 := &domain.ManagedCertificate{
ID: "mc-dup-a", Name: "dup-a", CommonName: "a.example.com",
SANs: []string{}, OwnerID: ownerID1, TeamID: teamID1,
IssuerID: issuerID1, RenewalPolicyID: policyID1,
Status: domain.CertificateStatusRevoked,
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
CreatedAt: now, UpdatedAt: now,
}
if err := certRepo.Create(ctx, cert1); err != nil {
t.Fatalf("Create cert1 failed: %v", err)
}
if err := repo.Create(ctx, &domain.CertificateRevocation{
ID: "rev-dup-a", CertificateID: "mc-dup-a", SerialNumber: "CAFEBABE01",
Reason: "keyCompromise", RevokedBy: "admin", RevokedAt: now,
IssuerID: issuerID1, CreatedAt: now,
}); err != nil {
t.Fatalf("Create revocation under issuer1 failed: %v", err)
}
// Second issuer + cert + revocation with the SAME serial "CAFEBABE01".
// Under the pre-000012 global-unique index this would silently drop via
// ON CONFLICT DO NOTHING. Under the new (issuer_id, serial_number) scope
// it must succeed.
ownerID2, teamID2, issuerID2, policyID2 := insertCertPrereqsRaw(t, db, ctx, "dup-b")
cert2 := &domain.ManagedCertificate{
ID: "mc-dup-b", Name: "dup-b", CommonName: "b.example.com",
SANs: []string{}, OwnerID: ownerID2, TeamID: teamID2,
IssuerID: issuerID2, RenewalPolicyID: policyID2,
Status: domain.CertificateStatusRevoked,
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
CreatedAt: now, UpdatedAt: now,
}
if err := certRepo.Create(ctx, cert2); err != nil {
t.Fatalf("Create cert2 failed: %v", err)
}
if err := repo.Create(ctx, &domain.CertificateRevocation{
ID: "rev-dup-b", CertificateID: "mc-dup-b", SerialNumber: "CAFEBABE01",
Reason: "superseded", RevokedBy: "admin", RevokedAt: now,
IssuerID: issuerID2, CreatedAt: now,
}); err != nil {
t.Fatalf("Create revocation under issuer2 failed (cross-issuer duplicate serial must be allowed): %v", err)
}
// Both revocations must be retrievable under their respective issuers.
revA, err := repo.GetByIssuerAndSerial(ctx, issuerID1, "CAFEBABE01")
if err != nil {
t.Fatalf("GetByIssuerAndSerial(issuer1) failed: %v", err)
}
if revA.ID != "rev-dup-a" || revA.Reason != "keyCompromise" {
t.Errorf("issuer1 lookup returned wrong row: id=%q reason=%q", revA.ID, revA.Reason)
}
revB, err := repo.GetByIssuerAndSerial(ctx, issuerID2, "CAFEBABE01")
if err != nil {
t.Fatalf("GetByIssuerAndSerial(issuer2) failed: %v", err)
}
if revB.ID != "rev-dup-b" || revB.Reason != "superseded" {
t.Errorf("issuer2 lookup returned wrong row: id=%q reason=%q", revB.ID, revB.Reason)
}
// ListAll should see both revocations.
all, err := repo.ListAll(ctx)
if err != nil {
t.Fatalf("ListAll failed: %v", err)
}
if len(all) != 2 {
t.Errorf("len(all) = %d, want 2 (cross-issuer duplicate serials)", len(all))
}
// Same-issuer idempotency guard still works (ON CONFLICT DO NOTHING on
// (issuer_id, serial_number) — re-inserting the same (issuer, serial)
// pair must not error and must not duplicate the row).
if err := repo.Create(ctx, &domain.CertificateRevocation{
ID: "rev-dup-a-repeat", CertificateID: "mc-dup-a", SerialNumber: "CAFEBABE01",
Reason: "superseded", RevokedBy: "admin", RevokedAt: now,
IssuerID: issuerID1, CreatedAt: now,
}); err != nil {
t.Fatalf("Idempotent create under same issuer failed: %v", err)
}
all, err = repo.ListAll(ctx)
if err != nil {
t.Fatalf("ListAll after idempotent re-insert failed: %v", err)
}
if len(all) != 2 {
t.Errorf("len(all) after idempotent re-insert = %d, want 2", len(all))
}
}
// ============================================================
// Team Repository Tests
// ============================================================
func TestTeamRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewTeamRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
team := &domain.Team{
ID: "team-test", Name: "Platform", Description: "Platform team",
CreatedAt: now, UpdatedAt: now,
}
if err := repo.Create(ctx, team); err != nil {
t.Fatalf("Create failed: %v", err)
}
got, err := repo.Get(ctx, "team-test")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.Name != "Platform" {
t.Errorf("Name = %q, want %q", got.Name, "Platform")
}
teams, err := repo.List(ctx)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(teams) != 1 {
t.Errorf("len(teams) = %d, want 1", len(teams))
}
team.Description = "Updated"
team.UpdatedAt = time.Now().Truncate(time.Microsecond)
if err := repo.Update(ctx, team); err != nil {
t.Fatalf("Update failed: %v", err)
}
if err := repo.Delete(ctx, "team-test"); err != nil {
t.Fatalf("Delete failed: %v", err)
}
}
// ============================================================
// Owner Repository Tests
// ============================================================
func TestOwnerRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
ownerRepo := postgres.NewOwnerRepository(db)
teamRepo := postgres.NewTeamRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
// Create team first (FK)
team := &domain.Team{
ID: "team-owner-test", Name: "Owner Test Team",
CreatedAt: now, UpdatedAt: now,
}
if err := teamRepo.Create(ctx, team); err != nil {
t.Fatalf("Create team failed: %v", err)
}
owner := &domain.Owner{
ID: "o-test", Name: "Alice", Email: "alice@example.com",
TeamID: "team-owner-test", CreatedAt: now, UpdatedAt: now,
}
if err := ownerRepo.Create(ctx, owner); err != nil {
t.Fatalf("Create failed: %v", err)
}
got, err := ownerRepo.Get(ctx, "o-test")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.Email != "alice@example.com" {
t.Errorf("Email = %q, want %q", got.Email, "alice@example.com")
}
owners, err := ownerRepo.List(ctx)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(owners) != 1 {
t.Errorf("len(owners) = %d, want 1", len(owners))
}
if err := ownerRepo.Delete(ctx, "o-test"); err != nil {
t.Fatalf("Delete failed: %v", err)
}
}
// ============================================================
// Policy Repository Tests
// ============================================================
func TestPolicyRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewPolicyRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
config, _ := json.Marshal(map[string]interface{}{"domains": []string{"*.example.com"}})
rule := &domain.PolicyRule{
ID: "pol-test", Name: "Test Policy", Type: domain.PolicyTypeAllowedDomains,
Config: config, Enabled: true, CreatedAt: now, UpdatedAt: now,
}
// CreateRule
if err := repo.CreateRule(ctx, rule); err != nil {
t.Fatalf("CreateRule failed: %v", err)
}
// GetRule
got, err := repo.GetRule(ctx, "pol-test")
if err != nil {
t.Fatalf("GetRule failed: %v", err)
}
if got.Type != domain.PolicyTypeAllowedDomains {
t.Errorf("Type = %q, want %q", got.Type, domain.PolicyTypeAllowedDomains)
}
// ListRules
rules, err := repo.ListRules(ctx)
if err != nil {
t.Fatalf("ListRules failed: %v", err)
}
if len(rules) != 1 {
t.Errorf("len(rules) = %d, want 1", len(rules))
}
// UpdateRule
rule.Enabled = false
rule.UpdatedAt = time.Now().Truncate(time.Microsecond)
if err := repo.UpdateRule(ctx, rule); err != nil {
t.Fatalf("UpdateRule failed: %v", err)
}
// DeleteRule
if err := repo.DeleteRule(ctx, "pol-test"); err != nil {
t.Fatalf("DeleteRule failed: %v", err)
}
}
// ============================================================
// Audit Repository Tests
// ============================================================
func TestAuditRepository_CreateAndList(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewAuditRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
event := &domain.AuditEvent{
ID: "audit-test-1", Actor: "admin", ActorType: "User",
Action: "certificate_created", ResourceType: "certificate",
ResourceID: "mc-test", Details: json.RawMessage(`{"cn":"test.example.com"}`),
Timestamp: now,
}
if err := repo.Create(ctx, event); err != nil {
t.Fatalf("Create failed: %v", err)
}
// List with filter
events, err := repo.List(ctx, &repository.AuditFilter{
Actor: "admin", Page: 1, PerPage: 10,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 1 {
t.Errorf("len(events) = %d, want 1", len(events))
}
// List with empty filter
events, err = repo.List(ctx, &repository.AuditFilter{Page: 1, PerPage: 10})
if err != nil {
t.Fatalf("List all failed: %v", err)
}
if len(events) != 1 {
t.Errorf("len(events) = %d, want 1", len(events))
}
}
// ============================================================
// Profile Repository Tests
// ============================================================
func TestProfileRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewProfileRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
profile := &domain.CertificateProfile{
ID: "prof-test", Name: "Test Profile", Description: "Test",
AllowedKeyAlgorithms: []domain.KeyAlgorithmRule{
{Algorithm: "RSA", MinSize: 2048},
{Algorithm: "ECDSA", MinSize: 256},
},
MaxTTLSeconds: 86400,
AllowedEKUs: []string{"serverAuth"},
AllowShortLived: false,
Enabled: true,
CreatedAt: now,
UpdatedAt: now,
}
if err := repo.Create(ctx, profile); err != nil {
t.Fatalf("Create failed: %v", err)
}
got, err := repo.Get(ctx, "prof-test")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.MaxTTLSeconds != 86400 {
t.Errorf("MaxTTLSeconds = %d, want 86400", got.MaxTTLSeconds)
}
if len(got.AllowedKeyAlgorithms) != 2 {
t.Errorf("len(AllowedKeyAlgorithms) = %d, want 2", len(got.AllowedKeyAlgorithms))
}
profiles, err := repo.List(ctx)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(profiles) != 1 {
t.Errorf("len(profiles) = %d, want 1", len(profiles))
}
if err := repo.Delete(ctx, "prof-test"); err != nil {
t.Fatalf("Delete failed: %v", err)
}
}
// ============================================================
// Notification Repository Tests
// ============================================================
func TestNotificationRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewNotificationRepository(db)
certRepo := postgres.NewCertificateRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "notif")
// Create prerequisite cert (notification references it via FK)
cert := &domain.ManagedCertificate{
ID: "mc-notif-test", Name: "notif-test", CommonName: "notif.example.com",
SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
IssuerID: issuerID, RenewalPolicyID: policyID,
Status: domain.CertificateStatusActive,
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
CreatedAt: now, UpdatedAt: now,
}
if err := certRepo.Create(ctx, cert); err != nil {
t.Fatalf("Create cert failed: %v", err)
}
certID := "mc-notif-test"
notif := &domain.NotificationEvent{
ID: "notif-test-1", Type: domain.NotificationTypeExpirationWarning,
CertificateID: &certID, Channel: domain.NotificationChannelEmail,
Recipient: "admin@example.com", Message: "Cert expiring in 7 days",
Status: "pending", CreatedAt: now,
}
if err := repo.Create(ctx, notif); err != nil {
t.Fatalf("Create failed: %v", err)
}
// List
notifications, err := repo.List(ctx, &repository.NotificationFilter{Page: 1, PerPage: 10})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(notifications) != 1 {
t.Errorf("len(notifications) = %d, want 1", len(notifications))
}
// UpdateStatus
sentAt := time.Now().Truncate(time.Microsecond)
if err := repo.UpdateStatus(ctx, "notif-test-1", "sent", sentAt); err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
}
// ============================================================
// Discovery Repository Tests
// ============================================================
func TestDiscoveryRepository_ScanCRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewDiscoveryRepository(db)
agentRepo := postgres.NewAgentRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
// Create agent first (the scan row references it by agent_id)
agent := &domain.Agent{
ID: "agent-disc-test", Name: "disc-agent", Hostname: "disc-host",
Status: domain.AgentStatusOnline, RegisteredAt: now, APIKeyHash: "dischash",
}
if err := agentRepo.Create(ctx, agent); err != nil {
t.Fatalf("Create agent failed: %v", err)
}
completedAt := now.Add(5 * time.Second)
scan := &domain.DiscoveryScan{
ID: "scan-test-1", AgentID: "agent-disc-test",
Directories: []string{"/etc/ssl", "/opt/certs"},
CertificatesFound: 10, CertificatesNew: 3, ErrorsCount: 1,
ScanDurationMs: 1500, StartedAt: now, CompletedAt: &completedAt,
}
// CreateScan
if err := repo.CreateScan(ctx, scan); err != nil {
t.Fatalf("CreateScan failed: %v", err)
}
// GetScan
got, err := repo.GetScan(ctx, "scan-test-1")
if err != nil {
t.Fatalf("GetScan failed: %v", err)
}
if got.CertificatesFound != 10 {
t.Errorf("CertificatesFound = %d, want 10", got.CertificatesFound)
}
if len(got.Directories) != 2 {
t.Errorf("len(Directories) = %d, want 2", len(got.Directories))
}
// ListScans
scans, total, err := repo.ListScans(ctx, "agent-disc-test", 1, 10)
if err != nil {
t.Fatalf("ListScans failed: %v", err)
}
if total != 1 || len(scans) != 1 {
t.Errorf("ListScans total=%d len=%d, want 1/1", total, len(scans))
}
// ListScans with an empty agent ID (no agent filter, so all scans match)
_, total, err = repo.ListScans(ctx, "", 1, 10)
if err != nil {
t.Fatalf("ListScans all failed: %v", err)
}
if total != 1 {
t.Errorf("ListScans all total=%d, want 1", total)
}
}
func TestDiscoveryRepository_DiscoveredCertCRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewDiscoveryRepository(db)
agentRepo := postgres.NewAgentRepository(db)
certRepo := postgres.NewCertificateRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
notBefore := now.Add(-30 * 24 * time.Hour)
notAfter := now.Add(60 * 24 * time.Hour)
// Create agent first
agent := &domain.Agent{
ID: "agent-dcert-test", Name: "dcert-agent", Hostname: "dcert-host",
Status: domain.AgentStatusOnline, RegisteredAt: now, APIKeyHash: "dcerthash",
}
if err := agentRepo.Create(ctx, agent); err != nil {
t.Fatalf("Create agent failed: %v", err)
}
// Create a managed cert that the discovered cert is linked to later, when its status is set to Managed (FK on managed_certificate_id)
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "dcert")
linkedCert := &domain.ManagedCertificate{
ID: "mc-linked-cert", Name: "linked-cert", CommonName: "linked.example.com",
SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
IssuerID: issuerID, RenewalPolicyID: policyID,
Status: domain.CertificateStatusActive,
ExpiresAt: now.Add(90 * 24 * time.Hour), Tags: map[string]string{},
CreatedAt: now, UpdatedAt: now,
}
if err := certRepo.Create(ctx, linkedCert); err != nil {
t.Fatalf("Create linked cert failed: %v", err)
}
cert := &domain.DiscoveredCertificate{
ID: "dc-test-1", FingerprintSHA256: "abcdef1234567890",
CommonName: "disc.example.com", SANs: []string{"disc.example.com", "www.disc.example.com"},
SerialNumber: "DISC01", IssuerDN: "CN=Test CA", SubjectDN: "CN=disc.example.com",
NotBefore: &notBefore, NotAfter: &notAfter, KeyAlgorithm: "RSA", KeySize: 2048,
IsCA: false, PEMData: "---PEM---", SourcePath: "/etc/ssl/certs/disc.pem",
SourceFormat: "PEM", AgentID: "agent-dcert-test",
Status: domain.DiscoveryStatusUnmanaged,
FirstSeenAt: now, LastSeenAt: now, CreatedAt: now, UpdatedAt: now,
}
// CreateDiscovered — new insert
isNew, err := repo.CreateDiscovered(ctx, cert)
if err != nil {
t.Fatalf("CreateDiscovered failed: %v", err)
}
if !isNew {
t.Error("expected isNew=true for first insert")
}
// CreateDiscovered again — upsert (same fingerprint+agent+path)
cert.ID = "dc-test-1-dup" // different ID, same fingerprint+agent+path
cert.LastSeenAt = now.Add(1 * time.Hour)
isNew, err = repo.CreateDiscovered(ctx, cert)
if err != nil {
t.Fatalf("CreateDiscovered upsert failed: %v", err)
}
if isNew {
t.Error("expected isNew=false for upsert")
}
// GetDiscovered
got, err := repo.GetDiscovered(ctx, "dc-test-1")
if err != nil {
t.Fatalf("GetDiscovered failed: %v", err)
}
if got.CommonName != "disc.example.com" {
t.Errorf("CommonName = %q, want %q", got.CommonName, "disc.example.com")
}
if len(got.SANs) != 2 {
t.Errorf("len(SANs) = %d, want 2", len(got.SANs))
}
// ListDiscovered
certs, total, err := repo.ListDiscovered(ctx, &repository.DiscoveryFilter{Page: 1, PerPage: 10})
if err != nil {
t.Fatalf("ListDiscovered failed: %v", err)
}
if total != 1 || len(certs) != 1 {
t.Errorf("ListDiscovered total=%d len=%d, want 1/1", total, len(certs))
}
// ListDiscovered by agent
certs, total, err = repo.ListDiscovered(ctx, &repository.DiscoveryFilter{
AgentID: "agent-dcert-test", Page: 1, PerPage: 10,
})
if err != nil {
t.Fatalf("ListDiscovered by agent failed: %v", err)
}
if total != 1 || len(certs) != 1 {
t.Errorf("agent filter: total=%d len=%d, want 1/1", total, len(certs))
}
// ListDiscovered by status
certs, _, err = repo.ListDiscovered(ctx, &repository.DiscoveryFilter{
Status: "Unmanaged", Page: 1, PerPage: 10,
})
if err != nil {
t.Fatalf("ListDiscovered by status failed: %v", err)
}
if len(certs) != 1 {
t.Errorf("status filter len = %d, want 1", len(certs))
}
// GetByFingerprint
fpCerts, err := repo.GetByFingerprint(ctx, "abcdef1234567890")
if err != nil {
t.Fatalf("GetByFingerprint failed: %v", err)
}
if len(fpCerts) != 1 {
t.Errorf("len(fpCerts) = %d, want 1", len(fpCerts))
}
// CountByStatus
counts, err := repo.CountByStatus(ctx)
if err != nil {
t.Fatalf("CountByStatus failed: %v", err)
}
if counts["Unmanaged"] != 1 {
t.Errorf("Unmanaged count = %d, want 1", counts["Unmanaged"])
}
// UpdateDiscoveredStatus to Dismissed
if err := repo.UpdateDiscoveredStatus(ctx, "dc-test-1", domain.DiscoveryStatusDismissed, ""); err != nil {
t.Fatalf("UpdateDiscoveredStatus to Dismissed failed: %v", err)
}
got, err = repo.GetDiscovered(ctx, "dc-test-1")
if err != nil {
t.Fatalf("GetDiscovered after dismiss failed: %v", err)
}
if got.Status != domain.DiscoveryStatusDismissed {
t.Errorf("Status = %q, want %q", got.Status, domain.DiscoveryStatusDismissed)
}
if got.DismissedAt == nil {
t.Error("expected DismissedAt to be set")
}
// UpdateDiscoveredStatus to Managed with link
if err := repo.UpdateDiscoveredStatus(ctx, "dc-test-1", domain.DiscoveryStatusManaged, "mc-linked-cert"); err != nil {
t.Fatalf("UpdateDiscoveredStatus to Managed failed: %v", err)
}
got, err = repo.GetDiscovered(ctx, "dc-test-1")
if err != nil {
t.Fatalf("GetDiscovered after managed link failed: %v", err)
}
if got.Status != domain.DiscoveryStatusManaged {
t.Errorf("Status = %q, want %q", got.Status, domain.DiscoveryStatusManaged)
}
if got.ManagedCertificateID != "mc-linked-cert" {
t.Errorf("ManagedCertificateID = %q, want %q", got.ManagedCertificateID, "mc-linked-cert")
}
// UpdateDiscoveredStatus NotFound
if err := repo.UpdateDiscoveredStatus(ctx, "nonexistent", domain.DiscoveryStatusDismissed, ""); err == nil {
t.Error("expected error for nonexistent status update")
}
}
// ============================================================
// Network Scan Repository Tests
// ============================================================
func TestNetworkScanRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
repo := postgres.NewNetworkScanRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
target := &domain.NetworkScanTarget{
ID: "ns-test-1", Name: "Internal Network",
CIDRs: []string{"10.0.0.0/24", "192.168.1.0/24"},
Ports: []int64{443, 8443},
Enabled: true, ScanIntervalHours: 6, TimeoutMs: 5000,
CreatedAt: now, UpdatedAt: now,
}
// Create
if err := repo.Create(ctx, target); err != nil {
t.Fatalf("Create failed: %v", err)
}
// Get
got, err := repo.Get(ctx, "ns-test-1")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.Name != "Internal Network" {
t.Errorf("Name = %q, want %q", got.Name, "Internal Network")
}
if len(got.CIDRs) != 2 {
t.Errorf("len(CIDRs) = %d, want 2", len(got.CIDRs))
}
if len(got.Ports) != 2 {
t.Errorf("len(Ports) = %d, want 2", len(got.Ports))
}
if got.LastScanAt != nil {
t.Error("expected LastScanAt to be nil initially")
}
// List
targets, err := repo.List(ctx)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(targets) != 1 {
t.Errorf("len(targets) = %d, want 1", len(targets))
}
// ListEnabled
enabled, err := repo.ListEnabled(ctx)
if err != nil {
t.Fatalf("ListEnabled failed: %v", err)
}
if len(enabled) != 1 {
t.Errorf("len(enabled) = %d, want 1", len(enabled))
}
// Update
target.Name = "Updated Network"
target.Enabled = false
if err := repo.Update(ctx, target); err != nil {
t.Fatalf("Update failed: %v", err)
}
got, err = repo.Get(ctx, "ns-test-1")
if err != nil {
t.Fatalf("Get after update failed: %v", err)
}
if got.Name != "Updated Network" {
t.Errorf("Name after update = %q, want %q", got.Name, "Updated Network")
}
// ListEnabled after disabling
enabled, err = repo.ListEnabled(ctx)
if err != nil {
t.Fatalf("ListEnabled after disable failed: %v", err)
}
if len(enabled) != 0 {
t.Errorf("len(enabled) after disable = %d, want 0", len(enabled))
}
// UpdateScanResults
scanTime := now.Add(1 * time.Hour)
if err := repo.UpdateScanResults(ctx, "ns-test-1", scanTime, 1500, 5); err != nil {
t.Fatalf("UpdateScanResults failed: %v", err)
}
got, err = repo.Get(ctx, "ns-test-1")
if err != nil {
t.Fatalf("Get after UpdateScanResults failed: %v", err)
}
if got.LastScanAt == nil {
t.Fatal("expected LastScanAt to be set after scan results update")
}
if got.LastScanCertsFound == nil {
t.Error("LastScanCertsFound = nil, want 5")
} else if *got.LastScanCertsFound != 5 {
t.Errorf("LastScanCertsFound = %d, want 5", *got.LastScanCertsFound)
}
if got.LastScanDurationMs == nil {
t.Error("LastScanDurationMs = nil, want 1500")
} else if *got.LastScanDurationMs != 1500 {
t.Errorf("LastScanDurationMs = %d, want 1500", *got.LastScanDurationMs)
}
// Delete
if err := repo.Delete(ctx, "ns-test-1"); err != nil {
t.Fatalf("Delete failed: %v", err)
}
_, err = repo.Get(ctx, "ns-test-1")
if err == nil {
t.Error("expected error after delete")
}
// Delete NotFound
if err := repo.Delete(ctx, "nonexistent"); err == nil {
t.Error("expected error for nonexistent delete")
}
// Update NotFound
target.ID = "nonexistent"
if err := repo.Update(ctx, target); err == nil {
t.Error("expected error for nonexistent update")
}
}
// ============================================================
// Agent Group Repository Tests
// ============================================================
func TestAgentGroupRepository_CRUD(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
groupRepo := postgres.NewAgentGroupRepository(db)
agentRepo := postgres.NewAgentRepository(db)
ctx := context.Background()
now := time.Now().Truncate(time.Microsecond)
group := &domain.AgentGroup{
ID: "grp-test", Name: "Linux Servers", Description: "All Linux agents",
MatchOS: "linux", MatchArchitecture: "amd64",
Enabled: true, CreatedAt: now, UpdatedAt: now,
}
// Create
if err := groupRepo.Create(ctx, group); err != nil {
t.Fatalf("Create failed: %v", err)
}
// Get
got, err := groupRepo.Get(ctx, "grp-test")
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if got.Name != "Linux Servers" {
t.Errorf("Name = %q, want %q", got.Name, "Linux Servers")
}
if got.MatchOS != "linux" {
t.Errorf("MatchOS = %q, want %q", got.MatchOS, "linux")
}
// List
groups, err := groupRepo.List(ctx)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(groups) != 1 {
t.Errorf("len(groups) = %d, want 1", len(groups))
}
// Update
group.Description = "Updated"
if err := groupRepo.Update(ctx, group); err != nil {
t.Fatalf("Update failed: %v", err)
}
got, err = groupRepo.Get(ctx, "grp-test")
if err != nil {
t.Fatalf("Get after update failed: %v", err)
}
if got.Description != "Updated" {
t.Errorf("Description after update = %q, want %q", got.Description, "Updated")
}
// Member management — create an agent first
agent := &domain.Agent{
ID: "agent-grp-test", Name: "grp-agent", Hostname: "grp-host",
Status: domain.AgentStatusOnline, RegisteredAt: now, APIKeyHash: "grphash",
}
if err := agentRepo.Create(ctx, agent); err != nil {
t.Fatalf("Create agent failed: %v", err)
}
// AddMember
if err := groupRepo.AddMember(ctx, "grp-test", "agent-grp-test", "include"); err != nil {
t.Fatalf("AddMember failed: %v", err)
}
// AddMember again (ON CONFLICT upsert)
if err := groupRepo.AddMember(ctx, "grp-test", "agent-grp-test", "exclude"); err != nil {
t.Fatalf("AddMember upsert failed: %v", err)
}
// ListMembers returns only "include" memberships; the agent was just switched to "exclude", so the list must be empty
members, err := groupRepo.ListMembers(ctx, "grp-test")
if err != nil {
t.Fatalf("ListMembers failed: %v", err)
}
if len(members) != 0 {
t.Errorf("len(members) = %d, want 0 (agent is excluded)", len(members))
}
// Change back to include
if err := groupRepo.AddMember(ctx, "grp-test", "agent-grp-test", "include"); err != nil {
t.Fatalf("AddMember back to include failed: %v", err)
}
members, err = groupRepo.ListMembers(ctx, "grp-test")
if err != nil {
t.Fatalf("ListMembers after re-include failed: %v", err)
}
if len(members) != 1 {
t.Errorf("len(members) = %d, want 1", len(members))
}
// RemoveMember
if err := groupRepo.RemoveMember(ctx, "grp-test", "agent-grp-test"); err != nil {
t.Fatalf("RemoveMember failed: %v", err)
}
members, err = groupRepo.ListMembers(ctx, "grp-test")
if err != nil {
t.Fatalf("ListMembers after remove failed: %v", err)
}
if len(members) != 0 {
t.Errorf("len(members) after remove = %d, want 0", len(members))
}
// Delete
if err := groupRepo.Delete(ctx, "grp-test"); err != nil {
t.Fatalf("Delete failed: %v", err)
}
_, err = groupRepo.Get(ctx, "grp-test")
if err == nil {
t.Error("expected error after delete")
}
// Delete NotFound
if err := groupRepo.Delete(ctx, "nonexistent"); err == nil {
t.Error("expected error for nonexistent delete")
}
}
// ============================================================
// Empty Result Set Tests
// ============================================================
func TestEmptyResultSets(t *testing.T) {
tdb := getTestDB(t)
db := tdb.freshSchema(t)
ctx := context.Background()
// Certificates
certRepo := postgres.NewCertificateRepository(db)
certs, total, err := certRepo.List(ctx, nil)
if err != nil {
t.Fatalf("cert List failed: %v", err)
}
if total != 0 || len(certs) != 0 {
t.Errorf("expected empty cert list, got total=%d len=%d", total, len(certs))
}
// Agents
agentRepo := postgres.NewAgentRepository(db)
agents, err := agentRepo.List(ctx)
if err != nil {
t.Fatalf("agent List failed: %v", err)
}
if len(agents) != 0 {
t.Errorf("expected empty agent list, got %d", len(agents))
}
// Revocations
revRepo := postgres.NewRevocationRepository(db)
revs, err := revRepo.ListAll(ctx)
if err != nil {
t.Fatalf("revocation ListAll failed: %v", err)
}
if len(revs) != 0 {
t.Errorf("expected empty revocations, got %d", len(revs))
}
// Discovery
discRepo := postgres.NewDiscoveryRepository(db)
dcerts, dtotal, err := discRepo.ListDiscovered(ctx, &repository.DiscoveryFilter{Page: 1, PerPage: 10})
if err != nil {
t.Fatalf("discovery ListDiscovered failed: %v", err)
}
if dtotal != 0 || len(dcerts) != 0 {
t.Errorf("expected empty discovered certs, got total=%d len=%d", dtotal, len(dcerts))
}
counts, err := discRepo.CountByStatus(ctx)
if err != nil {
t.Fatalf("discovery CountByStatus failed: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty status counts, got %d", len(counts))
}
// Network Scans
nsRepo := postgres.NewNetworkScanRepository(db)
nsTargets, err := nsRepo.List(ctx)
if err != nil {
t.Fatalf("network scan List failed: %v", err)
}
if len(nsTargets) != 0 {
t.Errorf("expected empty network scan targets, got %d", len(nsTargets))
}
// Agent Groups
grpRepo := postgres.NewAgentGroupRepository(db)
groups, err := grpRepo.List(ctx)
if err != nil {
t.Fatalf("agent group List failed: %v", err)
}
if len(groups) != 0 {
t.Errorf("expected empty agent groups, got %d", len(groups))
}
}
// ============================================================
// H-6 (CWE-362) Claim-Based Concurrency Tests
//
// These tests exercise the `SELECT ... FOR UPDATE SKIP LOCKED` worker-queue pattern
// introduced to remediate the H-6 race condition. They validate two invariants:
//
// 1. Disjoint claim: under concurrent callers, no Pending row is returned to more
// than one worker (i.e. each claim is exclusive).
// 2. State transition: claimed rows are atomically flipped to Running inside the
// same transaction that locked them, so a subsequent query must see the row in
// the Running state and no other worker can observe it as Pending again.
//
// Skipped automatically in `-short` mode (CI) since they require a real PostgreSQL
// instance and take ~1s under contention.
// ============================================================
// seedPendingJobs creates n Pending renewal jobs against a single prerequisite
// certificate and returns the generated job IDs.
func seedPendingJobs(t *testing.T, ctx context.Context, db *sql.DB, certID string, n int) []string {
t.Helper()
certRepo := postgres.NewCertificateRepository(db)
jobRepo := postgres.NewJobRepository(db)
ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, certID)
now := time.Now().Truncate(time.Microsecond)
cert := &domain.ManagedCertificate{
ID: "mc-" + certID, Name: certID, CommonName: certID + ".example.com",
SANs: []string{}, OwnerID: ownerID, TeamID: teamID,
IssuerID: issuerID, RenewalPolicyID: policyID,
Status: domain.CertificateStatusActive,
ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
CreatedAt: now, UpdatedAt: now,
}
if err := certRepo.Create(ctx, cert); err != nil {
t.Fatalf("seedPendingJobs: create cert failed: %v", err)
}
ids := make([]string, 0, n)
for i := 0; i < n; i++ {
job := &domain.Job{
ID: fmt.Sprintf("job-%s-%03d", certID, i),
Type: domain.JobTypeRenewal,
CertificateID: "mc-" + certID,
Status: domain.JobStatusPending,
Attempts: 0,
MaxAttempts: 3,
ScheduledAt: now,
CreatedAt: now,
}
if err := jobRepo.Create(ctx, job); err != nil {
t.Fatalf("seedPendingJobs: create job %d failed: %v", i, err)
}
ids = append(ids, job.ID)
}
return ids
}

// TestJobRepository_ClaimPendingJobs_FlipsToRunning validates the basic claim
// semantics: a single call atomically transitions every Pending row of the
// requested type to Running, the rows returned to the caller reflect the
// post-update state, the persisted rows agree, and a second claim finds
// nothing left to take.
func TestJobRepository_ClaimPendingJobs_FlipsToRunning(t *testing.T) {
	if testing.Short() {
		t.Skip("integration test requires PostgreSQL")
	}
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	jobRepo := postgres.NewJobRepository(db)
	ctx := context.Background()

	seeded := seedPendingJobs(t, ctx, db, "claimflip", 5)

	// limit=0 is assumed to mean "no cap": the length check below expects
	// every seeded Pending renewal to be claimed in this single call.
	claimed, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 0)
	if err != nil {
		t.Fatalf("ClaimPendingJobs failed: %v", err)
	}
	if len(claimed) != len(seeded) {
		t.Fatalf("len(claimed) = %d, want %d", len(claimed), len(seeded))
	}

	// In-memory return values must reflect the transitioned state.
	for _, j := range claimed {
		if j.Status != domain.JobStatusRunning {
			t.Errorf("claimed job %s Status = %q, want %q", j.ID, j.Status, domain.JobStatusRunning)
		}
	}

	// Persisted rows must also be Running: a fresh Get must not see Pending.
	for _, id := range seeded {
		got, err := jobRepo.Get(ctx, id)
		if err != nil {
			t.Fatalf("Get(%s) failed: %v", id, err)
		}
		if got.Status != domain.JobStatusRunning {
			t.Errorf("persisted job %s Status = %q, want %q", id, got.Status, domain.JobStatusRunning)
		}
	}

	// A subsequent claim must return zero rows because nothing is Pending anymore.
	residual, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 0)
	if err != nil {
		t.Fatalf("residual ClaimPendingJobs failed: %v", err)
	}
	if len(residual) != 0 {
		t.Errorf("residual claims = %d, want 0 (all jobs should be Running now)", len(residual))
	}
}

// TestJobRepository_ClaimPendingJobs_ConcurrentDisjoint validates the core H-6
// invariant: under concurrent access, no row is handed to more than one worker.
//
// The test seeds M Pending jobs, fans out N goroutines that each loop calling
// ClaimPendingJobs with limit=1, and finally asserts that the union of all
// claimed IDs contains exactly the M seeded jobs with no duplicates.
//
// A worker can transiently observe zero rows: when every remaining Pending row
// is locked by a peer's in-flight claim, SKIP LOCKED skips them all. Such a
// worker backs off briefly and retries instead of exiting. It stops once the
// shared atomic progress counter reaches M, after 20 consecutive empty polls
// (a starvation guard), or after 4*M iterations (a hang guard).
func TestJobRepository_ClaimPendingJobs_ConcurrentDisjoint(t *testing.T) {
	if testing.Short() {
		t.Skip("integration test requires PostgreSQL")
	}
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	jobRepo := postgres.NewJobRepository(db)
	ctx := context.Background()

	const M = 40 // seeded Pending jobs
	const N = 8  // concurrent workers
	seeded := seedPendingJobs(t, ctx, db, "concurrent", M)
	seededSet := make(map[string]bool, M)
	for _, id := range seeded {
		seededSet[id] = true
	}

	var (
		totalClaimed int64    // atomic progress counter
		allClaims    []string // guarded by mu
		mu           sync.Mutex
		wg           sync.WaitGroup
	)
	for w := 0; w < N; w++ {
		wg.Add(1)
		go func(worker int) {
			defer wg.Done()
			emptyStreak := 0
			for iter := 0; iter < M*4; iter++ { // hang guard: generous iteration ceiling
				claimed, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 1)
				if err != nil {
					// Errorf, unlike Fatalf, may be called from a goroutine
					// other than the one running the test.
					t.Errorf("worker %d ClaimPendingJobs failed: %v", worker, err)
					return
				}
				if len(claimed) == 0 {
					// Distinguish a transient zero (peers hold the remaining row
					// locks) from a terminal zero (everything has been claimed):
					// exit once the shared counter shows all M claimed, but cap
					// the empty streak so starvation cannot spin forever.
					if atomic.LoadInt64(&totalClaimed) >= int64(M) {
						return
					}
					emptyStreak++
					if emptyStreak >= 20 {
						return
					}
					time.Sleep(500 * time.Microsecond)
					continue
				}
				emptyStreak = 0
				mu.Lock()
				for _, j := range claimed {
					if j.Status != domain.JobStatusRunning {
						t.Errorf("worker %d got job %s in Status=%q (want Running): claim did not flip state", worker, j.ID, j.Status)
					}
					allClaims = append(allClaims, j.ID)
				}
				mu.Unlock()
				atomic.AddInt64(&totalClaimed, int64(len(claimed)))
			}
		}(w)
	}
	wg.Wait()

	// Invariant 1: no job is claimed by more than one worker.
	seen := make(map[string]int, len(allClaims))
	for _, id := range allClaims {
		seen[id]++
	}
	for id, count := range seen {
		if count > 1 {
			t.Errorf("job %s claimed %d times: SKIP LOCKED invariant violated", id, count)
		}
	}

	// Invariant 2: every seeded job was claimed. Together with invariant 1,
	// this means each seeded job was claimed exactly once.
	if len(seen) != M {
		t.Errorf("distinct claimed IDs = %d, want %d (all seeded jobs must be claimed)", len(seen), M)
	}
	for id := range seededSet {
		if seen[id] == 0 {
			t.Errorf("seeded job %s was never claimed by any worker", id)
		}
	}

	// Invariant 3: persisted state reflects the transition; every seeded row
	// is now Running and none is Pending.
	for id := range seededSet {
		got, err := jobRepo.Get(ctx, id)
		if err != nil {
			t.Fatalf("Get(%s) failed: %v", id, err)
		}
		if got.Status != domain.JobStatusRunning {
			t.Errorf("job %s Status = %q, want %q", id, got.Status, domain.JobStatusRunning)
		}
	}

	// The final progress counter must equal the number of seeded jobs.
	if got := atomic.LoadInt64(&totalClaimed); got != int64(M) {
		t.Errorf("totalClaimed = %d, want %d", got, M)
	}
}

// TestJobRepository_ClaimPendingByAgentID_TransitionsDeployments validates the
// agent-scoped claim variant. Pending deployment jobs assigned to the agent flip
// to Running. AwaitingCSR jobs are returned but keep their state, because the
// CSR submission path drives their next transition. Pending jobs of other types
// are neither returned nor modified.
func TestJobRepository_ClaimPendingByAgentID_TransitionsDeployments(t *testing.T) {
	if testing.Short() {
		t.Skip("integration test requires PostgreSQL")
	}
	tdb := getTestDB(t)
	db := tdb.freshSchema(t)
	jobRepo := postgres.NewJobRepository(db)
	agentRepo := postgres.NewAgentRepository(db)
	ctx := context.Background()

	ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "agentclaim")
	now := time.Now().Truncate(time.Microsecond)
	cert := &domain.ManagedCertificate{
		ID:        "mc-agentclaim", Name: "agentclaim", CommonName: "agentclaim.example.com",
		SANs:      []string{}, OwnerID: ownerID, TeamID: teamID,
		IssuerID:  issuerID, RenewalPolicyID: policyID,
		Status:    domain.CertificateStatusActive,
		ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{},
		CreatedAt: now, UpdatedAt: now,
	}
	if err := postgres.NewCertificateRepository(db).Create(ctx, cert); err != nil {
		t.Fatalf("create cert failed: %v", err)
	}

	agent := &domain.Agent{
		ID:           "a-claim",
		Name:         "claim-agent",
		Hostname:     "claim-agent-host",
		Status:       domain.AgentStatusOnline,
		RegisteredAt: now,
		APIKeyHash:   "hash-claim",
	}
	if err := agentRepo.Create(ctx, agent); err != nil {
		t.Fatalf("create agent failed: %v", err)
	}
	agentID := agent.ID

	// mkJob builds a job for the shared certificate, assigned to the test agent.
	mkJob := func(id string, typ domain.JobType, status domain.JobStatus) *domain.Job {
		return &domain.Job{
			ID:          id, Type: typ, CertificateID: cert.ID,
			AgentID:     &agentID,
			Status:      status,
			Attempts:    0,
			MaxAttempts: 3,
			ScheduledAt: now,
			CreatedAt:   now,
		}
	}
	jobs := []*domain.Job{
		mkJob("job-agentclaim-dep-1", domain.JobTypeDeployment, domain.JobStatusPending),
		mkJob("job-agentclaim-dep-2", domain.JobTypeDeployment, domain.JobStatusPending),
		mkJob("job-agentclaim-csr-1", domain.JobTypeRenewal, domain.JobStatusAwaitingCSR),
		// A Pending Renewal (not Deployment) must NOT be returned by the per-agent claim.
		mkJob("job-agentclaim-ren-pending", domain.JobTypeRenewal, domain.JobStatusPending),
	}
	for _, j := range jobs {
		if err := jobRepo.Create(ctx, j); err != nil {
			t.Fatalf("create job %s failed: %v", j.ID, err)
		}
	}

	claimed, err := jobRepo.ClaimPendingByAgentID(ctx, agentID)
	if err != nil {
		t.Fatalf("ClaimPendingByAgentID failed: %v", err)
	}
	// Expect exactly the 2 deployments plus the 1 AwaitingCSR job.
	if len(claimed) != 3 {
		t.Fatalf("len(claimed) = %d, want 3 (2 deployments + 1 AwaitingCSR)", len(claimed))
	}
	statusByID := map[string]domain.JobStatus{}
	for _, j := range claimed {
		statusByID[j.ID] = j.Status
	}

	// Both deployments must be Running in the returned slice (in-memory view).
	for _, id := range []string{"job-agentclaim-dep-1", "job-agentclaim-dep-2"} {
		if statusByID[id] != domain.JobStatusRunning {
			t.Errorf("returned deployment %s Status = %q, want Running", id, statusByID[id])
		}
	}
	// The AwaitingCSR job must be returned with its state unchanged.
	if statusByID["job-agentclaim-csr-1"] != domain.JobStatusAwaitingCSR {
		t.Errorf("returned AwaitingCSR Status = %q, want AwaitingCSR", statusByID["job-agentclaim-csr-1"])
	}
	// The unrelated Pending Renewal must not be returned.
	if _, ok := statusByID["job-agentclaim-ren-pending"]; ok {
		t.Error("Pending Renewal job was returned by ClaimPendingByAgentID: scope violation")
	}

	// Persisted state: deployments Running, AwaitingCSR unchanged, Pending Renewal still Pending.
	for id, want := range map[string]domain.JobStatus{
		"job-agentclaim-dep-1":       domain.JobStatusRunning,
		"job-agentclaim-dep-2":       domain.JobStatusRunning,
		"job-agentclaim-csr-1":       domain.JobStatusAwaitingCSR,
		"job-agentclaim-ren-pending": domain.JobStatusPending,
	} {
		got, err := jobRepo.Get(ctx, id)
		if err != nil {
			t.Fatalf("Get(%s) failed: %v", id, err)
		}
		if got.Status != want {
			t.Errorf("persisted %s Status = %q, want %q", id, got.Status, want)
		}
	}
}