fix: resolve NULL csr_pem scan errors and QA smoke test failures

Root cause: certificate_versions.csr_pem is nullable in the schema but
Go code scanned it into a plain string. Used sql.NullString in
ListVersions and GetLatestVersion to handle NULL values correctly.

Also includes: partial update fetch-merge-update pattern to prevent FK
violations, nil directory guard in discovery service, diagnostic slog
logging in handlers, export handler 422 for unparseable PEM, OpenAPI
spec corrections, MCP tool description improvements, and test fixes.

Rewrites the Release Sign-Off section in testing-guide.md to individual
test-level granularity (320 rows) with smoke test results audited and
checked off (121 pass, 5 skip, 194 manual remaining).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Shankar
2026-03-30 00:51:18 -04:00
parent ed3f9cc2db
commit 380fcab42e
12 changed files with 683 additions and 74 deletions
+11
View File
@@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strconv"
"strings"
@@ -134,6 +135,11 @@ func (h AgentHandler) RegisterAgent(w http.ResponseWriter, r *http.Request) {
created, err := h.svc.RegisterAgent(r.Context(), agent)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "unique") || strings.Contains(errMsg, "duplicate") || strings.Contains(errMsg, "already exists") {
ErrorWithRequestID(w, http.StatusConflict, "Agent with this name already exists", requestID)
return
}
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to register agent", requestID)
return
}
@@ -184,6 +190,11 @@ func (h AgentHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
}
if err := h.svc.Heartbeat(r.Context(), agentID, metadata); err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Agent not found", requestID)
return
}
slog.Error("Heartbeat failed", "agent_id", agentID, "error", err.Error())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to record heartbeat", requestID)
return
}
@@ -353,11 +353,12 @@ func TestCreateCertificate_Success(t *testing.T) {
handler := NewCertificateHandler(mock)
certBody := domain.ManagedCertificate{
Name: "Production Cert",
CommonName: "example.com",
OwnerID: "o-alice",
TeamID: "t-platform",
IssuerID: "iss-local",
Name: "Production Cert",
CommonName: "example.com",
OwnerID: "o-alice",
TeamID: "t-platform",
IssuerID: "iss-local",
RenewalPolicyID: "rp-standard",
}
body, _ := json.Marshal(certBody)
@@ -410,11 +411,12 @@ func TestCreateCertificate_ServiceError(t *testing.T) {
handler := NewCertificateHandler(mock)
certBody := domain.ManagedCertificate{
Name: "Production Cert",
CommonName: "example.com",
OwnerID: "o-alice",
TeamID: "t-platform",
IssuerID: "iss-local",
Name: "Production Cert",
CommonName: "example.com",
OwnerID: "o-alice",
TeamID: "t-platform",
IssuerID: "iss-local",
RenewalPolicyID: "rp-standard",
}
body, _ := json.Marshal(certBody)
@@ -534,8 +536,8 @@ func TestArchiveCertificate_NotFound(t *testing.T) {
handler.ArchiveCertificate(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
if w.Code != http.StatusNotFound {
t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code)
}
}
+38 -2
View File
@@ -2,6 +2,7 @@ package handler
import (
"encoding/json"
"log/slog"
"net/http"
"strconv"
"strings"
@@ -231,6 +232,14 @@ func (h CertificateHandler) CreateCertificate(w http.ResponseWriter, r *http.Req
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
if err := ValidateRequired("name", cert.Name); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
if err := ValidateRequired("renewal_policy_id", cert.RenewalPolicyID); err != nil {
ErrorWithRequestID(w, http.StatusBadRequest, err.Error(), requestID)
return
}
created, err := h.svc.CreateCertificate(cert)
if err != nil {
@@ -287,6 +296,11 @@ func (h CertificateHandler) UpdateCertificate(w http.ResponseWriter, r *http.Req
updated, err := h.svc.UpdateCertificate(id, cert)
if err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
slog.Error("UpdateCertificate failed", "cert_id", id, "error", err.Error())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to update certificate", requestID)
return
}
@@ -311,6 +325,10 @@ func (h CertificateHandler) ArchiveCertificate(w http.ResponseWriter, r *http.Re
}
if err := h.svc.ArchiveCertificate(id); err != nil {
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to archive certificate", requestID)
return
}
@@ -353,7 +371,12 @@ func (h CertificateHandler) GetCertificateVersions(w http.ResponseWriter, r *htt
versions, total, err := h.svc.GetCertificateVersions(certID, page, perPage)
if err != nil {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
if strings.Contains(err.Error(), "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
slog.Error("GetCertificateVersions failed", "cert_id", certID, "error", err.Error())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get certificate versions", requestID)
return
}
@@ -387,6 +410,19 @@ func (h CertificateHandler) TriggerRenewal(w http.ResponseWriter, r *http.Reques
certID := parts[0]
if err := h.svc.TriggerRenewal(certID); err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "not found") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
if strings.Contains(errMsg, "cannot renew") {
ErrorWithRequestID(w, http.StatusBadRequest, errMsg, requestID)
return
}
if strings.Contains(errMsg, "already in progress") {
ErrorWithRequestID(w, http.StatusConflict, errMsg, requestID)
return
}
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to trigger renewal", requestID)
return
}
@@ -480,7 +516,7 @@ func (h CertificateHandler) RevokeCertificate(w http.ResponseWriter, r *http.Req
ErrorWithRequestID(w, http.StatusBadRequest, errMsg, requestID)
return
}
if strings.Contains(errMsg, "not found") || strings.Contains(errMsg, "failed to fetch") {
if strings.Contains(errMsg, "not found") || strings.Contains(errMsg, "failed to fetch") || strings.Contains(errMsg, "failed to get") {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
+7
View File
@@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strings"
@@ -49,6 +50,7 @@ func (h ExportHandler) ExportPEM(w http.ResponseWriter, r *http.Request) {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
slog.Error("ExportPEM failed", "cert_id", id, "error", err.Error())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to export certificate", requestID)
return
}
@@ -96,6 +98,11 @@ func (h ExportHandler) ExportPKCS12(w http.ResponseWriter, r *http.Request) {
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
return
}
if strings.Contains(err.Error(), "cannot be parsed") || strings.Contains(err.Error(), "no certificates found") {
ErrorWithRequestID(w, http.StatusUnprocessableEntity, "Certificate data cannot be parsed as X.509", requestID)
return
}
slog.Error("ExportPKCS12 failed", "cert_id", id, "error", err.Error())
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to export PKCS#12", requestID)
return
}
+4 -4
View File
@@ -99,7 +99,7 @@ func registerCertificateTools(s *gomcp.Server, c *Client) {
gomcp.AddTool(s, &gomcp.Tool{
Name: "certctl_create_certificate",
Description: "Create a new managed certificate. Requires common_name and issuer_id at minimum.",
Description: "Create a new managed certificate. Requires name, common_name, renewal_policy_id, issuer_id, owner_id, and team_id.",
}, func(ctx context.Context, req *gomcp.CallToolRequest, input CreateCertificateInput) (*gomcp.CallToolResult, any, error) {
data, err := c.Post("/api/v1/certificates", input)
if err != nil {
@@ -144,7 +144,7 @@ func registerCertificateTools(s *gomcp.Server, c *Client) {
gomcp.AddTool(s, &gomcp.Tool{
Name: "certctl_trigger_renewal",
Description: "Trigger immediate renewal of a certificate. Creates a renewal job (async, returns 202).",
Description: "Trigger immediate renewal of a certificate. Creates a renewal job (async, returns 202). Returns 404 if certificate not found, 400 if certificate is archived/expired, 409 if renewal already in progress.",
}, func(ctx context.Context, req *gomcp.CallToolRequest, input GetByIDInput) (*gomcp.CallToolResult, any, error) {
data, err := c.Post("/api/v1/certificates/"+input.ID+"/renew", nil)
if err != nil {
@@ -385,7 +385,7 @@ func registerAgentTools(s *gomcp.Server, c *Client) {
gomcp.AddTool(s, &gomcp.Tool{
Name: "certctl_register_agent",
Description: "Register a new agent. Requires name and hostname.",
Description: "Register a new agent. Requires name and hostname. Returns 409 if an agent with the same name already exists.",
}, func(ctx context.Context, req *gomcp.CallToolRequest, input RegisterAgentInput) (*gomcp.CallToolResult, any, error) {
data, err := c.Post("/api/v1/agents", input)
if err != nil {
@@ -396,7 +396,7 @@ func registerAgentTools(s *gomcp.Server, c *Client) {
gomcp.AddTool(s, &gomcp.Tool{
Name: "certctl_agent_heartbeat",
Description: "Send agent heartbeat with optional metadata (OS, architecture, IP, version).",
Description: "Send agent heartbeat with optional metadata (OS, architecture, IP, version). Returns 404 if agent not found.",
}, func(ctx context.Context, req *gomcp.CallToolRequest, input struct {
ID string `json:"id" jsonschema:"Agent ID"`
Version string `json:"version,omitempty" jsonschema:"Agent version"`
+11 -7
View File
@@ -349,7 +349,7 @@ func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
fingerprint_sha256, pem_chain, csr_pem, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
@@ -363,10 +363,12 @@ func (r *CertificateRepository) ListVersions(ctx context.Context, certID string)
var versions []*domain.CertificateVersion
for rows.Next() {
var v domain.CertificateVersion
var csrPEM sql.NullString
if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt); err != nil {
&v.FingerprintSHA256, &v.PEMChain, &csrPEM, &v.CreatedAt); err != nil {
return nil, fmt.Errorf("failed to scan certificate version: %w", err)
}
v.CSRPEM = csrPEM.String
versions = append(versions, &v)
}
@@ -386,11 +388,11 @@ func (r *CertificateRepository) CreateVersion(ctx context.Context, version *doma
err := r.db.QueryRowContext(ctx, `
INSERT INTO certificate_versions (
id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
fingerprint_sha256, pem_chain, csr_pem, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id
`, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter,
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.KeyAlgorithm, version.KeySize, version.CreatedAt).Scan(&version.ID)
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID)
if err != nil {
return fmt.Errorf("failed to create certificate version: %w", err)
@@ -433,15 +435,17 @@ func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, bef
// GetLatestVersion returns the most recent certificate version for a certificate.
func (r *CertificateRepository) GetLatestVersion(ctx context.Context, certID string) (*domain.CertificateVersion, error) {
var v domain.CertificateVersion
var csrPEM sql.NullString
err := r.db.QueryRowContext(ctx, `
SELECT id, certificate_id, serial_number, not_before, not_after,
fingerprint_sha256, pem_chain, csr_pem, key_algorithm, key_size, created_at
fingerprint_sha256, pem_chain, csr_pem, created_at
FROM certificate_versions
WHERE certificate_id = $1
ORDER BY created_at DESC
LIMIT 1
`, certID).Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.KeyAlgorithm, &v.KeySize, &v.CreatedAt)
&v.FingerprintSHA256, &v.PEMChain, &csrPEM, &v.CreatedAt)
v.CSRPEM = csrPEM.String
if err != nil {
return nil, fmt.Errorf("failed to get latest certificate version: %w", err)
+48 -4
View File
@@ -311,12 +311,56 @@ func (s *CertificateService) CreateCertificate(cert domain.ManagedCertificate) (
}
// UpdateCertificate modifies a certificate (handler interface method).
func (s *CertificateService) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
cert.ID = id
if err := s.certRepo.Update(context.Background(), &cert); err != nil {
func (s *CertificateService) UpdateCertificate(id string, patch domain.ManagedCertificate) (*domain.ManagedCertificate, error) {
ctx := context.Background()
// Fetch existing certificate so partial updates don't zero out fields
existing, err := s.certRepo.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("certificate not found: %w", err)
}
// Merge non-zero fields from patch into existing
if patch.Name != "" {
existing.Name = patch.Name
}
if patch.CommonName != "" {
existing.CommonName = patch.CommonName
}
if len(patch.SANs) > 0 {
existing.SANs = patch.SANs
}
if patch.Environment != "" {
existing.Environment = patch.Environment
}
if patch.OwnerID != "" {
existing.OwnerID = patch.OwnerID
}
if patch.TeamID != "" {
existing.TeamID = patch.TeamID
}
if patch.IssuerID != "" {
existing.IssuerID = patch.IssuerID
}
if patch.RenewalPolicyID != "" {
existing.RenewalPolicyID = patch.RenewalPolicyID
}
if patch.CertificateProfileID != "" {
existing.CertificateProfileID = patch.CertificateProfileID
}
if patch.Status != "" {
existing.Status = patch.Status
}
if patch.Tags != nil {
existing.Tags = patch.Tags
}
existing.UpdatedAt = time.Now()
if err := s.certRepo.Update(ctx, existing); err != nil {
return nil, fmt.Errorf("failed to update certificate: %w", err)
}
return &cert, nil
return existing, nil
}
// ArchiveCertificate marks a certificate as archived (handler interface method).
+10 -5
View File
@@ -40,6 +40,11 @@ func (s *DiscoveryService) ProcessDiscoveryReport(ctx context.Context, report *d
return nil, fmt.Errorf("report must contain at least one certificate or error")
}
// Ensure directories is never nil (PostgreSQL TEXT[] NOT NULL)
if report.Directories == nil {
report.Directories = []string{}
}
now := time.Now()
scan := &domain.DiscoveryScan{
ID: generateID("dscan"),
@@ -52,6 +57,11 @@ func (s *DiscoveryService) ProcessDiscoveryReport(ctx context.Context, report *d
CompletedAt: &now,
}
// Store the scan record first (discovered certs reference scan via FK)
if err := s.discoveryRepo.CreateScan(ctx, scan); err != nil {
return nil, fmt.Errorf("failed to create scan record: %w", err)
}
// Upsert each discovered certificate
newCount := 0
for _, entry := range report.Certificates {
@@ -105,11 +115,6 @@ func (s *DiscoveryService) ProcessDiscoveryReport(ctx context.Context, report *d
scan.CertificatesNew = newCount
// Store the scan record
if err := s.discoveryRepo.CreateScan(ctx, scan); err != nil {
return nil, fmt.Errorf("failed to create scan record: %w", err)
}
// Audit trail
if err := s.auditService.RecordEvent(ctx, report.AgentID, domain.ActorTypeSystem,
"discovery_scan_completed", "discovery_scan", scan.ID,
+1 -1
View File
@@ -88,7 +88,7 @@ func (s *ExportService) ExportPKCS12(ctx context.Context, certID string, passwor
// Parse PEM chain into x509.Certificate objects
certs, err := parsePEMCertificates(version.PEMChain)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate chain: %w", err)
return nil, fmt.Errorf("certificate data cannot be parsed as X.509: %w", err)
}
if len(certs) == 0 {
+2 -2
View File
@@ -321,8 +321,8 @@ func TestTeamService_Create_EmptyName(t *testing.T) {
t.Fatalf("expected validation error for empty name, got nil")
}
if !errors.Is(err, errors.New("team name is required")) {
t.Logf("error: %v", err)
if !strings.Contains(err.Error(), "team name is required") {
t.Errorf("expected error containing 'team name is required', got: %v", err)
}
}