Files
certctl/internal/repository/postgres/certificate.go
T
shankar0123 80450c7180 fix(repository): populate TargetIDs in certificate scan helper (M-7)
scanCertificate never queried the certificate_target_mappings junction
table, so Certificate.TargetIDs was always nil on reads. This silently
broke deployment lookups, bulk revocation filters, cert detail pages,
and any code path that iterated TargetIDs to dispatch target work.

Fix:
- Convert scanCertificate to a receiver method (r *CertificateRepository)
  so it has access to the DB for the secondary junction query.
- Get(): scan the row, then call r.getTargetIDs(ctx, certID) to populate
  TargetIDs with a single targeted query.
- List() and GetExpiringCertificates(): inline the scan loop so we can
  collect all certIDs first, then call getTargetIDsForCertificates once
  with pq.Array(certIDs) to avoid N+1 round-trips. Build a map and
  attach TargetIDs to each certificate in the result set.
- Default TargetIDs to []string{} (not nil) when a cert has no mappings
  so JSON marshals as [] rather than null.

Tests:
- New integration test file certificate_targetids_test.go with 5
  subtests exercising Get / List / GetExpiringCertificates single
  and multi-target cases plus the empty-slice vs nil contract.
- Uses the shared testcontainers-go setupTestDB infrastructure and
  skips under 'go test -short' so CI (which excludes ./internal/repository/...
  from coverage paths anyway) stays green.

Addresses M-7 from certctl-audit-report.md.
2026-04-17 15:41:08 +00:00

698 lines
21 KiB
Go

