feat(M50): cloud secret manager discovery — AWS SM, Azure KV, GCP SM

Extend certificate discovery from filesystem + network to cloud secret
managers. Three pluggable DiscoverySource connectors feed into the
existing discovery pipeline via sentinel agent pattern, with a 9th
scheduler loop for periodic cloud scanning.

- AWS Secrets Manager: aws-sdk-go-v2, tag/prefix filtering, 10 tests
- Azure Key Vault: stdlib HTTP + OAuth2, base64 DER/PEM, 16 tests
- GCP Secret Manager: stdlib HTTP + JWT OAuth2, label filter, 14 tests
- CloudDiscoveryService orchestrator with 9 tests
- 9th scheduler loop (6h default, atomic.Bool idempotency)
- Discovery page: color-coded source type badges
- 14 new env vars across CloudDiscoveryConfig structs
- Docs: connectors.md, architecture.md, features.md, README updated

49 new tests. All CI checks pass (go vet, race, lint, coverage).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-04-15 23:01:00 -04:00
parent 3f619bcaac
commit e1bcde4cf1
19 changed files with 3791 additions and 24 deletions
@@ -0,0 +1,515 @@
// Package azurekv implements the domain.DiscoverySource interface for
// Azure Key Vault certificate discovery.
//
// Azure Key Vault is a cloud-based secret and certificate management service.
// This connector discovers certificates stored in an Azure Key Vault using the
// Azure Key Vault REST API with OAuth2 client credentials authentication.
//
// No Azure SDK dependency — uses stdlib net/http + OAuth2 for authentication.
//
// API endpoints used:
//
// GET /certificates?api-version=7.4 - List certificates
// GET /certificates/{name}/{version}?api-version=7.4 - Get certificate details
//
// Authentication: OAuth2 client credentials flow via Azure AD.
// Token is cached with 5-minute refresh buffer.
package azurekv
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// Config represents the Azure Key Vault discovery configuration.
type Config struct {
// VaultURL is the Azure Key Vault URL (e.g., "https://myvault.vault.azure.net").
// Required. Set via CERTCTL_AZURE_KV_VAULT_URL environment variable.
VaultURL string `json:"vault_url"`
// TenantID is the Azure AD tenant ID (e.g., "00000000-0000-0000-0000-000000000000").
// Required. Set via CERTCTL_AZURE_KV_TENANT_ID environment variable.
TenantID string `json:"tenant_id"`
// ClientID is the Azure AD application (client) ID.
// Required. Set via CERTCTL_AZURE_KV_CLIENT_ID environment variable.
ClientID string `json:"client_id"`
// ClientSecret is the Azure AD application secret or certificate.
// Required. Set via CERTCTL_AZURE_KV_CLIENT_SECRET environment variable.
ClientSecret string `json:"client_secret"`
}
// cachedToken holds an OAuth2 access token and its expiry time.
type cachedToken struct {
token string
expiresAt time.Time
}
// certificateListResponse represents the Azure Key Vault list certificates response.
type certificateListResponse struct {
Value []struct {
ID string `json:"id"`
Attributes struct {
Enabled int64 `json:"enabled"`
Created int64 `json:"created"`
Updated int64 `json:"updated"`
Exp int64 `json:"exp"`
} `json:"attributes,omitempty"`
Tags map[string]string `json:"tags,omitempty"`
} `json:"value"`
NextLink string `json:"nextLink"`
}
// certificateBundle represents the Azure Key Vault certificate details response.
type certificateBundle struct {
ID string `json:"id"`
CER string `json:"cer"`
Attributes struct {
Enabled int64 `json:"enabled"`
Created int64 `json:"created"`
Updated int64 `json:"updated"`
Exp int64 `json:"exp"`
} `json:"attributes,omitempty"`
}
// KVClient is an interface for Azure Key Vault operations, allowing injection for testing.
type KVClient interface {
// ListCertificates retrieves the list of certificates in the vault.
ListCertificates(ctx context.Context, vaultURL string) ([]struct {
ID string
Attributes struct {
Exp int64
}
}, error)
// GetCertificate retrieves a specific certificate version.
GetCertificate(ctx context.Context, vaultURL, certName, version string) (*certificateBundle, error)
}
// Source implements domain.DiscoverySource for Azure Key Vault.
type Source struct {
config Config
logger *slog.Logger
client KVClient
}
// New creates a new Azure Key Vault discovery source with real HTTP client.
func New(cfg Config, logger *slog.Logger) *Source {
return &Source{
config: cfg,
logger: logger,
client: &httpKVClient{
config: cfg,
httpClient: &http.Client{Timeout: 30 * time.Second},
},
}
}
// NewWithClient creates a new Azure Key Vault discovery source with injected client (for testing).
func NewWithClient(cfg Config, client KVClient, logger *slog.Logger) *Source {
return &Source{
config: cfg,
logger: logger,
client: client,
}
}
// Name returns a human-readable name for this discovery source.
func (s *Source) Name() string {
return "Azure Key Vault"
}
// Type returns the short type identifier for this discovery source.
func (s *Source) Type() string {
return "azure-kv"
}
// ValidateConfig checks that the Azure Key Vault configuration is valid.
func (s *Source) ValidateConfig() error {
if s.config.VaultURL == "" {
return fmt.Errorf("Azure Key Vault URL is required")
}
if s.config.TenantID == "" {
return fmt.Errorf("Azure Key Vault tenant ID is required")
}
if s.config.ClientID == "" {
return fmt.Errorf("Azure Key Vault client ID is required")
}
if s.config.ClientSecret == "" {
return fmt.Errorf("Azure Key Vault client secret is required")
}
// Basic URL validation
if !strings.HasPrefix(s.config.VaultURL, "https://") {
return fmt.Errorf("Azure Key Vault URL must use HTTPS")
}
return nil
}
// Discover scans the Azure Key Vault and returns a DiscoveryReport.
func (s *Source) Discover(ctx context.Context) (*domain.DiscoveryReport, error) {
s.logger.Info("starting Azure Key Vault discovery", "vault_url", s.config.VaultURL)
report := &domain.DiscoveryReport{
AgentID: "cloud-azure-kv",
Directories: []string{fmt.Sprintf("azure-kv://%s/", s.config.VaultURL)},
Certificates: []domain.DiscoveredCertEntry{},
Errors: []string{},
}
startTime := time.Now()
// List certificates
certs, err := s.client.ListCertificates(ctx, s.config.VaultURL)
if err != nil {
s.logger.Error("failed to list Azure Key Vault certificates", "error", err)
report.Errors = append(report.Errors, fmt.Sprintf("list certificates failed: %v", err))
return report, nil
}
// Process each certificate
for _, cert := range certs {
// Extract certificate name and version from ID
// ID format: https://myvault.vault.azure.net/certificates/mycert/version123
certName, version, err := extractCertNameAndVersion(cert.ID)
if err != nil {
s.logger.Warn("failed to parse certificate ID", "id", cert.ID, "error", err)
report.Errors = append(report.Errors, fmt.Sprintf("parse cert ID failed: %v", err))
continue
}
// Get certificate details
certBundle, err := s.client.GetCertificate(ctx, s.config.VaultURL, certName, version)
if err != nil {
s.logger.Warn("failed to get certificate details", "name", certName, "version", version, "error", err)
report.Errors = append(report.Errors, fmt.Sprintf("get cert %s/%s failed: %v", certName, version, err))
continue
}
// Decode the base64-encoded DER certificate
if certBundle.CER == "" {
s.logger.Warn("empty certificate data", "name", certName, "version", version)
continue
}
derBytes, err := base64.StdEncoding.DecodeString(certBundle.CER)
if err != nil {
s.logger.Warn("failed to decode certificate", "name", certName, "version", version, "error", err)
report.Errors = append(report.Errors, fmt.Sprintf("decode cert %s/%s failed: %v", certName, version, err))
continue
}
// Parse certificate
x509Cert, err := x509.ParseCertificate(derBytes)
if err != nil {
s.logger.Warn("failed to parse certificate", "name", certName, "version", version, "error", err)
report.Errors = append(report.Errors, fmt.Sprintf("parse cert %s/%s failed: %v", certName, version, err))
continue
}
// Extract certificate metadata
entry := extractCertMetadata(x509Cert, certName, version)
// Encode as PEM for inclusion in report
certPEM := encodeCertPEM(derBytes)
entry.PEMData = certPEM
report.Certificates = append(report.Certificates, entry)
s.logger.Info("discovered certificate",
"name", certName,
"common_name", entry.CommonName,
"serial", entry.SerialNumber,
"not_after", entry.NotAfter)
}
report.ScanDurationMs = int(time.Since(startTime).Milliseconds())
s.logger.Info("Azure Key Vault discovery completed",
"certs_found", len(report.Certificates),
"errors", len(report.Errors),
"duration_ms", report.ScanDurationMs)
return report, nil
}
// httpKVClient implements KVClient using Azure Key Vault REST API.
type httpKVClient struct {
config Config
httpClient *http.Client
// OAuth2 token caching
mu sync.Mutex
tokenCache *cachedToken
}
// ListCertificates retrieves the list of certificates in the vault.
func (c *httpKVClient) ListCertificates(ctx context.Context, vaultURL string) ([]struct {
ID string
Attributes struct {
Exp int64
}
}, error) {
var results []struct {
ID string
Attributes struct {
Exp int64
}
}
listURL := fmt.Sprintf("%s/certificates?api-version=7.4", strings.TrimSuffix(vaultURL, "/"))
for listURL != "" {
token, err := c.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, listURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("list request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("list certificates returned status %d: %s", resp.StatusCode, string(body))
}
var listResp certificateListResponse
if err := json.Unmarshal(body, &listResp); err != nil {
return nil, fmt.Errorf("failed to parse list response: %w", err)
}
for _, cert := range listResp.Value {
results = append(results, struct {
ID string
Attributes struct {
Exp int64
}
}{
ID: cert.ID,
Attributes: struct {
Exp int64
}{Exp: cert.Attributes.Exp},
})
}
// Handle pagination
if listResp.NextLink == "" {
break
}
listURL = listResp.NextLink
}
return results, nil
}
// GetCertificate retrieves a specific certificate version from the vault.
func (c *httpKVClient) GetCertificate(ctx context.Context, vaultURL, certName, version string) (*certificateBundle, error) {
token, err := c.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
// Ensure vaultURL has no trailing slash
vaultURL = strings.TrimSuffix(vaultURL, "/")
// Build the certificate URL
// Format: https://myvault.vault.azure.net/certificates/mycert/version123?api-version=7.4
certURL := fmt.Sprintf("%s/certificates/%s/%s?api-version=7.4",
vaultURL, url.PathEscape(certName), url.PathEscape(version))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, certURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("get certificate request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("get certificate returned status %d: %s", resp.StatusCode, string(body))
}
var certBundle certificateBundle
if err := json.Unmarshal(body, &certBundle); err != nil {
return nil, fmt.Errorf("failed to parse certificate response: %w", err)
}
return &certBundle, nil
}
// getAccessToken returns a valid OAuth2 access token, refreshing if needed.
func (c *httpKVClient) getAccessToken(ctx context.Context) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
// Return cached token if still valid (5 min buffer)
if c.tokenCache != nil && time.Now().Add(5*time.Minute).Before(c.tokenCache.expiresAt) {
return c.tokenCache.token, nil
}
// Exchange client credentials for access token
tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token",
url.PathEscape(c.config.TenantID))
form := url.Values{
"grant_type": {"client_credentials"},
"client_id": {c.config.ClientID},
"client_secret": {c.config.ClientSecret},
"scope": {"https://vault.azure.net/.default"},
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL,
strings.NewReader(form.Encode()))
if err != nil {
return "", fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read token response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("token request returned status %d: %s", resp.StatusCode, string(body))
}
var tokenResp struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return "", fmt.Errorf("failed to parse token response: %w", err)
}
if tokenResp.AccessToken == "" {
return "", fmt.Errorf("empty access token in response")
}
// Cache token
c.tokenCache = &cachedToken{
token: tokenResp.AccessToken,
expiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
}
return tokenResp.AccessToken, nil
}
// extractCertNameAndVersion extracts the certificate name and version from the Azure ID.
// ID format: https://myvault.vault.azure.net/certificates/mycert/version123
func extractCertNameAndVersion(id string) (name, version string, err error) {
// Use regex to extract name and version from the ID URL
// Pattern: /certificates/{name}/{version}
re := regexp.MustCompile(`/certificates/([^/]+)/([^/]+)$`)
matches := re.FindStringSubmatch(id)
if len(matches) != 3 {
return "", "", fmt.Errorf("cannot parse certificate ID: %s", id)
}
return matches[1], matches[2], nil
}
// extractCertMetadata extracts metadata from a parsed X.509 certificate.
func extractCertMetadata(cert *x509.Certificate, certName, version string) domain.DiscoveredCertEntry {
// Extract Subject Alternative Names (DNS names and email addresses)
sans := []string{}
sans = append(sans, cert.DNSNames...)
// Extract key algorithm
keyAlgo := "unknown"
keySize := 0
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
keyAlgo = "RSA"
keySize = pub.N.BitLen()
case *ecdsa.PublicKey:
keyAlgo = "ECDSA"
keySize = pub.Curve.Params().BitSize
}
// Compute SHA-256 fingerprint
fp := sha256.Sum256(cert.Raw)
fingerprint := fmt.Sprintf("%X", fp)
// Format times as RFC3339
notBefore := cert.NotBefore.UTC().Format(time.RFC3339)
notAfter := cert.NotAfter.UTC().Format(time.RFC3339)
return domain.DiscoveredCertEntry{
FingerprintSHA256: fingerprint,
CommonName: cert.Subject.CommonName,
SANs: sans,
SerialNumber: fmt.Sprintf("%x", cert.SerialNumber),
IssuerDN: cert.Issuer.String(),
SubjectDN: cert.Subject.String(),
NotBefore: notBefore,
NotAfter: notAfter,
KeyAlgorithm: keyAlgo,
KeySize: keySize,
IsCA: cert.IsCA,
SourcePath: fmt.Sprintf("azure-kv://%s/%s", certName, version),
SourceFormat: "DER",
}
}
// encodeCertPEM encodes a DER certificate as PEM.
func encodeCertPEM(derBytes []byte) string {
var buf bytes.Buffer
pem.Encode(&buf, &pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
return buf.String()
}
// Ensure Source implements domain.DiscoverySource.
var _ domain.DiscoverySource = (*Source)(nil)
@@ -0,0 +1,597 @@
package azurekv
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"log/slog"
"math/big"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
)
// TestValidateConfig_Success validates a correct configuration.
func TestValidateConfig_Success(t *testing.T) {
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "00000000-0000-0000-0000-000000000000",
ClientID: "11111111-1111-1111-1111-111111111111",
ClientSecret: "mysecret123",
}
src := &Source{config: cfg, logger: slog.Default()}
if err := src.ValidateConfig(); err != nil {
t.Fatalf("ValidateConfig failed: %v", err)
}
}
// TestValidateConfig_MissingVaultURL validates error when VaultURL is empty.
func TestValidateConfig_MissingVaultURL(t *testing.T) {
cfg := Config{
VaultURL: "",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := &Source{config: cfg, logger: slog.Default()}
if err := src.ValidateConfig(); err == nil {
t.Fatal("expected error for missing VaultURL")
}
}
// TestValidateConfig_MissingTenantID validates error when TenantID is empty.
func TestValidateConfig_MissingTenantID(t *testing.T) {
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "",
ClientID: "client-id",
ClientSecret: "secret",
}
src := &Source{config: cfg, logger: slog.Default()}
if err := src.ValidateConfig(); err == nil {
t.Fatal("expected error for missing TenantID")
}
}
// TestValidateConfig_MissingClientID validates error when ClientID is empty.
func TestValidateConfig_MissingClientID(t *testing.T) {
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "",
ClientSecret: "secret",
}
src := &Source{config: cfg, logger: slog.Default()}
if err := src.ValidateConfig(); err == nil {
t.Fatal("expected error for missing ClientID")
}
}
// TestValidateConfig_MissingClientSecret validates error when ClientSecret is empty.
func TestValidateConfig_MissingClientSecret(t *testing.T) {
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "",
}
src := &Source{config: cfg, logger: slog.Default()}
if err := src.ValidateConfig(); err == nil {
t.Fatal("expected error for missing ClientSecret")
}
}
// TestValidateConfig_InvalidURL validates error when VaultURL is not HTTPS.
func TestValidateConfig_InvalidURL(t *testing.T) {
cfg := Config{
VaultURL: "http://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := &Source{config: cfg, logger: slog.Default()}
if err := src.ValidateConfig(); err == nil {
t.Fatal("expected error for non-HTTPS URL")
}
}
// mockKVClient implements KVClient for testing.
type mockKVClient struct {
certs map[string]*certificateBundle
err error
}
func (m *mockKVClient) ListCertificates(ctx context.Context, vaultURL string) ([]struct {
ID string
Attributes struct {
Exp int64
}
}, error) {
if m.err != nil {
return nil, m.err
}
var results []struct {
ID string
Attributes struct {
Exp int64
}
}
for id := range m.certs {
results = append(results, struct {
ID string
Attributes struct {
Exp int64
}
}{ID: id})
}
return results, nil
}
func (m *mockKVClient) GetCertificate(ctx context.Context, vaultURL, certName, version string) (*certificateBundle, error) {
if m.err != nil {
return nil, m.err
}
id := fmt.Sprintf("https://myvault.vault.azure.net/certificates/%s/%s", certName, version)
cert, ok := m.certs[id]
if !ok {
return nil, fmt.Errorf("certificate not found")
}
return cert, nil
}
// generateTestCert generates a test X.509 certificate.
func generateTestCert(cn string, sans []string) ([]byte, error) {
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
serialNumber, err := rand.Int(rand.Reader, big.NewInt(0).Exp(big.NewInt(2), big.NewInt(64), nil))
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: cn,
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: false,
DNSNames: sans,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
if err != nil {
return nil, err
}
return derBytes, nil
}
// TestDiscover_Success validates successful certificate discovery.
func TestDiscover_Success(t *testing.T) {
// Generate test certificates
cert1DER, err := generateTestCert("example.com", []string{"www.example.com", "api.example.com"})
if err != nil {
t.Fatalf("failed to generate test cert: %v", err)
}
cert2DER, err := generateTestCert("test.example.com", []string{})
if err != nil {
t.Fatalf("failed to generate test cert: %v", err)
}
// Create mock client
mockClient := &mockKVClient{
certs: map[string]*certificateBundle{
"https://myvault.vault.azure.net/certificates/example/v1": {
ID: "https://myvault.vault.azure.net/certificates/example/v1",
CER: base64.StdEncoding.EncodeToString(cert1DER),
},
"https://myvault.vault.azure.net/certificates/test/v2": {
ID: "https://myvault.vault.azure.net/certificates/test/v2",
CER: base64.StdEncoding.EncodeToString(cert2DER),
},
},
}
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := NewWithClient(cfg, mockClient, slog.Default())
ctx := context.Background()
report, err := src.Discover(ctx)
if err != nil {
t.Fatalf("Discover failed: %v", err)
}
if report == nil {
t.Fatal("expected non-nil report")
}
if len(report.Certificates) != 2 {
t.Fatalf("expected 2 certificates, got %d", len(report.Certificates))
}
// Verify first cert metadata
if report.Certificates[0].CommonName == "" {
t.Fatal("expected common name in first cert")
}
// Verify PEM encoding
if report.Certificates[0].PEMData == "" {
t.Fatal("expected PEM data in first cert")
}
// Verify PEM is valid
block, _ := pem.Decode([]byte(report.Certificates[0].PEMData))
if block == nil {
t.Fatal("failed to decode PEM data")
}
}
// TestDiscover_ListError validates error handling when listing fails.
func TestDiscover_ListError(t *testing.T) {
mockClient := &mockKVClient{
err: fmt.Errorf("connection error"),
}
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := NewWithClient(cfg, mockClient, slog.Default())
ctx := context.Background()
report, err := src.Discover(ctx)
// Should return partial report with error
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(report.Errors) == 0 {
t.Fatal("expected errors in report")
}
}
// TestDiscover_EmptyResults validates handling of empty certificate list.
func TestDiscover_EmptyResults(t *testing.T) {
mockClient := &mockKVClient{
certs: map[string]*certificateBundle{},
}
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := NewWithClient(cfg, mockClient, slog.Default())
ctx := context.Background()
report, err := src.Discover(ctx)
if err != nil {
t.Fatalf("Discover failed: %v", err)
}
if len(report.Certificates) != 0 {
t.Fatalf("expected 0 certificates, got %d", len(report.Certificates))
}
if len(report.Errors) != 0 {
t.Fatalf("expected 0 errors, got %d", len(report.Errors))
}
}
// TestDiscover_InvalidCertData validates handling of invalid certificate data.
func TestDiscover_InvalidCertData(t *testing.T) {
// Generate one valid cert and one invalid
validDER, err := generateTestCert("valid.example.com", []string{})
if err != nil {
t.Fatalf("failed to generate test cert: %v", err)
}
mockClient := &mockKVClient{
certs: map[string]*certificateBundle{
"https://myvault.vault.azure.net/certificates/valid/v1": {
ID: "https://myvault.vault.azure.net/certificates/valid/v1",
CER: base64.StdEncoding.EncodeToString(validDER),
},
"https://myvault.vault.azure.net/certificates/invalid/v1": {
ID: "https://myvault.vault.azure.net/certificates/invalid/v1",
CER: "not-valid-base64!@#$%",
},
},
}
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := NewWithClient(cfg, mockClient, slog.Default())
ctx := context.Background()
report, err := src.Discover(ctx)
if err != nil {
t.Fatalf("Discover failed: %v", err)
}
// Should have 1 valid cert
if len(report.Certificates) != 1 {
t.Fatalf("expected 1 valid certificate, got %d", len(report.Certificates))
}
// Should have 1 error
if len(report.Errors) != 1 {
t.Fatalf("expected 1 error, got %d", len(report.Errors))
}
}
// TestDiscover_AgentIDAndSourcePath validates correct agent ID and source paths.
func TestDiscover_AgentIDAndSourcePath(t *testing.T) {
certDER, err := generateTestCert("test.example.com", []string{})
if err != nil {
t.Fatalf("failed to generate test cert: %v", err)
}
mockClient := &mockKVClient{
certs: map[string]*certificateBundle{
"https://myvault.vault.azure.net/certificates/mycert/v1": {
ID: "https://myvault.vault.azure.net/certificates/mycert/v1",
CER: base64.StdEncoding.EncodeToString(certDER),
},
},
}
cfg := Config{
VaultURL: "https://myvault.vault.azure.net",
TenantID: "tenant-id",
ClientID: "client-id",
ClientSecret: "secret",
}
src := NewWithClient(cfg, mockClient, slog.Default())
ctx := context.Background()
report, err := src.Discover(ctx)
if err != nil {
t.Fatalf("Discover failed: %v", err)
}
if report.AgentID != "cloud-azure-kv" {
t.Fatalf("expected agent_id 'cloud-azure-kv', got %s", report.AgentID)
}
if len(report.Directories) == 0 {
t.Fatal("expected directories in report")
}
if len(report.Certificates) > 0 {
cert := report.Certificates[0]
if !domain.IsValidDiscoveryStatus(cert.SourcePath) == false {
// SourcePath should follow azure-kv://certname/version format
if !contains(cert.SourcePath, "azure-kv://") {
t.Fatalf("expected source path to start with 'azure-kv://', got %s", cert.SourcePath)
}
}
}
}
// TestName validates the Name method.
func TestName(t *testing.T) {
src := &Source{
config: Config{},
logger: slog.Default(),
}
expected := "Azure Key Vault"
if src.Name() != expected {
t.Fatalf("expected Name '%s', got '%s'", expected, src.Name())
}
}
// TestType validates the Type method.
func TestType(t *testing.T) {
src := &Source{
config: Config{},
logger: slog.Default(),
}
expected := "azure-kv"
if src.Type() != expected {
t.Fatalf("expected Type '%s', got '%s'", expected, src.Type())
}
}
// TestExtractCertNameAndVersion validates certificate ID parsing.
func TestExtractCertNameAndVersion(t *testing.T) {
tests := []struct {
id string
wantName string
wantVer string
wantErr bool
}{
{
id: "https://myvault.vault.azure.net/certificates/example/v1",
wantName: "example",
wantVer: "v1",
wantErr: false,
},
{
id: "https://myvault.vault.azure.net/certificates/my-cert/version123",
wantName: "my-cert",
wantVer: "version123",
wantErr: false,
},
{
id: "invalid-id",
wantErr: true,
},
}
for _, tt := range tests {
name, ver, err := extractCertNameAndVersion(tt.id)
if (err != nil) != tt.wantErr {
t.Fatalf("extractCertNameAndVersion(%s) error = %v, wantErr %v", tt.id, err, tt.wantErr)
}
if !tt.wantErr {
if name != tt.wantName || ver != tt.wantVer {
t.Fatalf("extractCertNameAndVersion(%s) = (%s, %s), want (%s, %s)",
tt.id, name, ver, tt.wantName, tt.wantVer)
}
}
}
}
// TestExtractCertMetadata validates certificate metadata extraction.
func TestExtractCertMetadata(t *testing.T) {
// Generate a test certificate
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
serialNumber := big.NewInt(123456)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "test.example.com",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: false,
DNSNames: []string{"test.example.com", "www.test.example.com"},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
if err != nil {
t.Fatalf("failed to create cert: %v", err)
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
t.Fatalf("failed to parse cert: %v", err)
}
entry := extractCertMetadata(cert, "testcert", "v1")
if entry.CommonName != "test.example.com" {
t.Fatalf("expected CN 'test.example.com', got %s", entry.CommonName)
}
if len(entry.SANs) != 2 {
t.Fatalf("expected 2 SANs, got %d", len(entry.SANs))
}
if entry.KeyAlgorithm != "ECDSA" {
t.Fatalf("expected key algorithm ECDSA, got %s", entry.KeyAlgorithm)
}
if entry.KeySize != 256 {
t.Fatalf("expected key size 256, got %d", entry.KeySize)
}
if entry.SerialNumber == "" {
t.Fatal("expected serial number, got empty")
}
if entry.SourceFormat != "DER" {
t.Fatalf("expected source format DER, got %s", entry.SourceFormat)
}
// Verify fingerprint is valid hex
if len(entry.FingerprintSHA256) != 64 {
t.Fatalf("expected 64-char fingerprint, got %d chars", len(entry.FingerprintSHA256))
}
// Verify manually calculated fingerprint
fp := sha256.Sum256(derBytes)
expectedFP := fmt.Sprintf("%X", fp)
if entry.FingerprintSHA256 != expectedFP {
t.Fatalf("fingerprint mismatch: got %s, want %s", entry.FingerprintSHA256, expectedFP)
}
}
// TestEncodeCertPEM validates PEM encoding.
func TestEncodeCertPEM(t *testing.T) {
derBytes, err := generateTestCert("test.example.com", []string{})
if err != nil {
t.Fatalf("failed to generate test cert: %v", err)
}
pemStr := encodeCertPEM(derBytes)
// Verify PEM format
if !contains(pemStr, "-----BEGIN CERTIFICATE-----") {
t.Fatal("expected PEM header")
}
if !contains(pemStr, "-----END CERTIFICATE-----") {
t.Fatal("expected PEM footer")
}
// Verify we can decode it back
block, _ := pem.Decode([]byte(pemStr))
if block == nil {
t.Fatal("failed to decode PEM")
}
if len(block.Bytes) != len(derBytes) {
t.Fatal("decoded PEM does not match original DER")
}
}
// Helper function to check if string contains substring
func contains(s, substr string) bool {
return len(s) > 0 && len(substr) > 0 && s != substr &&
(s == substr || len(s) > len(substr))
}