Files
certctl/internal/service/discovery_test.go
T
shankar0123 fe7e766510 Close M-004 (OCSP issuer binding) and M-005 (discovery actor propagation) coverage-gap findings
M-004 — OCSP issuer binding (composite key):
  The OCSP lookup path now binds (issuer_id, serial) as a composite key
  rather than resolving by serial alone. CertificateRepository and
  RevocationRepository gain GetByIssuerAndSerial methods; ca_operations.go
  scopes both lookups by the issuer_id path param. When no managed cert
  binds to that (issuer, serial) tuple, GetOCSPResponse constructs an
  RFC 6960 §2.2 'unknown' response (CertStatus=2) instead of the prior
  default 'good'. Short-lived cert exemption (profile TTL < 1h) is
  preserved. Real repo errors (non-sql.ErrNoRows) fail closed with a log.

  Regression coverage: internal/service/ca_operations_test.go
    - TestCAOperationsSvc_GetOCSPResponse_Unknown_CrossIssuer
    - TestCAOperationsSvc_GetOCSPResponse_Unknown_UnknownSerial

M-005 — Discovery Claim/Dismiss actor propagation:
  DiscoveryService.ClaimDiscovered and DismissDiscovered now accept an
  explicit 'actor string' parameter (propagation pattern mirrors
  bulk_revocation.go / revocation_svc.go). The handler layer passes
  resolveActor(r.Context()) — the named-key identity established by the
  M-002 auth unification — and the service falls back to 'api' (the same
  safe sentinel resolveActor uses when no auth context is present) only
  when the caller passes an empty string. Never falls back to 'operator'.

  Regression coverage: internal/service/discovery_test.go
    - TestDiscoveryService_ClaimDiscovered_AuditActor
    - TestDiscoveryService_DismissDiscovered_AuditActor
    - TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI
    - TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI

Each new test asserts event.Actor matches the caller-supplied string (or
'api' on empty input) and explicitly asserts event.Actor != 'operator'
to lock in the historical fix intent.

Files:
  internal/api/handler/discovery.go          — pass resolveActor(ctx)
  internal/api/handler/discovery_handler_test.go — updated call sites
  internal/integration/lifecycle_test.go     — updated mock wiring
  internal/repository/interfaces.go          — GetByIssuerAndSerial on
                                               CertificateRepository +
                                               RevocationRepository
  internal/repository/postgres/certificate.go — composite key lookup
  internal/service/ca_operations.go          — (issuer_id, serial) scoping
  internal/service/ca_operations_test.go     — 2 new M-004 tests
  internal/service/discovery.go              — actor parameter + 'api' fallback
  internal/service/discovery_test.go         — 4 new M-005 tests
  internal/service/shortlived_test.go        — mock signature update
  internal/service/testutil_test.go          — mock GetByIssuerAndSerial
2026-04-18 22:20:25 +00:00

675 lines
19 KiB
Go

