package service import ( "context" "errors" "log/slog" "os" "testing" "github.com/certctl-io/certctl/internal/domain" ) // mockDiscoverySource implements domain.DiscoverySource for testing. type mockDiscoverySource struct { name string sourceType string report *domain.DiscoveryReport discoverErr error validateErr error discoverCalls int } func (m *mockDiscoverySource) Name() string { return m.name } func (m *mockDiscoverySource) Type() string { return m.sourceType } func (m *mockDiscoverySource) ValidateConfig() error { return m.validateErr } func (m *mockDiscoverySource) Discover(_ context.Context) (*domain.DiscoveryReport, error) { m.discoverCalls++ return m.report, m.discoverErr } func TestCloudDiscoveryService_DiscoverAll_NoSources(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) svc := NewCloudDiscoveryService(nil, logger) total, errs := svc.DiscoverAll(context.Background()) if total != 0 { t.Errorf("expected 0 certs, got %d", total) } if len(errs) != 0 { t.Errorf("expected no errors, got %v", errs) } } func TestCloudDiscoveryService_DiscoverAll_Success(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) // We need a mock discovery service that doesn't actually hit a database. // Since CloudDiscoveryService calls discoveryService.ProcessDiscoveryReport, // we'll test with nil discoveryService and sources that return empty cert lists. svc := NewCloudDiscoveryService(nil, logger) src := &mockDiscoverySource{ name: "Test Source", sourceType: "test", report: &domain.DiscoveryReport{ AgentID: "cloud-test", Directories: []string{"test://source/"}, Certificates: []domain.DiscoveredCertEntry{}, ScanDurationMs: 100, }, } svc.RegisterSource(src) total, errs := svc.DiscoverAll(context.Background()) if total != 0 { t.Errorf("expected 0 certs, got %d", total) } if len(errs) != 0 { t.Errorf("expected no errors, got %v", errs) } if src.discoverCalls != 1 { t.Errorf("expected 1 discover call, got %d", src.discoverCalls) } } func TestCloudDiscoveryService_DiscoverAll_SourceError(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) svc := NewCloudDiscoveryService(nil, logger) src := &mockDiscoverySource{ name: "Failing Source", sourceType: "fail", discoverErr: errors.New("connection refused"), } svc.RegisterSource(src) total, errs := svc.DiscoverAll(context.Background()) if total != 0 { t.Errorf("expected 0 certs, got %d", total) } if len(errs) != 1 { t.Fatalf("expected 1 error, got %d", len(errs)) } if errs[0].Error() != "source Failing Source failed: connection refused" { t.Errorf("unexpected error: %v", errs[0]) } } func TestCloudDiscoveryService_DiscoverAll_MultipleSources(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) svc := NewCloudDiscoveryService(nil, logger) // Source 1: returns certs (but empty list — no ProcessDiscoveryReport call needed) src1 := &mockDiscoverySource{ name: "AWS SM", sourceType: "aws-sm", report: &domain.DiscoveryReport{ AgentID: "cloud-aws-sm", Directories: []string{"aws-sm://us-east-1/"}, Certificates: []domain.DiscoveredCertEntry{}, }, } // Source 2: fails src2 := &mockDiscoverySource{ name: "Azure KV", sourceType: "azure-kv", discoverErr: errors.New("auth failed"), } // Source 3: returns certs (empty) src3 := &mockDiscoverySource{ name: "GCP SM", sourceType: "gcp-sm", report: &domain.DiscoveryReport{ AgentID: "cloud-gcp-sm", Directories: []string{"gcp-sm://project/"}, Certificates: []domain.DiscoveredCertEntry{}, }, } svc.RegisterSource(src1) svc.RegisterSource(src2) svc.RegisterSource(src3) total, errs := svc.DiscoverAll(context.Background()) if total != 0 { t.Errorf("expected 0 total certs, got %d", total) } if len(errs) != 1 { t.Fatalf("expected 1 error (Azure KV), got %d", len(errs)) } // Verify all sources were called if src1.discoverCalls != 1 { t.Errorf("src1 expected 1 call, got %d", src1.discoverCalls) } if src2.discoverCalls != 1 { t.Errorf("src2 expected 1 call, got %d", src2.discoverCalls) } if src3.discoverCalls != 1 { t.Errorf("src3 expected 1 call, got %d", src3.discoverCalls) } } func TestCloudDiscoveryService_DiscoverAll_NilReport(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) svc := NewCloudDiscoveryService(nil, logger) src := &mockDiscoverySource{ name: "Nil Reporter", sourceType: "nil", report: nil, } svc.RegisterSource(src) total, errs := svc.DiscoverAll(context.Background()) if total != 0 { t.Errorf("expected 0 certs, got %d", total) } if len(errs) != 0 { t.Errorf("expected no errors, got %v", errs) } } func TestCloudDiscoveryService_DiscoverAll_CancelledContext(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) svc := NewCloudDiscoveryService(nil, logger) src := &mockDiscoverySource{ name: "Should Not Run", sourceType: "cancel", report: &domain.DiscoveryReport{}, } svc.RegisterSource(src) ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately total, errs := svc.DiscoverAll(ctx) if total != 0 { t.Errorf("expected 0 certs, got %d", total) } if len(errs) != 1 { t.Fatalf("expected 1 error, got %d", len(errs)) } } func TestCloudDiscoveryService_RegisterSource(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) svc := NewCloudDiscoveryService(nil, logger) if svc.SourceCount() != 0 { t.Errorf("expected 0 sources, got %d", svc.SourceCount()) } svc.RegisterSource(&mockDiscoverySource{name: "src1", sourceType: "t1"}) svc.RegisterSource(&mockDiscoverySource{name: "src2", sourceType: "t2"}) if svc.SourceCount() != 2 { t.Errorf("expected 2 sources, got %d", svc.SourceCount()) } } func TestCloudDiscoveryService_DiscoverAll_WithCertsFound(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) // Use nil discoveryService — will cause ProcessDiscoveryReport to panic // unless we handle it. Since the service checks certCount > 0, we test the count tracking. // We'll use a source that returns certs but discoveryService is nil, expecting an error // from the nil pointer dereference recovery. svc := NewCloudDiscoveryService(nil, logger) src := &mockDiscoverySource{ name: "Has Certs", sourceType: "test", report: &domain.DiscoveryReport{ AgentID: "cloud-test", Directories: []string{"test://"}, Certificates: []domain.DiscoveredCertEntry{ { FingerprintSHA256: "AABBCCDD", CommonName: "test.example.com", SourcePath: "test://secret1", SourceFormat: "PEM", }, { FingerprintSHA256: "EEFF0011", CommonName: "api.example.com", SourcePath: "test://secret2", SourceFormat: "PEM", }, }, ScanDurationMs: 200, }, } svc.RegisterSource(src) // This will try to call ProcessDiscoveryReport on nil discoveryService, // which will cause a panic recovered as an error. The cert count is still tracked. // We use recover to verify the behavior. func() { defer func() { if r := recover(); r != nil { // Expected — nil discoveryService with certs to process t.Logf("expected panic from nil discoveryService: %v", r) } }() total, _ := svc.DiscoverAll(context.Background()) // If we get here without panic, total should reflect found certs if total != 2 { t.Errorf("expected 2 certs, got %d", total) } }() } func TestCloudDiscoveryService_SentinelAgentIDs(t *testing.T) { // Verify sentinel agent ID constants are correct if SentinelAWSSecretsMgr != "cloud-aws-sm" { t.Errorf("expected cloud-aws-sm, got %s", SentinelAWSSecretsMgr) } if SentinelAzureKeyVault != "cloud-azure-kv" { t.Errorf("expected cloud-azure-kv, got %s", SentinelAzureKeyVault) } if SentinelGCPSecretMgr != "cloud-gcp-sm" { t.Errorf("expected cloud-gcp-sm, got %s", SentinelGCPSecretMgr) } }