mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 22:21:30 +00:00
743dca2fb3
EST handler tests used fake PEM data (e.g., "MIIBmjCCAUCgAwIBAgIRATest") which is invalid base64 (25 chars, not divisible by 4). Go's pem.Decode fails silently, causing pemToDERChain to return "no certificates found" and tests to get 500 instead of 200. Added generateTestCertPEM() helper that creates a real self-signed ECDSA P-256 certificate, used across all EST handler tests that need cert PEM. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
399 lines
11 KiB
Go
399 lines
11 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"errors"
|
|
"math/big"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/shankar0123/certctl/internal/domain"
|
|
)
|
|
|
|
// mockESTService implements ESTService for testing.
|
|
type mockESTService struct {
|
|
CACertPEM string
|
|
CACertErr error
|
|
EnrollResult *domain.ESTEnrollResult
|
|
EnrollErr error
|
|
CSRAttrs []byte
|
|
CSRAttrsErr error
|
|
}
|
|
|
|
func (m *mockESTService) GetCACerts(ctx context.Context) (string, error) {
|
|
return m.CACertPEM, m.CACertErr
|
|
}
|
|
|
|
func (m *mockESTService) SimpleEnroll(ctx context.Context, csrPEM string) (*domain.ESTEnrollResult, error) {
|
|
return m.EnrollResult, m.EnrollErr
|
|
}
|
|
|
|
func (m *mockESTService) SimpleReEnroll(ctx context.Context, csrPEM string) (*domain.ESTEnrollResult, error) {
|
|
return m.EnrollResult, m.EnrollErr
|
|
}
|
|
|
|
func (m *mockESTService) GetCSRAttrs(ctx context.Context) ([]byte, error) {
|
|
return m.CSRAttrs, m.CSRAttrsErr
|
|
}
|
|
|
|
// generateTestCSRPEM returns a PEM-encoded certificate signing request signed
// by a freshly generated ECDSA P-256 key. The request carries the common name
// "test.example.com" and two DNS names. Any failure aborts the calling test.
func generateTestCSRPEM(t *testing.T) string {
	t.Helper()

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}

	request := x509.CertificateRequest{
		Subject:  pkix.Name{CommonName: "test.example.com"},
		DNSNames: []string{"test.example.com", "www.example.com"},
	}
	der, err := x509.CreateCertificateRequest(rand.Reader, &request, priv)
	if err != nil {
		t.Fatalf("failed to create CSR: %v", err)
	}

	block := pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der}
	return string(pem.EncodeToMemory(&block))
}
|
|
|
|
// generateTestCSRBase64DER returns a certificate signing request as standard
// base64 over the raw DER bytes (no PEM armor), for the EST wire format. The
// request is signed by a freshly generated ECDSA P-256 key and names
// "test.example.com". Any failure aborts the calling test.
func generateTestCSRBase64DER(t *testing.T) string {
	t.Helper()

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}

	request := x509.CertificateRequest{
		Subject:  pkix.Name{CommonName: "test.example.com"},
		DNSNames: []string{"test.example.com"},
	}
	der, err := x509.CreateCertificateRequest(rand.Reader, &request, priv)
	if err != nil {
		t.Fatalf("failed to create CSR: %v", err)
	}

	encoded := base64.StdEncoding.EncodeToString(der)
	return encoded
}
|
|
|
|
// generateTestCertPEM returns a PEM-encoded, self-signed CA certificate
// backed by a freshly generated ECDSA P-256 key. The certificate has serial
// number 1, subject "Test CA", and is valid from one hour ago until 24 hours
// from now. Any failure aborts the calling test.
func generateTestCertPEM(t *testing.T) string {
	t.Helper()

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}

	caTemplate := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "Test CA"},
		NotBefore:             time.Now().Add(-time.Hour),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		IsCA:                  true,
		BasicConstraintsValid: true,
	}

	// Self-signed: the template is its own parent and the key signs itself.
	der, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &priv.PublicKey, priv)
	if err != nil {
		t.Fatalf("failed to create certificate: %v", err)
	}

	block := pem.Block{Type: "CERTIFICATE", Bytes: der}
	return string(pem.EncodeToMemory(&block))
}
|
|
|
|
func TestESTCACerts_Success(t *testing.T) {
|
|
certPEM := generateTestCertPEM(t)
|
|
svc := &mockESTService{
|
|
CACertPEM: certPEM,
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est/cacerts", nil)
|
|
w := httptest.NewRecorder()
|
|
h.CACerts(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
}
|
|
ct := w.Header().Get("Content-Type")
|
|
if !strings.Contains(ct, "application/pkcs7-mime") {
|
|
t.Errorf("expected application/pkcs7-mime content type, got %s", ct)
|
|
}
|
|
cte := w.Header().Get("Content-Transfer-Encoding")
|
|
if cte != "base64" {
|
|
t.Errorf("expected base64 content-transfer-encoding, got %s", cte)
|
|
}
|
|
}
|
|
|
|
func TestESTCACerts_MethodNotAllowed(t *testing.T) {
|
|
svc := &mockESTService{}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/cacerts", nil)
|
|
w := httptest.NewRecorder()
|
|
h.CACerts(w, req)
|
|
|
|
if w.Code != http.StatusMethodNotAllowed {
|
|
t.Errorf("expected 405, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTCACerts_ServiceError(t *testing.T) {
|
|
svc := &mockESTService{
|
|
CACertErr: errors.New("issuer unavailable"),
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est/cacerts", nil)
|
|
w := httptest.NewRecorder()
|
|
h.CACerts(w, req)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Errorf("expected 500, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleEnroll_Success_PEM(t *testing.T) {
|
|
csrPEM := generateTestCSRPEM(t)
|
|
certPEM := generateTestCertPEM(t)
|
|
svc := &mockESTService{
|
|
EnrollResult: &domain.ESTEnrollResult{
|
|
CertPEM: certPEM,
|
|
ChainPEM: certPEM,
|
|
},
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/simpleenroll", strings.NewReader(csrPEM))
|
|
req.Header.Set("Content-Type", "application/pkcs10")
|
|
w := httptest.NewRecorder()
|
|
h.SimpleEnroll(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
}
|
|
ct := w.Header().Get("Content-Type")
|
|
if !strings.Contains(ct, "application/pkcs7-mime") {
|
|
t.Errorf("expected application/pkcs7-mime, got %s", ct)
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleEnroll_Success_Base64DER(t *testing.T) {
|
|
csrB64 := generateTestCSRBase64DER(t)
|
|
certPEM := generateTestCertPEM(t)
|
|
svc := &mockESTService{
|
|
EnrollResult: &domain.ESTEnrollResult{
|
|
CertPEM: certPEM,
|
|
ChainPEM: "",
|
|
},
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/simpleenroll", strings.NewReader(csrB64))
|
|
req.Header.Set("Content-Type", "application/pkcs10")
|
|
w := httptest.NewRecorder()
|
|
h.SimpleEnroll(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleEnroll_MethodNotAllowed(t *testing.T) {
|
|
svc := &mockESTService{}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est/simpleenroll", nil)
|
|
w := httptest.NewRecorder()
|
|
h.SimpleEnroll(w, req)
|
|
|
|
if w.Code != http.StatusMethodNotAllowed {
|
|
t.Errorf("expected 405, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleEnroll_EmptyBody(t *testing.T) {
|
|
svc := &mockESTService{}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/simpleenroll", strings.NewReader(""))
|
|
w := httptest.NewRecorder()
|
|
h.SimpleEnroll(w, req)
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected 400, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleEnroll_InvalidCSR(t *testing.T) {
|
|
svc := &mockESTService{}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/simpleenroll", strings.NewReader("not-a-valid-csr"))
|
|
w := httptest.NewRecorder()
|
|
h.SimpleEnroll(w, req)
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected 400, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleEnroll_ServiceError(t *testing.T) {
|
|
csrPEM := generateTestCSRPEM(t)
|
|
svc := &mockESTService{
|
|
EnrollErr: errors.New("issuance failed"),
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/simpleenroll", strings.NewReader(csrPEM))
|
|
w := httptest.NewRecorder()
|
|
h.SimpleEnroll(w, req)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Errorf("expected 500, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleReEnroll_Success(t *testing.T) {
|
|
csrPEM := generateTestCSRPEM(t)
|
|
certPEM := generateTestCertPEM(t)
|
|
svc := &mockESTService{
|
|
EnrollResult: &domain.ESTEnrollResult{
|
|
CertPEM: certPEM,
|
|
ChainPEM: certPEM,
|
|
},
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/simplereenroll", strings.NewReader(csrPEM))
|
|
w := httptest.NewRecorder()
|
|
h.SimpleReEnroll(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestESTSimpleReEnroll_MethodNotAllowed(t *testing.T) {
|
|
svc := &mockESTService{}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est/simplereenroll", nil)
|
|
w := httptest.NewRecorder()
|
|
h.SimpleReEnroll(w, req)
|
|
|
|
if w.Code != http.StatusMethodNotAllowed {
|
|
t.Errorf("expected 405, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTCSRAttrs_NoContent(t *testing.T) {
|
|
svc := &mockESTService{
|
|
CSRAttrs: nil,
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est/csrattrs", nil)
|
|
w := httptest.NewRecorder()
|
|
h.CSRAttrs(w, req)
|
|
|
|
if w.Code != http.StatusNoContent {
|
|
t.Errorf("expected 204, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestESTCSRAttrs_WithData(t *testing.T) {
|
|
svc := &mockESTService{
|
|
CSRAttrs: []byte{0x30, 0x00}, // empty SEQUENCE
|
|
}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/est/csrattrs", nil)
|
|
w := httptest.NewRecorder()
|
|
h.CSRAttrs(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", w.Code)
|
|
}
|
|
ct := w.Header().Get("Content-Type")
|
|
if ct != "application/csrattrs" {
|
|
t.Errorf("expected application/csrattrs, got %s", ct)
|
|
}
|
|
}
|
|
|
|
func TestESTCSRAttrs_MethodNotAllowed(t *testing.T) {
|
|
svc := &mockESTService{}
|
|
h := NewESTHandler(svc)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/.well-known/est/csrattrs", nil)
|
|
w := httptest.NewRecorder()
|
|
h.CSRAttrs(w, req)
|
|
|
|
if w.Code != http.StatusMethodNotAllowed {
|
|
t.Errorf("expected 405, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestBuildCertsOnlyPKCS7(t *testing.T) {
|
|
// Test with a dummy DER certificate
|
|
dummyCert := []byte{0x30, 0x82, 0x01, 0x00} // minimal ASN.1 SEQUENCE
|
|
result, err := buildCertsOnlyPKCS7([][]byte{dummyCert})
|
|
if err != nil {
|
|
t.Fatalf("buildCertsOnlyPKCS7 failed: %v", err)
|
|
}
|
|
if len(result) == 0 {
|
|
t.Error("expected non-empty PKCS#7 output")
|
|
}
|
|
// Verify it starts with SEQUENCE tag
|
|
if result[0] != 0x30 {
|
|
t.Errorf("expected PKCS#7 to start with SEQUENCE tag (0x30), got 0x%02x", result[0])
|
|
}
|
|
}
|
|
|
|
func TestPemToDERChain(t *testing.T) {
|
|
pemData := generateTestCertPEM(t)
|
|
certs, err := pemToDERChain(pemData)
|
|
if err != nil {
|
|
t.Fatalf("pemToDERChain failed: %v", err)
|
|
}
|
|
if len(certs) != 1 {
|
|
t.Errorf("expected 1 cert, got %d", len(certs))
|
|
}
|
|
}
|
|
|
|
func TestPemToDERChain_NoCerts(t *testing.T) {
|
|
_, err := pemToDERChain("not a PEM")
|
|
if err == nil {
|
|
t.Error("expected error for invalid PEM")
|
|
}
|
|
}
|
|
|
|
func TestASN1EncodeLength(t *testing.T) {
|
|
tests := []struct {
|
|
length int
|
|
expected []byte
|
|
}{
|
|
{0, []byte{0x00}},
|
|
{1, []byte{0x01}},
|
|
{127, []byte{0x7f}},
|
|
{128, []byte{0x81, 0x80}},
|
|
{256, []byte{0x82, 0x01, 0x00}},
|
|
}
|
|
for _, tt := range tests {
|
|
result := asn1EncodeLength(tt.length)
|
|
if len(result) != len(tt.expected) {
|
|
t.Errorf("asn1EncodeLength(%d): expected %d bytes, got %d", tt.length, len(tt.expected), len(result))
|
|
continue
|
|
}
|
|
for i := range result {
|
|
if result[i] != tt.expected[i] {
|
|
t.Errorf("asn1EncodeLength(%d): byte %d: expected 0x%02x, got 0x%02x", tt.length, i, tt.expected[i], result[i])
|
|
}
|
|
}
|
|
}
|
|
}
|