mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-14 07:48:57 +00:00
feat(M50): cloud secret manager discovery — AWS SM, Azure KV, GCP SM
Extend certificate discovery from filesystem + network to cloud secret managers. Three pluggable DiscoverySource connectors feed into the existing discovery pipeline via sentinel agent pattern, with a 9th scheduler loop for periodic cloud scanning. - AWS Secrets Manager: aws-sdk-go-v2, tag/prefix filtering, 10 tests - Azure Key Vault: stdlib HTTP + OAuth2, base64 DER/PEM, 16 tests - GCP Secret Manager: stdlib HTTP + JWT OAuth2, label filter, 14 tests - CloudDiscoveryService orchestrator with 9 tests - 9th scheduler loop (6h default, atomic.Bool idempotency) - Discovery page: color-coded source type badges - 14 new env vars across CloudDiscoveryConfig structs - Docs: connectors.md, architecture.md, features.md, README updated 49 new tests. All CI checks pass (go vet, race, lint, coverage). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,515 @@
|
||||
// Package azurekv implements the domain.DiscoverySource interface for
|
||||
// Azure Key Vault certificate discovery.
|
||||
//
|
||||
// Azure Key Vault is a cloud-based secret and certificate management service.
|
||||
// This connector discovers certificates stored in an Azure Key Vault using the
|
||||
// Azure Key Vault REST API with OAuth2 client credentials authentication.
|
||||
//
|
||||
// No Azure SDK dependency — uses stdlib net/http + OAuth2 for authentication.
|
||||
//
|
||||
// API endpoints used:
|
||||
//
|
||||
// GET /certificates?api-version=7.4 - List certificates
|
||||
// GET /certificates/{name}/{version}?api-version=7.4 - Get certificate details
|
||||
//
|
||||
// Authentication: OAuth2 client credentials flow via Azure AD.
|
||||
// Token is cached with 5-minute refresh buffer.
|
||||
package azurekv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// Config represents the Azure Key Vault discovery configuration.
|
||||
type Config struct {
|
||||
// VaultURL is the Azure Key Vault URL (e.g., "https://myvault.vault.azure.net").
|
||||
// Required. Set via CERTCTL_AZURE_KV_VAULT_URL environment variable.
|
||||
VaultURL string `json:"vault_url"`
|
||||
|
||||
// TenantID is the Azure AD tenant ID (e.g., "00000000-0000-0000-0000-000000000000").
|
||||
// Required. Set via CERTCTL_AZURE_KV_TENANT_ID environment variable.
|
||||
TenantID string `json:"tenant_id"`
|
||||
|
||||
// ClientID is the Azure AD application (client) ID.
|
||||
// Required. Set via CERTCTL_AZURE_KV_CLIENT_ID environment variable.
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
// ClientSecret is the Azure AD application secret or certificate.
|
||||
// Required. Set via CERTCTL_AZURE_KV_CLIENT_SECRET environment variable.
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// cachedToken holds an OAuth2 access token and its expiry time.
|
||||
type cachedToken struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// certificateListResponse represents the Azure Key Vault list certificates response.
|
||||
type certificateListResponse struct {
|
||||
Value []struct {
|
||||
ID string `json:"id"`
|
||||
Attributes struct {
|
||||
Enabled int64 `json:"enabled"`
|
||||
Created int64 `json:"created"`
|
||||
Updated int64 `json:"updated"`
|
||||
Exp int64 `json:"exp"`
|
||||
} `json:"attributes,omitempty"`
|
||||
Tags map[string]string `json:"tags,omitempty"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"nextLink"`
|
||||
}
|
||||
|
||||
// certificateBundle represents the Azure Key Vault certificate details response.
|
||||
type certificateBundle struct {
|
||||
ID string `json:"id"`
|
||||
CER string `json:"cer"`
|
||||
Attributes struct {
|
||||
Enabled int64 `json:"enabled"`
|
||||
Created int64 `json:"created"`
|
||||
Updated int64 `json:"updated"`
|
||||
Exp int64 `json:"exp"`
|
||||
} `json:"attributes,omitempty"`
|
||||
}
|
||||
|
||||
// KVClient is an interface for Azure Key Vault operations, allowing injection for testing.
|
||||
type KVClient interface {
|
||||
// ListCertificates retrieves the list of certificates in the vault.
|
||||
ListCertificates(ctx context.Context, vaultURL string) ([]struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}, error)
|
||||
// GetCertificate retrieves a specific certificate version.
|
||||
GetCertificate(ctx context.Context, vaultURL, certName, version string) (*certificateBundle, error)
|
||||
}
|
||||
|
||||
// Source implements domain.DiscoverySource for Azure Key Vault.
|
||||
type Source struct {
|
||||
config Config
|
||||
logger *slog.Logger
|
||||
client KVClient
|
||||
}
|
||||
|
||||
// New creates a new Azure Key Vault discovery source with real HTTP client.
|
||||
func New(cfg Config, logger *slog.Logger) *Source {
|
||||
return &Source{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
client: &httpKVClient{
|
||||
config: cfg,
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithClient creates a new Azure Key Vault discovery source with injected client (for testing).
|
||||
func NewWithClient(cfg Config, client KVClient, logger *slog.Logger) *Source {
|
||||
return &Source{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns a human-readable name for this discovery source.
|
||||
func (s *Source) Name() string {
|
||||
return "Azure Key Vault"
|
||||
}
|
||||
|
||||
// Type returns the short type identifier for this discovery source.
|
||||
func (s *Source) Type() string {
|
||||
return "azure-kv"
|
||||
}
|
||||
|
||||
// ValidateConfig checks that the Azure Key Vault configuration is valid.
|
||||
func (s *Source) ValidateConfig() error {
|
||||
if s.config.VaultURL == "" {
|
||||
return fmt.Errorf("Azure Key Vault URL is required")
|
||||
}
|
||||
if s.config.TenantID == "" {
|
||||
return fmt.Errorf("Azure Key Vault tenant ID is required")
|
||||
}
|
||||
if s.config.ClientID == "" {
|
||||
return fmt.Errorf("Azure Key Vault client ID is required")
|
||||
}
|
||||
if s.config.ClientSecret == "" {
|
||||
return fmt.Errorf("Azure Key Vault client secret is required")
|
||||
}
|
||||
|
||||
// Basic URL validation
|
||||
if !strings.HasPrefix(s.config.VaultURL, "https://") {
|
||||
return fmt.Errorf("Azure Key Vault URL must use HTTPS")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Discover scans the Azure Key Vault and returns a DiscoveryReport.
|
||||
func (s *Source) Discover(ctx context.Context) (*domain.DiscoveryReport, error) {
|
||||
s.logger.Info("starting Azure Key Vault discovery", "vault_url", s.config.VaultURL)
|
||||
|
||||
report := &domain.DiscoveryReport{
|
||||
AgentID: "cloud-azure-kv",
|
||||
Directories: []string{fmt.Sprintf("azure-kv://%s/", s.config.VaultURL)},
|
||||
Certificates: []domain.DiscoveredCertEntry{},
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// List certificates
|
||||
certs, err := s.client.ListCertificates(ctx, s.config.VaultURL)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to list Azure Key Vault certificates", "error", err)
|
||||
report.Errors = append(report.Errors, fmt.Sprintf("list certificates failed: %v", err))
|
||||
return report, nil
|
||||
}
|
||||
|
||||
// Process each certificate
|
||||
for _, cert := range certs {
|
||||
// Extract certificate name and version from ID
|
||||
// ID format: https://myvault.vault.azure.net/certificates/mycert/version123
|
||||
certName, version, err := extractCertNameAndVersion(cert.ID)
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to parse certificate ID", "id", cert.ID, "error", err)
|
||||
report.Errors = append(report.Errors, fmt.Sprintf("parse cert ID failed: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Get certificate details
|
||||
certBundle, err := s.client.GetCertificate(ctx, s.config.VaultURL, certName, version)
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to get certificate details", "name", certName, "version", version, "error", err)
|
||||
report.Errors = append(report.Errors, fmt.Sprintf("get cert %s/%s failed: %v", certName, version, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Decode the base64-encoded DER certificate
|
||||
if certBundle.CER == "" {
|
||||
s.logger.Warn("empty certificate data", "name", certName, "version", version)
|
||||
continue
|
||||
}
|
||||
|
||||
derBytes, err := base64.StdEncoding.DecodeString(certBundle.CER)
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to decode certificate", "name", certName, "version", version, "error", err)
|
||||
report.Errors = append(report.Errors, fmt.Sprintf("decode cert %s/%s failed: %v", certName, version, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse certificate
|
||||
x509Cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to parse certificate", "name", certName, "version", version, "error", err)
|
||||
report.Errors = append(report.Errors, fmt.Sprintf("parse cert %s/%s failed: %v", certName, version, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract certificate metadata
|
||||
entry := extractCertMetadata(x509Cert, certName, version)
|
||||
|
||||
// Encode as PEM for inclusion in report
|
||||
certPEM := encodeCertPEM(derBytes)
|
||||
entry.PEMData = certPEM
|
||||
|
||||
report.Certificates = append(report.Certificates, entry)
|
||||
s.logger.Info("discovered certificate",
|
||||
"name", certName,
|
||||
"common_name", entry.CommonName,
|
||||
"serial", entry.SerialNumber,
|
||||
"not_after", entry.NotAfter)
|
||||
}
|
||||
|
||||
report.ScanDurationMs = int(time.Since(startTime).Milliseconds())
|
||||
|
||||
s.logger.Info("Azure Key Vault discovery completed",
|
||||
"certs_found", len(report.Certificates),
|
||||
"errors", len(report.Errors),
|
||||
"duration_ms", report.ScanDurationMs)
|
||||
|
||||
return report, nil
|
||||
}
|
||||
|
||||
// httpKVClient implements KVClient using Azure Key Vault REST API.
|
||||
type httpKVClient struct {
|
||||
config Config
|
||||
httpClient *http.Client
|
||||
|
||||
// OAuth2 token caching
|
||||
mu sync.Mutex
|
||||
tokenCache *cachedToken
|
||||
}
|
||||
|
||||
// ListCertificates retrieves the list of certificates in the vault.
|
||||
func (c *httpKVClient) ListCertificates(ctx context.Context, vaultURL string) ([]struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}, error) {
|
||||
var results []struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}
|
||||
|
||||
listURL := fmt.Sprintf("%s/certificates?api-version=7.4", strings.TrimSuffix(vaultURL, "/"))
|
||||
|
||||
for listURL != "" {
|
||||
token, err := c.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, listURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("list certificates returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var listResp certificateListResponse
|
||||
if err := json.Unmarshal(body, &listResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse list response: %w", err)
|
||||
}
|
||||
|
||||
for _, cert := range listResp.Value {
|
||||
results = append(results, struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}{
|
||||
ID: cert.ID,
|
||||
Attributes: struct {
|
||||
Exp int64
|
||||
}{Exp: cert.Attributes.Exp},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle pagination
|
||||
if listResp.NextLink == "" {
|
||||
break
|
||||
}
|
||||
listURL = listResp.NextLink
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetCertificate retrieves a specific certificate version from the vault.
|
||||
func (c *httpKVClient) GetCertificate(ctx context.Context, vaultURL, certName, version string) (*certificateBundle, error) {
|
||||
token, err := c.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
// Ensure vaultURL has no trailing slash
|
||||
vaultURL = strings.TrimSuffix(vaultURL, "/")
|
||||
|
||||
// Build the certificate URL
|
||||
// Format: https://myvault.vault.azure.net/certificates/mycert/version123?api-version=7.4
|
||||
certURL := fmt.Sprintf("%s/certificates/%s/%s?api-version=7.4",
|
||||
vaultURL, url.PathEscape(certName), url.PathEscape(version))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, certURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get certificate request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get certificate returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var certBundle certificateBundle
|
||||
if err := json.Unmarshal(body, &certBundle); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate response: %w", err)
|
||||
}
|
||||
|
||||
return &certBundle, nil
|
||||
}
|
||||
|
||||
// getAccessToken returns a valid OAuth2 access token, refreshing if needed.
|
||||
func (c *httpKVClient) getAccessToken(ctx context.Context) (string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Return cached token if still valid (5 min buffer)
|
||||
if c.tokenCache != nil && time.Now().Add(5*time.Minute).Before(c.tokenCache.expiresAt) {
|
||||
return c.tokenCache.token, nil
|
||||
}
|
||||
|
||||
// Exchange client credentials for access token
|
||||
tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token",
|
||||
url.PathEscape(c.config.TenantID))
|
||||
|
||||
form := url.Values{
|
||||
"grant_type": {"client_credentials"},
|
||||
"client_id": {c.config.ClientID},
|
||||
"client_secret": {c.config.ClientSecret},
|
||||
"scope": {"https://vault.azure.net/.default"},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read token response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("token request returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return "", fmt.Errorf("empty access token in response")
|
||||
}
|
||||
|
||||
// Cache token
|
||||
c.tokenCache = &cachedToken{
|
||||
token: tokenResp.AccessToken,
|
||||
expiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second),
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
// extractCertNameAndVersion extracts the certificate name and version from the Azure ID.
|
||||
// ID format: https://myvault.vault.azure.net/certificates/mycert/version123
|
||||
func extractCertNameAndVersion(id string) (name, version string, err error) {
|
||||
// Use regex to extract name and version from the ID URL
|
||||
// Pattern: /certificates/{name}/{version}
|
||||
re := regexp.MustCompile(`/certificates/([^/]+)/([^/]+)$`)
|
||||
matches := re.FindStringSubmatch(id)
|
||||
|
||||
if len(matches) != 3 {
|
||||
return "", "", fmt.Errorf("cannot parse certificate ID: %s", id)
|
||||
}
|
||||
|
||||
return matches[1], matches[2], nil
|
||||
}
|
||||
|
||||
// extractCertMetadata extracts metadata from a parsed X.509 certificate.
|
||||
func extractCertMetadata(cert *x509.Certificate, certName, version string) domain.DiscoveredCertEntry {
|
||||
// Extract Subject Alternative Names (DNS names and email addresses)
|
||||
sans := []string{}
|
||||
sans = append(sans, cert.DNSNames...)
|
||||
|
||||
// Extract key algorithm
|
||||
keyAlgo := "unknown"
|
||||
keySize := 0
|
||||
|
||||
switch pub := cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
keyAlgo = "RSA"
|
||||
keySize = pub.N.BitLen()
|
||||
case *ecdsa.PublicKey:
|
||||
keyAlgo = "ECDSA"
|
||||
keySize = pub.Curve.Params().BitSize
|
||||
}
|
||||
|
||||
// Compute SHA-256 fingerprint
|
||||
fp := sha256.Sum256(cert.Raw)
|
||||
fingerprint := fmt.Sprintf("%X", fp)
|
||||
|
||||
// Format times as RFC3339
|
||||
notBefore := cert.NotBefore.UTC().Format(time.RFC3339)
|
||||
notAfter := cert.NotAfter.UTC().Format(time.RFC3339)
|
||||
|
||||
return domain.DiscoveredCertEntry{
|
||||
FingerprintSHA256: fingerprint,
|
||||
CommonName: cert.Subject.CommonName,
|
||||
SANs: sans,
|
||||
SerialNumber: fmt.Sprintf("%x", cert.SerialNumber),
|
||||
IssuerDN: cert.Issuer.String(),
|
||||
SubjectDN: cert.Subject.String(),
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyAlgorithm: keyAlgo,
|
||||
KeySize: keySize,
|
||||
IsCA: cert.IsCA,
|
||||
SourcePath: fmt.Sprintf("azure-kv://%s/%s", certName, version),
|
||||
SourceFormat: "DER",
|
||||
}
|
||||
}
|
||||
|
||||
// encodeCertPEM encodes a DER certificate as PEM.
|
||||
func encodeCertPEM(derBytes []byte) string {
|
||||
var buf bytes.Buffer
|
||||
pem.Encode(&buf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: derBytes,
|
||||
})
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Ensure Source implements domain.DiscoverySource.
|
||||
var _ domain.DiscoverySource = (*Source)(nil)
|
||||
@@ -0,0 +1,597 @@
|
||||
package azurekv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// TestValidateConfig_Success validates a correct configuration.
|
||||
func TestValidateConfig_Success(t *testing.T) {
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "00000000-0000-0000-0000-000000000000",
|
||||
ClientID: "11111111-1111-1111-1111-111111111111",
|
||||
ClientSecret: "mysecret123",
|
||||
}
|
||||
|
||||
src := &Source{config: cfg, logger: slog.Default()}
|
||||
|
||||
if err := src.ValidateConfig(); err != nil {
|
||||
t.Fatalf("ValidateConfig failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateConfig_MissingVaultURL validates error when VaultURL is empty.
|
||||
func TestValidateConfig_MissingVaultURL(t *testing.T) {
|
||||
cfg := Config{
|
||||
VaultURL: "",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := &Source{config: cfg, logger: slog.Default()}
|
||||
|
||||
if err := src.ValidateConfig(); err == nil {
|
||||
t.Fatal("expected error for missing VaultURL")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateConfig_MissingTenantID validates error when TenantID is empty.
|
||||
func TestValidateConfig_MissingTenantID(t *testing.T) {
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := &Source{config: cfg, logger: slog.Default()}
|
||||
|
||||
if err := src.ValidateConfig(); err == nil {
|
||||
t.Fatal("expected error for missing TenantID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateConfig_MissingClientID validates error when ClientID is empty.
|
||||
func TestValidateConfig_MissingClientID(t *testing.T) {
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := &Source{config: cfg, logger: slog.Default()}
|
||||
|
||||
if err := src.ValidateConfig(); err == nil {
|
||||
t.Fatal("expected error for missing ClientID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateConfig_MissingClientSecret validates error when ClientSecret is empty.
|
||||
func TestValidateConfig_MissingClientSecret(t *testing.T) {
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "",
|
||||
}
|
||||
|
||||
src := &Source{config: cfg, logger: slog.Default()}
|
||||
|
||||
if err := src.ValidateConfig(); err == nil {
|
||||
t.Fatal("expected error for missing ClientSecret")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateConfig_InvalidURL validates error when VaultURL is not HTTPS.
|
||||
func TestValidateConfig_InvalidURL(t *testing.T) {
|
||||
cfg := Config{
|
||||
VaultURL: "http://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := &Source{config: cfg, logger: slog.Default()}
|
||||
|
||||
if err := src.ValidateConfig(); err == nil {
|
||||
t.Fatal("expected error for non-HTTPS URL")
|
||||
}
|
||||
}
|
||||
|
||||
// mockKVClient implements KVClient for testing.
|
||||
type mockKVClient struct {
|
||||
certs map[string]*certificateBundle
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockKVClient) ListCertificates(ctx context.Context, vaultURL string) ([]struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
var results []struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}
|
||||
|
||||
for id := range m.certs {
|
||||
results = append(results, struct {
|
||||
ID string
|
||||
Attributes struct {
|
||||
Exp int64
|
||||
}
|
||||
}{ID: id})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *mockKVClient) GetCertificate(ctx context.Context, vaultURL, certName, version string) (*certificateBundle, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
id := fmt.Sprintf("https://myvault.vault.azure.net/certificates/%s/%s", certName, version)
|
||||
cert, ok := m.certs[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("certificate not found")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// generateTestCert generates a test X.509 certificate.
|
||||
func generateTestCert(cn string, sans []string) ([]byte, error) {
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, big.NewInt(0).Exp(big.NewInt(2), big.NewInt(64), nil))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: cn,
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: false,
|
||||
DNSNames: sans,
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return derBytes, nil
|
||||
}
|
||||
|
||||
// TestDiscover_Success validates successful certificate discovery.
|
||||
func TestDiscover_Success(t *testing.T) {
|
||||
// Generate test certificates
|
||||
cert1DER, err := generateTestCert("example.com", []string{"www.example.com", "api.example.com"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
cert2DER, err := generateTestCert("test.example.com", []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
// Create mock client
|
||||
mockClient := &mockKVClient{
|
||||
certs: map[string]*certificateBundle{
|
||||
"https://myvault.vault.azure.net/certificates/example/v1": {
|
||||
ID: "https://myvault.vault.azure.net/certificates/example/v1",
|
||||
CER: base64.StdEncoding.EncodeToString(cert1DER),
|
||||
},
|
||||
"https://myvault.vault.azure.net/certificates/test/v2": {
|
||||
ID: "https://myvault.vault.azure.net/certificates/test/v2",
|
||||
CER: base64.StdEncoding.EncodeToString(cert2DER),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := NewWithClient(cfg, mockClient, slog.Default())
|
||||
|
||||
ctx := context.Background()
|
||||
report, err := src.Discover(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Discover failed: %v", err)
|
||||
}
|
||||
|
||||
if report == nil {
|
||||
t.Fatal("expected non-nil report")
|
||||
}
|
||||
|
||||
if len(report.Certificates) != 2 {
|
||||
t.Fatalf("expected 2 certificates, got %d", len(report.Certificates))
|
||||
}
|
||||
|
||||
// Verify first cert metadata
|
||||
if report.Certificates[0].CommonName == "" {
|
||||
t.Fatal("expected common name in first cert")
|
||||
}
|
||||
|
||||
// Verify PEM encoding
|
||||
if report.Certificates[0].PEMData == "" {
|
||||
t.Fatal("expected PEM data in first cert")
|
||||
}
|
||||
|
||||
// Verify PEM is valid
|
||||
block, _ := pem.Decode([]byte(report.Certificates[0].PEMData))
|
||||
if block == nil {
|
||||
t.Fatal("failed to decode PEM data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDiscover_ListError validates error handling when listing fails.
|
||||
func TestDiscover_ListError(t *testing.T) {
|
||||
mockClient := &mockKVClient{
|
||||
err: fmt.Errorf("connection error"),
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := NewWithClient(cfg, mockClient, slog.Default())
|
||||
|
||||
ctx := context.Background()
|
||||
report, err := src.Discover(ctx)
|
||||
|
||||
// Should return partial report with error
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(report.Errors) == 0 {
|
||||
t.Fatal("expected errors in report")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDiscover_EmptyResults validates handling of empty certificate list.
|
||||
func TestDiscover_EmptyResults(t *testing.T) {
|
||||
mockClient := &mockKVClient{
|
||||
certs: map[string]*certificateBundle{},
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := NewWithClient(cfg, mockClient, slog.Default())
|
||||
|
||||
ctx := context.Background()
|
||||
report, err := src.Discover(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Discover failed: %v", err)
|
||||
}
|
||||
|
||||
if len(report.Certificates) != 0 {
|
||||
t.Fatalf("expected 0 certificates, got %d", len(report.Certificates))
|
||||
}
|
||||
|
||||
if len(report.Errors) != 0 {
|
||||
t.Fatalf("expected 0 errors, got %d", len(report.Errors))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDiscover_InvalidCertData validates handling of invalid certificate data.
|
||||
func TestDiscover_InvalidCertData(t *testing.T) {
|
||||
// Generate one valid cert and one invalid
|
||||
validDER, err := generateTestCert("valid.example.com", []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
mockClient := &mockKVClient{
|
||||
certs: map[string]*certificateBundle{
|
||||
"https://myvault.vault.azure.net/certificates/valid/v1": {
|
||||
ID: "https://myvault.vault.azure.net/certificates/valid/v1",
|
||||
CER: base64.StdEncoding.EncodeToString(validDER),
|
||||
},
|
||||
"https://myvault.vault.azure.net/certificates/invalid/v1": {
|
||||
ID: "https://myvault.vault.azure.net/certificates/invalid/v1",
|
||||
CER: "not-valid-base64!@#$%",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := NewWithClient(cfg, mockClient, slog.Default())
|
||||
|
||||
ctx := context.Background()
|
||||
report, err := src.Discover(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Discover failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have 1 valid cert
|
||||
if len(report.Certificates) != 1 {
|
||||
t.Fatalf("expected 1 valid certificate, got %d", len(report.Certificates))
|
||||
}
|
||||
|
||||
// Should have 1 error
|
||||
if len(report.Errors) != 1 {
|
||||
t.Fatalf("expected 1 error, got %d", len(report.Errors))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDiscover_AgentIDAndSourcePath validates correct agent ID and source paths.
|
||||
func TestDiscover_AgentIDAndSourcePath(t *testing.T) {
|
||||
certDER, err := generateTestCert("test.example.com", []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
mockClient := &mockKVClient{
|
||||
certs: map[string]*certificateBundle{
|
||||
"https://myvault.vault.azure.net/certificates/mycert/v1": {
|
||||
ID: "https://myvault.vault.azure.net/certificates/mycert/v1",
|
||||
CER: base64.StdEncoding.EncodeToString(certDER),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
VaultURL: "https://myvault.vault.azure.net",
|
||||
TenantID: "tenant-id",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
|
||||
src := NewWithClient(cfg, mockClient, slog.Default())
|
||||
|
||||
ctx := context.Background()
|
||||
report, err := src.Discover(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Discover failed: %v", err)
|
||||
}
|
||||
|
||||
if report.AgentID != "cloud-azure-kv" {
|
||||
t.Fatalf("expected agent_id 'cloud-azure-kv', got %s", report.AgentID)
|
||||
}
|
||||
|
||||
if len(report.Directories) == 0 {
|
||||
t.Fatal("expected directories in report")
|
||||
}
|
||||
|
||||
if len(report.Certificates) > 0 {
|
||||
cert := report.Certificates[0]
|
||||
if !domain.IsValidDiscoveryStatus(cert.SourcePath) == false {
|
||||
// SourcePath should follow azure-kv://certname/version format
|
||||
if !contains(cert.SourcePath, "azure-kv://") {
|
||||
t.Fatalf("expected source path to start with 'azure-kv://', got %s", cert.SourcePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestName validates the Name method.
|
||||
func TestName(t *testing.T) {
|
||||
src := &Source{
|
||||
config: Config{},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
expected := "Azure Key Vault"
|
||||
if src.Name() != expected {
|
||||
t.Fatalf("expected Name '%s', got '%s'", expected, src.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestType validates the Type method.
|
||||
func TestType(t *testing.T) {
|
||||
src := &Source{
|
||||
config: Config{},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
expected := "azure-kv"
|
||||
if src.Type() != expected {
|
||||
t.Fatalf("expected Type '%s', got '%s'", expected, src.Type())
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractCertNameAndVersion validates certificate ID parsing.
|
||||
func TestExtractCertNameAndVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
id string
|
||||
wantName string
|
||||
wantVer string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
id: "https://myvault.vault.azure.net/certificates/example/v1",
|
||||
wantName: "example",
|
||||
wantVer: "v1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
id: "https://myvault.vault.azure.net/certificates/my-cert/version123",
|
||||
wantName: "my-cert",
|
||||
wantVer: "version123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
id: "invalid-id",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
name, ver, err := extractCertNameAndVersion(tt.id)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("extractCertNameAndVersion(%s) error = %v, wantErr %v", tt.id, err, tt.wantErr)
|
||||
}
|
||||
if !tt.wantErr {
|
||||
if name != tt.wantName || ver != tt.wantVer {
|
||||
t.Fatalf("extractCertNameAndVersion(%s) = (%s, %s), want (%s, %s)",
|
||||
tt.id, name, ver, tt.wantName, tt.wantVer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractCertMetadata validates certificate metadata extraction.
|
||||
func TestExtractCertMetadata(t *testing.T) {
|
||||
// Generate a test certificate
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
serialNumber := big.NewInt(123456)
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "test.example.com",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: false,
|
||||
DNSNames: []string{"test.example.com", "www.test.example.com"},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create cert: %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse cert: %v", err)
|
||||
}
|
||||
|
||||
entry := extractCertMetadata(cert, "testcert", "v1")
|
||||
|
||||
if entry.CommonName != "test.example.com" {
|
||||
t.Fatalf("expected CN 'test.example.com', got %s", entry.CommonName)
|
||||
}
|
||||
|
||||
if len(entry.SANs) != 2 {
|
||||
t.Fatalf("expected 2 SANs, got %d", len(entry.SANs))
|
||||
}
|
||||
|
||||
if entry.KeyAlgorithm != "ECDSA" {
|
||||
t.Fatalf("expected key algorithm ECDSA, got %s", entry.KeyAlgorithm)
|
||||
}
|
||||
|
||||
if entry.KeySize != 256 {
|
||||
t.Fatalf("expected key size 256, got %d", entry.KeySize)
|
||||
}
|
||||
|
||||
if entry.SerialNumber == "" {
|
||||
t.Fatal("expected serial number, got empty")
|
||||
}
|
||||
|
||||
if entry.SourceFormat != "DER" {
|
||||
t.Fatalf("expected source format DER, got %s", entry.SourceFormat)
|
||||
}
|
||||
|
||||
// Verify fingerprint is valid hex
|
||||
if len(entry.FingerprintSHA256) != 64 {
|
||||
t.Fatalf("expected 64-char fingerprint, got %d chars", len(entry.FingerprintSHA256))
|
||||
}
|
||||
|
||||
// Verify manually calculated fingerprint
|
||||
fp := sha256.Sum256(derBytes)
|
||||
expectedFP := fmt.Sprintf("%X", fp)
|
||||
if entry.FingerprintSHA256 != expectedFP {
|
||||
t.Fatalf("fingerprint mismatch: got %s, want %s", entry.FingerprintSHA256, expectedFP)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeCertPEM validates PEM encoding.
|
||||
func TestEncodeCertPEM(t *testing.T) {
|
||||
derBytes, err := generateTestCert("test.example.com", []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
pemStr := encodeCertPEM(derBytes)
|
||||
|
||||
// Verify PEM format
|
||||
if !contains(pemStr, "-----BEGIN CERTIFICATE-----") {
|
||||
t.Fatal("expected PEM header")
|
||||
}
|
||||
|
||||
if !contains(pemStr, "-----END CERTIFICATE-----") {
|
||||
t.Fatal("expected PEM footer")
|
||||
}
|
||||
|
||||
// Verify we can decode it back
|
||||
block, _ := pem.Decode([]byte(pemStr))
|
||||
if block == nil {
|
||||
t.Fatal("failed to decode PEM")
|
||||
}
|
||||
|
||||
if len(block.Bytes) != len(derBytes) {
|
||||
t.Fatal("decoded PEM does not match original DER")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) > 0 && len(substr) > 0 && s != substr &&
|
||||
(s == substr || len(s) > len(substr))
|
||||
}
|
||||
Reference in New Issue
Block a user