feat: add network certificate discovery (M21) and Prometheus metrics (M22)

M21 adds server-side active TLS scanning of CIDR ranges with concurrent
probing, sentinel agent pattern for pipeline reuse, and full CRUD API for
scan targets. M22 adds Prometheus exposition format endpoint alongside
existing JSON metrics. Comprehensive documentation audit updates all docs
to reflect 91 endpoints, 19 tables, 6 scheduler loops, and 900+ tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Shankar
2026-03-24 23:37:47 -04:00
parent 3dc76e0b87
commit be85fbd77e
26 changed files with 2022 additions and 71 deletions
+92 -3
View File
@@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
@@ -14,9 +15,9 @@ type MetricsService interface {
GetDashboardSummary(ctx context.Context) (interface{}, error)
}
// MetricsHandler handles HTTP requests for Prometheus-style metrics.
// In V2, returns JSON metrics (not Prometheus format).
// Prometheus format can be added in V3 when observability becomes a paid feature.
// MetricsHandler handles HTTP requests for metrics.
// Supports both JSON format (GET /api/v1/metrics) and Prometheus exposition format
// (GET /api/v1/metrics/prometheus) for integration with Prometheus, Grafana, Datadog, etc.
type MetricsHandler struct {
svc MetricsService
serverStarted time.Time
@@ -117,6 +118,94 @@ func (h MetricsHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, metricsResp)
}
// GetPrometheusMetrics returns metrics in Prometheus exposition format (text/plain).
// GET /api/v1/metrics/prometheus
// Compatible with Prometheus, Grafana Agent, Datadog Agent, Victoria Metrics, and any
// OpenMetrics-compatible scraper. Metric names follow Prometheus naming conventions
// (lowercase, snake_case, prefixed with certctl_).
func (h MetricsHandler) GetPrometheusMetrics(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
requestID := middleware.GetRequestID(r.Context())
summary, err := h.svc.GetDashboardSummary(r.Context())
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to collect metrics", requestID)
return
}
// Extract fields from summary via JSON round-trip (avoids cross-package type assertion)
jsonBytes, err := json.Marshal(summary)
if err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to marshal metrics data", requestID)
return
}
var dashboardSummary DashboardSummary
if err := json.Unmarshal(jsonBytes, &dashboardSummary); err != nil {
ErrorWithRequestID(w, http.StatusInternalServerError, "Invalid metrics data", requestID)
return
}
// Compute derived values
active := dashboardSummary.TotalCertificates - dashboardSummary.ExpiringCertificates - dashboardSummary.ExpiredCertificates - dashboardSummary.RevokedCertificates
uptimeSeconds := int64(time.Since(h.serverStarted).Seconds())
// Build Prometheus exposition format
// See: https://prometheus.io/docs/instrumenting/exposition_formats/
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
w.WriteHeader(http.StatusOK)
// Gauges — point-in-time values
fmt.Fprintf(w, "# HELP certctl_certificate_total Total number of managed certificates.\n")
fmt.Fprintf(w, "# TYPE certctl_certificate_total gauge\n")
fmt.Fprintf(w, "certctl_certificate_total %d\n\n", dashboardSummary.TotalCertificates)
fmt.Fprintf(w, "# HELP certctl_certificate_active Number of active (non-expiring, non-expired, non-revoked) certificates.\n")
fmt.Fprintf(w, "# TYPE certctl_certificate_active gauge\n")
fmt.Fprintf(w, "certctl_certificate_active %d\n\n", active)
fmt.Fprintf(w, "# HELP certctl_certificate_expiring_soon Number of certificates expiring within 30 days.\n")
fmt.Fprintf(w, "# TYPE certctl_certificate_expiring_soon gauge\n")
fmt.Fprintf(w, "certctl_certificate_expiring_soon %d\n\n", dashboardSummary.ExpiringCertificates)
fmt.Fprintf(w, "# HELP certctl_certificate_expired Number of expired certificates.\n")
fmt.Fprintf(w, "# TYPE certctl_certificate_expired gauge\n")
fmt.Fprintf(w, "certctl_certificate_expired %d\n\n", dashboardSummary.ExpiredCertificates)
fmt.Fprintf(w, "# HELP certctl_certificate_revoked Number of revoked certificates.\n")
fmt.Fprintf(w, "# TYPE certctl_certificate_revoked gauge\n")
fmt.Fprintf(w, "certctl_certificate_revoked %d\n\n", dashboardSummary.RevokedCertificates)
fmt.Fprintf(w, "# HELP certctl_agent_total Total number of registered agents.\n")
fmt.Fprintf(w, "# TYPE certctl_agent_total gauge\n")
fmt.Fprintf(w, "certctl_agent_total %d\n\n", dashboardSummary.TotalAgents)
fmt.Fprintf(w, "# HELP certctl_agent_online Number of agents currently online.\n")
fmt.Fprintf(w, "# TYPE certctl_agent_online gauge\n")
fmt.Fprintf(w, "certctl_agent_online %d\n\n", dashboardSummary.ActiveAgents)
fmt.Fprintf(w, "# HELP certctl_job_pending Number of jobs currently pending.\n")
fmt.Fprintf(w, "# TYPE certctl_job_pending gauge\n")
fmt.Fprintf(w, "certctl_job_pending %d\n\n", dashboardSummary.PendingJobs)
// Counters — cumulative values
fmt.Fprintf(w, "# HELP certctl_job_completed_total Total number of completed jobs.\n")
fmt.Fprintf(w, "# TYPE certctl_job_completed_total counter\n")
fmt.Fprintf(w, "certctl_job_completed_total %d\n\n", dashboardSummary.CompleteJobs)
fmt.Fprintf(w, "# HELP certctl_job_failed_total Total number of failed jobs.\n")
fmt.Fprintf(w, "# TYPE certctl_job_failed_total counter\n")
fmt.Fprintf(w, "certctl_job_failed_total %d\n\n", dashboardSummary.FailedJobs)
// Info — server uptime
fmt.Fprintf(w, "# HELP certctl_uptime_seconds Server uptime in seconds.\n")
fmt.Fprintf(w, "# TYPE certctl_uptime_seconds gauge\n")
fmt.Fprintf(w, "certctl_uptime_seconds %d\n", uptimeSeconds)
}
// DashboardSummary mirrors the service.DashboardSummary for JSON unmarshaling.
// JSON tags must match the service-layer struct exactly.
type DashboardSummary struct {
+179
View File
@@ -0,0 +1,179 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/shankar0123/certctl/internal/domain"
)
// NetworkScanService defines the interface used by the network scan handler.
type NetworkScanService interface {
ListTargets(ctx context.Context) ([]*domain.NetworkScanTarget, error)
GetTarget(ctx context.Context, id string) (*domain.NetworkScanTarget, error)
CreateTarget(ctx context.Context, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error)
UpdateTarget(ctx context.Context, id string, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error)
DeleteTarget(ctx context.Context, id string) error
TriggerScan(ctx context.Context, targetID string) (*domain.DiscoveryScan, error)
}
// NetworkScanHandler handles HTTP requests for network scan targets.
type NetworkScanHandler struct {
svc NetworkScanService
}
// NewNetworkScanHandler creates a new network scan handler.
func NewNetworkScanHandler(svc NetworkScanService) NetworkScanHandler {
return NetworkScanHandler{svc: svc}
}
// ListNetworkScanTargets handles GET /api/v1/network-scan-targets
func (h NetworkScanHandler) ListNetworkScanTargets(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
targets, err := h.svc.ListTargets(r.Context())
if err != nil {
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to list network scan targets: %v", err))
return
}
if targets == nil {
targets = []*domain.NetworkScanTarget{}
}
JSON(w, http.StatusOK, PagedResponse{
Data: targets,
Total: int64(len(targets)),
Page: 1,
PerPage: len(targets),
})
}
// GetNetworkScanTarget handles GET /api/v1/network-scan-targets/{id}
func (h NetworkScanHandler) GetNetworkScanTarget(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
id := r.PathValue("id")
if id == "" {
Error(w, http.StatusBadRequest, "network scan target ID is required")
return
}
target, err := h.svc.GetTarget(r.Context(), id)
if err != nil {
Error(w, http.StatusNotFound, fmt.Sprintf("network scan target not found: %v", err))
return
}
JSON(w, http.StatusOK, target)
}
// CreateNetworkScanTarget handles POST /api/v1/network-scan-targets
func (h NetworkScanHandler) CreateNetworkScanTarget(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
var target domain.NetworkScanTarget
if err := json.NewDecoder(r.Body).Decode(&target); err != nil {
Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
return
}
created, err := h.svc.CreateTarget(r.Context(), &target)
if err != nil {
Error(w, http.StatusBadRequest, fmt.Sprintf("failed to create network scan target: %v", err))
return
}
JSON(w, http.StatusCreated, created)
}
// UpdateNetworkScanTarget handles PUT /api/v1/network-scan-targets/{id}
func (h NetworkScanHandler) UpdateNetworkScanTarget(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
id := r.PathValue("id")
if id == "" {
Error(w, http.StatusBadRequest, "network scan target ID is required")
return
}
var target domain.NetworkScanTarget
if err := json.NewDecoder(r.Body).Decode(&target); err != nil {
Error(w, http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
return
}
updated, err := h.svc.UpdateTarget(r.Context(), id, &target)
if err != nil {
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to update network scan target: %v", err))
return
}
JSON(w, http.StatusOK, updated)
}
// DeleteNetworkScanTarget handles DELETE /api/v1/network-scan-targets/{id}
func (h NetworkScanHandler) DeleteNetworkScanTarget(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
id := r.PathValue("id")
if id == "" {
Error(w, http.StatusBadRequest, "network scan target ID is required")
return
}
if err := h.svc.DeleteTarget(r.Context(), id); err != nil {
Error(w, http.StatusNotFound, fmt.Sprintf("failed to delete network scan target: %v", err))
return
}
JSON(w, http.StatusNoContent, nil)
}
// TriggerNetworkScan handles POST /api/v1/network-scan-targets/{id}/scan
func (h NetworkScanHandler) TriggerNetworkScan(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
return
}
id := r.PathValue("id")
if id == "" {
Error(w, http.StatusBadRequest, "network scan target ID is required")
return
}
scan, err := h.svc.TriggerScan(r.Context(), id)
if err != nil {
Error(w, http.StatusInternalServerError, fmt.Sprintf("failed to trigger scan: %v", err))
return
}
// scan may be nil if no certs found
if scan == nil {
JSON(w, http.StatusOK, map[string]string{
"status": "completed",
"message": "Scan completed, no certificates found",
})
return
}
JSON(w, http.StatusAccepted, scan)
}
@@ -0,0 +1,220 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/shankar0123/certctl/internal/domain"
)
// mockNetworkScanService implements NetworkScanService for testing.
type mockNetworkScanService struct {
targets []*domain.NetworkScanTarget
}
func (m *mockNetworkScanService) ListTargets(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
return m.targets, nil
}
func (m *mockNetworkScanService) GetTarget(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
for _, t := range m.targets {
if t.ID == id {
return t, nil
}
}
return nil, fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanService) CreateTarget(ctx context.Context, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
if target.Name == "" {
return nil, fmt.Errorf("name is required")
}
target.ID = "nst-test-123"
m.targets = append(m.targets, target)
return target, nil
}
func (m *mockNetworkScanService) UpdateTarget(ctx context.Context, id string, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
for _, t := range m.targets {
if t.ID == id {
if target.Name != "" {
t.Name = target.Name
}
return t, nil
}
}
return nil, fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanService) DeleteTarget(ctx context.Context, id string) error {
for i, t := range m.targets {
if t.ID == id {
m.targets = append(m.targets[:i], m.targets[i+1:]...)
return nil
}
}
return fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanService) TriggerScan(ctx context.Context, targetID string) (*domain.DiscoveryScan, error) {
for _, t := range m.targets {
if t.ID == targetID {
return &domain.DiscoveryScan{
ID: "dscan-test",
AgentID: "server-scanner",
CertificatesFound: 3,
}, nil
}
}
return nil, fmt.Errorf("not found: %s", targetID)
}
func TestListNetworkScanTargets(t *testing.T) {
svc := &mockNetworkScanService{
targets: []*domain.NetworkScanTarget{
{ID: "nst-1", Name: "target1", CIDRs: []string{"10.0.0.0/24"}, Ports: []int{443}},
{ID: "nst-2", Name: "target2", CIDRs: []string{"192.168.0.0/16"}, Ports: []int{443, 8443}},
},
}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodGet, "/api/v1/network-scan-targets", nil)
w := httptest.NewRecorder()
h.ListNetworkScanTargets(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
var resp PagedResponse
json.NewDecoder(w.Body).Decode(&resp)
if resp.Total != 2 {
t.Errorf("expected total 2, got %d", resp.Total)
}
}
func TestListNetworkScanTargets_Empty(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodGet, "/api/v1/network-scan-targets", nil)
w := httptest.NewRecorder()
h.ListNetworkScanTargets(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
}
func TestCreateNetworkScanTarget(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
body, _ := json.Marshal(map[string]interface{}{
"name": "Production",
"cidrs": []string{"10.0.0.0/24"},
"ports": []int{443},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/network-scan-targets", bytes.NewReader(body))
w := httptest.NewRecorder()
h.CreateNetworkScanTarget(w, req)
if w.Code != http.StatusCreated {
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
}
}
func TestCreateNetworkScanTarget_InvalidJSON(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodPost, "/api/v1/network-scan-targets", bytes.NewReader([]byte("not json")))
w := httptest.NewRecorder()
h.CreateNetworkScanTarget(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestCreateNetworkScanTarget_MissingName(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
body, _ := json.Marshal(map[string]interface{}{
"cidrs": []string{"10.0.0.0/24"},
})
req := httptest.NewRequest(http.MethodPost, "/api/v1/network-scan-targets", bytes.NewReader(body))
w := httptest.NewRecorder()
h.CreateNetworkScanTarget(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestDeleteNetworkScanTarget_NotFound(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/network-scan-targets/nst-nonexistent", nil)
req.SetPathValue("id", "nst-nonexistent")
w := httptest.NewRecorder()
h.DeleteNetworkScanTarget(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d", w.Code)
}
}
func TestTriggerNetworkScan(t *testing.T) {
svc := &mockNetworkScanService{
targets: []*domain.NetworkScanTarget{
{ID: "nst-1", Name: "target1"},
},
}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodPost, "/api/v1/network-scan-targets/nst-1/scan", nil)
req.SetPathValue("id", "nst-1")
w := httptest.NewRecorder()
h.TriggerNetworkScan(w, req)
if w.Code != http.StatusAccepted {
t.Errorf("expected 202, got %d: %s", w.Code, w.Body.String())
}
}
func TestTriggerNetworkScan_NotFound(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodPost, "/api/v1/network-scan-targets/nst-nonexistent/scan", nil)
req.SetPathValue("id", "nst-nonexistent")
w := httptest.NewRecorder()
h.TriggerNetworkScan(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d", w.Code)
}
}
func TestListNetworkScanTargets_MethodNotAllowed(t *testing.T) {
svc := &mockNetworkScanService{}
h := NewNetworkScanHandler(svc)
req := httptest.NewRequest(http.MethodPost, "/api/v1/network-scan-targets", nil)
w := httptest.NewRecorder()
h.ListNetworkScanTargets(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected 405, got %d", w.Code)
}
}
+114
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
@@ -202,3 +203,116 @@ func TestGetMetrics_ServiceError(t *testing.T) {
t.Errorf("expected 500, got %d", w.Code)
}
}
// --- Prometheus metrics endpoint tests ---
func TestGetPrometheusMetrics_Success(t *testing.T) {
mock := &MockStatsService{
GetDashboardSummaryFn: func(ctx context.Context) (interface{}, error) {
return &DashboardSummary{
TotalCertificates: 25,
ExpiringCertificates: 3,
ExpiredCertificates: 2,
RevokedCertificates: 1,
ActiveAgents: 4,
TotalAgents: 6,
PendingJobs: 2,
FailedJobs: 1,
CompleteJobs: 15,
}, nil
},
}
h := NewMetricsHandler(mock, time.Now().Add(-1*time.Hour))
req := httptest.NewRequest(http.MethodGet, "/api/v1/metrics/prometheus", nil)
w := httptest.NewRecorder()
h.GetPrometheusMetrics(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
contentType := w.Header().Get("Content-Type")
if contentType != "text/plain; version=0.0.4; charset=utf-8" {
t.Errorf("expected Prometheus content type, got %q", contentType)
}
body := w.Body.String()
// Check metric lines are present
expected := []string{
"certctl_certificate_total 25",
"certctl_certificate_active 19",
"certctl_certificate_expiring_soon 3",
"certctl_certificate_expired 2",
"certctl_certificate_revoked 1",
"certctl_agent_total 6",
"certctl_agent_online 4",
"certctl_job_pending 2",
"certctl_job_completed_total 15",
"certctl_job_failed_total 1",
"# TYPE certctl_certificate_total gauge",
"# TYPE certctl_job_completed_total counter",
"# HELP certctl_uptime_seconds",
"# TYPE certctl_uptime_seconds gauge",
}
for _, exp := range expected {
if !containsLine(body, exp) {
t.Errorf("expected body to contain %q", exp)
}
}
}
func TestGetPrometheusMetrics_MethodNotAllowed(t *testing.T) {
mock := &MockStatsService{}
h := NewMetricsHandler(mock, time.Now())
req := httptest.NewRequest(http.MethodPost, "/api/v1/metrics/prometheus", nil)
w := httptest.NewRecorder()
h.GetPrometheusMetrics(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected 405, got %d", w.Code)
}
}
func TestGetPrometheusMetrics_ServiceError(t *testing.T) {
mock := &MockStatsService{
GetDashboardSummaryFn: func(ctx context.Context) (interface{}, error) {
return nil, fmt.Errorf("db error")
},
}
h := NewMetricsHandler(mock, time.Now())
req := httptest.NewRequest(http.MethodGet, "/api/v1/metrics/prometheus", nil)
w := httptest.NewRecorder()
h.GetPrometheusMetrics(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d", w.Code)
}
}
func TestGetPrometheusMetrics_ZeroValues(t *testing.T) {
mock := &MockStatsService{
GetDashboardSummaryFn: func(ctx context.Context) (interface{}, error) {
return &DashboardSummary{}, nil
},
}
h := NewMetricsHandler(mock, time.Now())
req := httptest.NewRequest(http.MethodGet, "/api/v1/metrics/prometheus", nil)
w := httptest.NewRecorder()
h.GetPrometheusMetrics(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d", w.Code)
}
body := w.Body.String()
if !containsLine(body, "certctl_certificate_total 0") {
t.Error("expected zero value for certificate_total")
}
if !containsLine(body, "certctl_job_pending 0") {
t.Error("expected zero value for job_pending")
}
}
// containsLine checks if the text contains the given substring.
func containsLine(text, substr string) bool {
return strings.Contains(text, substr)
}
+10
View File
@@ -61,6 +61,7 @@ func (r *Router) RegisterHandlers(
metrics handler.MetricsHandler,
health handler.HealthHandler,
discovery handler.DiscoveryHandler,
networkScan handler.NetworkScanHandler,
) {
// Health endpoints (no auth middleware — must always be accessible)
r.mux.Handle("GET /health", middleware.Chain(
@@ -188,6 +189,7 @@ func (r *Router) RegisterHandlers(
// Metrics routes: /api/v1/metrics
r.Register("GET /api/v1/metrics", http.HandlerFunc(metrics.GetMetrics))
r.Register("GET /api/v1/metrics/prometheus", http.HandlerFunc(metrics.GetPrometheusMetrics))
// Discovery routes: /api/v1/discovered-certificates, /api/v1/discovery-scans
r.Register("POST /api/v1/agents/{id}/discoveries", http.HandlerFunc(discovery.SubmitDiscoveryReport))
@@ -197,6 +199,14 @@ func (r *Router) RegisterHandlers(
r.Register("POST /api/v1/discovered-certificates/{id}/dismiss", http.HandlerFunc(discovery.DismissDiscovered))
r.Register("GET /api/v1/discovery-scans", http.HandlerFunc(discovery.ListScans))
r.Register("GET /api/v1/discovery-summary", http.HandlerFunc(discovery.GetDiscoverySummary))
// Network scan routes: /api/v1/network-scan-targets
r.Register("GET /api/v1/network-scan-targets", http.HandlerFunc(networkScan.ListNetworkScanTargets))
r.Register("POST /api/v1/network-scan-targets", http.HandlerFunc(networkScan.CreateNetworkScanTarget))
r.Register("GET /api/v1/network-scan-targets/{id}", http.HandlerFunc(networkScan.GetNetworkScanTarget))
r.Register("PUT /api/v1/network-scan-targets/{id}", http.HandlerFunc(networkScan.UpdateNetworkScanTarget))
r.Register("DELETE /api/v1/network-scan-targets/{id}", http.HandlerFunc(networkScan.DeleteNetworkScanTarget))
r.Register("POST /api/v1/network-scan-targets/{id}/scan", http.HandlerFunc(networkScan.TriggerNetworkScan))
}
// GetMux returns the underlying http.ServeMux for direct access if needed.
+21 -10
View File
@@ -11,16 +11,17 @@ import (
// Config represents the complete application configuration.
// All configuration values are read from environment variables with CERTCTL_ prefix.
type Config struct {
Server ServerConfig
Database DatabaseConfig
Scheduler SchedulerConfig
Log LogConfig
Auth AuthConfig
RateLimit RateLimitConfig
CORS CORSConfig
Keygen KeygenConfig
CA CAConfig
Notifiers NotifierConfig
Server ServerConfig
Database DatabaseConfig
Scheduler SchedulerConfig
Log LogConfig
Auth AuthConfig
RateLimit RateLimitConfig
CORS CORSConfig
Keygen KeygenConfig
CA CAConfig
Notifiers NotifierConfig
NetworkScan NetworkScanConfig
}
// NotifierConfig contains configuration for notification connectors.
@@ -80,6 +81,12 @@ type OpenSSLConfig struct {
TimeoutSeconds int
}
// NetworkScanConfig controls the server-side active TLS scanner.
type NetworkScanConfig struct {
Enabled bool // Enable network scanning (default false)
ScanInterval time.Duration // How often to run network scans (default 6h)
}
// ServerConfig contains HTTP server configuration.
type ServerConfig struct {
Host string
@@ -178,6 +185,10 @@ func Load() (*Config, error) {
OpsGenieAPIKey: getEnv("CERTCTL_OPSGENIE_API_KEY", ""),
OpsGeniePriority: getEnv("CERTCTL_OPSGENIE_PRIORITY", "P3"),
},
NetworkScan: NetworkScanConfig{
Enabled: getEnvBool("CERTCTL_NETWORK_SCAN_ENABLED", false),
ScanInterval: getEnvDuration("CERTCTL_NETWORK_SCAN_INTERVAL", 6*time.Hour),
},
}
if err := cfg.Validate(); err != nil {
+27
View File
@@ -0,0 +1,27 @@
package domain
import "time"
// NetworkScanTarget defines a network range to scan for TLS certificates.
type NetworkScanTarget struct {
ID string `json:"id"`
Name string `json:"name"`
CIDRs []string `json:"cidrs"`
Ports []int `json:"ports"`
Enabled bool `json:"enabled"`
ScanIntervalHours int `json:"scan_interval_hours"`
TimeoutMs int `json:"timeout_ms"`
LastScanAt *time.Time `json:"last_scan_at,omitempty"`
LastScanDurationMs *int `json:"last_scan_duration_ms,omitempty"`
LastScanCertsFound *int `json:"last_scan_certs_found,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// NetworkScanResult holds the outcome of scanning a single endpoint.
type NetworkScanResult struct {
Address string // "ip:port"
Certs []DiscoveredCertEntry
Error string
LatencyMs int
}
+67
View File
@@ -0,0 +1,67 @@
package domain
import (
"testing"
"time"
)
func TestNetworkScanTarget_Defaults(t *testing.T) {
target := NetworkScanTarget{
ID: "nst-test",
Name: "Test Target",
CIDRs: []string{"10.0.0.0/24"},
Ports: []int{443},
Enabled: true,
ScanIntervalHours: 6,
TimeoutMs: 5000,
}
if target.ID != "nst-test" {
t.Errorf("expected ID nst-test, got %s", target.ID)
}
if len(target.CIDRs) != 1 || target.CIDRs[0] != "10.0.0.0/24" {
t.Errorf("unexpected CIDRs: %v", target.CIDRs)
}
if target.LastScanAt != nil {
t.Error("expected nil LastScanAt for new target")
}
}
func TestNetworkScanTarget_WithScanResults(t *testing.T) {
now := time.Now()
duration := 1500
found := 12
target := NetworkScanTarget{
ID: "nst-prod",
Name: "Production Network",
CIDRs: []string{"192.168.1.0/24", "10.0.0.0/16"},
Ports: []int{443, 8443, 636},
Enabled: true,
ScanIntervalHours: 1,
TimeoutMs: 3000,
LastScanAt: &now,
LastScanDurationMs: &duration,
LastScanCertsFound: &found,
}
if len(target.Ports) != 3 {
t.Errorf("expected 3 ports, got %d", len(target.Ports))
}
if *target.LastScanCertsFound != 12 {
t.Errorf("expected 12 certs found, got %d", *target.LastScanCertsFound)
}
}
func TestNetworkScanResult_Fields(t *testing.T) {
result := NetworkScanResult{
Address: "192.168.1.1:443",
Error: "",
LatencyMs: 45,
}
if result.Address != "192.168.1.1:443" {
t.Errorf("expected address 192.168.1.1:443, got %s", result.Address)
}
if result.LatencyMs != 45 {
t.Errorf("expected latency 45ms, got %d", result.LatencyMs)
}
}
+29
View File
@@ -80,6 +80,7 @@ func TestCertificateLifecycle(t *testing.T) {
metricsHandler := handler.NewMetricsHandler(&mockStatsService{}, time.Now())
healthHandler := handler.NewHealthHandler("none")
discoveryHandler := handler.NewDiscoveryHandler(&mockDiscoveryService{})
networkScanHandler := handler.NewNetworkScanHandler(&mockNetworkScanService{})
// Create router and register handlers
r := router.New()
@@ -100,6 +101,7 @@ func TestCertificateLifecycle(t *testing.T) {
metricsHandler,
healthHandler,
discoveryHandler,
networkScanHandler,
)
// Create test server
@@ -1174,3 +1176,30 @@ func (m *mockDiscoveryService) GetScan(ctx context.Context, id string) (*domain.
func (m *mockDiscoveryService) GetDiscoverySummary(ctx context.Context) (map[string]int, error) {
return map[string]int{}, nil
}
// mockNetworkScanService implements handler.NetworkScanService for integration tests.
type mockNetworkScanService struct{}
func (m *mockNetworkScanService) ListTargets(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
return nil, nil
}
func (m *mockNetworkScanService) GetTarget(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
return nil, fmt.Errorf("not found")
}
func (m *mockNetworkScanService) CreateTarget(ctx context.Context, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
return target, nil
}
func (m *mockNetworkScanService) UpdateTarget(ctx context.Context, id string, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
return target, nil
}
func (m *mockNetworkScanService) DeleteTarget(ctx context.Context, id string) error {
return nil
}
func (m *mockNetworkScanService) TriggerScan(ctx context.Context, targetID string) (*domain.DiscoveryScan, error) {
return nil, nil
}
+4
View File
@@ -73,6 +73,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository
metricsHandler := handler.NewMetricsHandler(&mockStatsService{}, time.Now())
healthHandler := handler.NewHealthHandler("none")
discoveryHandler := handler.NewDiscoveryHandler(&mockDiscoveryService{})
networkScanHandler := handler.NewNetworkScanHandler(&mockNetworkScanService{})
r := router.New()
r.RegisterHandlers(
@@ -92,6 +93,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, *mockCertificateRepository
metricsHandler,
healthHandler,
discoveryHandler,
networkScanHandler,
)
server := httptest.NewServer(r)
@@ -796,3 +798,5 @@ func TestRevocationEndpoints(t *testing.T) {
}
})
}
// mockNetworkScanService is defined in lifecycle_test.go (same package)
+18
View File
@@ -238,6 +238,24 @@ type DiscoveryFilter struct {
PerPage int
}
// NetworkScanRepository defines operations for managing network scan targets.
type NetworkScanRepository interface {
// List returns all network scan targets.
List(ctx context.Context) ([]*domain.NetworkScanTarget, error)
// ListEnabled returns only enabled scan targets.
ListEnabled(ctx context.Context) ([]*domain.NetworkScanTarget, error)
// Get retrieves a network scan target by ID.
Get(ctx context.Context, id string) (*domain.NetworkScanTarget, error)
// Create stores a new network scan target.
Create(ctx context.Context, target *domain.NetworkScanTarget) error
// Update modifies an existing network scan target.
Update(ctx context.Context, target *domain.NetworkScanTarget) error
// Delete removes a network scan target.
Delete(ctx context.Context, id string) error
// UpdateScanResults records the outcome of the last scan for a target.
UpdateScanResults(ctx context.Context, id string, scanAt time.Time, durationMs int, certsFound int) error
}
// OwnerRepository defines operations for managing certificate owners.
type OwnerRepository interface {
// List returns all owners.
@@ -0,0 +1,181 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/lib/pq"
"github.com/shankar0123/certctl/internal/domain"
)
// NetworkScanRepository implements repository.NetworkScanRepository using PostgreSQL.
type NetworkScanRepository struct {
db *sql.DB
}
// NewNetworkScanRepository creates a new PostgreSQL-backed network scan repository.
func NewNetworkScanRepository(db *sql.DB) *NetworkScanRepository {
return &NetworkScanRepository{db: db}
}
// List returns all network scan targets.
func (r *NetworkScanRepository) List(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, name, cidrs, ports, enabled, scan_interval_hours, timeout_ms,
last_scan_at, last_scan_duration_ms, last_scan_certs_found,
created_at, updated_at
FROM network_scan_targets
ORDER BY created_at DESC`)
if err != nil {
return nil, fmt.Errorf("list network scan targets: %w", err)
}
defer rows.Close()
return r.scanRows(rows)
}
// ListEnabled returns only enabled scan targets.
func (r *NetworkScanRepository) ListEnabled(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, name, cidrs, ports, enabled, scan_interval_hours, timeout_ms,
last_scan_at, last_scan_duration_ms, last_scan_certs_found,
created_at, updated_at
FROM network_scan_targets
WHERE enabled = TRUE
ORDER BY created_at DESC`)
if err != nil {
return nil, fmt.Errorf("list enabled network scan targets: %w", err)
}
defer rows.Close()
return r.scanRows(rows)
}
// Get retrieves a network scan target by ID.
func (r *NetworkScanRepository) Get(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
target := &domain.NetworkScanTarget{}
var lastScanAt sql.NullTime
var lastScanDurationMs, lastScanCertsFound sql.NullInt64
err := r.db.QueryRowContext(ctx, `
SELECT id, name, cidrs, ports, enabled, scan_interval_hours, timeout_ms,
last_scan_at, last_scan_duration_ms, last_scan_certs_found,
created_at, updated_at
FROM network_scan_targets
WHERE id = $1`, id).Scan(
&target.ID, &target.Name, pq.Array(&target.CIDRs), pq.Array(&target.Ports),
&target.Enabled, &target.ScanIntervalHours, &target.TimeoutMs,
&lastScanAt, &lastScanDurationMs, &lastScanCertsFound,
&target.CreatedAt, &target.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("network scan target not found: %s", id)
}
if err != nil {
return nil, fmt.Errorf("get network scan target: %w", err)
}
if lastScanAt.Valid {
target.LastScanAt = &lastScanAt.Time
}
if lastScanDurationMs.Valid {
v := int(lastScanDurationMs.Int64)
target.LastScanDurationMs = &v
}
if lastScanCertsFound.Valid {
v := int(lastScanCertsFound.Int64)
target.LastScanCertsFound = &v
}
return target, nil
}
// Create stores a new network scan target.
func (r *NetworkScanRepository) Create(ctx context.Context, target *domain.NetworkScanTarget) error {
_, err := r.db.ExecContext(ctx, `
INSERT INTO network_scan_targets (id, name, cidrs, ports, enabled, scan_interval_hours, timeout_ms, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
target.ID, target.Name, pq.Array(target.CIDRs), pq.Array(target.Ports),
target.Enabled, target.ScanIntervalHours, target.TimeoutMs,
target.CreatedAt, target.UpdatedAt,
)
if err != nil {
return fmt.Errorf("create network scan target: %w", err)
}
return nil
}
// Update modifies an existing network scan target.
func (r *NetworkScanRepository) Update(ctx context.Context, target *domain.NetworkScanTarget) error {
result, err := r.db.ExecContext(ctx, `
UPDATE network_scan_targets
SET name = $1, cidrs = $2, ports = $3, enabled = $4, scan_interval_hours = $5, timeout_ms = $6, updated_at = $7
WHERE id = $8`,
target.Name, pq.Array(target.CIDRs), pq.Array(target.Ports),
target.Enabled, target.ScanIntervalHours, target.TimeoutMs,
time.Now(), target.ID,
)
if err != nil {
return fmt.Errorf("update network scan target: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("network scan target not found: %s", target.ID)
}
return nil
}
// Delete removes a network scan target.
func (r *NetworkScanRepository) Delete(ctx context.Context, id string) error {
result, err := r.db.ExecContext(ctx, `DELETE FROM network_scan_targets WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete network scan target: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("network scan target not found: %s", id)
}
return nil
}
// UpdateScanResults records the outcome of the last scan for a target.
func (r *NetworkScanRepository) UpdateScanResults(ctx context.Context, id string, scanAt time.Time, durationMs int, certsFound int) error {
_, err := r.db.ExecContext(ctx, `
UPDATE network_scan_targets
SET last_scan_at = $1, last_scan_duration_ms = $2, last_scan_certs_found = $3, updated_at = $4
WHERE id = $5`,
scanAt, durationMs, certsFound, time.Now(), id,
)
if err != nil {
return fmt.Errorf("update scan results: %w", err)
}
return nil
}
// scanRows scans multiple rows from a query result.
func (r *NetworkScanRepository) scanRows(rows *sql.Rows) ([]*domain.NetworkScanTarget, error) {
var targets []*domain.NetworkScanTarget
for rows.Next() {
target := &domain.NetworkScanTarget{}
var lastScanAt sql.NullTime
var lastScanDurationMs, lastScanCertsFound sql.NullInt64
if err := rows.Scan(
&target.ID, &target.Name, pq.Array(&target.CIDRs), pq.Array(&target.Ports),
&target.Enabled, &target.ScanIntervalHours, &target.TimeoutMs,
&lastScanAt, &lastScanDurationMs, &lastScanCertsFound,
&target.CreatedAt, &target.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan network scan target row: %w", err)
}
if lastScanAt.Valid {
target.LastScanAt = &lastScanAt.Time
}
if lastScanDurationMs.Valid {
v := int(lastScanDurationMs.Int64)
target.LastScanDurationMs = &v
}
if lastScanCertsFound.Valid {
v := int(lastScanCertsFound.Int64)
target.LastScanCertsFound = &v
}
targets = append(targets, target)
}
return targets, rows.Err()
}
+45
View File
@@ -16,6 +16,7 @@ type Scheduler struct {
jobService *service.JobService
agentService *service.AgentService
notificationService *service.NotificationService
networkScanService *service.NetworkScanService
logger *slog.Logger
// Configurable tick intervals
@@ -24,6 +25,7 @@ type Scheduler struct {
agentHealthCheckInterval time.Duration
notificationProcessInterval time.Duration
shortLivedExpiryCheckInterval time.Duration
networkScanInterval time.Duration
}
// NewScheduler creates a new scheduler with configurable intervals.
@@ -32,6 +34,7 @@ func NewScheduler(
jobService *service.JobService,
agentService *service.AgentService,
notificationService *service.NotificationService,
networkScanService *service.NetworkScanService,
logger *slog.Logger,
) *Scheduler {
return &Scheduler{
@@ -39,6 +42,7 @@ func NewScheduler(
jobService: jobService,
agentService: agentService,
notificationService: notificationService,
networkScanService: networkScanService,
logger: logger,
// Default intervals
@@ -47,6 +51,7 @@ func NewScheduler(
agentHealthCheckInterval: 2 * time.Minute,
notificationProcessInterval: 1 * time.Minute,
shortLivedExpiryCheckInterval: 30 * time.Second,
networkScanInterval: 6 * time.Hour,
}
}
@@ -70,6 +75,11 @@ func (s *Scheduler) SetNotificationProcessInterval(d time.Duration) {
s.notificationProcessInterval = d
}
// SetNetworkScanInterval configures the interval for network scanning.
func (s *Scheduler) SetNetworkScanInterval(d time.Duration) {
s.networkScanInterval = d
}
// Start initiates all background scheduler loops. It returns a channel that signals
// when the scheduler has started all loops. The scheduler runs until the context is cancelled.
func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
@@ -90,6 +100,9 @@ func (s *Scheduler) Start(ctx context.Context) <-chan struct{} {
go s.agentHealthCheckLoop(ctx)
go s.notificationProcessLoop(ctx)
go s.shortLivedExpiryCheckLoop(ctx)
if s.networkScanService != nil {
go s.networkScanLoop(ctx)
}
// Wait for context cancellation
<-ctx.Done()
@@ -258,3 +271,35 @@ func (s *Scheduler) runShortLivedExpiryCheck(ctx context.Context) {
s.logger.Debug("short-lived expiry check completed")
}
}
// networkScanLoop runs every networkScanInterval and performs active TLS scanning
// of configured network targets.
func (s *Scheduler) networkScanLoop(ctx context.Context) {
ticker := time.NewTicker(s.networkScanInterval)
defer ticker.Stop()
// Run immediately on start
s.runNetworkScan(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.runNetworkScan(ctx)
}
}
}
// runNetworkScan executes a single network scan cycle with error recovery.
func (s *Scheduler) runNetworkScan(ctx context.Context) {
opCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel()
if err := s.networkScanService.ScanAllTargets(opCtx); err != nil {
s.logger.Error("network scan failed",
"error", err,
"interval", s.networkScanInterval.String())
} else {
s.logger.Debug("network scan completed")
}
}
+436
View File
@@ -0,0 +1,436 @@
package service
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"log/slog"
"net"
"sync"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// SentinelAgentID is the agent ID used for network-discovered certificates.
// This allows the existing discovery dedup constraint (fingerprint, agent_id, source_path)
// to work without schema changes.
const SentinelAgentID = "server-scanner"
// NetworkScanService manages active TLS scanning of network endpoints.
type NetworkScanService struct {
networkScanRepo repository.NetworkScanRepository
discoveryService *DiscoveryService
auditService *AuditService
logger *slog.Logger
concurrency int
}
// NewNetworkScanService creates a new network scan service.
func NewNetworkScanService(
networkScanRepo repository.NetworkScanRepository,
discoveryService *DiscoveryService,
auditService *AuditService,
logger *slog.Logger,
) *NetworkScanService {
return &NetworkScanService{
networkScanRepo: networkScanRepo,
discoveryService: discoveryService,
auditService: auditService,
logger: logger,
concurrency: 50,
}
}
// ListTargets returns all network scan targets.
func (s *NetworkScanService) ListTargets(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
return s.networkScanRepo.List(ctx)
}
// GetTarget retrieves a network scan target by ID.
func (s *NetworkScanService) GetTarget(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
return s.networkScanRepo.Get(ctx, id)
}
// CreateTarget creates a new network scan target.
func (s *NetworkScanService) CreateTarget(ctx context.Context, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
if target.Name == "" {
return nil, fmt.Errorf("name is required")
}
if len(target.CIDRs) == 0 {
return nil, fmt.Errorf("at least one CIDR is required")
}
// Validate CIDRs
for _, cidr := range target.CIDRs {
if _, _, err := net.ParseCIDR(cidr); err != nil {
// Try parsing as plain IP
if ip := net.ParseIP(cidr); ip == nil {
return nil, fmt.Errorf("invalid CIDR or IP: %s", cidr)
}
}
}
if len(target.Ports) == 0 {
target.Ports = []int{443}
}
if target.ScanIntervalHours == 0 {
target.ScanIntervalHours = 6
}
if target.TimeoutMs == 0 {
target.TimeoutMs = 5000
}
target.ID = generateID("nst")
target.Enabled = true
target.CreatedAt = time.Now()
target.UpdatedAt = time.Now()
if err := s.networkScanRepo.Create(ctx, target); err != nil {
return nil, err
}
s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser,
"network_scan_target_created", "network_scan_target", target.ID,
map[string]interface{}{
"name": target.Name,
"cidrs": target.CIDRs,
"ports": target.Ports,
})
return target, nil
}
// UpdateTarget updates an existing network scan target.
func (s *NetworkScanService) UpdateTarget(ctx context.Context, id string, target *domain.NetworkScanTarget) (*domain.NetworkScanTarget, error) {
existing, err := s.networkScanRepo.Get(ctx, id)
if err != nil {
return nil, err
}
if target.Name != "" {
existing.Name = target.Name
}
if len(target.CIDRs) > 0 {
// Validate new CIDRs
for _, cidr := range target.CIDRs {
if _, _, err := net.ParseCIDR(cidr); err != nil {
if ip := net.ParseIP(cidr); ip == nil {
return nil, fmt.Errorf("invalid CIDR or IP: %s", cidr)
}
}
}
existing.CIDRs = target.CIDRs
}
if len(target.Ports) > 0 {
existing.Ports = target.Ports
}
if target.ScanIntervalHours > 0 {
existing.ScanIntervalHours = target.ScanIntervalHours
}
if target.TimeoutMs > 0 {
existing.TimeoutMs = target.TimeoutMs
}
// Always update enabled field (it's a boolean, so 0-value is meaningful)
existing.Enabled = target.Enabled
if err := s.networkScanRepo.Update(ctx, existing); err != nil {
return nil, err
}
return existing, nil
}
// DeleteTarget removes a network scan target.
func (s *NetworkScanService) DeleteTarget(ctx context.Context, id string) error {
if err := s.networkScanRepo.Delete(ctx, id); err != nil {
return err
}
s.auditService.RecordEvent(ctx, "operator", domain.ActorTypeUser,
"network_scan_target_deleted", "network_scan_target", id, nil)
return nil
}
// ScanAllTargets runs the active TLS scan for all enabled targets.
// This is called by the scheduler on the configured interval.
func (s *NetworkScanService) ScanAllTargets(ctx context.Context) error {
targets, err := s.networkScanRepo.ListEnabled(ctx)
if err != nil {
return fmt.Errorf("list enabled targets: %w", err)
}
if len(targets) == 0 {
if s.logger != nil {
s.logger.Debug("no enabled network scan targets")
}
return nil
}
if s.logger != nil {
s.logger.Info("starting network scan", "targets", len(targets))
}
for _, target := range targets {
if ctx.Err() != nil {
return ctx.Err()
}
s.scanTarget(ctx, target)
}
return nil
}
// TriggerScan runs an immediate scan for a specific target.
func (s *NetworkScanService) TriggerScan(ctx context.Context, targetID string) (*domain.DiscoveryScan, error) {
target, err := s.networkScanRepo.Get(ctx, targetID)
if err != nil {
return nil, err
}
return s.scanTarget(ctx, target), nil
}
// scanTarget scans a single network target and feeds results into the discovery pipeline.
func (s *NetworkScanService) scanTarget(ctx context.Context, target *domain.NetworkScanTarget) *domain.DiscoveryScan {
startTime := time.Now()
if s.logger != nil {
s.logger.Info("scanning network target",
"target_id", target.ID,
"name", target.Name,
"cidrs", target.CIDRs,
"ports", target.Ports)
}
// Expand CIDRs to individual IPs
endpoints := s.expandEndpoints(target.CIDRs, target.Ports)
if s.logger != nil {
s.logger.Debug("expanded endpoints", "count", len(endpoints))
}
// Scan endpoints concurrently
timeout := time.Duration(target.TimeoutMs) * time.Millisecond
results := s.scanEndpoints(ctx, endpoints, timeout)
// Collect discovered cert entries
var entries []domain.DiscoveredCertEntry
var scanErrors []string
for _, result := range results {
if result.Error != "" {
// Only log connection errors at debug level (many hosts won't have TLS)
if s.logger != nil {
s.logger.Debug("scan endpoint error",
"address", result.Address,
"error", result.Error)
}
continue
}
entries = append(entries, result.Certs...)
}
scanDuration := time.Since(startTime)
if s.logger != nil {
s.logger.Info("network target scan completed",
"target_id", target.ID,
"endpoints_scanned", len(endpoints),
"certificates_found", len(entries),
"errors", len(scanErrors),
"duration_ms", scanDuration.Milliseconds())
}
// Update scan results on target
s.networkScanRepo.UpdateScanResults(ctx, target.ID, time.Now(),
int(scanDuration.Milliseconds()), len(entries))
// Feed into discovery pipeline if we found certs
if len(entries) == 0 {
return nil
}
// Build directories list from CIDRs for the scan record
dirs := make([]string, len(target.CIDRs))
copy(dirs, target.CIDRs)
report := &domain.DiscoveryReport{
AgentID: SentinelAgentID,
Directories: dirs,
Certificates: entries,
Errors: scanErrors,
ScanDurationMs: int(scanDuration.Milliseconds()),
}
scan, err := s.discoveryService.ProcessDiscoveryReport(ctx, report)
if err != nil {
if s.logger != nil {
s.logger.Error("failed to process network scan report",
"target_id", target.ID,
"error", err)
}
return nil
}
return scan
}
// expandEndpoints converts CIDR ranges and ports into a list of "ip:port" endpoints.
func (s *NetworkScanService) expandEndpoints(cidrs []string, ports []int) []string {
var endpoints []string
for _, cidr := range cidrs {
ips := expandCIDR(cidr)
for _, ip := range ips {
for _, port := range ports {
endpoints = append(endpoints, fmt.Sprintf("%s:%d", ip, port))
}
}
}
return endpoints
}
// expandCIDR expands a CIDR notation or single IP into a list of IPs.
// Limits expansion to /20 (4096 IPs) to prevent accidental huge scans.
func expandCIDR(cidr string) []string {
// Try as CIDR first
ip, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
// Try as single IP
if singleIP := net.ParseIP(cidr); singleIP != nil {
return []string{singleIP.String()}
}
return nil
}
// Count network size and cap at /20
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
if hostBits > 12 { // More than 4096 hosts
return nil // Skip overly large networks
}
var ips []string
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incrementIP(ip) {
// Copy IP before appending (net.IP is a mutable slice)
ipCopy := make(net.IP, len(ip))
copy(ipCopy, ip)
ips = append(ips, ipCopy.String())
}
// Remove network and broadcast for IPv4 /31 and larger
if len(ips) > 2 {
ips = ips[1 : len(ips)-1]
}
return ips
}
// incrementIP increments an IP address by one.
func incrementIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// scanEndpoints probes TLS endpoints concurrently and returns results.
func (s *NetworkScanService) scanEndpoints(ctx context.Context, endpoints []string, timeout time.Duration) []domain.NetworkScanResult {
results := make([]domain.NetworkScanResult, len(endpoints))
sem := make(chan struct{}, s.concurrency)
var wg sync.WaitGroup
for i, endpoint := range endpoints {
if ctx.Err() != nil {
break
}
wg.Add(1)
sem <- struct{}{}
go func(idx int, addr string) {
defer wg.Done()
defer func() { <-sem }()
results[idx] = s.probeTLS(ctx, addr, timeout)
}(i, endpoint)
}
wg.Wait()
return results
}
// probeTLS connects to an endpoint, performs a TLS handshake, and extracts certificates.
func (s *NetworkScanService) probeTLS(ctx context.Context, address string, timeout time.Duration) domain.NetworkScanResult {
startTime := time.Now()
result := domain.NetworkScanResult{Address: address}
dialer := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(dialer, "tcp", address, &tls.Config{
InsecureSkipVerify: true, // We want to discover ALL certs, including self-signed
})
if err != nil {
result.Error = err.Error()
result.LatencyMs = int(time.Since(startTime).Milliseconds())
return result
}
defer conn.Close()
result.LatencyMs = int(time.Since(startTime).Milliseconds())
// Extract certificates from TLS connection state
state := conn.ConnectionState()
for _, cert := range state.PeerCertificates {
entry := tlsCertToEntry(cert, address)
result.Certs = append(result.Certs, entry)
}
return result
}
// tlsCertToEntry converts an x509.Certificate from a TLS handshake into a DiscoveredCertEntry.
func tlsCertToEntry(cert *x509.Certificate, address string) domain.DiscoveredCertEntry {
// Compute SHA-256 fingerprint
fingerprintBytes := sha256.Sum256(cert.Raw)
fingerprint := fmt.Sprintf("%x", fingerprintBytes)
// Encode as PEM
pemBlock := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
pemData := string(pem.EncodeToMemory(pemBlock))
// Key algorithm and size
keyAlg, keySize := tlsCertKeyInfo(cert)
return domain.DiscoveredCertEntry{
FingerprintSHA256: fingerprint,
CommonName: cert.Subject.CommonName,
SANs: cert.DNSNames,
SerialNumber: cert.SerialNumber.Text(16),
IssuerDN: cert.Issuer.String(),
SubjectDN: cert.Subject.String(),
NotBefore: cert.NotBefore.UTC().Format(time.RFC3339),
NotAfter: cert.NotAfter.UTC().Format(time.RFC3339),
KeyAlgorithm: keyAlg,
KeySize: keySize,
IsCA: cert.IsCA,
PEMData: pemData,
SourcePath: address,
SourceFormat: "network",
}
}
// tlsCertKeyInfo extracts key algorithm name and size from a certificate.
func tlsCertKeyInfo(cert *x509.Certificate) (string, int) {
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
return "RSA", pub.N.BitLen()
case *ecdsa.PublicKey:
return "ECDSA", pub.Curve.Params().BitSize
default:
switch cert.PublicKeyAlgorithm {
case x509.Ed25519:
return "Ed25519", 256
default:
return cert.PublicKeyAlgorithm.String(), 0
}
}
}
+244
View File
@@ -0,0 +1,244 @@
package service
import (
"context"
"fmt"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// mockNetworkScanRepo for testing
type mockNetworkScanRepo struct {
targets []*domain.NetworkScanTarget
}
func (m *mockNetworkScanRepo) List(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
return m.targets, nil
}
func (m *mockNetworkScanRepo) ListEnabled(ctx context.Context) ([]*domain.NetworkScanTarget, error) {
var enabled []*domain.NetworkScanTarget
for _, t := range m.targets {
if t.Enabled {
enabled = append(enabled, t)
}
}
return enabled, nil
}
func (m *mockNetworkScanRepo) Get(ctx context.Context, id string) (*domain.NetworkScanTarget, error) {
for _, t := range m.targets {
if t.ID == id {
return t, nil
}
}
return nil, fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanRepo) Create(ctx context.Context, target *domain.NetworkScanTarget) error {
m.targets = append(m.targets, target)
return nil
}
func (m *mockNetworkScanRepo) Update(ctx context.Context, target *domain.NetworkScanTarget) error {
for i, t := range m.targets {
if t.ID == target.ID {
m.targets[i] = target
return nil
}
}
return fmt.Errorf("not found: %s", target.ID)
}
func (m *mockNetworkScanRepo) Delete(ctx context.Context, id string) error {
for i, t := range m.targets {
if t.ID == id {
m.targets = append(m.targets[:i], m.targets[i+1:]...)
return nil
}
}
return fmt.Errorf("not found: %s", id)
}
func (m *mockNetworkScanRepo) UpdateScanResults(ctx context.Context, id string, scanAt time.Time, durationMs int, certsFound int) error {
for _, t := range m.targets {
if t.ID == id {
t.LastScanAt = &scanAt
d := durationMs
t.LastScanDurationMs = &d
c := certsFound
t.LastScanCertsFound = &c
return nil
}
}
return fmt.Errorf("not found: %s", id)
}
func TestExpandCIDR_SingleIP(t *testing.T) {
ips := expandCIDR("192.168.1.1")
if len(ips) != 1 || ips[0] != "192.168.1.1" {
t.Errorf("expected [192.168.1.1], got %v", ips)
}
}
func TestExpandCIDR_Slash30(t *testing.T) {
// /30 = 4 total addresses, 2 usable (remove network + broadcast)
ips := expandCIDR("10.0.0.0/30")
if len(ips) != 2 {
t.Errorf("expected 2 usable IPs for /30, got %d: %v", len(ips), ips)
}
}
func TestExpandCIDR_Slash24(t *testing.T) {
ips := expandCIDR("10.0.0.0/24")
if len(ips) != 254 {
t.Errorf("expected 254 usable IPs for /24, got %d", len(ips))
}
}
func TestExpandCIDR_TooLarge(t *testing.T) {
// /16 = 65536 IPs, exceeds /20 cap
ips := expandCIDR("10.0.0.0/16")
if len(ips) != 0 {
t.Errorf("expected empty for /16 (too large), got %d", len(ips))
}
}
func TestExpandCIDR_InvalidInput(t *testing.T) {
ips := expandCIDR("not-a-cidr")
if len(ips) != 0 {
t.Errorf("expected empty for invalid input, got %v", ips)
}
}
func TestNetworkScanService_CreateTarget(t *testing.T) {
repo := &mockNetworkScanRepo{}
auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo)
svc := NewNetworkScanService(repo, nil, auditService, nil)
target, err := svc.CreateTarget(context.Background(), &domain.NetworkScanTarget{
Name: "Test Network",
CIDRs: []string{"10.0.0.0/24"},
Ports: []int{443, 8443},
})
if err != nil {
t.Fatalf("CreateTarget failed: %v", err)
}
if target.ID == "" {
t.Error("expected non-empty ID")
}
if !target.Enabled {
t.Error("expected target to be enabled by default")
}
if target.ScanIntervalHours != 6 {
t.Errorf("expected default interval 6h, got %d", target.ScanIntervalHours)
}
if target.TimeoutMs != 5000 {
t.Errorf("expected default timeout 5000ms, got %d", target.TimeoutMs)
}
}
func TestNetworkScanService_CreateTarget_ValidationErrors(t *testing.T) {
repo := &mockNetworkScanRepo{}
auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo)
svc := NewNetworkScanService(repo, nil, auditService, nil)
tests := []struct {
name string
target *domain.NetworkScanTarget
errMsg string
}{
{
name: "missing name",
target: &domain.NetworkScanTarget{CIDRs: []string{"10.0.0.0/24"}},
errMsg: "name is required",
},
{
name: "missing cidrs",
target: &domain.NetworkScanTarget{Name: "test"},
errMsg: "at least one CIDR is required",
},
{
name: "invalid cidr",
target: &domain.NetworkScanTarget{Name: "test", CIDRs: []string{"not-valid"}},
errMsg: "invalid CIDR or IP",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := svc.CreateTarget(context.Background(), tt.target)
if err == nil {
t.Fatal("expected error")
}
if !containsSubstring(err.Error(), tt.errMsg) {
t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error())
}
})
}
}
func TestNetworkScanService_DeleteTarget(t *testing.T) {
repo := &mockNetworkScanRepo{
targets: []*domain.NetworkScanTarget{
{ID: "nst-1", Name: "test"},
},
}
auditRepo := newMockAuditRepository()
auditService := NewAuditService(auditRepo)
svc := NewNetworkScanService(repo, nil, auditService, nil)
if err := svc.DeleteTarget(context.Background(), "nst-1"); err != nil {
t.Fatalf("DeleteTarget failed: %v", err)
}
if len(repo.targets) != 0 {
t.Error("expected target to be deleted")
}
}
func TestNetworkScanService_ListTargets(t *testing.T) {
repo := &mockNetworkScanRepo{
targets: []*domain.NetworkScanTarget{
{ID: "nst-1", Name: "target1"},
{ID: "nst-2", Name: "target2"},
},
}
svc := NewNetworkScanService(repo, nil, nil, nil)
targets, err := svc.ListTargets(context.Background())
if err != nil {
t.Fatalf("ListTargets failed: %v", err)
}
if len(targets) != 2 {
t.Errorf("expected 2 targets, got %d", len(targets))
}
}
func TestExpandEndpoints(t *testing.T) {
svc := &NetworkScanService{}
endpoints := svc.expandEndpoints([]string{"192.168.1.1"}, []int{443, 8443})
if len(endpoints) != 2 {
t.Errorf("expected 2 endpoints, got %d: %v", len(endpoints), endpoints)
}
if endpoints[0] != "192.168.1.1:443" {
t.Errorf("expected 192.168.1.1:443, got %s", endpoints[0])
}
if endpoints[1] != "192.168.1.1:8443" {
t.Errorf("expected 192.168.1.1:8443, got %s", endpoints[1])
}
}
// containsSubstring checks if a string contains a substring (helper)
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}