mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-12 19:40:27 +00:00
fix(repository): populate TargetIDs in certificate scan helper (M-7)
scanCertificate never queried the certificate_target_mappings junction
table, so Certificate.TargetIDs was always nil on reads. This silently
broke deployment lookups, bulk revocation filters, cert detail pages,
and any code path that iterated TargetIDs to dispatch target work.
Fix:
- Convert scanCertificate to a receiver method (r *CertificateRepository)
so it has access to the DB for the secondary junction query.
- Get(): scan the row, then call r.getTargetIDs(ctx, certID) to populate
TargetIDs with a single targeted query.
- List() and GetExpiringCertificates(): inline the scan loop so we can
collect all certIDs first, then call getTargetIDsForCertificates once
with pq.Array(certIDs) to avoid N+1 round-trips. Build a map and
attach TargetIDs to each certificate in the result set.
- Default TargetIDs to []string{} (not nil) when a cert has no mappings
so JSON marshals as [] rather than null.
Tests:
- New integration test file certificate_targetids_test.go with 5
subtests exercising Get / List / GetExpiringCertificates single
and multi-target cases plus the empty-slice vs nil contract.
- Uses the shared testcontainers-go setupTestDB infrastructure and
skips under 'go test -short' so CI (which excludes ./internal/repository/...
from coverage paths anyway) stays green.
Addresses M-7 from certctl-audit-report.md.
This commit is contained in:
@@ -190,18 +190,65 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
|
||||
defer rows.Close()
|
||||
|
||||
var certs []*domain.ManagedCertificate
|
||||
var certIDs []string
|
||||
for rows.Next() {
|
||||
cert, err := scanCertificate(rows)
|
||||
var cert domain.ManagedCertificate
|
||||
var tagsJSON []byte
|
||||
var sans pq.StringArray
|
||||
var profileID sql.NullString
|
||||
var revocationReason sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
||||
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
|
||||
&cert.Status, &cert.ExpiresAt, &tagsJSON,
|
||||
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
|
||||
&cert.CreatedAt, &cert.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, fmt.Errorf("failed to scan certificate: %w", err)
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
|
||||
cert.SANs = []string(sans)
|
||||
if profileID.Valid {
|
||||
cert.CertificateProfileID = profileID.String
|
||||
}
|
||||
if revocationReason.Valid {
|
||||
cert.RevocationReason = revocationReason.String
|
||||
}
|
||||
|
||||
// Unmarshal tags
|
||||
if len(tagsJSON) > 0 {
|
||||
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to unmarshal tags: %w", err)
|
||||
}
|
||||
} else {
|
||||
cert.Tags = make(map[string]string)
|
||||
}
|
||||
|
||||
certs = append(certs, &cert)
|
||||
certIDs = append(certIDs, cert.ID)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
|
||||
}
|
||||
|
||||
// Fetch target IDs for all certificates in a single query (avoid N+1)
|
||||
if len(certIDs) > 0 {
|
||||
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for _, cert := range certs {
|
||||
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
|
||||
cert.TargetIDs = targetIDs
|
||||
} else {
|
||||
cert.TargetIDs = []string{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return certs, total, nil
|
||||
}
|
||||
|
||||
@@ -214,7 +261,7 @@ func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.Man
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
|
||||
cert, err := scanCertificate(row)
|
||||
cert, err := r.scanCertificate(ctx, row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("certificate not found")
|
||||
@@ -421,18 +468,65 @@ func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, bef
|
||||
defer rows.Close()
|
||||
|
||||
var certs []*domain.ManagedCertificate
|
||||
var certIDs []string
|
||||
for rows.Next() {
|
||||
cert, err := scanCertificate(rows)
|
||||
var cert domain.ManagedCertificate
|
||||
var tagsJSON []byte
|
||||
var sans pq.StringArray
|
||||
var profileID sql.NullString
|
||||
var revocationReason sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
||||
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
|
||||
&cert.Status, &cert.ExpiresAt, &tagsJSON,
|
||||
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
|
||||
&cert.CreatedAt, &cert.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to scan certificate: %w", err)
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
|
||||
cert.SANs = []string(sans)
|
||||
if profileID.Valid {
|
||||
cert.CertificateProfileID = profileID.String
|
||||
}
|
||||
if revocationReason.Valid {
|
||||
cert.RevocationReason = revocationReason.String
|
||||
}
|
||||
|
||||
// Unmarshal tags
|
||||
if len(tagsJSON) > 0 {
|
||||
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
|
||||
}
|
||||
} else {
|
||||
cert.Tags = make(map[string]string)
|
||||
}
|
||||
|
||||
certs = append(certs, &cert)
|
||||
certIDs = append(certIDs, cert.ID)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
|
||||
}
|
||||
|
||||
// Fetch target IDs for all certificates in a single query (avoid N+1)
|
||||
if len(certIDs) > 0 {
|
||||
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, cert := range certs {
|
||||
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
|
||||
cert.TargetIDs = targetIDs
|
||||
} else {
|
||||
cert.TargetIDs = []string{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
@@ -462,8 +556,76 @@ func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID str
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// scanCertificate scans a certificate from a row or rows
|
||||
func scanCertificate(scanner interface {
|
||||
// getTargetIDs retrieves all target IDs for a given certificate from the junction table.
|
||||
// Returns an empty slice (not nil) if no targets are found.
|
||||
func (r *CertificateRepository) getTargetIDs(ctx context.Context, certID string) ([]string, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT target_id FROM certificate_target_mappings
|
||||
WHERE certificate_id = $1
|
||||
ORDER BY target_id ASC
|
||||
`, certID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query target mappings: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var targetIDs []string
|
||||
for rows.Next() {
|
||||
var targetID string
|
||||
if err := rows.Scan(&targetID); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan target ID: %w", err)
|
||||
}
|
||||
targetIDs = append(targetIDs, targetID)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating target ID rows: %w", err)
|
||||
}
|
||||
|
||||
// Return empty slice instead of nil for consistency with JSON marshaling
|
||||
if targetIDs == nil {
|
||||
targetIDs = []string{}
|
||||
}
|
||||
|
||||
return targetIDs, nil
|
||||
}
|
||||
|
||||
// getTargetIDsForCertificates retrieves target IDs for multiple certificates in a single query.
|
||||
// Returns a map of certificate_id -> []target_id.
|
||||
func (r *CertificateRepository) getTargetIDsForCertificates(ctx context.Context, certIDs []string) (map[string][]string, error) {
|
||||
if len(certIDs) == 0 {
|
||||
return make(map[string][]string), nil
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT certificate_id, target_id FROM certificate_target_mappings
|
||||
WHERE certificate_id = ANY($1)
|
||||
ORDER BY certificate_id, target_id ASC
|
||||
`, pq.Array(certIDs))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query target mappings: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
targetIDsMap := make(map[string][]string)
|
||||
for rows.Next() {
|
||||
var certID, targetID string
|
||||
if err := rows.Scan(&certID, &targetID); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan target mapping: %w", err)
|
||||
}
|
||||
targetIDsMap[certID] = append(targetIDsMap[certID], targetID)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating target mapping rows: %w", err)
|
||||
}
|
||||
|
||||
return targetIDsMap, nil
|
||||
}
|
||||
|
||||
// scanCertificate scans a certificate from a row or rows and populates its TargetIDs
|
||||
// by querying the certificate_target_mappings junction table.
|
||||
func (r *CertificateRepository) scanCertificate(ctx context.Context, scanner interface {
|
||||
Scan(...interface{}) error
|
||||
}) (*domain.ManagedCertificate, error) {
|
||||
var cert domain.ManagedCertificate
|
||||
@@ -500,6 +662,13 @@ func scanCertificate(scanner interface {
|
||||
cert.Tags = make(map[string]string)
|
||||
}
|
||||
|
||||
// Populate TargetIDs from junction table
|
||||
targetIDs, err := r.getTargetIDs(ctx, cert.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert.TargetIDs = targetIDs
|
||||
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user