diff --git a/internal/integration/lifecycle_test.go b/internal/integration/lifecycle_test.go index 385da76..ed6d7a5 100644 --- a/internal/integration/lifecycle_test.go +++ b/internal/integration/lifecycle_test.go @@ -682,6 +682,46 @@ func (m *mockJobRepository) ListPendingByAgentID(ctx context.Context, agentID st return result, nil } +// ClaimPendingJobs mirrors the production H-6 semantics: Pending jobs of the given type +// (or any type when jobType is empty) flip to Running before being returned. limit <= 0 +// means unlimited. +func (m *mockJobRepository) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) { + var claimed []*domain.Job + for _, j := range m.jobs { + if j.Status != domain.JobStatusPending { + continue + } + if jobType != "" && j.Type != jobType { + continue + } + j.Status = domain.JobStatusRunning + claimed = append(claimed, j) + if limit > 0 && len(claimed) >= limit { + break + } + } + return claimed, nil +} + +// ClaimPendingByAgentID mirrors the production H-6 semantics: Pending deployment rows for +// the agent flip to Running; AwaitingCSR rows are returned with state preserved. 
+func (m *mockJobRepository) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) { + var result []*domain.Job + for _, j := range m.jobs { + if j.AgentID == nil || *j.AgentID != agentID { + continue + } + switch { + case j.Status == domain.JobStatusPending && j.Type == domain.JobTypeDeployment: + j.Status = domain.JobStatusRunning + result = append(result, j) + case j.Status == domain.JobStatusAwaitingCSR: + result = append(result, j) + } + } + return result, nil +} + type mockAuditRepository struct { events []*domain.AuditEvent } diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 71716f3..26ba1cb 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -120,10 +120,20 @@ type JobRepository interface { ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) // UpdateStatus updates a job's status and optional error message. UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error - // GetPendingJobs returns jobs not yet processed of a specific type. + // GetPendingJobs returns jobs not yet processed of a specific type. Prefer ClaimPendingJobs in + // production paths where concurrent schedulers may race — see H-6 (CWE-362) remediation. GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) // ListPendingByAgentID returns pending deployment jobs and AwaitingCSR jobs for a specific agent. + // Prefer ClaimPendingByAgentID in production paths — see H-6 (CWE-362) remediation. ListPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) + // ClaimPendingJobs atomically claims up to `limit` Pending jobs and transitions them to Running + // using SELECT FOR UPDATE SKIP LOCKED inside a transaction. An empty jobType matches any type; + // limit <= 0 means no limit. H-6 (CWE-362) race remediation. 
+ ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) + // ClaimPendingByAgentID atomically claims pending deployment jobs for an agent (flipping them + // to Running) and locks AwaitingCSR jobs against concurrent observers (leaving state intact, + // since the CSR-submission path drives the next transition). H-6 (CWE-362) race remediation. + ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) } // RenewalPolicyRepository defines operations for managing renewal policies. diff --git a/internal/repository/postgres/job.go b/internal/repository/postgres/job.go index cd7524d..4e521d7 100644 --- a/internal/repository/postgres/job.go +++ b/internal/repository/postgres/job.go @@ -237,7 +237,14 @@ func (r *JobRepository) UpdateStatus(ctx context.Context, id string, status doma return nil } -// GetPendingJobs returns jobs not yet processed of a specific type +// GetPendingJobs returns jobs not yet processed of a specific type. +// +// The SELECT carries FOR UPDATE SKIP LOCKED, but this method runs it directly +// on r.db rather than inside an explicit transaction, so the row locks are +// released as soon as the statement finishes. The clause therefore only makes +// this read skip rows that a concurrent claim transaction currently holds; it +// does not stop two callers from seeing the same Pending rows. Use +// ClaimPendingJobs for race-free dispatch (CWE-362 fix for H-6). 
func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts, @@ -245,6 +252,7 @@ func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobTy FROM jobs WHERE type = $1 AND status = $2 ORDER BY scheduled_at ASC + FOR UPDATE SKIP LOCKED `, jobType, domain.JobStatusPending) if err != nil { @@ -268,10 +276,115 @@ func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobTy return jobs, nil } -// ListPendingByAgentID returns pending deployment jobs and AwaitingCSR jobs for a specific agent. -// Deployment jobs are matched by agent_id directly (set at creation time), with a fallback -// for legacy jobs where agent_id is NULL but target_id resolves to the agent via deployment_targets. -// AwaitingCSR jobs are matched through certificate → target mappings → agent ownership. +// ClaimPendingJobs atomically claims up to `limit` Pending jobs and transitions +// them to Running inside a single transaction. The SELECT uses FOR UPDATE SKIP +// LOCKED so concurrent scheduler replicas observe disjoint result sets — each +// row can be claimed by exactly one caller per tick (CWE-362 fix for H-6). +// +// Passing an empty jobType claims any type. Passing limit<=0 claims all +// available rows. The claimed rows are returned with Status already set to +// domain.JobStatusRunning. +// +// Downstream processors (ProcessRenewalJob, ProcessDeploymentJob) already call +// UpdateStatus(Running) unconditionally on entry, so this pre-flip is +// idempotent with respect to existing processing logic. 
+func (r *JobRepository) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to begin claim transaction: %w", err) + } + // Rollback is a no-op after Commit — safe deferred cleanup if an error path + // triggers an early return before Commit(). + defer func() { _ = tx.Rollback() }() + + // Build the SELECT — jobType="" means any type, limit<=0 means unlimited. + query := ` + SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + WHERE status = $1` + args := []interface{}{domain.JobStatusPending} + if jobType != "" { + query += ` AND type = $2` + args = append(args, jobType) + } + query += ` + ORDER BY scheduled_at ASC + FOR UPDATE SKIP LOCKED` + if limit > 0 { + query += fmt.Sprintf(` LIMIT %d`, limit) + } + + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to query claimable jobs: %w", err) + } + + var jobs []*domain.Job + for rows.Next() { + job, err := scanJob(rows) + if err != nil { + rows.Close() + return nil, err + } + jobs = append(jobs, job) + } + if err := rows.Err(); err != nil { + rows.Close() + return nil, fmt.Errorf("error iterating claimable job rows: %w", err) + } + rows.Close() + + if len(jobs) == 0 { + // No rows to claim — commit the (read-only) tx and return. + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit empty claim tx: %w", err) + } + return nil, nil + } + + // Flip claimed rows to Running. Build IN clause safely with placeholders. + ids := make([]interface{}, len(jobs)) + placeholders := make([]byte, 0, len(jobs)*5) + for i, job := range jobs { + ids[i] = job.ID + if i > 0 { + placeholders = append(placeholders, ',') + } + placeholders = append(placeholders, fmt.Sprintf("$%d", i+2)...) 
+ } + updateQuery := fmt.Sprintf( + `UPDATE jobs SET status = $1 WHERE id IN (%s)`, + string(placeholders), + ) + updateArgs := append([]interface{}{domain.JobStatusRunning}, ids...) + if _, err := tx.ExecContext(ctx, updateQuery, updateArgs...); err != nil { + return nil, fmt.Errorf("failed to transition claimed jobs to Running: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit claim transaction: %w", err) + } + + // Reflect the committed state in the returned objects. + for _, job := range jobs { + job.Status = domain.JobStatusRunning + } + + return jobs, nil +} + +// ListPendingByAgentID returns pending deployment jobs and AwaitingCSR jobs for +// a specific agent. Deployment jobs are matched by agent_id directly (set at +// creation time), with a fallback for legacy jobs where agent_id is NULL but +// target_id resolves to the agent via deployment_targets. AwaitingCSR jobs are +// matched through certificate → target mappings → agent ownership. +// +// The SELECT uses FOR UPDATE SKIP LOCKED so concurrent pollers (e.g. two agent +// instances running with the same agent_id) cannot observe the same rows when +// this method is invoked inside a transaction. For the production agent work +// poll path, prefer ClaimPendingByAgentID which additionally transitions +// claimed Pending deployment rows to Running atomically (H-6 CWE-362 fix). func (r *JobRepository) ListPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts, @@ -326,6 +439,137 @@ func (r *JobRepository) ListPendingByAgentID(ctx context.Context, agentID string return jobs, nil } +// ClaimPendingByAgentID atomically claims agent work inside a single +// transaction. 
Pending Deployment jobs assigned to the agent (directly via +// agent_id, or via legacy target→agent fallback) are transitioned from +// Pending to Running. AwaitingCSR Renewal/Issuance jobs linked to the agent +// via certificate → target mappings are locked with FOR UPDATE SKIP LOCKED +// and returned without a state transition — the flow requires the agent to +// submit a CSR to advance state, and pre-flipping AwaitingCSR would violate +// the renewal state machine (CWE-362 fix for H-6). +// +// Claimed rows are invisible to other concurrent claim calls for the lifetime +// of the transaction; rows claimed as Running remain invisible after commit +// because ListPendingByAgentID's filter is status='Pending'. +func (r *JobRepository) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to begin agent claim transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Branch 1 + 2: Pending Deployment jobs (direct agent_id match or legacy + // target fallback). These get flipped to Running atomically below. 
+ pendingRows, err := tx.QueryContext(ctx, ` + SELECT id, type, certificate_id, target_id, agent_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + WHERE agent_id = $1 AND status = 'Pending' AND type = 'Deployment' + + UNION ALL + + SELECT j.id, j.type, j.certificate_id, j.target_id, j.agent_id, j.status, j.attempts, j.max_attempts, + j.last_error, j.scheduled_at, j.started_at, j.completed_at, j.created_at + FROM jobs j + INNER JOIN deployment_targets dt ON j.target_id = dt.id + WHERE j.agent_id IS NULL AND j.status = 'Pending' AND j.type = 'Deployment' + AND dt.agent_id = $1 + + ORDER BY created_at ASC + FOR UPDATE SKIP LOCKED + `, agentID) + if err != nil { + return nil, fmt.Errorf("failed to query pending deployment jobs for agent: %w", err) + } + + var pendingJobs []*domain.Job + for pendingRows.Next() { + job, err := scanJob(pendingRows) + if err != nil { + pendingRows.Close() + return nil, err + } + pendingJobs = append(pendingJobs, job) + } + if err := pendingRows.Err(); err != nil { + pendingRows.Close() + return nil, fmt.Errorf("error iterating pending deployment rows: %w", err) + } + pendingRows.Close() + + // Branch 3: AwaitingCSR jobs for this agent. Locked with FOR UPDATE SKIP + // LOCKED to prevent duplicate delivery to concurrent pollers, but state is + // NOT transitioned — the agent advances state via CSR submission. 
+ csrRows, err := tx.QueryContext(ctx, ` + SELECT j.id, j.type, j.certificate_id, j.target_id, j.agent_id, j.status, j.attempts, j.max_attempts, + j.last_error, j.scheduled_at, j.started_at, j.completed_at, j.created_at + FROM jobs j + WHERE j.status = 'AwaitingCSR' + AND j.type IN ('Renewal', 'Issuance') + AND EXISTS ( + SELECT 1 FROM certificate_target_mappings ctm + INNER JOIN deployment_targets dt ON ctm.target_id = dt.id + WHERE ctm.certificate_id = j.certificate_id + AND dt.agent_id = $1 + ) + ORDER BY j.created_at ASC + FOR UPDATE SKIP LOCKED + `, agentID) + if err != nil { + return nil, fmt.Errorf("failed to query AwaitingCSR jobs for agent: %w", err) + } + + var csrJobs []*domain.Job + for csrRows.Next() { + job, err := scanJob(csrRows) + if err != nil { + csrRows.Close() + return nil, err + } + csrJobs = append(csrJobs, job) + } + if err := csrRows.Err(); err != nil { + csrRows.Close() + return nil, fmt.Errorf("error iterating AwaitingCSR rows: %w", err) + } + csrRows.Close() + + // Transition locked Pending deployments to Running before commit. + if len(pendingJobs) > 0 { + ids := make([]interface{}, len(pendingJobs)) + placeholders := make([]byte, 0, len(pendingJobs)*5) + for i, job := range pendingJobs { + ids[i] = job.ID + if i > 0 { + placeholders = append(placeholders, ',') + } + placeholders = append(placeholders, fmt.Sprintf("$%d", i+2)...) + } + updateQuery := fmt.Sprintf( + `UPDATE jobs SET status = $1 WHERE id IN (%s)`, + string(placeholders), + ) + updateArgs := append([]interface{}{domain.JobStatusRunning}, ids...) + if _, err := tx.ExecContext(ctx, updateQuery, updateArgs...); err != nil { + return nil, fmt.Errorf("failed to transition claimed deployment jobs to Running: %w", err) + } + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit agent claim transaction: %w", err) + } + + // Reflect the committed state in returned Pending deployment jobs; leave + // AwaitingCSR jobs untouched. 
+ for _, job := range pendingJobs { + job.Status = domain.JobStatusRunning + } + + // Preserve the legacy ordering: Pending deployments first, AwaitingCSR + // second. Callers that want a strict created_at merge can re-sort. + return append(pendingJobs, csrJobs...), nil +} + // scanJob scans a job from a row or rows func scanJob(scanner interface { Scan(...interface{}) error diff --git a/internal/repository/postgres/repo_test.go b/internal/repository/postgres/repo_test.go index 7896151..20d3dc8 100644 --- a/internal/repository/postgres/repo_test.go +++ b/internal/repository/postgres/repo_test.go @@ -7,6 +7,9 @@ import ( "context" "database/sql" "encoding/json" + "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -1682,3 +1685,334 @@ func TestEmptyResultSets(t *testing.T) { t.Errorf("expected empty agent groups, got %d", len(groups)) } } + +// ============================================================ +// H-6 (CWE-362) Claim-Based Concurrency Tests +// +// These tests exercise the `SELECT ... FOR UPDATE SKIP LOCKED` worker-queue pattern +// introduced to remediate the H-6 race condition. They validate two invariants: +// +// 1. Disjoint claim: under concurrent callers, no Pending row is returned to more +// than one worker (i.e. each claim is exclusive). +// 2. State transition: claimed rows are atomically flipped to Running inside the +// same transaction that locked them, so a subsequent query must see the row in +// the Running state and no other worker can observe it as Pending again. +// +// Skipped automatically in `-short` mode (CI) since they require a real PostgreSQL +// instance and take ~1s under contention. +// ============================================================ + +// seedPendingJobs creates n Pending renewal jobs against a single prerequisite +// certificate and returns the generated job IDs. 
+func seedPendingJobs(t *testing.T, ctx context.Context, db *sql.DB, certID string, n int) []string { + t.Helper() + certRepo := postgres.NewCertificateRepository(db) + jobRepo := postgres.NewJobRepository(db) + + ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, certID) + + now := time.Now().Truncate(time.Microsecond) + cert := &domain.ManagedCertificate{ + ID: "mc-" + certID, Name: certID, CommonName: certID + ".example.com", + SANs: []string{}, OwnerID: ownerID, TeamID: teamID, + IssuerID: issuerID, RenewalPolicyID: policyID, + Status: domain.CertificateStatusActive, + ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{}, + CreatedAt: now, UpdatedAt: now, + } + if err := certRepo.Create(ctx, cert); err != nil { + t.Fatalf("seedPendingJobs: create cert failed: %v", err) + } + + ids := make([]string, 0, n) + for i := 0; i < n; i++ { + job := &domain.Job{ + ID: fmt.Sprintf("job-%s-%03d", certID, i), + Type: domain.JobTypeRenewal, + CertificateID: "mc-" + certID, + Status: domain.JobStatusPending, + Attempts: 0, + MaxAttempts: 3, + ScheduledAt: now, + CreatedAt: now, + } + if err := jobRepo.Create(ctx, job); err != nil { + t.Fatalf("seedPendingJobs: create job %d failed: %v", i, err) + } + ids = append(ids, job.ID) + } + return ids +} + +// TestJobRepository_ClaimPendingJobs_FlipsToRunning validates the basic claim +// semantics: a single call transitions Pending rows to Running atomically, and +// the rows returned to the caller reflect the post-update state. 
+func TestJobRepository_ClaimPendingJobs_FlipsToRunning(t *testing.T) { + if testing.Short() { + t.Skip("integration test requires PostgreSQL") + } + tdb := getTestDB(t) + db := tdb.freshSchema(t) + jobRepo := postgres.NewJobRepository(db) + ctx := context.Background() + + seeded := seedPendingJobs(t, ctx, db, "claimflip", 5) + + claimed, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 0) + if err != nil { + t.Fatalf("ClaimPendingJobs failed: %v", err) + } + if len(claimed) != len(seeded) { + t.Fatalf("len(claimed) = %d, want %d", len(claimed), len(seeded)) + } + + // In-memory return values must reflect the transitioned state. + for _, j := range claimed { + if j.Status != domain.JobStatusRunning { + t.Errorf("claimed job %s Status = %q, want %q", j.ID, j.Status, domain.JobStatusRunning) + } + } + + // Persisted rows must also be Running — a fresh Get must not see Pending. + for _, id := range seeded { + got, err := jobRepo.Get(ctx, id) + if err != nil { + t.Fatalf("Get(%s) failed: %v", id, err) + } + if got.Status != domain.JobStatusRunning { + t.Errorf("persisted job %s Status = %q, want %q", id, got.Status, domain.JobStatusRunning) + } + } + + // A subsequent claim must return zero rows — nothing is Pending anymore. + residual, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 0) + if err != nil { + t.Fatalf("residual ClaimPendingJobs failed: %v", err) + } + if len(residual) != 0 { + t.Errorf("residual claims = %d, want 0 (all should be Running now)", len(residual)) + } +} + +// TestJobRepository_ClaimPendingJobs_ConcurrentDisjoint validates the core H-6 +// invariant: under concurrent access, no row is handed to more than one worker. +// +// The test seeds M Pending jobs, fans out N goroutines each of which loops +// calling ClaimPendingJobs with limit=1, and finally asserts the union of all +// claimed IDs is exactly M with zero duplicates. 
Workers that transiently +// observe zero rows (because peers are holding the only remaining rows) re-check +// an atomic progress counter before exiting, so transient SKIP-LOCKED zeros do +// not cause premature termination. +func TestJobRepository_ClaimPendingJobs_ConcurrentDisjoint(t *testing.T) { + if testing.Short() { + t.Skip("integration test requires PostgreSQL") + } + tdb := getTestDB(t) + db := tdb.freshSchema(t) + jobRepo := postgres.NewJobRepository(db) + ctx := context.Background() + + const M = 40 // seeded Pending jobs + const N = 8 // concurrent workers + seeded := seedPendingJobs(t, ctx, db, "concurrent", M) + seededSet := make(map[string]bool, M) + for _, id := range seeded { + seededSet[id] = true + } + + var ( + totalClaimed int64 + allClaims []string + mu sync.Mutex + wg sync.WaitGroup + ) + + for w := 0; w < N; w++ { + wg.Add(1) + go func(worker int) { + defer wg.Done() + emptyStreak := 0 + for iter := 0; iter < M*4; iter++ { // generous ceiling to prevent hangs + claimed, err := jobRepo.ClaimPendingJobs(ctx, domain.JobTypeRenewal, 1) + if err != nil { + t.Errorf("worker %d ClaimPendingJobs failed: %v", worker, err) + return + } + if len(claimed) == 0 { + // Transient zero (peer holds lock) vs. terminal zero (all claimed). + // Bail only once the shared counter proves work is done, but guard + // with a streak so we don't spin forever under starvation. + if atomic.LoadInt64(&totalClaimed) >= int64(M) { + return + } + emptyStreak++ + if emptyStreak >= 20 { + return + } + time.Sleep(500 * time.Microsecond) + continue + } + emptyStreak = 0 + mu.Lock() + for _, j := range claimed { + if j.Status != domain.JobStatusRunning { + t.Errorf("worker %d got job %s in Status=%q (want Running) — claim did not flip state", worker, j.ID, j.Status) + } + allClaims = append(allClaims, j.ID) + } + mu.Unlock() + atomic.AddInt64(&totalClaimed, int64(len(claimed))) + } + }(w) + } + wg.Wait() + + // Invariant 1: no duplicate claims across the worker pool. 
+ seen := make(map[string]int, len(allClaims)) + for _, id := range allClaims { + seen[id]++ + } + for id, count := range seen { + if count > 1 { + t.Errorf("job %s claimed %d times — SKIP LOCKED invariant violated", id, count) + } + } + + // Invariant 2: every seeded job appears in the claim set exactly once. + if len(seen) != M { + t.Errorf("distinct claimed IDs = %d, want %d (all seeded jobs must be claimed)", len(seen), M) + } + for id := range seededSet { + if seen[id] == 0 { + t.Errorf("seeded job %s was never claimed by any worker", id) + } + } + + // Invariant 3: persisted state reflects the transition — every seeded row + // is now Running; none is Pending. + for id := range seededSet { + got, err := jobRepo.Get(ctx, id) + if err != nil { + t.Fatalf("Get(%s) failed: %v", id, err) + } + if got.Status != domain.JobStatusRunning { + t.Errorf("job %s Status = %q, want %q", id, got.Status, domain.JobStatusRunning) + } + } + + // Final progress counter must match the total number of seeded jobs. + if got := atomic.LoadInt64(&totalClaimed); got != int64(M) { + t.Errorf("totalClaimed = %d, want %d", got, M) + } +} + +// TestJobRepository_ClaimPendingByAgentID_TransitionsDeployments validates the +// agent-scoped claim variant: Pending deployment rows for a given agent flip to +// Running; AwaitingCSR rows are returned but their state is preserved (the CSR +// submission path drives their next transition). 
+func TestJobRepository_ClaimPendingByAgentID_TransitionsDeployments(t *testing.T) { + if testing.Short() { + t.Skip("integration test requires PostgreSQL") + } + tdb := getTestDB(t) + db := tdb.freshSchema(t) + jobRepo := postgres.NewJobRepository(db) + agentRepo := postgres.NewAgentRepository(db) + ctx := context.Background() + + ownerID, teamID, issuerID, policyID := insertCertPrereqsRaw(t, db, ctx, "agentclaim") + + now := time.Now().Truncate(time.Microsecond) + cert := &domain.ManagedCertificate{ + ID: "mc-agentclaim", Name: "agentclaim", CommonName: "agentclaim.example.com", + SANs: []string{}, OwnerID: ownerID, TeamID: teamID, + IssuerID: issuerID, RenewalPolicyID: policyID, + Status: domain.CertificateStatusActive, + ExpiresAt: now.Add(30 * 24 * time.Hour), Tags: map[string]string{}, + CreatedAt: now, UpdatedAt: now, + } + if err := postgres.NewCertificateRepository(db).Create(ctx, cert); err != nil { + t.Fatalf("create cert failed: %v", err) + } + + agent := &domain.Agent{ + ID: "a-claim", + Name: "claim-agent", + Hostname: "claim-agent-host", + Status: domain.AgentStatusOnline, + RegisteredAt: now, + APIKeyHash: "hash-claim", + } + if err := agentRepo.Create(ctx, agent); err != nil { + t.Fatalf("create agent failed: %v", err) + } + + agentID := agent.ID + mkJob := func(id string, typ domain.JobType, status domain.JobStatus) *domain.Job { + return &domain.Job{ + ID: id, Type: typ, CertificateID: cert.ID, + AgentID: &agentID, + Status: status, + Attempts: 0, + MaxAttempts: 3, + ScheduledAt: now, + CreatedAt: now, + } + } + jobs := []*domain.Job{ + mkJob("job-agentclaim-dep-1", domain.JobTypeDeployment, domain.JobStatusPending), + mkJob("job-agentclaim-dep-2", domain.JobTypeDeployment, domain.JobStatusPending), + mkJob("job-agentclaim-csr-1", domain.JobTypeRenewal, domain.JobStatusAwaitingCSR), + // A Pending Renewal (not Deployment) must NOT be returned by the per-agent claim. 
+ mkJob("job-agentclaim-ren-pending", domain.JobTypeRenewal, domain.JobStatusPending), + } + for _, j := range jobs { + if err := jobRepo.Create(ctx, j); err != nil { + t.Fatalf("create job %s failed: %v", j.ID, err) + } + } + + claimed, err := jobRepo.ClaimPendingByAgentID(ctx, agentID) + if err != nil { + t.Fatalf("ClaimPendingByAgentID failed: %v", err) + } + // Expect exactly the 2 deployments + 1 AwaitingCSR. + if len(claimed) != 3 { + t.Fatalf("len(claimed) = %d, want 3 (2 deployments + 1 AwaitingCSR)", len(claimed)) + } + + statusByID := map[string]domain.JobStatus{} + for _, j := range claimed { + statusByID[j.ID] = j.Status + } + // Both deployments must be Running in the returned slice (in-memory reflection). + for _, id := range []string{"job-agentclaim-dep-1", "job-agentclaim-dep-2"} { + if statusByID[id] != domain.JobStatusRunning { + t.Errorf("returned deployment %s Status = %q, want Running", id, statusByID[id]) + } + } + // AwaitingCSR must remain AwaitingCSR. + if statusByID["job-agentclaim-csr-1"] != domain.JobStatusAwaitingCSR { + t.Errorf("returned AwaitingCSR Status = %q, want AwaitingCSR", statusByID["job-agentclaim-csr-1"]) + } + // The unrelated Pending Renewal must not be returned. + if _, ok := statusByID["job-agentclaim-ren-pending"]; ok { + t.Errorf("Pending Renewal job was returned by ClaimPendingByAgentID — scope violation") + } + + // Persisted state: deployments Running, AwaitingCSR unchanged, Pending Renewal still Pending. 
+ for id, want := range map[string]domain.JobStatus{ + "job-agentclaim-dep-1": domain.JobStatusRunning, + "job-agentclaim-dep-2": domain.JobStatusRunning, + "job-agentclaim-csr-1": domain.JobStatusAwaitingCSR, + "job-agentclaim-ren-pending": domain.JobStatusPending, + } { + got, err := jobRepo.Get(ctx, id) + if err != nil { + t.Fatalf("Get(%s) failed: %v", id, err) + } + if got.Status != want { + t.Errorf("persisted %s Status = %q, want %q", id, got.Status, want) + } + } +} diff --git a/internal/service/agent.go b/internal/service/agent.go index 43b9962..1a14718 100644 --- a/internal/service/agent.go +++ b/internal/service/agent.go @@ -284,8 +284,13 @@ func (s *AgentService) GetPendingWork(ctx context.Context, agentID string) ([]*d return nil, fmt.Errorf("failed to fetch agent: %w", err) } - // Return only jobs assigned to this agent (via agent_id or target→agent relationship) - return s.jobRepo.ListPendingByAgentID(ctx, agentID) + // Atomically claim jobs assigned to this agent. H-6 (CWE-362) remediation: + // ClaimPendingByAgentID uses SELECT ... FOR UPDATE SKIP LOCKED so concurrent poll + // requests (duplicate agents, retry storms, or a lagging long-poll) never observe + // the same Pending deployment row. Pending deployments are flipped to Running inside + // the claim transaction; AwaitingCSR jobs keep their state since CSR submission is + // the state-machine trigger for their next transition. + return s.jobRepo.ClaimPendingByAgentID(ctx, agentID) } // ReportJobStatus updates a job's status based on agent feedback. diff --git a/internal/service/job.go b/internal/service/job.go index b4b7650..89fc2d1 100644 --- a/internal/service/job.go +++ b/internal/service/job.go @@ -35,11 +35,18 @@ func NewJobService( // ProcessPendingJobs fetches and processes all pending jobs. // It routes jobs to the appropriate service based on job type and handles errors gracefully. +// +// Concurrency (H-6 CWE-362): jobs are claimed via ClaimPendingJobs which uses +// SELECT ... 
FOR UPDATE SKIP LOCKED and flips Pending → Running atomically. Concurrent +// scheduler replicas in HA deployments will therefore never observe the same Pending row, +// and the subsequent UpdateStatus(Running) calls inside the downstream service methods are +// idempotent against the pre-flipped state. func (s *JobService) ProcessPendingJobs(ctx context.Context) error { - // Fetch pending jobs - pendingJobs, err := s.jobRepo.ListByStatus(ctx, domain.JobStatusPending) + // Claim pending jobs atomically (H-6 remediation: was ListByStatus which had no row lock). + // Empty jobType matches all types; zero limit means unlimited (preserves prior semantics). + pendingJobs, err := s.jobRepo.ClaimPendingJobs(ctx, "", 0) if err != nil { - return fmt.Errorf("failed to list pending jobs: %w", err) + return fmt.Errorf("failed to claim pending jobs: %w", err) } if len(pendingJobs) == 0 { diff --git a/internal/service/testutil_test.go b/internal/service/testutil_test.go index 9f1ffb3..f88a7b3 100644 --- a/internal/service/testutil_test.go +++ b/internal/service/testutil_test.go @@ -278,6 +278,56 @@ func (m *mockJobRepo) ListPendingByAgentID(ctx context.Context, agentID string) return result, nil } +// ClaimPendingJobs simulates the H-6 atomic claim semantics: matching rows are transitioned +// Pending → Running before being returned. The in-memory mock has no concurrency primitives +// beyond the existing mutex, which is sufficient for single-goroutine service tests. 
+func (m *mockJobRepo) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.ListErr != nil { + return nil, m.ListErr + } + var claimed []*domain.Job + for _, j := range m.Jobs { + if j.Status != domain.JobStatusPending { + continue + } + if jobType != "" && j.Type != jobType { + continue + } + j.Status = domain.JobStatusRunning + claimed = append(claimed, j) + if limit > 0 && len(claimed) >= limit { + break + } + } + return claimed, nil +} + +// ClaimPendingByAgentID simulates the H-6 per-agent claim: Pending deployment rows scoped +// to the agent flip to Running; AwaitingCSR rows are returned but keep their state. +func (m *mockJobRepo) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.ListErr != nil { + return nil, m.ListErr + } + var result []*domain.Job + for _, j := range m.Jobs { + if j.AgentID == nil || *j.AgentID != agentID { + continue + } + switch { + case j.Status == domain.JobStatusPending && j.Type == domain.JobTypeDeployment: + j.Status = domain.JobStatusRunning + result = append(result, j) + case j.Status == domain.JobStatusAwaitingCSR: + result = append(result, j) + } + } + return result, nil +} + func (m *mockJobRepo) AddJob(job *domain.Job) { m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/service/verification_test.go b/internal/service/verification_test.go index e5c4ee8..b5785c4 100644 --- a/internal/service/verification_test.go +++ b/internal/service/verification_test.go @@ -69,6 +69,14 @@ func (m *mockVerificationJobRepo) ListPendingByAgentID(ctx context.Context, agen return nil, nil } +func (m *mockVerificationJobRepo) ClaimPendingJobs(ctx context.Context, jobType domain.JobType, limit int) ([]*domain.Job, error) { + return nil, nil +} + +func (m *mockVerificationJobRepo) ClaimPendingByAgentID(ctx context.Context, agentID string) ([]*domain.Job, error) { + return nil, nil +} 
+ // newVerificationTestService creates a VerificationService wired with test doubles. func newVerificationTestService(jobs map[string]*domain.Job, jobRepoErr error) (*VerificationService, *mockVerificationJobRepo, *mockAuditRepo) { jobRepo := &mockVerificationJobRepo{jobs: jobs, err: jobRepoErr}