Files
certctl/internal/repository/postgres/discovery.go
T
shankar0123 21aeed4f4e legal: addlicense headers + normalize legacy variants (Phase 0 RED-4)
Phase 0 closure (Path B2, post-rewrite):

addlicense sweep — adds the canonical certctl LLC copyright + BUSL-1.1
SPDX header to every production Go file. Template:

  // Copyright 2026 certctl LLC. All rights reserved.
  // SPDX-License-Identifier: BUSL-1.1

Coverage: 338 / 338 production Go files (cmd/ + internal/, excluding
*_test.go and **/testdata/**). Pre-sweep coverage was 22 / 338 (6.5%);
post-sweep is 338 / 338 (100%).

Normalized 22 pre-existing legacy headers (`// Copyright (c) certctl`
+ `// SPDX-License-Identifier: BSL-1.1`) and 1 file using a
`Certctl Contributors` attribution. The legacy SPDX ID `BSL-1.1`
is non-standard; the official SPDX identifier for Business Source
License 1.1 is `BUSL-1.1` (capital U). All 338 files now share the
canonical form.

Generated via:
  addlicense -c "certctl LLC" -y 2026 \
    -f cowork/legal/copyright-header.tpl \
    -ignore '**/testdata/**' -ignore '**/*_test.go' \
    cmd/ internal/

Verification:
  find cmd internal -name '*.go' -not -name '*_test.go' \
    -not -path '*/testdata/*' \
    -exec grep -L '^// Copyright 2026 certctl LLC' {} \; | wc -l

  Returns: 0

gofmt clean. Header additions are comments only, no compile impact.

Closes: cowork/certctl-architecture-diligence-audit.html#fix-RED-4
2026-05-13 21:23:35 +00:00

400 lines
14 KiB
Go

// Copyright 2026 certctl LLC. All rights reserved.
// SPDX-License-Identifier: BUSL-1.1
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/certctl-io/certctl/internal/domain"
"github.com/certctl-io/certctl/internal/repository"
"github.com/lib/pq"
)
// DiscoveryRepository implements the repository.DiscoveryRepository interface.
// It persists discovery scans and discovered certificates through the wrapped
// database handle.
type DiscoveryRepository struct {
// db is the database handle every method of this repository issues its queries on.
db *sql.DB
}
// NewDiscoveryRepository builds a PostgreSQL-backed discovery repository that
// runs all of its queries on the supplied database handle.
func NewDiscoveryRepository(db *sql.DB) *DiscoveryRepository {
	repo := new(DiscoveryRepository)
	repo.db = db
	return repo
}
// --- Discovery Scans ---
// CreateScan inserts a discovery scan row.
//
// When a row with the same id already exists the insert is skipped
// (ON CONFLICT (id) DO NOTHING) and nil is still returned. Any database error
// is wrapped and returned.
func (r *DiscoveryRepository) CreateScan(ctx context.Context, scan *domain.DiscoveryScan) error {
	const insertScan = `
INSERT INTO discovery_scans (id, agent_id, directories, certificates_found, certificates_new, errors_count, scan_duration_ms, started_at, completed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (id) DO NOTHING`

	// Parameters are listed in the same order as the column list above.
	params := []interface{}{
		scan.ID,
		scan.AgentID,
		pq.Array(scan.Directories),
		scan.CertificatesFound,
		scan.CertificatesNew,
		scan.ErrorsCount,
		scan.ScanDurationMs,
		scan.StartedAt,
		scan.CompletedAt,
	}

	if _, err := r.db.ExecContext(ctx, insertScan, params...); err != nil {
		return fmt.Errorf("failed to create discovery scan: %w", err)
	}
	return nil
}
// GetScan loads a single discovery scan by its ID.
//
// It returns an error wrapping repository.ErrNotFound when no row matches, and
// a wrapped database error for any other failure.
func (r *DiscoveryRepository) GetScan(ctx context.Context, id string) (*domain.DiscoveryScan, error) {
	const selectScan = `
SELECT id, agent_id, directories, certificates_found, certificates_new, errors_count, scan_duration_ms, started_at, completed_at
FROM discovery_scans WHERE id = $1`

	var (
		found       domain.DiscoveryScan
		directories []string
	)

	row := r.db.QueryRowContext(ctx, selectScan, id)
	scanErr := row.Scan(
		&found.ID,
		&found.AgentID,
		pq.Array(&directories),
		&found.CertificatesFound,
		&found.CertificatesNew,
		&found.ErrorsCount,
		&found.ScanDurationMs,
		&found.StartedAt,
		&found.CompletedAt,
	)

	switch {
	case scanErr == sql.ErrNoRows:
		return nil, fmt.Errorf("discovery scan not found: %w", repository.ErrNotFound)
	case scanErr != nil:
		return nil, fmt.Errorf("failed to get discovery scan: %w", scanErr)
	}

	// The directories column is decoded via pq.Array into a local slice first.
	found.Directories = directories
	return &found, nil
}
// ListScans returns one page of discovery scans ordered by started_at
// descending, together with the total number of rows matching the filter.
//
// An empty agentID disables the agent filter. page values below 1 are treated
// as 1; perPage values outside 1..500 fall back to 50. The returned slice is
// nil when the page contains no rows.
func (r *DiscoveryRepository) ListScans(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) {
	if page < 1 {
		page = 1
	}
	if perPage <= 0 || perPage > 500 {
		perPage = 50
	}

	// Build the optional WHERE clause with numbered placeholders; argCount is
	// always the index of the next free placeholder.
	var whereConditions []string
	var args []interface{}
	argCount := 1
	if agentID != "" {
		whereConditions = append(whereConditions, fmt.Sprintf("agent_id = $%d", argCount))
		args = append(args, agentID)
		argCount++
	}
	whereClause := ""
	if len(whereConditions) > 0 {
		whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
	}

	// Count all matching rows (ignores pagination).
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM discovery_scans %s", whereClause)
	var total int
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count discovery scans: %w", err)
	}

	// Fetch the requested page; LIMIT and OFFSET take the next two placeholders.
	offset := (page - 1) * perPage
	listQuery := fmt.Sprintf(`
SELECT id, agent_id, directories, certificates_found, certificates_new, errors_count, scan_duration_ms, started_at, completed_at
FROM discovery_scans %s
ORDER BY started_at DESC
LIMIT $%d OFFSET $%d`, whereClause, argCount, argCount+1)
	args = append(args, perPage, offset)

	rows, err := r.db.QueryContext(ctx, listQuery, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to list discovery scans: %w", err)
	}
	defer rows.Close()

	var scans []*domain.DiscoveryScan
	for rows.Next() {
		scan := &domain.DiscoveryScan{}
		var dirs []string
		if err := rows.Scan(
			&scan.ID, &scan.AgentID, pq.Array(&dirs),
			&scan.CertificatesFound, &scan.CertificatesNew, &scan.ErrorsCount,
			&scan.ScanDurationMs, &scan.StartedAt, &scan.CompletedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("failed to scan discovery scan row: %w", err)
		}
		scan.Directories = dirs
		scans = append(scans, scan)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated page would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to iterate discovery scans: %w", err)
	}
	return scans, total, nil
}
// --- Discovered Certificates ---
// CreateDiscovered upserts a discovered certificate.
//
// A conflict on (fingerprint_sha256, agent_id, source_path) does not insert a
// second row; instead the existing row gets last_seen_at and discovery_scan_id
// from the incoming record and updated_at set to NOW(). An empty
// cert.DiscoveryScanID is sent as SQL NULL.
//
// The boolean result is the value the statement returns as is_new
// ((xmax = 0)): the query intends it to be true for a fresh insert and false
// when the conflict branch updated an existing row.
func (r *DiscoveryRepository) CreateDiscovered(ctx context.Context, cert *domain.DiscoveredCertificate) (bool, error) {
	const upsertDiscovered = `
INSERT INTO discovered_certificates (
id, fingerprint_sha256, common_name, sans, serial_number, issuer_dn, subject_dn,
not_before, not_after, key_algorithm, key_size, is_ca, pem_data,
source_path, source_format, agent_id, discovery_scan_id,
status, first_seen_at, last_seen_at, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
ON CONFLICT (fingerprint_sha256, agent_id, source_path) DO UPDATE SET
last_seen_at = EXCLUDED.last_seen_at,
discovery_scan_id = EXCLUDED.discovery_scan_id,
updated_at = NOW()
RETURNING (xmax = 0) AS is_new`

	// One entry per placeholder, in column-list order ($1..$22).
	params := []interface{}{
		cert.ID,
		cert.FingerprintSHA256,
		cert.CommonName,
		pq.Array(cert.SANs),
		cert.SerialNumber,
		cert.IssuerDN,
		cert.SubjectDN,
		cert.NotBefore,
		cert.NotAfter,
		cert.KeyAlgorithm,
		cert.KeySize,
		cert.IsCA,
		cert.PEMData,
		cert.SourcePath,
		cert.SourceFormat,
		cert.AgentID,
		nullableString(cert.DiscoveryScanID),
		string(cert.Status),
		cert.FirstSeenAt,
		cert.LastSeenAt,
		cert.CreatedAt,
		cert.UpdatedAt,
	}

	var inserted bool
	if err := r.db.QueryRowContext(ctx, upsertDiscovered, params...).Scan(&inserted); err != nil {
		return false, fmt.Errorf("failed to upsert discovered certificate: %w", err)
	}
	return inserted, nil
}
// GetDiscovered loads a single discovered certificate by its ID.
//
// It returns an error wrapping repository.ErrNotFound when no row matches, and
// a wrapped database error for any other failure. NULL discovery_scan_id and
// managed_certificate_id columns leave the corresponding fields empty.
func (r *DiscoveryRepository) GetDiscovered(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) {
	const selectDiscovered = `
SELECT id, fingerprint_sha256, common_name, sans, serial_number, issuer_dn, subject_dn,
not_before, not_after, key_algorithm, key_size, is_ca, pem_data,
source_path, source_format, agent_id, discovery_scan_id, managed_certificate_id,
status, first_seen_at, last_seen_at, dismissed_at, created_at, updated_at
FROM discovered_certificates WHERE id = $1`

	var (
		found        domain.DiscoveredCertificate
		altNames     []string
		scanRef      sql.NullString
		managedRef   sql.NullString
	)

	row := r.db.QueryRowContext(ctx, selectDiscovered, id)
	scanErr := row.Scan(
		&found.ID,
		&found.FingerprintSHA256,
		&found.CommonName,
		pq.Array(&altNames),
		&found.SerialNumber,
		&found.IssuerDN,
		&found.SubjectDN,
		&found.NotBefore,
		&found.NotAfter,
		&found.KeyAlgorithm,
		&found.KeySize,
		&found.IsCA,
		&found.PEMData,
		&found.SourcePath,
		&found.SourceFormat,
		&found.AgentID,
		&scanRef,
		&managedRef,
		&found.Status,
		&found.FirstSeenAt,
		&found.LastSeenAt,
		&found.DismissedAt,
		&found.CreatedAt,
		&found.UpdatedAt,
	)

	switch {
	case scanErr == sql.ErrNoRows:
		return nil, fmt.Errorf("discovered certificate not found: %w", repository.ErrNotFound)
	case scanErr != nil:
		return nil, fmt.Errorf("failed to get discovered certificate: %w", scanErr)
	}

	found.SANs = altNames
	// Nullable references are only copied over when the column held a value.
	if scanRef.Valid {
		found.DiscoveryScanID = scanRef.String
	}
	if managedRef.Valid {
		found.ManagedCertificateID = managedRef.String
	}
	return &found, nil
}
// ListDiscovered returns one page of discovered certificates matching filter,
// ordered by last_seen_at descending, together with the total number of
// matching rows.
//
// filter must be non-nil. Its Page and PerPage fields are normalized in place:
// Page below 1 becomes 1, and PerPage outside 1..500 becomes 50. AgentID and
// Status are applied only when non-empty; IsExpired and IsCA add their
// condition only when true. The returned slice is nil when the page contains
// no rows.
func (r *DiscoveryRepository) ListDiscovered(ctx context.Context, filter *repository.DiscoveryFilter) ([]*domain.DiscoveredCertificate, int, error) {
	if filter.Page < 1 {
		filter.Page = 1
	}
	if filter.PerPage <= 0 || filter.PerPage > 500 {
		filter.PerPage = 50
	}

	// Build the optional WHERE clause with numbered placeholders; argCount is
	// always the index of the next free placeholder.
	var whereConditions []string
	var args []interface{}
	argCount := 1
	if filter.AgentID != "" {
		whereConditions = append(whereConditions, fmt.Sprintf("agent_id = $%d", argCount))
		args = append(args, filter.AgentID)
		argCount++
	}
	if filter.Status != "" {
		whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
		args = append(args, filter.Status)
		argCount++
	}
	if filter.IsExpired {
		whereConditions = append(whereConditions, "not_after < NOW()")
	}
	if filter.IsCA {
		whereConditions = append(whereConditions, "is_ca = TRUE")
	}
	whereClause := ""
	if len(whereConditions) > 0 {
		whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
	}

	// Count all matching rows (ignores pagination).
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM discovered_certificates %s", whereClause)
	var total int
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count discovered certificates: %w", err)
	}

	// Fetch the requested page; LIMIT and OFFSET take the next two placeholders.
	offset := (filter.Page - 1) * filter.PerPage
	listQuery := fmt.Sprintf(`
SELECT id, fingerprint_sha256, common_name, sans, serial_number, issuer_dn, subject_dn,
not_before, not_after, key_algorithm, key_size, is_ca, pem_data,
source_path, source_format, agent_id, discovery_scan_id, managed_certificate_id,
status, first_seen_at, last_seen_at, dismissed_at, created_at, updated_at
FROM discovered_certificates %s
ORDER BY last_seen_at DESC
LIMIT $%d OFFSET $%d`, whereClause, argCount, argCount+1)
	args = append(args, filter.PerPage, offset)

	rows, err := r.db.QueryContext(ctx, listQuery, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to list discovered certificates: %w", err)
	}
	defer rows.Close()

	var certs []*domain.DiscoveredCertificate
	for rows.Next() {
		cert := &domain.DiscoveredCertificate{}
		var sans []string
		var scanID, managedID sql.NullString
		if err := rows.Scan(
			&cert.ID, &cert.FingerprintSHA256, &cert.CommonName, pq.Array(&sans),
			&cert.SerialNumber, &cert.IssuerDN, &cert.SubjectDN,
			&cert.NotBefore, &cert.NotAfter, &cert.KeyAlgorithm, &cert.KeySize, &cert.IsCA,
			&cert.PEMData, &cert.SourcePath, &cert.SourceFormat,
			&cert.AgentID, &scanID, &managedID,
			&cert.Status, &cert.FirstSeenAt, &cert.LastSeenAt, &cert.DismissedAt,
			&cert.CreatedAt, &cert.UpdatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("failed to scan discovered certificate row: %w", err)
		}
		cert.SANs = sans
		// Nullable references are only copied over when the column held a value.
		if scanID.Valid {
			cert.DiscoveryScanID = scanID.String
		}
		if managedID.Valid {
			cert.ManagedCertificateID = managedID.String
		}
		certs = append(certs, cert)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated page would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to iterate discovered certificates: %w", err)
	}
	return certs, total, nil
}
// UpdateDiscoveredStatus sets the status of the discovered certificate with
// the given id and stamps updated_at with the current time.
//
// The columns touched depend on the new status:
//   - domain.DiscoveryStatusManaged: also stores managedCertID in
//     managed_certificate_id.
//   - domain.DiscoveryStatusDismissed: also sets dismissed_at to the current
//     time; managedCertID is not used.
//   - any other status: clears both managed_certificate_id and dismissed_at;
//     managedCertID is not used.
//
// It returns an error wrapping repository.ErrNotFound when no row has the
// given id, and a wrapped database error for any other failure.
func (r *DiscoveryRepository) UpdateDiscoveredStatus(ctx context.Context, id string, status domain.DiscoveryStatus, managedCertID string) error {
	var query string
	var args []interface{}
	now := time.Now()

	switch status {
	case domain.DiscoveryStatusManaged:
		query = `UPDATE discovered_certificates SET status = $1, managed_certificate_id = $2, updated_at = $3 WHERE id = $4`
		args = []interface{}{string(status), managedCertID, now, id}
	case domain.DiscoveryStatusDismissed:
		query = `UPDATE discovered_certificates SET status = $1, dismissed_at = $2, updated_at = $3 WHERE id = $4`
		args = []interface{}{string(status), now, now, id}
	default:
		query = `UPDATE discovered_certificates SET status = $1, managed_certificate_id = NULL, dismissed_at = NULL, updated_at = $2 WHERE id = $3`
		args = []interface{}{string(status), now, id}
	}

	result, err := r.db.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("failed to update discovered certificate status: %w", err)
	}

	// Surface a RowsAffected failure explicitly; dropping it would leave the
	// count at zero and misreport the failure as "not found".
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to read rows affected for discovered certificate status update: %w", err)
	}
	if rowsAffected == 0 {
		return fmt.Errorf("discovered certificate not found: %w", repository.ErrNotFound)
	}
	return nil
}
// GetByFingerprint returns every discovered certificate row whose
// fingerprint_sha256 equals fingerprint, ordered by last_seen_at descending.
//
// The query selects an empty string literal in place of the pem_data column,
// so PEMData is always empty on the returned records. The returned slice is
// nil when nothing matches; no not-found error is produced.
func (r *DiscoveryRepository) GetByFingerprint(ctx context.Context, fingerprint string) ([]*domain.DiscoveredCertificate, error) {
	query := `
SELECT id, fingerprint_sha256, common_name, sans, serial_number, issuer_dn, subject_dn,
not_before, not_after, key_algorithm, key_size, is_ca, '',
source_path, source_format, agent_id, discovery_scan_id, managed_certificate_id,
status, first_seen_at, last_seen_at, dismissed_at, created_at, updated_at
FROM discovered_certificates WHERE fingerprint_sha256 = $1
ORDER BY last_seen_at DESC`

	rows, err := r.db.QueryContext(ctx, query, fingerprint)
	if err != nil {
		return nil, fmt.Errorf("failed to get by fingerprint: %w", err)
	}
	defer rows.Close()

	var certs []*domain.DiscoveredCertificate
	for rows.Next() {
		cert := &domain.DiscoveredCertificate{}
		var sans []string
		var scanID, managedID sql.NullString
		if err := rows.Scan(
			&cert.ID, &cert.FingerprintSHA256, &cert.CommonName, pq.Array(&sans),
			&cert.SerialNumber, &cert.IssuerDN, &cert.SubjectDN,
			&cert.NotBefore, &cert.NotAfter, &cert.KeyAlgorithm, &cert.KeySize, &cert.IsCA,
			&cert.PEMData, &cert.SourcePath, &cert.SourceFormat,
			&cert.AgentID, &scanID, &managedID,
			&cert.Status, &cert.FirstSeenAt, &cert.LastSeenAt, &cert.DismissedAt,
			&cert.CreatedAt, &cert.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		cert.SANs = sans
		// Nullable references are only copied over when the column held a value.
		if scanID.Valid {
			cert.DiscoveryScanID = scanID.String
		}
		if managedID.Valid {
			cert.ManagedCertificateID = managedID.String
		}
		certs = append(certs, cert)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated result would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows by fingerprint: %w", err)
	}
	return certs, nil
}
// CountByStatus returns the number of discovered certificates per status,
// keyed by the status column value.
//
// Statuses with no rows are absent from the map; the map is non-nil and empty
// when the table has no rows.
func (r *DiscoveryRepository) CountByStatus(ctx context.Context) (map[string]int, error) {
	query := `SELECT status, COUNT(*) FROM discovered_certificates GROUP BY status`

	rows, err := r.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to count by status: %w", err)
	}
	defer rows.Close()

	counts := make(map[string]int)
	for rows.Next() {
		var status string
		var count int
		if err := rows.Scan(&status, &count); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		counts[status] = count
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check partial counts would be returned as if they were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate status counts: %w", err)
	}
	return counts, nil
}
// nullableString converts s into a sql.NullString. The empty string maps to
// the zero value (Valid == false, i.e. SQL NULL); any other string is carried
// through unchanged with Valid set.
func nullableString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}