package service
import (
"context"
"errors"
"testing"
"time"
"github.com/shankar0123/certctl/internal/domain"
"github.com/shankar0123/certctl/internal/repository"
)
// mockDiscoveryRepo is an in-memory discovery repository for the tests in
// this file. Scans and Discovered hold records keyed by their ID. Every *Err
// field, when non-nil, is returned by the matching method before it touches
// the maps, which lets a test inject a repository failure.
type mockDiscoveryRepo struct {
	Scans      map[string]*domain.DiscoveryScan
	Discovered map[string]*domain.DiscoveredCertificate

	// Injected failures for the scan methods.
	CreateScanErr, GetScanErr, ListScansErr error

	// Injected failures for creating, fetching and listing discovered certs.
	CreateDiscoveredErr, GetDiscoveredErr, ListDiscoveredErr error

	// Injected failures for status updates, fingerprint lookups and counts.
	UpdateStatusErr, GetByFingerprintErr, CountByStatusErr error
}
// newMockDiscoveryRepository returns a mockDiscoveryRepo whose two maps are
// allocated and empty and whose injected-error fields are all nil.
func newMockDiscoveryRepository() *mockDiscoveryRepo {
	repo := new(mockDiscoveryRepo)
	repo.Scans = map[string]*domain.DiscoveryScan{}
	repo.Discovered = map[string]*domain.DiscoveredCertificate{}
	return repo
}
// CreateScan stores scan under scan.ID unless CreateScanErr is set, in which
// case nothing is stored and that error is returned.
func (m *mockDiscoveryRepo) CreateScan(ctx context.Context, scan *domain.DiscoveryScan) error {
	if m.CreateScanErr == nil {
		m.Scans[scan.ID] = scan
	}
	return m.CreateScanErr
}
// GetScan looks up a stored scan by ID. It returns GetScanErr when injected,
// and errNotFound when no scan is stored under id.
func (m *mockDiscoveryRepo) GetScan(ctx context.Context, id string) (*domain.DiscoveryScan, error) {
	if injected := m.GetScanErr; injected != nil {
		return nil, injected
	}
	if found, present := m.Scans[id]; present {
		return found, nil
	}
	return nil, errNotFound
}
// ListScans returns every stored scan whose AgentID equals agentID, or all
// scans when agentID is empty, together with the number of scans returned.
// page and perPage are accepted but ignored, and the result order follows map
// iteration and is therefore unspecified. ListScansErr, when injected, is
// returned instead.
func (m *mockDiscoveryRepo) ListScans(ctx context.Context, agentID string, page, perPage int) ([]*domain.DiscoveryScan, int, error) {
	if injected := m.ListScansErr; injected != nil {
		return nil, 0, injected
	}
	var matched []*domain.DiscoveryScan
	for _, candidate := range m.Scans {
		if agentID != "" && candidate.AgentID != agentID {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, len(matched), nil
}
// CreateDiscovered stores cert under cert.ID, overwriting any previous entry.
// The boolean result is true when no entry existed under that ID beforehand.
// CreateDiscoveredErr, when injected, is returned without storing anything.
func (m *mockDiscoveryRepo) CreateDiscovered(ctx context.Context, cert *domain.DiscoveredCertificate) (bool, error) {
	if injected := m.CreateDiscoveredErr; injected != nil {
		return false, injected
	}
	_, alreadyStored := m.Discovered[cert.ID]
	m.Discovered[cert.ID] = cert
	isNew := !alreadyStored
	return isNew, nil
}
// GetDiscovered looks up a stored discovered certificate by ID. It returns
// GetDiscoveredErr when injected, and errNotFound when id is not stored.
func (m *mockDiscoveryRepo) GetDiscovered(ctx context.Context, id string) (*domain.DiscoveredCertificate, error) {
	if injected := m.GetDiscoveredErr; injected != nil {
		return nil, injected
	}
	if found, present := m.Discovered[id]; present {
		return found, nil
	}
	return nil, errNotFound
}
// ListDiscovered returns the stored discovered certificates that satisfy
// filter, plus the number returned. An empty filter.AgentID or filter.Status
// disables that criterion; other filter fields are not consulted. filter is
// dereferenced for each stored certificate, so it must be non-nil whenever the
// mock holds data. Result order follows map iteration and is unspecified.
// ListDiscoveredErr, when injected, is returned instead.
func (m *mockDiscoveryRepo) ListDiscovered(ctx context.Context, filter *repository.DiscoveryFilter) ([]*domain.DiscoveredCertificate, int, error) {
	if injected := m.ListDiscoveredErr; injected != nil {
		return nil, 0, injected
	}
	var matched []*domain.DiscoveredCertificate
	for _, candidate := range m.Discovered {
		agentOK := filter.AgentID == "" || candidate.AgentID == filter.AgentID
		statusOK := filter.Status == "" || string(candidate.Status) == filter.Status
		if agentOK && statusOK {
			matched = append(matched, candidate)
		}
	}
	return matched, len(matched), nil
}
// UpdateDiscoveredStatus sets Status and ManagedCertificateID on the stored
// certificate with the given id. When the new status is
// domain.DiscoveryStatusDismissed it also stamps DismissedAt with the current
// time. It returns UpdateStatusErr when injected and errNotFound when id is
// not stored; in both cases nothing is modified.
func (m *mockDiscoveryRepo) UpdateDiscoveredStatus(ctx context.Context, id string, status domain.DiscoveryStatus, managedCertID string) error {
	if injected := m.UpdateStatusErr; injected != nil {
		return injected
	}
	target, present := m.Discovered[id]
	if !present {
		return errNotFound
	}
	target.Status, target.ManagedCertificateID = status, managedCertID
	if status == domain.DiscoveryStatusDismissed {
		dismissedAt := time.Now()
		target.DismissedAt = &dismissedAt
	}
	return nil
}
// GetByFingerprint returns every stored discovered certificate whose
// FingerprintSHA256 equals fingerprint (a nil slice when none match), in
// unspecified order. GetByFingerprintErr, when injected, is returned instead.
func (m *mockDiscoveryRepo) GetByFingerprint(ctx context.Context, fingerprint string) ([]*domain.DiscoveredCertificate, error) {
	if injected := m.GetByFingerprintErr; injected != nil {
		return nil, injected
	}
	var matches []*domain.DiscoveredCertificate
	for _, candidate := range m.Discovered {
		if candidate.FingerprintSHA256 != fingerprint {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, nil
}
// CountByStatus tallies the stored discovered certificates by their Status,
// using the status converted to a string as the map key. CountByStatusErr,
// when injected, is returned instead.
func (m *mockDiscoveryRepo) CountByStatus(ctx context.Context) (map[string]int, error) {
	if injected := m.CountByStatusErr; injected != nil {
		return nil, injected
	}
	tally := map[string]int{}
	for _, cert := range m.Discovered {
		key := string(cert.Status)
		tally[key] = tally[key] + 1
	}
	return tally, nil
}
// newDiscoveryTestService wires a DiscoveryService to freshly created mock
// discovery, certificate and audit repositories. The mocks are returned
// alongside the service so a test can seed data and inspect what was recorded.
func newDiscoveryTestService() (*DiscoveryService, *mockDiscoveryRepo, *mockCertRepo, *mockAuditRepo) {
	var (
		discoveries = newMockDiscoveryRepository()
		certs       = newMockCertificateRepository()
		audits      = newMockAuditRepository()
	)
	svc := NewDiscoveryService(discoveries, certs, NewAuditService(audits))
	return svc, discoveries, certs, audits
}
// TestProcessDiscoveryReport_Success submits a report containing a single
// certificate and checks the returned scan summary, that exactly one scan and
// one discovered certificate ended up in the mock repository, and that a
// discovery_scan_completed audit event was recorded.
func TestProcessDiscoveryReport_Success(t *testing.T) {
	svc, discoveryRepo, _, auditRepo := newDiscoveryTestService()

	// Validity window: starts a year in the past, ends a year in the future.
	notBefore := time.Now().AddDate(-1, 0, 0).Format(time.RFC3339)
	notAfter := time.Now().AddDate(1, 0, 0).Format(time.RFC3339)
	entry := domain.DiscoveredCertEntry{
		FingerprintSHA256: "abc123",
		CommonName:        "example.com",
		SANs:              []string{"www.example.com"},
		SerialNumber:      "001",
		IssuerDN:          "CN=Let's Encrypt",
		SubjectDN:         "CN=example.com",
		NotBefore:         notBefore,
		NotAfter:          notAfter,
		KeyAlgorithm:      "RSA",
		KeySize:           2048,
		IsCA:              false,
		PEMData:           "-----BEGIN CERTIFICATE-----...",
		SourcePath:        "/etc/certs/example.com.crt",
		SourceFormat:      "PEM",
	}
	report := &domain.DiscoveryReport{
		AgentID:        "agent-1",
		Directories:    []string{"/etc/certs", "/opt/certs"},
		ScanDurationMs: 150,
		Certificates:   []domain.DiscoveredCertEntry{entry},
		Errors:         []string{},
	}

	scan, err := svc.ProcessDiscoveryReport(context.Background(), report)
	switch {
	case err != nil:
		t.Fatalf("expected no error, got: %v", err)
	case scan == nil:
		t.Fatal("expected scan to be returned")
	}

	// Summary fields on the returned scan.
	if got := scan.AgentID; got != "agent-1" {
		t.Errorf("expected agent ID agent-1, got %s", got)
	}
	if got := scan.CertificatesFound; got != 1 {
		t.Errorf("expected 1 certificate found, got %d", got)
	}
	if got := scan.CertificatesNew; got != 1 {
		t.Errorf("expected 1 new certificate, got %d", got)
	}

	// Persistence in the mock repository.
	if n := len(discoveryRepo.Scans); n != 1 {
		t.Fatalf("expected 1 scan in repo, got %d", n)
	}
	if n := len(discoveryRepo.Discovered); n != 1 {
		t.Fatalf("expected 1 discovered cert in repo, got %d", n)
	}

	// Audit trail.
	if len(auditRepo.Events) == 0 {
		t.Error("expected audit event to be recorded")
	}
	completed := false
	for _, ev := range auditRepo.Events {
		if ev.Action == "discovery_scan_completed" {
			completed = true
			break
		}
	}
	if !completed {
		t.Error("expected discovery_scan_completed audit event")
	}
}
// TestProcessDiscoveryReport_EmptyAgentID verifies that ProcessDiscoveryReport
// rejects a report whose AgentID is empty by returning a non-nil error.
func TestProcessDiscoveryReport_EmptyAgentID(t *testing.T) {
	svc, _, _, _ := newDiscoveryTestService()
	report := &domain.DiscoveryReport{
		AgentID: "", // deliberately empty: this is the input under test
		Certificates: []domain.DiscoveredCertEntry{
			{
				FingerprintSHA256: "abc123",
				CommonName:        "example.com",
			},
		},
	}

	// Only the presence of an error is asserted. Comparing an error with
	// itself via errors.Is is always true and therefore checks nothing, so no
	// such follow-up is made here.
	// TODO(review): if the service exposes a validation sentinel, match it
	// with errors.Is to pin the failure reason.
	if _, err := svc.ProcessDiscoveryReport(context.Background(), report); err == nil {
		t.Fatal("expected error for empty agent_id")
	}
}
// TestProcessDiscoveryReport_EmptyReport verifies that a report with a valid
// agent ID but no certificates and no errors is rejected with an error.
func TestProcessDiscoveryReport_EmptyReport(t *testing.T) {
	svc, _, _, _ := newDiscoveryTestService()
	emptyReport := domain.DiscoveryReport{
		AgentID:        "agent-1",
		Certificates:   []domain.DiscoveredCertEntry{},
		Errors:         []string{},
		ScanDurationMs: 100,
	}
	if _, err := svc.ProcessDiscoveryReport(context.Background(), &emptyReport); err == nil {
		t.Fatal("expected error for empty report")
	}
}
// TestListDiscovered_Success seeds two discovered certificates for agent-1
// (one Unmanaged, one Managed) and checks that listing by agent with no
// status filter returns both and reports a total of 2.
func TestListDiscovered_Success(t *testing.T) {
	svc, discoveryRepo, _, _ := newDiscoveryTestService()

	seededAt := time.Now()
	seed := []*domain.DiscoveredCertificate{
		{
			ID:         "dcert-1",
			AgentID:    "agent-1",
			CommonName: "example.com",
			Status:     domain.DiscoveryStatusUnmanaged,
			CreatedAt:  seededAt,
			UpdatedAt:  seededAt,
		},
		{
			ID:         "dcert-2",
			AgentID:    "agent-1",
			CommonName: "api.example.com",
			Status:     domain.DiscoveryStatusManaged,
			CreatedAt:  seededAt,
			UpdatedAt:  seededAt,
		},
	}
	for _, c := range seed {
		discoveryRepo.Discovered[c.ID] = c
	}

	certs, total, err := svc.ListDiscovered(context.Background(), "agent-1", "", 1, 50)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}
	if n := len(certs); n != 2 {
		t.Errorf("expected 2 certs, got %d", n)
	}
	if total != 2 {
		t.Errorf("expected total 2, got %d", total)
	}
}
// TestListDiscovered_WithStatusFilter seeds one Unmanaged and one Managed
// certificate for agent-1 and checks that listing with the status filter
// "Unmanaged" returns exactly one certificate and a total of 1.
func TestListDiscovered_WithStatusFilter(t *testing.T) {
	svc, discoveryRepo, _, _ := newDiscoveryTestService()

	seededAt := time.Now()
	seed := []*domain.DiscoveredCertificate{
		{
			ID:         "dcert-1",
			AgentID:    "agent-1",
			CommonName: "example.com",
			Status:     domain.DiscoveryStatusUnmanaged,
			CreatedAt:  seededAt,
			UpdatedAt:  seededAt,
		},
		{
			ID:         "dcert-2",
			AgentID:    "agent-1",
			CommonName: "api.example.com",
			Status:     domain.DiscoveryStatusManaged,
			CreatedAt:  seededAt,
			UpdatedAt:  seededAt,
		},
	}
	for _, c := range seed {
		discoveryRepo.Discovered[c.ID] = c
	}

	certs, total, err := svc.ListDiscovered(context.Background(), "agent-1", "Unmanaged", 1, 50)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}
	if n := len(certs); n != 1 {
		t.Errorf("expected 1 cert, got %d", n)
	}
	if total != 1 {
		t.Errorf("expected total 1, got %d", total)
	}
}
// TestGetDiscovered_Success seeds one discovered certificate and checks that
// GetDiscovered returns it by ID without error.
func TestGetDiscovered_Success(t *testing.T) {
	svc, discoveryRepo, _, _ := newDiscoveryTestService()

	const wantID = "dcert-1"
	seededAt := time.Now()
	discoveryRepo.Discovered[wantID] = &domain.DiscoveredCertificate{
		ID:         wantID,
		CommonName: "example.com",
		Status:     domain.DiscoveryStatusUnmanaged,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	}

	got, err := svc.GetDiscovered(context.Background(), wantID)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}
	if got.ID != wantID {
		t.Errorf("expected ID dcert-1, got %s", got.ID)
	}
}
// TestClaimDiscovered_Success claims an Unmanaged discovered certificate
// against an existing managed certificate and checks that the discovered
// record becomes Managed, is linked to mc-prod-1, and that a
// discovery_cert_claimed audit event is recorded.
func TestClaimDiscovered_Success(t *testing.T) {
	svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService()

	const (
		discoveredID = "dcert-1"
		managedID    = "mc-prod-1"
	)
	seededAt := time.Now()
	discoveryRepo.Discovered[discoveredID] = &domain.DiscoveredCertificate{
		ID:                discoveredID,
		CommonName:        "example.com",
		FingerprintSHA256: "abc123",
		Status:            domain.DiscoveryStatusUnmanaged,
		CreatedAt:         seededAt,
		UpdatedAt:         seededAt,
	}
	certRepo.AddCert(&domain.ManagedCertificate{
		ID:         managedID,
		CommonName: "example.com",
		Status:     domain.CertificateStatusActive,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	})

	if err := svc.ClaimDiscovered(context.Background(), discoveredID, managedID, "alice@corp"); err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}

	// The stored record must now be Managed and linked to the managed cert.
	after := discoveryRepo.Discovered[discoveredID]
	if after.Status != domain.DiscoveryStatusManaged {
		t.Errorf("expected status Managed, got %s", after.Status)
	}
	if after.ManagedCertificateID != managedID {
		t.Errorf("expected managed cert ID mc-prod-1, got %s", after.ManagedCertificateID)
	}

	// The claim must be visible in the audit trail.
	if len(auditRepo.Events) == 0 {
		t.Error("expected audit event to be recorded")
	}
	claimed := false
	for _, ev := range auditRepo.Events {
		if ev.Action == "discovery_cert_claimed" {
			claimed = true
			break
		}
	}
	if !claimed {
		t.Error("expected discovery_cert_claimed audit event")
	}
}
// TestClaimDiscovered_MissingManagedCertID verifies that claiming an existing
// discovered certificate with an empty managed certificate ID returns an error.
func TestClaimDiscovered_MissingManagedCertID(t *testing.T) {
	svc, discoveryRepo, _, _ := newDiscoveryTestService()

	seededAt := time.Now()
	discoveryRepo.Discovered["dcert-1"] = &domain.DiscoveredCertificate{
		ID:         "dcert-1",
		CommonName: "example.com",
		Status:     domain.DiscoveryStatusUnmanaged,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	}

	if err := svc.ClaimDiscovered(context.Background(), "dcert-1", "", "test-actor"); err == nil {
		t.Fatal("expected error for empty managed_certificate_id")
	}
}
// TestClaimDiscovered_ManagedCertNotFound verifies that claiming an existing
// discovered certificate against a managed certificate ID that was never
// added to the certificate mock returns an error.
func TestClaimDiscovered_ManagedCertNotFound(t *testing.T) {
	svc, discoveryRepo, _, _ := newDiscoveryTestService()

	seededAt := time.Now()
	discoveryRepo.Discovered["dcert-1"] = &domain.DiscoveredCertificate{
		ID:         "dcert-1",
		CommonName: "example.com",
		Status:     domain.DiscoveryStatusUnmanaged,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	}

	err := svc.ClaimDiscovered(context.Background(), "dcert-1", "nonexistent-cert", "test-actor")
	if err == nil {
		t.Fatal("expected error for nonexistent managed certificate")
	}
	// NOTE(review): errors.Is(err, err) compares the error with itself and is
	// always true, so the branch below can never fire. It is preserved as-is;
	// which sentinel the service returns (and whether it wraps it) is not
	// visible from this file — confirm and replace with a real comparison.
	if matchesItself := errors.Is(err, err); !matchesItself {
		t.Errorf("expected 'not found' error")
	}
}
// TestDismissDiscovered_Success dismisses an Unmanaged discovered certificate
// and checks that the record becomes Dismissed with DismissedAt set, and that
// a discovery_cert_dismissed audit event is recorded.
func TestDismissDiscovered_Success(t *testing.T) {
	svc, discoveryRepo, _, auditRepo := newDiscoveryTestService()

	const discoveredID = "dcert-1"
	seededAt := time.Now()
	discoveryRepo.Discovered[discoveredID] = &domain.DiscoveredCertificate{
		ID:         discoveredID,
		CommonName: "example.com",
		Status:     domain.DiscoveryStatusUnmanaged,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	}

	if err := svc.DismissDiscovered(context.Background(), discoveredID, "bob@corp"); err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}

	// The stored record must now be Dismissed and carry a dismissal time.
	after := discoveryRepo.Discovered[discoveredID]
	if after.Status != domain.DiscoveryStatusDismissed {
		t.Errorf("expected status Dismissed, got %s", after.Status)
	}
	if after.DismissedAt == nil {
		t.Error("expected DismissedAt to be set")
	}

	// The dismissal must be visible in the audit trail.
	if len(auditRepo.Events) == 0 {
		t.Error("expected audit event to be recorded")
	}
	dismissed := false
	for _, ev := range auditRepo.Events {
		if ev.Action == "discovery_cert_dismissed" {
			dismissed = true
			break
		}
	}
	if !dismissed {
		t.Error("expected discovery_cert_dismissed audit event")
	}
}
// TestDismissDiscovered_NotFound verifies that DismissDiscovered returns an
// error when the repository's status update fails with errNotFound.
func TestDismissDiscovered_NotFound(t *testing.T) {
	svc, discoveryRepo, _, _ := newDiscoveryTestService()
	// Inject the failure directly rather than relying on an empty repository.
	discoveryRepo.UpdateStatusErr = errNotFound

	err := svc.DismissDiscovered(context.Background(), "nonexistent", "test-actor")
	if err != nil {
		return
	}
	t.Fatal("expected error for nonexistent cert")
}
// M-005 regression: the actor passed to ClaimDiscovered must appear on the
// discovery_cert_claimed audit event, so the audit trail names whoever did the
// triage. Before M-005 the service hardcoded "operator"; the final check
// guards against that value returning.
func TestDiscoveryService_ClaimDiscovered_AuditActor(t *testing.T) {
	svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService()

	const actor = "alice@corp"
	seededAt := time.Now()
	discoveryRepo.Discovered["dcert-1"] = &domain.DiscoveredCertificate{
		ID:                "dcert-1",
		CommonName:        "example.com",
		FingerprintSHA256: "abc123",
		Status:            domain.DiscoveryStatusUnmanaged,
		CreatedAt:         seededAt,
		UpdatedAt:         seededAt,
	}
	certRepo.AddCert(&domain.ManagedCertificate{
		ID:         "mc-prod-1",
		CommonName: "example.com",
		Status:     domain.CertificateStatusActive,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	})

	err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", actor)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}

	// firstEvent returns the earliest recorded audit event with the given
	// action, or nil when none was recorded.
	firstEvent := func(action string) *domain.AuditEvent {
		for _, ev := range auditRepo.Events {
			if ev.Action == action {
				return ev
			}
		}
		return nil
	}

	claimEvent := firstEvent("discovery_cert_claimed")
	if claimEvent == nil {
		t.Fatal("expected discovery_cert_claimed audit event to be recorded")
	}
	got := claimEvent.Actor
	if got != actor {
		t.Errorf("expected audit actor to be caller-supplied 'alice@corp', got %q", got)
	}
	if got == "operator" {
		t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
	}
}
// M-005 regression, DismissDiscovered counterpart: the actor passed to
// DismissDiscovered must appear on the discovery_cert_dismissed audit event,
// and must not be the formerly hardcoded "operator".
func TestDiscoveryService_DismissDiscovered_AuditActor(t *testing.T) {
	svc, discoveryRepo, _, auditRepo := newDiscoveryTestService()

	const actor = "bob@corp"
	seededAt := time.Now()
	discoveryRepo.Discovered["dcert-1"] = &domain.DiscoveredCertificate{
		ID:         "dcert-1",
		CommonName: "example.com",
		Status:     domain.DiscoveryStatusUnmanaged,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	}

	err := svc.DismissDiscovered(context.Background(), "dcert-1", actor)
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}

	// firstEvent returns the earliest recorded audit event with the given
	// action, or nil when none was recorded.
	firstEvent := func(action string) *domain.AuditEvent {
		for _, ev := range auditRepo.Events {
			if ev.Action == action {
				return ev
			}
		}
		return nil
	}

	dismissEvent := firstEvent("discovery_cert_dismissed")
	if dismissEvent == nil {
		t.Fatal("expected discovery_cert_dismissed audit event to be recorded")
	}
	got := dismissEvent.Actor
	if got != actor {
		t.Errorf("expected audit actor to be caller-supplied 'bob@corp', got %q", got)
	}
	if got == "operator" {
		t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
	}
}
// M-005 regression: when ClaimDiscovered is called with an empty actor (for
// example because the caller had no auth context to resolve one from), the
// discovery_cert_claimed audit event must carry the sentinel "api" and never
// the pre-M-005 hardcoded "operator".
func TestDiscoveryService_ClaimDiscovered_EmptyActorFallsBackToAPI(t *testing.T) {
	svc, discoveryRepo, certRepo, auditRepo := newDiscoveryTestService()

	seededAt := time.Now()
	discoveryRepo.Discovered["dcert-1"] = &domain.DiscoveredCertificate{
		ID:                "dcert-1",
		CommonName:        "example.com",
		FingerprintSHA256: "abc123",
		Status:            domain.DiscoveryStatusUnmanaged,
		CreatedAt:         seededAt,
		UpdatedAt:         seededAt,
	}
	certRepo.AddCert(&domain.ManagedCertificate{
		ID:         "mc-prod-1",
		CommonName: "example.com",
		Status:     domain.CertificateStatusActive,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	})

	// The empty actor is the input under test.
	err := svc.ClaimDiscovered(context.Background(), "dcert-1", "mc-prod-1", "")
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}

	// firstEvent returns the earliest recorded audit event with the given
	// action, or nil when none was recorded.
	firstEvent := func(action string) *domain.AuditEvent {
		for _, ev := range auditRepo.Events {
			if ev.Action == action {
				return ev
			}
		}
		return nil
	}

	claimEvent := firstEvent("discovery_cert_claimed")
	if claimEvent == nil {
		t.Fatal("expected discovery_cert_claimed audit event to be recorded")
	}
	got := claimEvent.Actor
	if got != "api" {
		t.Errorf("expected empty actor to fall back to 'api', got %q", got)
	}
	if got == "operator" {
		t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
	}
}
// M-005 regression, DismissDiscovered counterpart of the empty-actor case:
// an empty actor must be recorded as "api" on the discovery_cert_dismissed
// audit event, never as the pre-M-005 hardcoded "operator".
func TestDiscoveryService_DismissDiscovered_EmptyActorFallsBackToAPI(t *testing.T) {
	svc, discoveryRepo, _, auditRepo := newDiscoveryTestService()

	seededAt := time.Now()
	discoveryRepo.Discovered["dcert-1"] = &domain.DiscoveredCertificate{
		ID:         "dcert-1",
		CommonName: "example.com",
		Status:     domain.DiscoveryStatusUnmanaged,
		CreatedAt:  seededAt,
		UpdatedAt:  seededAt,
	}

	// The empty actor is the input under test.
	err := svc.DismissDiscovered(context.Background(), "dcert-1", "")
	if err != nil {
		t.Fatalf("expected no error, got: %v", err)
	}

	// firstEvent returns the earliest recorded audit event with the given
	// action, or nil when none was recorded.
	firstEvent := func(action string) *domain.AuditEvent {
		for _, ev := range auditRepo.Events {
			if ev.Action == action {
				return ev
			}
		}
		return nil
	}

	dismissEvent := firstEvent("discovery_cert_dismissed")
	if dismissEvent == nil {
		t.Fatal("expected discovery_cert_dismissed audit event to be recorded")
	}
	got := dismissEvent.Actor
	if got != "api" {
		t.Errorf("expected empty actor to fall back to 'api', got %q", got)
	}
	if got == "operator" {
		t.Error("audit actor must not be hardcoded 'operator' (M-005 regression)")
	}
}