feat: add network certificate discovery (M21) and Prometheus metrics (M22)

M21 adds server-side active TLS scanning of CIDR ranges with concurrent
probing, sentinel agent pattern for pipeline reuse, and full CRUD API for
scan targets. M22 adds Prometheus exposition format endpoint alongside
existing JSON metrics. Comprehensive documentation audit updates all docs
to reflect 91 endpoints, 19 tables, 6 scheduler loops, and 900+ tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Shankar
2026-03-24 23:37:47 -04:00
parent 3dc76e0b87
commit be85fbd77e
26 changed files with 2022 additions and 71 deletions
+436
View File
@@ -0,0 +1,436 @@
package service
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"log/slog"
"net"
"sync"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// SentinelAgentID is the agent ID used for network-discovered certificates.
// This allows the existing discovery dedup constraint (fingerprint, agent_id, source_path)
// to work without schema changes.
const SentinelAgentID = "server-scanner"
// NetworkScanService manages active TLS scanning of network endpoints.
type NetworkScanService struct {
networkScanRepo repository.NetworkScanRepository
discoveryService *DiscoveryService
auditService *AuditService
logger *slog.Logger
concurrency int
}
// NewNetworkScanService creates a new network scan service.
func NewNetworkScanService(
networkScanRepo repository.NetworkScanRepository,
discoveryService *DiscoveryService,
auditService *AuditService,
logger *slog.Logger,
) *NetworkScanService {
return &NetworkScanService{
networkScanRepo: networkScanRepo,
discoveryService: discoveryService,
auditService: auditService,
logger: logger,
concurrency: 50,
}
}
// ListTargets returns all network scan targets.
func (s *NetworkScanService) ListTargets(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
return s.networkScanRepo.List(ctx)
}
// GetTarget retrieves a network scan target by ID.
func (s *NetworkScanService) GetTarget(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
return s.networkScanRepo.Get(ctx, id)
}
// CreateTarget creates a new network scan target.
func (s *NetworkScanService) CreateTarget(ctx context.Context, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
if target.Name == "" {
return nil, fmt.Errorf("name is required")
}
if len(target.CIDRs) == 0 {
return nil, fmt.Errorf("at least one CIDR is required")
}
// Validate CIDRs
for _, cidr := range target.CIDRs {
if _, _, err := net.ParseCIDR(cidr); err != nil {
// Try parsing as plain IP
if ip := net.ParseIP(cidr); ip == nil {
return nil, fmt.Errorf("invalid CIDR or IP: %s", cidr)
}
}
}
if len(target.Ports) == 0 {
target.Ports = []int{443}
}
if target.ScanIntervalHours == 0 {
target.ScanIntervalHours = 6
}
if target.TimeoutMs == 0 {
target.TimeoutMs = 5000
}
target.ID = generateID("nst")
target.Enabled = true
target.CreatedAt = time.Now()
target.UpdatedAt = time.Now()
if err := s.networkScanRepo.Create(ctx, target); err != nil {
return nil, err
}
s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser,
"network_scan_target_created", "network_scan_target", target.ID,
map[string]interface{}{
"name": target.Name,
"cidrs": target.CIDRs,
"ports": target.Ports,
})
return target, nil
}
// UpdateTarget updates an existing network scan target.
func (s *NetworkScanService) UpdateTarget(ctx context.Context, id string, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
existing, err := s.networkScanRepo.Get(ctx, id)
if err != nil {
return nil, err
}
if target.Name != "" {
existing.Name = target.Name
}
if len(target.CIDRs) > 0 {
// Validate new CIDRs
for _, cidr := range target.CIDRs {
if _, _, err := net.ParseCIDR(cidr); err != nil {
if ip := net.ParseIP(cidr); ip == nil {
return nil, fmt.Errorf("invalid CIDR or IP: %s", cidr)
}
}
}
existing.CIDRs = target.CIDRs
}
if len(target.Ports) > 0 {
existing.Ports = target.Ports
}
if target.ScanIntervalHours > 0 {
existing.ScanIntervalHours = target.ScanIntervalHours
}
if target.TimeoutMs > 0 {
existing.TimeoutMs = target.TimeoutMs
}
// Always update enabled field (it's a boolean, so 0-value is meaningful)
existing.Enabled = target.Enabled
if err := s.networkScanRepo.Update(ctx, existing); err != nil {
return nil, err
}
return existing, nil
}
// DeleteTarget removes a network scan target.
func (s *NetworkScanService) DeleteTarget(ctx context.Context, id string) error {
if err := s.networkScanRepo.Delete(ctx, id); err != nil {
return err
}
s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser,
"network_scan_target_deleted", "network_scan_target", id, nil)
return nil
}
// ScanAllTargets runs the active TLS scan for all enabled targets.
// This is called by the scheduler on the configured interval.
func (s *NetworkScanService) ScanAllTargets(ctx context.Context) error {
targets, err := s.networkScanRepo.ListEnabled(ctx)
if err != nil {
return fmt.Errorf("list enabled targets: %w", err)
}
if len(targets) == 0 {
if s.logger != nil {
s.logger.Debug("no enabled network scan targets")
}
return nil
}
if s.logger != nil {
s.logger.Info("starting network scan", "targets", len(targets))
}
for _, target := range targets {
if ctx.Err() != nil {
return ctx.Err()
}
s.scanTarget(ctx, target)
}
return nil
}
// TriggerScan runs an immediate scan for a specific target.
func (s *NetworkScanService) TriggerScan(ctx context.Context, targetID string) (*domain.DiscoveryScan, error) {
target, err := s.networkScanRepo.Get(ctx, targetID)
if err != nil {
return nil, err
}
return s.scanTarget(ctx, target), nil
}
// scanTarget scans a single network target and feeds results into the discovery pipeline.
func (s *NetworkScanService) scanTarget(ctx context.Context, target *domain.NetworkScanTarget) *domain.DiscoveryScan {
startTime := time.Now()
if s.logger != nil {
s.logger.Info("scanning network target",
"target_id", target.ID,
"name", target.Name,
"cidrs", target.CIDRs,
"ports", target.Ports)
}
// Expand CIDRs to individual IPs
endpoints := s.expandEndpoints(target.CIDRs, target.Ports)
if s.logger != nil {
s.logger.Debug("expanded endpoints", "count", len(endpoints))
}
// Scan endpoints concurrently
timeout := time.Duration(target.TimeoutMs) * time.Millisecond
results := s.scanEndpoints(ctx, endpoints, timeout)
// Collect discovered cert entries
var entries []domain.DiscoveredCertEntry
var scanErrors []string
for _, result := range results {
if result.Error != "" {
// Only log connection errors at debug level (many hosts won't have TLS)
if s.logger != nil {
s.logger.Debug("scan endpoint error",
"address", result.Address,
"error", result.Error)
}
continue
}
entries = append(entries, result.Certs...)
}
scanDuration := time.Since(startTime)
if s.logger != nil {
s.logger.Info("network target scan completed",
"target_id", target.ID,
"endpoints_scanned", len(endpoints),
"certificates_found", len(entries),
"errors", len(scanErrors),
"duration_ms", scanDuration.Milliseconds())
}
// Update scan results on target
s.networkScanRepo.UpdateScanResults(ctx, target.ID, time.Now(),
int(scanDuration.Milliseconds()), len(entries))
// Feed into discovery pipeline if we found certs
if len(entries) == 0 {
return nil
}
// Build directories list from CIDRs for the scan record
dirs := make([]string, len(target.CIDRs))
copy(dirs, target.CIDRs)
report := &domain.DiscoveryReport{
AgentID: SentinelAgentID,
Directories: dirs,
Certificates: entries,
Errors: scanErrors,
ScanDurationMs: int(scanDuration.Milliseconds()),
}
scan, err := s.discoveryService.ProcessDiscoveryReport(ctx, report)
if err != nil {
if s.logger != nil {
s.logger.Error("failed to process network scan report",
"target_id", target.ID,
"error", err)
}
return nil
}
return scan
}
// expandEndpoints converts CIDR ranges and ports into a list of "ip:port" endpoints.
func (s *NetworkScanService) expandEndpoints(cidrs []string, ports []int) []string {
var endpoints []string
for _, cidr := range cidrs {
ips := expandCIDR(cidr)
for _, ip := range ips {
for _, port := range ports {
endpoints = append(endpoints, fmt.Sprintf("%s:%d", ip, port))
}
}
}
return endpoints
}
// expandCIDR expands a CIDR notation or single IP into a list of IPs.
// Limits expansion to /20 (4096 IPs) to prevent accidental huge scans.
func expandCIDR(cidr string) []string {
// Try as CIDR first
ip, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
// Try as single IP
if singleIP := net.ParseIP(cidr); singleIP != nil {
return []string{singleIP.String()}
}
return nil
}
// Count network size and cap at /20
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
if hostBits > 12 { // More than 4096 hosts
return nil // Skip overly large networks
}
var ips []string
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incrementIP(ip) {
// Copy IP before appending (net.IP is a mutable slice)
ipCopy := make(net.IP, len(ip))
copy(ipCopy, ip)
ips = append(ips, ipCopy.String())
}
// Remove network and broadcast for IPv4 /31 and larger
if len(ips) > 2 {
ips = ips[1 : len(ips)-1]
}
return ips
}
// incrementIP increments an IP address by one.
func incrementIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// scanEndpoints probes TLS endpoints concurrently and returns results.
func (s *NetworkScanService) scanEndpoints(ctx context.Context, endpoints []string, timeout time.Duration) []domain.NetworkScanResult {
results := make([]domain.NetworkScanResult, len(endpoints))
sem := make(chan struct{}, s.concurrency)
var wg sync.WaitGroup
for i, endpoint := range endpoints {
if ctx.Err() != nil {
break
}
wg.Add(1)
sem <- struct{}{}
go func(idx int, addr string) {
defer wg.Done()
defer func() { <-sem }()
results[idx] = s.probeTLS(ctx, addr, timeout)
}(i, endpoint)
}
wg.Wait()
return results
}
// probeTLS connects to an endpoint, performs a TLS handshake, and extracts certificates.
func (s *NetworkScanService) probeTLS(ctx context.Context, address string, timeout time.Duration) domain.NetworkScanResult {
startTime := time.Now()
result := domain.NetworkScanResult{Address: address}
dialer := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(dialer, "tcp", address, &tls.Config{
InsecureSkipVerify: true, // We want to discover ALL certs, including self-signed
})
if err != nil {
result.Error = err.Error()
result.LatencyMs = int(time.Since(startTime).Milliseconds())
return result
}
defer conn.Close()
result.LatencyMs = int(time.Since(startTime).Milliseconds())
// Extract certificates from TLS connection state
state := conn.ConnectionState()
for _, cert := range state.PeerCertificates {
entry := tlsCertToEntry(cert, address)
result.Certs = append(result.Certs, entry)
}
return result
}
// tlsCertToEntry converts an x509.Certificate from a TLS handshake into a DiscoveredCertEntry.
func tlsCertToEntry(cert *x509.Certificate, address string) domain.DiscoveredCertEntry {
// Compute SHA-256 fingerprint
fingerprintBytes := sha256.Sum256(cert.Raw)
fingerprint := fmt.Sprintf("%x", fingerprintBytes)
// Encode as PEM
pemBlock := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
pemData := string(pem.EncodeToMemory(pemBlock))
// Key algorithm and size
keyAlg, keySize := tlsCertKeyInfo(cert)
return domain.DiscoveredCertEntry{
FingerprintSHA256: fingerprint,
CommonName: cert.Subject.CommonName,
SANs: cert.DNSNames,
SerialNumber: cert.SerialNumber.Text(16),
IssuerDN: cert.Issuer.String(),
SubjectDN: cert.Subject.String(),
NotBefore: cert.NotBefore.UTC().Format(time.RFC3339),
NotAfter: cert.NotAfter.UTC().Format(time.RFC3339),
KeyAlgorithm: keyAlg,
KeySize: keySize,
IsCA: cert.IsCA,
PEMData: pemData,
SourcePath: address,
SourceFormat: "network",
}
}
// tlsCertKeyInfo extracts key algorithm name and size from a certificate.
func tlsCertKeyInfo(cert *x509.Certificate) (string, int) {
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
return "RSA", pub.N.BitLen()
case *ecdsa.PublicKey:
return "ECDSA", pub.Curve.Params().BitSize
default:
switch cert.PublicKeyAlgorithm {
case x509.Ed25519:
return "Ed25519", 256
default:
return cert.PublicKeyAlgorithm.String(), 0
}
}
}
+244
View File
@@ -0,0 +1,244 @@
package service
import (
"context"
"fmt"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// mockNetworkScanRepo for testing
type mockNetworkScanRepo struct {
targets []*domain.NetworkScanTarget
}
func (m *mockNetworkScanRepo) List(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
return m.targets, nil
}
func (m *mockNetworkScanRepo) ListEnabled(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
var enabled []*domain.NetworkScanTarget
for _, t := range m.targets {
if t.Enabled {
enabled = append(enabled, t)
}
}
return enabled, nil
}
func (m *mockNetworkScanRepo) Get(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
for _, t := range m.targets {
if t.ID == id {
return t, nil
}
}
return nil, fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanRepo) Create(ctx context.Context, target *domain.NetworkScanTarget) error {
m.targets = append(m.targets, target)
return nil
}
func (m *mockNetworkScanRepo) Update(ctx context.Context, target *domain.NetworkScanTarget) error {
for i, t := range m.targets {
if t.ID == target.ID {
m.targets[i] = target
return nil
}
}
return fmt.Errorf("not found: %s", target.ID)
}
func (m *mockNetworkScanRepo) Delete(ctx context.Context, id string) error {
for i, t := range m.targets {
if t.ID == id {
m.targets = append(m.targets[:i], m.targets[i+1:]...)
return nil
}
}
return fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanRepo) UpdateScanResults(ctx context.Context, id string, scanAt time.Time, durationMs int, certsFound int) error {
for _, t := range m.targets {
if t.ID == id {
t.LastScanAt = &scanAt
d := durationMs
t.LastScanDurationMs = &d
c := certsFound
t.LastScanCertsFound = &c
return nil
}
}
return fmt.Errorf("not found: %s", id)
}
func TestExpandCIDR_SingleIP(t *testing.T) {
ips := expandCIDR("192.168.1.1")
if len(ips) != 1 || ips[0] != "192.168.1.1" {
t.Errorf("expected [192.168.1.1], got %v", ips)
}
}
func TestExpandCIDR_Slash30(t *testing.T) {
// /30 = 4 total addresses, 2 usable (remove network + broadcast)
ips := expandCIDR("10.0.0.0/30")
if len(ips) != 2 {
t.Errorf("expected 2 usable IPs for /30, got %d: %v", len(ips), ips)
}
}
func TestExpandCIDR_Slash24(t *testing.T) {
ips := expandCIDR("10.0.0.0/24")
if len(ips) != 254 {
t.Errorf("expected 254 usable IPs for /24, got %d", len(ips))
}
}
func TestExpandCIDR_TooLarge(t *testing.T) {
// /16 = 65536 IPs, exceeds /20 cap
ips := expandCIDR("10.0.0.0/16")
if len(ips) != 0 {
t.Errorf("expected empty for /16 (too large), got %d", len(ips))
}
}
func TestExpandCIDR_InvalidInput(t *testing.T) {
ips := expandCIDR("not-a-cidr")
if len(ips) != 0 {
t.Errorf("expected empty for invalid input, got %v", ips)
}
}
func TestNetworkScanService_CreateTarget(t *testing.T) {
repo := &mockNetworkScanRepo{}
auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo)
svc := NewNetworkScanService(repo, nil, auditService, nil)
target, err := svc.CreateTarget(context.Background(), &domain.NetworkScanTarget{
Name: "Test Network",
CIDRs: []string{"10.0.0.0/24"},
Ports: []int{443, 8443},
})
if err != nil {
t.Fatalf("CreateTarget failed: %v", err)
}
if target.ID == "" {
t.Error("expected non-empty ID")
}
if !target.Enabled {
t.Error("expected target to be enabled by default")
}
if target.ScanIntervalHours != 6 {
t.Errorf("expected default interval 6h, got %d", target.ScanIntervalHours)
}
if target.TimeoutMs != 5000 {
t.Errorf("expected default timeout 5000ms, got %d", target.TimeoutMs)
}
}
func TestNetworkScanService_CreateTarget_ValidationErrors(t *testing.T) {
repo := &mockNetworkScanRepo{}
auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo)
svc := NewNetworkScanService(repo, nil, auditService, nil)
tests := []struct {
name string
target *domain.NetworkScanTarget
errMsg string
}{
{
name: "missing name",
target: &domain.NetworkScanTarget{CIDRs: []string{"10.0.0.0/24"}},
errMsg: "name is required",
},
{
name: "missing cidrs",
target: &domain.NetworkScanTarget{Name: "test"},
errMsg: "at least one CIDR is required",
},
{
name: "invalid cidr",
target: &domain.NetworkScanTarget{Name: "test", CIDRs: []string{"not-valid"}},
errMsg: "invalid CIDR or IP",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := svc.CreateTarget(context.Background(), tt.target)
if err == nil {
t.Fatal("expected error")
}
if !containsSubstring(err.Error(), tt.errMsg) {
t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error())
}
})
}
}
func TestNetworkScanService_DeleteTarget(t *testing.T) {
repo := &mockNetworkScanRepo{
targets: []*domain.NetworkScanTarget{
{ID: "nst-1", Name: "test"},
},
}
auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo)
svc := NewNetworkScanService(repo, nil, auditService, nil)
if err := svc.DeleteTarget(context.Background(), "nst-1"); err != nil {
t.Fatalf("DeleteTarget failed: %v", err)
}
if len(repo.targets) != 0 {
t.Error("expected target to be deleted")
}
}
func TestNetworkScanService_ListTargets(t *testing.T) {
repo := &mockNetworkScanRepo{
targets: []*domain.NetworkScanTarget{
{ID: "nst-1", Name: "target1"},
{ID: "nst-2", Name: "target2"},
},
}
svc := NewNetworkScanService(repo, nil, nil, nil)
targets, err := svc.ListTargets(context.Background())
if err != nil {
t.Fatalf("ListTargets failed: %v", err)
}
if len(targets) != 2 {
t.Errorf("expected 2 targets, got %d", len(targets))
}
}
func TestExpandEndpoints(t *testing.T) {
svc := &NetworkScanService{}
endpoints := svc.expandEndpoints([]string{"192.168.1.1"}, []int{443, 8443})
if len(endpoints) != 2 {
t.Errorf("expected 2 endpoints, got %d: %v", len(endpoints), endpoints)
}
if endpoints[0] != "192.168.1.1:443" {
t.Errorf("expected 192.168.1.1:443, got %s", endpoints[0])
}
if endpoints[1] != "192.168.1.1:8443" {
t.Errorf("expected 192.168.1.1:8443, got %s", endpoints[1])
}
}
// containsSubstring checks if a string contains a substring (helper)
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}