package postgres
import (
	"context"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"

	"github.com/shankar0123/certctl/internal/domain"
	"github.com/shankar0123/certctl/internal/repository"
)
// CertificateRepository is the PostgreSQL-backed implementation of
// repository.CertificateRepository. Every method issues its queries through
// the wrapped *sql.DB handle.
type CertificateRepository struct {
	// db is the database handle shared by all repository methods.
	db *sql.DB
}
// NewCertificateRepository builds a CertificateRepository that runs all of
// its queries against db. The handle is stored as-is; it is neither pinged
// nor validated here.
func NewCertificateRepository(db *sql.DB) *CertificateRepository {
	repo := new(CertificateRepository)
	repo.db = db
	return repo
}
// List returns one page of certificates matching filter, together with the
// number of rows matched by the generated WHERE clause.
//
// A nil filter is treated as an empty filter. Page defaults to 1 and PerPage
// to 50 (also when PerPage exceeds 500); both defaults are written back into
// the caller's filter. Every populated filter field contributes one AND-ed,
// parameterised condition. TargetIDs for the returned rows are loaded with a
// single follow-up query; certificates without mappings get an empty,
// non-nil slice.
//
// NOTE(review): when a cursor decodes successfully its condition is also part
// of the COUNT(*) query, so the total then only counts rows behind the
// cursor - confirm that is intended.
// NOTE(review): a cursor that fails to decode is silently ignored, yet the
// query still takes the cursor branch (LIMIT without OFFSET), so the first
// page is returned - confirm callers expect this rather than an error.
// NOTE(review): the cursor condition is (created_at, id) < (...), which only
// lines up with ORDER BY created_at DESC, while the direction is ASC unless
// filter.SortDesc is set and ORDER BY has no id tie-breaker.
func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
if filter == nil {
filter = &repository.CertificateFilter{}
}
// Set defaults (these mutate the filter passed in by the caller)
if filter.Page < 1 {
filter.Page = 1
}
if filter.PerPage == 0 || filter.PerPage > 500 {
filter.PerPage = 50
}
// Build WHERE clause. argCount is always the number of the NEXT positional
// placeholder ($1, $2, ...); it must advance in lock-step with args.
var whereConditions []string
var args []interface{}
argCount := 1
if filter.Status != "" {
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
args = append(args, filter.Status)
argCount++
}
if filter.Environment != "" {
whereConditions = append(whereConditions, fmt.Sprintf("environment = $%d", argCount))
args = append(args, filter.Environment)
argCount++
}
if filter.OwnerID != "" {
whereConditions = append(whereConditions, fmt.Sprintf("owner_id = $%d", argCount))
args = append(args, filter.OwnerID)
argCount++
}
if filter.TeamID != "" {
whereConditions = append(whereConditions, fmt.Sprintf("team_id = $%d", argCount))
args = append(args, filter.TeamID)
argCount++
}
if filter.IssuerID != "" {
whereConditions = append(whereConditions, fmt.Sprintf("issuer_id = $%d", argCount))
args = append(args, filter.IssuerID)
argCount++
}
if filter.ProfileID != "" {
whereConditions = append(whereConditions, fmt.Sprintf("certificate_profile_id = $%d", argCount))
args = append(args, filter.ProfileID)
argCount++
}
// Time bounds are strict (< / >), so a row exactly at the bound is excluded.
if filter.ExpiresBefore != nil {
whereConditions = append(whereConditions, fmt.Sprintf("expires_at < $%d", argCount))
args = append(args, filter.ExpiresBefore)
argCount++
}
if filter.ExpiresAfter != nil {
whereConditions = append(whereConditions, fmt.Sprintf("expires_at > $%d", argCount))
args = append(args, filter.ExpiresAfter)
argCount++
}
if filter.CreatedAfter != nil {
whereConditions = append(whereConditions, fmt.Sprintf("created_at > $%d", argCount))
args = append(args, filter.CreatedAfter)
argCount++
}
if filter.UpdatedAfter != nil {
whereConditions = append(whereConditions, fmt.Sprintf("updated_at > $%d", argCount))
args = append(args, filter.UpdatedAfter)
argCount++
}
if filter.AgentID != "" {
// Filter by agent_id via deployment_targets and certificate_target_mappings
whereConditions = append(whereConditions, fmt.Sprintf(`id IN (
SELECT DISTINCT certificate_id FROM certificate_target_mappings ctm
JOIN deployment_targets dt ON ctm.target_id = dt.id
WHERE dt.agent_id = $%d
)`, argCount))
args = append(args, filter.AgentID)
argCount++
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Handle cursor-based pagination. A decode error is not reported: the
// cursor condition is simply left out.
if filter.Cursor != "" {
createdAt, id, err := decodeCursor(filter.Cursor)
if err == nil {
// Add cursor condition: (created_at, id) < (cursor_time, cursor_id)
whereConditions = append(whereConditions, fmt.Sprintf("(created_at, id) < ($%d, $%d)", argCount, argCount+1))
args = append(args, createdAt, id)
argCount += 2
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
}
// Get total count. This runs before the LIMIT/OFFSET arguments are appended,
// so args holds exactly the WHERE-clause parameters at this point.
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause)
var total int
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
}
// Determine sort field and direction. Only keys of sortFieldMap can reach the
// ORDER BY clause; an unknown filter.Sort falls back to created_at, so the
// caller-supplied string is never interpolated into the SQL text.
sortField := "created_at"
sortDir := "DESC"
sortFieldMap := map[string]string{
"notAfter": "expires_at",
"expiresAt": "expires_at",
"createdAt": "created_at",
"updatedAt": "updated_at",
"commonName": "common_name",
"name": "name",
"status": "status",
"environment": "environment",
}
if filter.Sort != "" {
if mappedField, ok := sortFieldMap[filter.Sort]; ok {
sortField = mappedField
}
}
// The initial "DESC" above is always overwritten here: the direction is
// decided by filter.SortDesc alone.
if filter.SortDesc {
sortDir = "DESC"
} else {
sortDir = "ASC"
}
// Get paginated results. PageSize, when within 1..500, takes precedence over PerPage.
pageSize := filter.PerPage
if filter.PageSize > 0 && filter.PageSize <= 500 {
pageSize = filter.PageSize
}
var limitClause string
var offset int
if filter.Cursor != "" {
// Cursor-based pagination
limitClause = fmt.Sprintf("LIMIT $%d", argCount)
args = append(args, pageSize)
argCount++
} else {
// Page-based pagination
offset = (filter.Page - 1) * pageSize
limitClause = fmt.Sprintf("LIMIT $%d OFFSET $%d", argCount, argCount+1)
args = append(args, pageSize, offset)
argCount += 2
}
query := fmt.Sprintf(`
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
FROM managed_certificates
%s
ORDER BY %s %s
%s
`, whereClause, sortField, sortDir, limitClause)
rows, err := r.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, 0, fmt.Errorf("failed to query certificates: %w", err)
}
defer rows.Close()
var certs []*domain.ManagedCertificate
var certIDs []string
for rows.Next() {
// cert is declared inside the loop so every appended pointer refers to
// its own struct.
var cert domain.ManagedCertificate
var tagsJSON []byte
var sans pq.StringArray
// Nullable columns are scanned through sql.NullString and copied over
// only when a value is present.
var profileID sql.NullString
var revocationReason sql.NullString
err := rows.Scan(
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
&cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &profileID,
&cert.Status, &cert.ExpiresAt, &tagsJSON,
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.RevokedAt, &revocationReason,
&cert.CreatedAt, &cert.UpdatedAt)
if err != nil {
return nil, 0, fmt.Errorf("failed to scan certificate: %w", err)
}
cert.SANs = []string(sans)
if profileID.Valid {
cert.CertificateProfileID = profileID.String
}
if revocationReason.Valid {
cert.RevocationReason = revocationReason.String
}
// Unmarshal tags; an empty column yields an empty, non-nil map
if len(tagsJSON) > 0 {
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal tags: %w", err)
}
} else {
cert.Tags = make(map[string]string)
}
certs = append(certs, &cert)
certIDs = append(certIDs, cert.ID)
}
if err := rows.Err(); err != nil {
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
}
// Fetch target IDs for all certificates in a single query (avoid N+1)
if len(certIDs) > 0 {
targetIDsMap, err := r.getTargetIDsForCertificates(ctx, certIDs)
if err != nil {
return nil, 0, err
}
for _, cert := range certs {
if targetIDs, ok := targetIDsMap[cert.ID]; ok {
cert.TargetIDs = targetIDs
} else {
// No mappings: use an empty slice rather than nil.
cert.TargetIDs = []string{}
}
}
}
return certs, total, nil
}
// Get retrieves a single certificate by ID, including its TargetIDs.
//
// It returns a "certificate not found" error when no row matches id. Any
// other failure (row scan, tag decoding, target-mapping lookup) is wrapped
// with "failed to query certificate" and returned.
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
	row := r.db.QueryRowContext(ctx, `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
FROM managed_certificates
WHERE id = $1
`, id)

	cert, err := r.scanCertificate(ctx, row)
	if err != nil {
		// scanCertificate wraps the scan error with %w, so comparing with ==
		// against sql.ErrNoRows can never match; errors.Is walks the chain.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("certificate not found")
		}
		return nil, fmt.Errorf("failed to query certificate: %w", err)
	}
	return cert, nil
}
// Create inserts cert into managed_certificates.
//
// When cert.ID is empty a new UUID is generated and written back to cert.
// Tags are stored as JSON. An empty CertificateProfileID or RevocationReason
// is stored as NULL. CreatedAt and UpdatedAt are written exactly as supplied
// by the caller.
func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
	if cert.ID == "" {
		cert.ID = uuid.New().String()
	}

	encodedTags, err := json.Marshal(cert.Tags)
	if err != nil {
		return fmt.Errorf("failed to marshal tags: %w", err)
	}

	// A nil *string is sent to the database as NULL.
	var (
		profileArg *string
		reasonArg  *string
	)
	if cert.CertificateProfileID != "" {
		profileArg = &cert.CertificateProfileID
	}
	if cert.RevocationReason != "" {
		reasonArg = &cert.RevocationReason
	}

	const insertSQL = `
INSERT INTO managed_certificates (
id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
RETURNING id
`
	// Values in the same order as the column list above.
	params := []interface{}{
		cert.ID,
		cert.Name,
		cert.CommonName,
		pq.Array(cert.SANs),
		cert.Environment,
		cert.OwnerID,
		cert.TeamID,
		cert.IssuerID,
		cert.RenewalPolicyID,
		profileArg,
		cert.Status,
		cert.ExpiresAt,
		encodedTags,
		cert.LastRenewalAt,
		cert.LastDeploymentAt,
		cert.RevokedAt,
		reasonArg,
		cert.CreatedAt,
		cert.UpdatedAt,
	}

	if err := r.db.QueryRowContext(ctx, insertSQL, params...).Scan(&cert.ID); err != nil {
		return fmt.Errorf("failed to create certificate: %w", err)
	}
	return nil
}
// Update overwrites the mutable columns of the certificate identified by
// cert.ID with the values held in cert.
//
// Tags are stored as JSON; an empty CertificateProfileID or RevocationReason
// is stored as NULL. A "certificate not found" error is returned when no row
// has the given ID.
//
// NOTE(review): renewal_policy_id is not part of the SET list, so changes to
// cert.RenewalPolicyID are not persisted here - confirm this is intentional.
func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
	encodedTags, err := json.Marshal(cert.Tags)
	if err != nil {
		return fmt.Errorf("failed to marshal tags: %w", err)
	}

	// A nil *string is sent to the database as NULL.
	var (
		profileArg *string
		reasonArg  *string
	)
	if cert.CertificateProfileID != "" {
		profileArg = &cert.CertificateProfileID
	}
	if cert.RevocationReason != "" {
		reasonArg = &cert.RevocationReason
	}

	const updateSQL = `
UPDATE managed_certificates SET
name = $1,
common_name = $2,
sans = $3,
environment = $4,
owner_id = $5,
team_id = $6,
issuer_id = $7,
certificate_profile_id = $8,
status = $9,
expires_at = $10,
tags = $11,
last_renewal_at = $12,
last_deployment_at = $13,
revoked_at = $14,
revocation_reason = $15,
updated_at = $16
WHERE id = $17
`
	// Values in placeholder order $1..$17.
	params := []interface{}{
		cert.Name,
		cert.CommonName,
		pq.Array(cert.SANs),
		cert.Environment,
		cert.OwnerID,
		cert.TeamID,
		cert.IssuerID,
		profileArg,
		cert.Status,
		cert.ExpiresAt,
		encodedTags,
		cert.LastRenewalAt,
		cert.LastDeploymentAt,
		cert.RevokedAt,
		reasonArg,
		cert.UpdatedAt,
		cert.ID,
	}

	res, err := r.db.ExecContext(ctx, updateSQL, params...)
	if err != nil {
		return fmt.Errorf("failed to update certificate: %w", err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to get rows affected: %w", err)
	case affected == 0:
		// Nothing matched the WHERE clause: the ID does not exist.
		return fmt.Errorf("certificate not found")
	}
	return nil
}
// Archive sets the status of the certificate with the given id to
// domain.CertificateStatusArchived and stamps updated_at with the current
// time. A "certificate not found" error is returned when no row matches.
func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
	const archiveSQL = `
UPDATE managed_certificates SET status = $1, updated_at = $2 WHERE id = $3
`
	res, err := r.db.ExecContext(ctx, archiveSQL, domain.CertificateStatusArchived, time.Now(), id)
	if err != nil {
		return fmt.Errorf("failed to archive certificate: %w", err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to get rows affected: %w", err)
	case affected == 0:
		// No row carried this ID.
		return fmt.Errorf("certificate not found")
	}
	return nil
}
// ListVersions returns every stored version of the certificate identified by
// certID, newest first (ordered by created_at descending). NULL csr_pem,
// key_algorithm and key_size columns are mapped to their zero values. The
// result is nil when the certificate has no versions.
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
	const versionsSQL = `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
`
	rows, err := r.db.QueryContext(ctx, versionsSQL, certID)
	if err != nil {
		return nil, fmt.Errorf("failed to query certificate versions: %w", err)
	}
	defer rows.Close()

	var result []*domain.CertificateVersion
	for rows.Next() {
		var (
			csr     sql.NullString
			algo    sql.NullString
			bits    sql.NullInt64
			version = new(domain.CertificateVersion)
		)
		scanErr := rows.Scan(
			&version.ID,
			&version.CertificateID,
			&version.SerialNumber,
			&version.NotBefore,
			&version.NotAfter,
			&version.FingerprintSHA256,
			&version.PEMChain,
			&csr,
			&algo,
			&bits,
			&version.CreatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan certificate version: %w", scanErr)
		}

		// The Null* wrappers hold the zero value when the column was NULL.
		version.CSRPEM = csr.String
		version.KeyAlgorithm = algo.String
		version.KeySize = int(bits.Int64)

		result = append(result, version)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating version rows: %w", iterErr)
	}
	return result, nil
}
// CreateVersion inserts version into certificate_versions. When version.ID
// is empty a new UUID is generated first; the ID returned by the database is
// written back to version.ID.
func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
	if version.ID == "" {
		version.ID = uuid.New().String()
	}

	const insertSQL = `
INSERT INTO certificate_versions (
id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id
`
	// Values in the same order as the column list above.
	params := []interface{}{
		version.ID,
		version.CertificateID,
		version.SerialNumber,
		version.NotBefore,
		version.NotAfter,
		version.FingerprintSHA256,
		version.PEMChain,
		version.CSRPEM,
		version.KeyAlgorithm,
		version.KeySize,
		version.CreatedAt,
	}

	if err := r.db.QueryRowContext(ctx, insertSQL, params...).Scan(&version.ID); err != nil {
		return fmt.Errorf("failed to create certificate version: %w", err)
	}
	return nil
}
// GetExpiringCertificates returns all non-archived certificates whose
// expires_at lies strictly before the given time, soonest expiry first.
// TargetIDs are loaded for the whole result set with one additional query;
// certificates without mappings get an empty, non-nil slice.
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
	const expiringSQL = `
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
FROM managed_certificates
WHERE expires_at < $1 AND status != $2
ORDER BY expires_at ASC
`
	rows, err := r.db.QueryContext(ctx, expiringSQL, before, domain.CertificateStatusArchived)
	if err != nil {
		return nil, fmt.Errorf("failed to query expiring certificates: %w", err)
	}
	defer rows.Close()

	var (
		result []*domain.ManagedCertificate
		ids    []string
	)
	for rows.Next() {
		var (
			rawTags []byte
			sanList pq.StringArray
			profile sql.NullString
			reason  sql.NullString
			entry   = new(domain.ManagedCertificate)
		)
		scanErr := rows.Scan(
			&entry.ID, &entry.Name, &entry.CommonName, &sanList, &entry.Environment, &entry.OwnerID,
			&entry.TeamID, &entry.IssuerID, &entry.RenewalPolicyID, &profile,
			&entry.Status, &entry.ExpiresAt, &rawTags,
			&entry.LastRenewalAt, &entry.LastDeploymentAt, &entry.RevokedAt, &reason,
			&entry.CreatedAt, &entry.UpdatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan certificate: %w", scanErr)
		}

		entry.SANs = []string(sanList)
		// Nullable columns are copied only when a value is present.
		if profile.Valid {
			entry.CertificateProfileID = profile.String
		}
		if reason.Valid {
			entry.RevocationReason = reason.String
		}

		// An empty tags column becomes an empty map rather than nil.
		if len(rawTags) == 0 {
			entry.Tags = make(map[string]string)
		} else if jsonErr := json.Unmarshal(rawTags, &entry.Tags); jsonErr != nil {
			return nil, fmt.Errorf("failed to unmarshal tags: %w", jsonErr)
		}

		result = append(result, entry)
		ids = append(ids, entry.ID)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating expiring certificate rows: %w", iterErr)
	}

	// Nothing matched: skip the junction-table lookup entirely.
	if len(ids) == 0 {
		return result, nil
	}

	// One batched lookup for every certificate avoids N+1 round-trips.
	byCert, err := r.getTargetIDsForCertificates(ctx, ids)
	if err != nil {
		return nil, err
	}
	for _, entry := range result {
		targets, found := byCert[entry.ID]
		if !found {
			targets = []string{}
		}
		entry.TargetIDs = targets
	}
	return result, nil
}
// GetLatestVersion returns the certificate version with the newest
// created_at for certID. NULL csr_pem, key_algorithm and key_size columns
// are mapped to their zero values. Any query or scan error, including the
// no-rows case, is wrapped and returned.
func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) {
	const latestSQL = `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
LIMIT 1
`
	var (
		csr  sql.NullString
		algo sql.NullString
		bits sql.NullInt64
	)
	latest := new(domain.CertificateVersion)

	row := r.db.QueryRowContext(ctx, latestSQL, certID)
	if err := row.Scan(
		&latest.ID,
		&latest.CertificateID,
		&latest.SerialNumber,
		&latest.NotBefore,
		&latest.NotAfter,
		&latest.FingerprintSHA256,
		&latest.PEMChain,
		&csr,
		&algo,
		&bits,
		&latest.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to get latest certificate version: %w", err)
	}

	// The Null* wrappers hold the zero value when the column was NULL.
	latest.CSRPEM = csr.String
	latest.KeyAlgorithm = algo.String
	latest.KeySize = int(bits.Int64)
	return latest, nil
}
// getTargetIDs loads the target IDs mapped to certID in
// certificate_target_mappings, sorted ascending. On success the result is
// never nil: a certificate without mappings yields an empty slice.
func (r *CertificateRepository) getTargetIDs(ctx context.Context, certID string) ([]string, error) {
	const mappingSQL = `
SELECT target_id FROM certificate_target_mappings
WHERE certificate_id = $1
ORDER BY target_id ASC
`
	rows, err := r.db.QueryContext(ctx, mappingSQL, certID)
	if err != nil {
		return nil, fmt.Errorf("failed to query target mappings: %w", err)
	}
	defer rows.Close()

	// Start from an empty (non-nil) slice so that "no mappings" is
	// distinguishable from nil once the value is marshaled to JSON.
	ids := []string{}
	for rows.Next() {
		var id string
		if scanErr := rows.Scan(&id); scanErr != nil {
			return nil, fmt.Errorf("failed to scan target ID: %w", scanErr)
		}
		ids = append(ids, id)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating target ID rows: %w", iterErr)
	}
	return ids, nil
}
// getTargetIDsForCertificates loads the target mappings of all given
// certificates with one query and groups them by certificate ID. Target IDs
// within each group are in ascending order. Certificates without any mapping
// have no entry in the returned map. An empty certIDs slice returns an empty
// map without touching the database.
func (r *CertificateRepository) getTargetIDsForCertificates(ctx context.Context, certIDs []string) (map[string][]string, error) {
	grouped := make(map[string][]string)
	if len(certIDs) == 0 {
		return grouped, nil
	}

	const mappingSQL = `
SELECT certificate_id, target_id FROM certificate_target_mappings
WHERE certificate_id = ANY($1)
ORDER BY certificate_id, target_id ASC
`
	rows, err := r.db.QueryContext(ctx, mappingSQL, pq.Array(certIDs))
	if err != nil {
		return nil, fmt.Errorf("failed to query target mappings: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var owner, target string
		if scanErr := rows.Scan(&owner, &target); scanErr != nil {
			return nil, fmt.Errorf("failed to scan target mapping: %w", scanErr)
		}
		grouped[owner] = append(grouped[owner], target)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating target mapping rows: %w", iterErr)
	}
	return grouped, nil
}
// scanCertificate reads one managed_certificates row from scanner (anything
// with a Scan method, such as *sql.Row or *sql.Rows) and then fills in
// TargetIDs with a separate query against certificate_target_mappings.
//
// The scanned columns must appear in the order used by the SELECT statements
// in this file. A scan failure is wrapped with %w, so the underlying error
// (for example sql.ErrNoRows) remains reachable through errors.Is.
func (r *CertificateRepository) scanCertificate(ctx context.Context, scanner interface {
	Scan(...interface{}) error
}) (*domain.ManagedCertificate, error) {
	var (
		rawTags []byte
		sanList pq.StringArray
		profile sql.NullString
		reason  sql.NullString
	)
	result := new(domain.ManagedCertificate)

	if err := scanner.Scan(
		&result.ID, &result.Name, &result.CommonName, &sanList, &result.Environment, &result.OwnerID,
		&result.TeamID, &result.IssuerID, &result.RenewalPolicyID, &profile,
		&result.Status, &result.ExpiresAt, &rawTags,
		&result.LastRenewalAt, &result.LastDeploymentAt, &result.RevokedAt, &reason,
		&result.CreatedAt, &result.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to scan certificate: %w", err)
	}

	result.SANs = []string(sanList)
	// Nullable columns are copied only when a value is present.
	if profile.Valid {
		result.CertificateProfileID = profile.String
	}
	if reason.Valid {
		result.RevocationReason = reason.String
	}

	// An empty tags column becomes an empty map rather than nil.
	if len(rawTags) == 0 {
		result.Tags = make(map[string]string)
	} else if err := json.Unmarshal(rawTags, &result.Tags); err != nil {
		return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
	}

	// Populate TargetIDs from the junction table.
	targets, err := r.getTargetIDs(ctx, result.ID)
	if err != nil {
		return nil, err
	}
	result.TargetIDs = targets
	return result, nil
}
// decodeCursor extracts a timestamp and ID from a cursor token.
//
// A token is the URL-safe base64 encoding of "<RFC3339Nano timestamp>:<id>".
// The timestamp itself contains ':' characters (for example
// "2026-04-17T15:41:08Z"), so the payload cannot be split at its first colon.
// Instead every colon is tried from left to right and the first one whose
// prefix parses as an RFC3339Nano timestamp is taken as the separator;
// everything after it, including any further colons, is the ID.
//
// Errors are returned when the token is not valid base64, when the payload
// contains no colon at all, or when no prefix parses as a timestamp.
func decodeCursor(cursor string) (time.Time, string, error) {
	raw, err := base64.URLEncoding.DecodeString(cursor)
	if err != nil {
		return time.Time{}, "", fmt.Errorf("invalid cursor: %w", err)
	}
	payload := string(raw)

	// parseErr stays nil only if the payload holds no colon at all.
	var parseErr error
	for i := 0; i < len(payload); i++ {
		if payload[i] != ':' {
			continue
		}
		ts, err := time.Parse(time.RFC3339Nano, payload[:i])
		if err == nil {
			return ts, payload[i+1:], nil
		}
		parseErr = err
	}

	if parseErr == nil {
		return time.Time{}, "", fmt.Errorf("invalid cursor format")
	}
	return time.Time{}, "", fmt.Errorf("invalid cursor timestamp: %w", parseErr)
}
// encodeCursor creates an opaque cursor token from a timestamp and ID.
// The token is the URL-safe base64 encoding of the RFC3339Nano-formatted
// timestamp, a ":" separator, and the ID. Note that the formatted timestamp
// itself contains ":" characters.
// The function is bound to the blank identifier, so nothing can call it yet.
// Reserved for future use in repository-level cursor pagination.
var _ = func(createdAt time.Time, id string) string {
raw := createdAt.Format(time.RFC3339Nano) + ":" + id
return base64.URLEncoding.EncodeToString([]byte(raw))
}