mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-12 12:49:00 +00:00
feat: M20 Enhanced Query API — sort, time-range filters, cursor pagination, sparse fields, deployments endpoint
V2 (free) query enhancements for certificates:
- `sort` param with direction (`?sort=-notAfter` for descending)
- Time-range filters: `expires_before`, `expires_after`, `created_after`, `updated_after`
- Cursor-based pagination (`?cursor=token&page_size=100`) alongside page-based
- Sparse field selection (`?fields=id,commonName,status`)
- Additional filters: `agent_id`, `profile_id`
- New endpoint: `GET /api/v1/certificates/{id}/deployments`
25 new tests (12 handler + 13 e2e) covering all M20 features.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -193,6 +193,7 @@ func main() {
|
|||||||
certificateService.SetNotificationService(notificationService)
|
certificateService.SetNotificationService(notificationService)
|
||||||
certificateService.SetIssuerRegistry(issuerRegistry)
|
certificateService.SetIssuerRegistry(issuerRegistry)
|
||||||
certificateService.SetProfileRepo(profileRepo)
|
certificateService.SetProfileRepo(profileRepo)
|
||||||
|
certificateService.SetTargetRepo(targetRepo)
|
||||||
renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, profileRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode)
|
renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, profileRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode)
|
||||||
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certificateRepo, auditService, notificationService)
|
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certificateRepo, auditService, notificationService)
|
||||||
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
||||||
|
|||||||
@@ -12,22 +12,25 @@ import (
|
|||||||
|
|
||||||
"github.com/shankar0123/certctl/internal/api/middleware"
|
"github.com/shankar0123/certctl/internal/api/middleware"
|
||||||
"github.com/shankar0123/certctl/internal/domain"
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockCertificateService is a mock implementation of CertificateService interface.
|
// MockCertificateService is a mock implementation of CertificateService interface.
|
||||||
type MockCertificateService struct {
|
type MockCertificateService struct {
|
||||||
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
||||||
GetCertificateFn func(id string) (*domain.ManagedCertificate, error)
|
ListCertificatesWithFilterFn func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
|
||||||
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
GetCertificateFn func(id string) (*domain.ManagedCertificate, error)
|
||||||
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
ArchiveCertificateFn func(id string) error
|
UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
ArchiveCertificateFn func(id string) error
|
||||||
TriggerRenewalFn func(certID string) error
|
GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error)
|
||||||
TriggerDeploymentFn func(certID string, targetID string) error
|
TriggerRenewalFn func(certID string) error
|
||||||
RevokeCertificateFn func(certID string, reason string) error
|
TriggerDeploymentFn func(certID string, targetID string) error
|
||||||
GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error)
|
RevokeCertificateFn func(certID string, reason string) error
|
||||||
GenerateDERCRLFn func(issuerID string) ([]byte, error)
|
GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error)
|
||||||
GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error)
|
GenerateDERCRLFn func(issuerID string) ([]byte, error)
|
||||||
|
GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error)
|
||||||
|
GetCertificateDeploymentsFn func(certID string) ([]domain.DeploymentTarget, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
||||||
@@ -114,6 +117,20 @@ func (m *MockCertificateService) GetOCSPResponse(issuerID string, serialHex stri
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockCertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if m.ListCertificatesWithFilterFn != nil {
|
||||||
|
return m.ListCertificatesWithFilterFn(filter)
|
||||||
|
}
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) {
|
||||||
|
if m.GetCertificateDeploymentsFn != nil {
|
||||||
|
return m.GetCertificateDeploymentsFn(certID)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create context with request ID.
|
// Helper function to create context with request ID.
|
||||||
func contextWithRequestID() context.Context {
|
func contextWithRequestID() context.Context {
|
||||||
return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123")
|
return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123")
|
||||||
@@ -141,8 +158,8 @@ func TestListCertificates_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if page == 1 && perPage == 50 {
|
if filter.Page == 1 && filter.PerPage == 50 {
|
||||||
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
|
return []domain.ManagedCertificate{cert1, cert2}, 2, nil
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
@@ -180,8 +197,8 @@ func TestListCertificates_Success(t *testing.T) {
|
|||||||
// Test ListCertificates - with filters
|
// Test ListCertificates - with filters
|
||||||
func TestListCertificates_WithFilters(t *testing.T) {
|
func TestListCertificates_WithFilters(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
if status == "Active" && environment == "prod" {
|
if filter.Status == "Active" && filter.Environment == "prod" {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
@@ -219,7 +236,7 @@ func TestListCertificates_MethodNotAllowed(t *testing.T) {
|
|||||||
// Test ListCertificates - service error
|
// Test ListCertificates - service error
|
||||||
func TestListCertificates_ServiceError(t *testing.T) {
|
func TestListCertificates_ServiceError(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
return nil, 0, ErrMockServiceFailed
|
return nil, 0, ErrMockServiceFailed
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -697,9 +714,9 @@ func TestTriggerDeployment_NoTargetID(t *testing.T) {
|
|||||||
// Test ListCertificates - invalid page parameter
|
// Test ListCertificates - invalid page parameter
|
||||||
func TestListCertificates_InvalidPageParam(t *testing.T) {
|
func TestListCertificates_InvalidPageParam(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
// Should default to page 1
|
// Should default to page 1
|
||||||
if page == 1 {
|
if filter.Page == 1 {
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
@@ -721,9 +738,9 @@ func TestListCertificates_InvalidPageParam(t *testing.T) {
|
|||||||
// Test ListCertificates - per_page exceeds max
|
// Test ListCertificates - per_page exceeds max
|
||||||
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
|
func TestListCertificates_PerPageExceedsMax(t *testing.T) {
|
||||||
mock := &MockCertificateService{
|
mock := &MockCertificateService{
|
||||||
ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) {
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
// Should cap perPage at 500
|
// Should cap perPage at 500
|
||||||
if perPage == 50 { // defaults to 50 if > 500
|
if filter.PerPage == 50 { // defaults to 50 if > 500
|
||||||
return []domain.ManagedCertificate{}, 0, nil
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
}
|
}
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
@@ -1236,3 +1253,388 @@ func TestHandleOCSP_MethodNotAllowed(t *testing.T) {
|
|||||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === M20 Enhanced Query API Tests ===
|
||||||
|
|
||||||
|
// TestListCertificates_SortParam tests sort parameter parsing and passing to service.
|
||||||
|
func TestListCertificates_SortParam(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
// Handler strips the '-' prefix and sets SortDesc = true
|
||||||
|
if filter.Sort != "notAfter" || !filter.SortDesc {
|
||||||
|
t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?sort=-notAfter", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending).
|
||||||
|
func TestListCertificates_SortParam_Ascending(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if filter.Sort != "createdAt" || filter.SortDesc {
|
||||||
|
t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc)
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?sort=createdAt", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_TimeRangeFilters tests time-range filter parsing.
|
||||||
|
func TestListCertificates_TimeRangeFilters(t *testing.T) {
|
||||||
|
before := time.Now().AddDate(0, 0, 90)
|
||||||
|
after := time.Now().AddDate(0, 0, -90)
|
||||||
|
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if filter.ExpiresBefore == nil {
|
||||||
|
t.Error("expected ExpiresBefore to be set")
|
||||||
|
}
|
||||||
|
if filter.ExpiresAfter == nil {
|
||||||
|
t.Error("expected ExpiresAfter to be set")
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
url := fmt.Sprintf("/api/v1/certificates?expires_before=%s&expires_after=%s",
|
||||||
|
before.Format(time.RFC3339), after.Format(time.RFC3339))
|
||||||
|
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_CreatedAfterFilter tests created_after filter parsing.
|
||||||
|
func TestListCertificates_CreatedAfterFilter(t *testing.T) {
|
||||||
|
past := time.Now().AddDate(-1, 0, 0)
|
||||||
|
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if filter.CreatedAfter == nil {
|
||||||
|
t.Error("expected CreatedAfter to be set")
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
url := fmt.Sprintf("/api/v1/certificates?created_after=%s", past.Format(time.RFC3339))
|
||||||
|
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_CursorPagination tests cursor-based pagination response.
|
||||||
|
func TestListCertificates_CursorPagination(t *testing.T) {
|
||||||
|
cert := domain.ManagedCertificate{
|
||||||
|
ID: "mc-cursor-test-1",
|
||||||
|
CommonName: "cursor.example.com",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
return []domain.ManagedCertificate{cert}, 1, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?cursor=abc123&page_size=10", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp CursorPagedResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.NextCursor == "" {
|
||||||
|
t.Error("expected NextCursor to be populated with cursor pagination")
|
||||||
|
}
|
||||||
|
if resp.PageSize != 10 {
|
||||||
|
t.Errorf("expected PageSize=10, got %d", resp.PageSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_SparseFields tests field filtering in response.
|
||||||
|
func TestListCertificates_SparseFields(t *testing.T) {
|
||||||
|
cert := domain.ManagedCertificate{
|
||||||
|
ID: "mc-sparse-test-1",
|
||||||
|
Name: "Sparse Test Cert",
|
||||||
|
CommonName: "sparse.example.com",
|
||||||
|
Environment: "staging",
|
||||||
|
Status: domain.CertificateStatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if len(filter.Fields) != 2 {
|
||||||
|
t.Errorf("expected 2 fields, got %d", len(filter.Fields))
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{cert}, 1, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?fields=id,common_name", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp PagedResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response data should have sparse fields applied
|
||||||
|
data, ok := resp.Data.([]interface{})
|
||||||
|
if !ok || len(data) == 0 {
|
||||||
|
t.Fatal("expected data array in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
certMap, ok := data[0].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected cert object in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that requested fields are present
|
||||||
|
if _, ok := certMap["id"]; !ok {
|
||||||
|
t.Error("expected 'id' field in filtered response")
|
||||||
|
}
|
||||||
|
if _, ok := certMap["common_name"]; !ok {
|
||||||
|
t.Error("expected 'common_name' field in filtered response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_ProfileFilter tests profile_id filter.
|
||||||
|
func TestListCertificates_ProfileFilter(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if filter.ProfileID != "prof-standard" {
|
||||||
|
t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID)
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?profile_id=prof-standard", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_AgentIDFilter tests agent_id filter.
|
||||||
|
func TestListCertificates_AgentIDFilter(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if filter.AgentID != "agent-prod-001" {
|
||||||
|
t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID)
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?agent_id=agent-prod-001", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListCertificates_CombinedFilters tests multiple filters together.
|
||||||
|
func TestListCertificates_CombinedFilters(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" {
|
||||||
|
t.Error("expected all filters to be set")
|
||||||
|
}
|
||||||
|
return []domain.ManagedCertificate{}, 0, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?status=Active&environment=production&profile_id=prof-standard", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ListCertificates(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCertificateDeployments_Success tests retrieving deployments for a certificate.
|
||||||
|
func TestGetCertificateDeployments_Success(t *testing.T) {
|
||||||
|
deployments := []domain.DeploymentTarget{
|
||||||
|
{
|
||||||
|
ID: "t-nginx-prod-1",
|
||||||
|
Name: "NGINX Production",
|
||||||
|
Type: "NGINX",
|
||||||
|
Config: json.RawMessage(`{"cert_path": "/etc/nginx/ssl/cert.pem"}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "t-haproxy-prod-1",
|
||||||
|
Name: "HAProxy Production",
|
||||||
|
Type: "HAProxy",
|
||||||
|
Config: json.RawMessage(`{"pem_path": "/etc/haproxy/ssl/cert.pem"}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
|
||||||
|
if certID != "mc-prod-001" {
|
||||||
|
return nil, ErrMockNotFound
|
||||||
|
}
|
||||||
|
return deployments, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001/deployments", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.GetCertificateDeployments(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, ok := resp["data"].([]interface{}); !ok || len(data) != 2 {
|
||||||
|
t.Errorf("expected 2 deployments in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if total, ok := resp["total"].(float64); !ok || int(total) != 2 {
|
||||||
|
t.Errorf("expected total=2, got %v", resp["total"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate.
|
||||||
|
func TestGetCertificateDeployments_NotFound(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
|
||||||
|
return nil, fmt.Errorf("certificate not found")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-nonexistent/deployments", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.GetCertificateDeployments(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("expected 404, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCertificateDeployments_Empty tests successful response with no deployments.
|
||||||
|
func TestGetCertificateDeployments_Empty(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{
|
||||||
|
GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) {
|
||||||
|
if certID == "mc-no-deployments" {
|
||||||
|
return []domain.DeploymentTarget{}, nil
|
||||||
|
}
|
||||||
|
return nil, ErrMockNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-no-deployments/deployments", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.GetCertificateDeployments(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if total, ok := resp["total"].(float64); !ok || int(total) != 0 {
|
||||||
|
t.Errorf("expected total=0, got %v", resp["total"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCertificateDeployments_MethodNotAllowed tests 405 for non-GET requests.
|
||||||
|
func TestGetCertificateDeployments_MethodNotAllowed(t *testing.T) {
|
||||||
|
mock := &MockCertificateService{}
|
||||||
|
handler := NewCertificateHandler(mock)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/deployments", nil)
|
||||||
|
req = req.WithContext(contextWithRequestID())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.GetCertificateDeployments(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("expected 405, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,11 +9,13 @@ import (
|
|||||||
|
|
||||||
"github.com/shankar0123/certctl/internal/api/middleware"
|
"github.com/shankar0123/certctl/internal/api/middleware"
|
||||||
"github.com/shankar0123/certctl/internal/domain"
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertificateService defines the service interface for certificate operations.
|
// CertificateService defines the service interface for certificate operations.
|
||||||
type CertificateService interface {
|
type CertificateService interface {
|
||||||
ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error)
|
||||||
|
ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error)
|
||||||
GetCertificate(id string) (*domain.ManagedCertificate, error)
|
GetCertificate(id string) (*domain.ManagedCertificate, error)
|
||||||
CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error)
|
||||||
@@ -25,6 +27,7 @@ type CertificateService interface {
|
|||||||
GetRevokedCertificates() ([]*domain.CertificateRevocation, error)
|
GetRevokedCertificates() ([]*domain.CertificateRevocation, error)
|
||||||
GenerateDERCRL(issuerID string) ([]byte, error)
|
GenerateDERCRL(issuerID string) ([]byte, error)
|
||||||
GetOCSPResponse(issuerID string, serialHex string) ([]byte, error)
|
GetOCSPResponse(issuerID string, serialHex string) ([]byte, error)
|
||||||
|
GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CertificateHandler handles HTTP requests for certificate operations.
|
// CertificateHandler handles HTTP requests for certificate operations.
|
||||||
@@ -38,7 +41,7 @@ func NewCertificateHandler(svc CertificateService) CertificateHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListCertificates lists certificates with optional filtering.
|
// ListCertificates lists certificates with optional filtering.
|
||||||
// GET /api/v1/certificates?status=Active&environment=prod&owner_id=...&team_id=...&issuer_id=...&page=1&per_page=50
|
// GET /api/v1/certificates?status=Active&environment=prod&owner_id=...&team_id=...&issuer_id=...&agent_id=...&profile_id=...&expires_before=...&expires_after=...&created_after=...&updated_after=...&sort=notAfter&sort_desc=false&cursor=...&page=1&per_page=50&fields=id,commonName,status
|
||||||
func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Request) {
|
func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||||
@@ -49,12 +52,56 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
// Parse query parameters
|
// Parse query parameters
|
||||||
query := r.URL.Query()
|
query := r.URL.Query()
|
||||||
status := query.Get("status")
|
|
||||||
environment := query.Get("environment")
|
|
||||||
ownerID := query.Get("owner_id")
|
|
||||||
teamID := query.Get("team_id")
|
|
||||||
issuerID := query.Get("issuer_id")
|
|
||||||
|
|
||||||
|
// Basic filters
|
||||||
|
filter := &repository.CertificateFilter{
|
||||||
|
Status: query.Get("status"),
|
||||||
|
Environment: query.Get("environment"),
|
||||||
|
OwnerID: query.Get("owner_id"),
|
||||||
|
TeamID: query.Get("team_id"),
|
||||||
|
IssuerID: query.Get("issuer_id"),
|
||||||
|
AgentID: query.Get("agent_id"),
|
||||||
|
ProfileID: query.Get("profile_id"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Time-range filters
|
||||||
|
if eb := query.Get("expires_before"); eb != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, eb); err == nil {
|
||||||
|
filter.ExpiresBefore = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ea := query.Get("expires_after"); ea != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, ea); err == nil {
|
||||||
|
filter.ExpiresAfter = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ca := query.Get("created_after"); ca != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, ca); err == nil {
|
||||||
|
filter.CreatedAfter = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ua := query.Get("updated_after"); ua != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, ua); err == nil {
|
||||||
|
filter.UpdatedAfter = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorting
|
||||||
|
if sort := query.Get("sort"); sort != "" {
|
||||||
|
// Handle sort direction prefix
|
||||||
|
if strings.HasPrefix(sort, "-") {
|
||||||
|
filter.Sort = sort[1:]
|
||||||
|
filter.SortDesc = true
|
||||||
|
} else {
|
||||||
|
filter.Sort = sort
|
||||||
|
filter.SortDesc = query.Get("sort_desc") == "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cursor-based pagination
|
||||||
|
filter.Cursor = query.Get("cursor")
|
||||||
|
|
||||||
|
// Page-based pagination
|
||||||
page := 1
|
page := 1
|
||||||
perPage := 50
|
perPage := 50
|
||||||
if p := query.Get("page"); p != "" {
|
if p := query.Get("page"); p != "" {
|
||||||
@@ -67,21 +114,59 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ
|
|||||||
perPage = parsed
|
perPage = parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if ps := query.Get("page_size"); ps != "" {
|
||||||
|
if parsed, err := strconv.Atoi(ps); err == nil && parsed > 0 && parsed <= 500 {
|
||||||
|
filter.PageSize = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filter.Page = page
|
||||||
|
filter.PerPage = perPage
|
||||||
|
|
||||||
certs, total, err := h.svc.ListCertificates(status, environment, ownerID, teamID, issuerID, page, perPage)
|
// Sparse fields
|
||||||
|
if fieldsStr := query.Get("fields"); fieldsStr != "" {
|
||||||
|
filter.Fields = strings.Split(fieldsStr, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
certs, total, err := h.svc.ListCertificatesWithFilter(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := PagedResponse{
|
// Apply sparse field filtering if requested
|
||||||
Data: certs,
|
var responseData interface{} = certs
|
||||||
Total: total,
|
if len(filter.Fields) > 0 {
|
||||||
Page: page,
|
responseData = filterFields(certs, filter.Fields)
|
||||||
PerPage: perPage,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, http.StatusOK, response)
|
// Return cursor-based or page-based response depending on which pagination is used
|
||||||
|
if filter.Cursor != "" {
|
||||||
|
// Compute next cursor from last result
|
||||||
|
nextCursor := ""
|
||||||
|
if len(certs) > 0 {
|
||||||
|
lastCert := certs[len(certs)-1]
|
||||||
|
nextCursor = encodeCursor(lastCert.CreatedAt, lastCert.ID)
|
||||||
|
}
|
||||||
|
pageSize := filter.PageSize
|
||||||
|
if pageSize == 0 {
|
||||||
|
pageSize = filter.PerPage
|
||||||
|
}
|
||||||
|
response := CursorPagedResponse{
|
||||||
|
Data: responseData,
|
||||||
|
Total: int64(total),
|
||||||
|
NextCursor: nextCursor,
|
||||||
|
PageSize: pageSize,
|
||||||
|
}
|
||||||
|
JSON(w, http.StatusOK, response)
|
||||||
|
} else {
|
||||||
|
response := PagedResponse{
|
||||||
|
Data: responseData,
|
||||||
|
Total: int64(total),
|
||||||
|
Page: page,
|
||||||
|
PerPage: perPage,
|
||||||
|
}
|
||||||
|
JSON(w, http.StatusOK, response)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificate retrieves a single certificate by ID.
|
// GetCertificate retrieves a single certificate by ID.
|
||||||
@@ -525,3 +610,39 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(derBytes)
|
w.Write(derBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateDeployments retrieves all deployment targets for a certificate.
|
||||||
|
// GET /api/v1/certificates/{id}/deployments
|
||||||
|
func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
Error(w, http.StatusMethodNotAllowed, "Method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := middleware.GetRequestID(r.Context())
|
||||||
|
|
||||||
|
// Extract certificate ID from path /api/v1/certificates/{id}/deployments
|
||||||
|
path := strings.TrimPrefix(r.URL.Path, "/api/v1/certificates/")
|
||||||
|
parts := strings.Split(path, "/")
|
||||||
|
if len(parts) < 2 || parts[0] == "" {
|
||||||
|
ErrorWithRequestID(w, http.StatusBadRequest, "Certificate ID is required", requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
certID := parts[0]
|
||||||
|
|
||||||
|
deployments, err := h.svc.GetCertificateDeployments(certID)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := err.Error()
|
||||||
|
if strings.Contains(errMsg, "not found") {
|
||||||
|
ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get deployments", requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
JSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"data": deployments,
|
||||||
|
"total": len(deployments),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PagedResponse represents a paginated API response.
|
// PagedResponse represents a paginated API response.
|
||||||
@@ -13,6 +17,14 @@ type PagedResponse struct {
|
|||||||
PerPage int `json:"per_page"`
|
PerPage int `json:"per_page"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CursorPagedResponse represents a cursor-paginated API response.
|
||||||
|
type CursorPagedResponse struct {
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
NextCursor string `json:"next_cursor,omitempty"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
// ErrorResponse represents a standard error response.
|
// ErrorResponse represents a standard error response.
|
||||||
type ErrorResponse struct {
|
type ErrorResponse struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
@@ -49,3 +61,72 @@ func ErrorWithRequestID(w http.ResponseWriter, status int, message, requestID st
|
|||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
return json.NewEncoder(w).Encode(errResp)
|
return json.NewEncoder(w).Encode(errResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encodeCursor creates an opaque cursor token from a timestamp and ID.
|
||||||
|
func encodeCursor(createdAt time.Time, id string) string {
|
||||||
|
raw := createdAt.Format(time.RFC3339Nano) + ":" + id
|
||||||
|
return base64.URLEncoding.EncodeToString([]byte(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeCursor extracts a timestamp and ID from a cursor token.
|
||||||
|
func decodeCursor(cursor string) (time.Time, string, error) {
|
||||||
|
raw, err := base64.URLEncoding.DecodeString(cursor)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, "", fmt.Errorf("invalid cursor: %w", err)
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(string(raw), ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return time.Time{}, "", fmt.Errorf("invalid cursor format")
|
||||||
|
}
|
||||||
|
t, err := time.Parse(time.RFC3339Nano, parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, "", fmt.Errorf("invalid cursor timestamp: %w", err)
|
||||||
|
}
|
||||||
|
return t, parts[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterFields removes fields not in the allowed list from the response data.
|
||||||
|
// Works with both single objects and slices.
|
||||||
|
func filterFields(data interface{}, fields []string) interface{} {
|
||||||
|
if len(fields) == 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create field set for O(1) lookup
|
||||||
|
fieldSet := make(map[string]bool, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
fieldSet[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal to JSON, then unmarshal to generic structure
|
||||||
|
bytes, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try as array first
|
||||||
|
var arr []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(bytes, &arr); err == nil {
|
||||||
|
for i := range arr {
|
||||||
|
for key := range arr[i] {
|
||||||
|
if !fieldSet[key] {
|
||||||
|
delete(arr[i], key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return arr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try as object
|
||||||
|
var obj map[string]interface{}
|
||||||
|
if err := json.Unmarshal(bytes, &obj); err == nil {
|
||||||
|
for key := range obj {
|
||||||
|
if !fieldSet[key] {
|
||||||
|
delete(obj, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ func (r *Router) RegisterHandlers(
|
|||||||
r.Register("PUT /api/v1/certificates/{id}", http.HandlerFunc(certificates.UpdateCertificate))
|
r.Register("PUT /api/v1/certificates/{id}", http.HandlerFunc(certificates.UpdateCertificate))
|
||||||
r.Register("DELETE /api/v1/certificates/{id}", http.HandlerFunc(certificates.ArchiveCertificate))
|
r.Register("DELETE /api/v1/certificates/{id}", http.HandlerFunc(certificates.ArchiveCertificate))
|
||||||
r.Register("GET /api/v1/certificates/{id}/versions", http.HandlerFunc(certificates.GetCertificateVersions))
|
r.Register("GET /api/v1/certificates/{id}/versions", http.HandlerFunc(certificates.GetCertificateVersions))
|
||||||
|
r.Register("GET /api/v1/certificates/{id}/deployments", http.HandlerFunc(certificates.GetCertificateDeployments))
|
||||||
r.Register("POST /api/v1/certificates/{id}/renew", http.HandlerFunc(certificates.TriggerRenewal))
|
r.Register("POST /api/v1/certificates/{id}/renew", http.HandlerFunc(certificates.TriggerRenewal))
|
||||||
r.Register("POST /api/v1/certificates/{id}/deploy", http.HandlerFunc(certificates.TriggerDeployment))
|
r.Register("POST /api/v1/certificates/{id}/deploy", http.HandlerFunc(certificates.TriggerDeployment))
|
||||||
r.Register("POST /api/v1/certificates/{id}/revoke", http.HandlerFunc(certificates.RevokeCertificate))
|
r.Register("POST /api/v1/certificates/{id}/revoke", http.HandlerFunc(certificates.RevokeCertificate))
|
||||||
|
|||||||
@@ -679,3 +679,216 @@ func TestIssuerAndTargetCRUD(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestM20EnhancedQueryAPI exercises M20 query API enhancements: sorting, time-range filters,
|
||||||
|
// cursor pagination, sparse fields, profile/agent filters, and the deployments endpoint.
|
||||||
|
func TestM20EnhancedQueryAPI(t *testing.T) {
|
||||||
|
server, certRepo, _, _ := setupTestServer(t)
|
||||||
|
|
||||||
|
// Setup: Create a certificate for testing
|
||||||
|
now := time.Now()
|
||||||
|
cert := &domain.ManagedCertificate{
|
||||||
|
ID: "mc-m20-test-1",
|
||||||
|
Name: "M20 Test Cert",
|
||||||
|
CommonName: "m20.example.com",
|
||||||
|
Environment: "production",
|
||||||
|
Status: domain.CertificateStatusActive,
|
||||||
|
IssuerID: "iss-local",
|
||||||
|
OwnerID: "owner-ops",
|
||||||
|
TeamID: "team-platform",
|
||||||
|
CertificateProfileID: "prof-standard",
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
certRepo.certs["mc-m20-test-1"] = cert
|
||||||
|
|
||||||
|
t.Run("ListWithSortDescending", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?sort=-notAfter&page=1&per_page=10")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
var respBody map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&respBody)
|
||||||
|
if _, ok := respBody["data"]; !ok {
|
||||||
|
t.Error("expected data field in response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ListWithSortAscending", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?sort=createdAt&page=1&per_page=10")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
var respBody map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&respBody)
|
||||||
|
if _, ok := respBody["page"]; !ok {
|
||||||
|
t.Error("expected page-based pagination response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TimeRangeFilter_ExpiresBefore", func(t *testing.T) {
|
||||||
|
future := now.AddDate(0, 0, 365).Format(time.RFC3339)
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?expires_before=" + future)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Errorf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TimeRangeFilter_ExpiresAfter", func(t *testing.T) {
|
||||||
|
past := now.AddDate(0, 0, -90).Format(time.RFC3339)
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?expires_after=" + past)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TimeRangeFilter_CreatedAfter", func(t *testing.T) {
|
||||||
|
past := now.AddDate(-1, 0, 0).Format(time.RFC3339)
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?created_after=" + past)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SparseFields", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?fields=id,common_name,status")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Errorf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
var respBody map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&respBody)
|
||||||
|
if data, ok := respBody["data"].([]interface{}); ok && len(data) > 0 {
|
||||||
|
firstCert, ok := data[0].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected cert object in data array")
|
||||||
|
}
|
||||||
|
// Should have requested fields
|
||||||
|
if _, ok := firstCert["id"]; !ok {
|
||||||
|
t.Error("expected 'id' field in sparse response")
|
||||||
|
}
|
||||||
|
// Should NOT have unrequested fields like 'environment'
|
||||||
|
if _, ok := firstCert["environment"]; ok {
|
||||||
|
t.Error("did not expect 'environment' field in sparse response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ProfileFilter", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?profile_id=prof-standard")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AgentIDFilter", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?agent_id=agent-prod-001")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CursorPagination", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?cursor=abc123&page_size=10")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
var respBody map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&respBody)
|
||||||
|
if _, ok := respBody["next_cursor"]; !ok {
|
||||||
|
t.Error("expected next_cursor field with cursor pagination")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CombinedFilters", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?status=Active&environment=production&profile_id=prof-standard&sort=-createdAt&per_page=10")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetCertificateDeployments_Success", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates/mc-m20-test-1/deployments")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Errorf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
var respBody map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&respBody)
|
||||||
|
if _, ok := respBody["data"]; !ok {
|
||||||
|
t.Error("expected data field in response")
|
||||||
|
}
|
||||||
|
if _, ok := respBody["total"]; !ok {
|
||||||
|
t.Error("expected total field in response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetCertificateDeployments_NotFound", func(t *testing.T) {
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates/mc-nonexistent-m20/deployments")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected 404, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InvalidTimeRange", func(t *testing.T) {
|
||||||
|
// Invalid RFC3339 should be silently ignored (no filter applied)
|
||||||
|
resp, err := http.Get(server.URL + "/api/v1/certificates?expires_before=not-a-date&page=1&per_page=10")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 (invalid time ignored), got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ func TestCertificateLifecycle(t *testing.T) {
|
|||||||
certificateService.SetRevocationRepo(revocationRepo)
|
certificateService.SetRevocationRepo(revocationRepo)
|
||||||
certificateService.SetNotificationService(notificationService)
|
certificateService.SetNotificationService(notificationService)
|
||||||
certificateService.SetIssuerRegistry(issuerRegistry)
|
certificateService.SetIssuerRegistry(issuerRegistry)
|
||||||
|
certificateService.SetTargetRepo(targetRepo)
|
||||||
renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server")
|
renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server")
|
||||||
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
|
deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService)
|
||||||
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger)
|
||||||
|
|||||||
@@ -9,8 +9,27 @@ type CertificateFilter struct {
|
|||||||
OwnerID string
|
OwnerID string
|
||||||
TeamID string
|
TeamID string
|
||||||
IssuerID string
|
IssuerID string
|
||||||
Page int // 1-indexed; default 1
|
AgentID string // filter by agent_id (via deployment targets)
|
||||||
PerPage int // default 50, max 500
|
ProfileID string // filter by certificate_profile_id
|
||||||
|
Page int // 1-indexed; default 1
|
||||||
|
PerPage int // default 50, max 500
|
||||||
|
|
||||||
|
// Time-range filters
|
||||||
|
ExpiresBefore *time.Time // certs expiring before this date
|
||||||
|
ExpiresAfter *time.Time // certs expiring after this date
|
||||||
|
CreatedAfter *time.Time // certs created after this date
|
||||||
|
UpdatedAfter *time.Time // certs updated after this date
|
||||||
|
|
||||||
|
// Sorting
|
||||||
|
Sort string // field name to sort by (e.g., "notAfter", "createdAt", "commonName")
|
||||||
|
SortDesc bool // true = DESC, false = ASC
|
||||||
|
|
||||||
|
// Cursor pagination (alternative to page-based)
|
||||||
|
Cursor string // opaque cursor token (base64-encoded "created_at:id")
|
||||||
|
PageSize int // for cursor-based pagination (alias for PerPage)
|
||||||
|
|
||||||
|
// Sparse fields
|
||||||
|
Fields []string // if non-empty, only return these JSON field names
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobFilter defines filtering criteria for job queries.
|
// JobFilter defines filtering criteria for job queries.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package postgres
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -68,12 +69,59 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
|
|||||||
args = append(args, filter.IssuerID)
|
args = append(args, filter.IssuerID)
|
||||||
argCount++
|
argCount++
|
||||||
}
|
}
|
||||||
|
if filter.ProfileID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("certificate_profile_id = $%d", argCount))
|
||||||
|
args = append(args, filter.ProfileID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.ExpiresBefore != nil {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("expires_at < $%d", argCount))
|
||||||
|
args = append(args, filter.ExpiresBefore)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.ExpiresAfter != nil {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("expires_at > $%d", argCount))
|
||||||
|
args = append(args, filter.ExpiresAfter)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.CreatedAfter != nil {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("created_at > $%d", argCount))
|
||||||
|
args = append(args, filter.CreatedAfter)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.UpdatedAfter != nil {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("updated_at > $%d", argCount))
|
||||||
|
args = append(args, filter.UpdatedAfter)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.AgentID != "" {
|
||||||
|
// Filter by agent_id via deployment_targets and certificate_target_mappings
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf(`id IN (
|
||||||
|
SELECT DISTINCT certificate_id FROM certificate_target_mappings ctm
|
||||||
|
JOIN deployment_targets dt ON ctm.target_id = dt.id
|
||||||
|
WHERE dt.agent_id = $%d
|
||||||
|
)`, argCount))
|
||||||
|
args = append(args, filter.AgentID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
|
||||||
whereClause := ""
|
whereClause := ""
|
||||||
if len(whereConditions) > 0 {
|
if len(whereConditions) > 0 {
|
||||||
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle cursor-based pagination
|
||||||
|
if filter.Cursor != "" {
|
||||||
|
createdAt, id, err := decodeCursor(filter.Cursor)
|
||||||
|
if err == nil {
|
||||||
|
// Add cursor condition: (created_at, id) < (cursor_time, cursor_id)
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("(created_at, id) < ($%d, $%d)", argCount, argCount+1))
|
||||||
|
args = append(args, createdAt, id)
|
||||||
|
argCount += 2
|
||||||
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get total count
|
// Get total count
|
||||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause)
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause)
|
||||||
var total int
|
var total int
|
||||||
@@ -81,18 +129,59 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer
|
|||||||
return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
|
return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine sort field and direction
|
||||||
|
sortField := "created_at"
|
||||||
|
sortDir := "DESC"
|
||||||
|
sortFieldMap := map[string]string{
|
||||||
|
"notAfter": "expires_at",
|
||||||
|
"expiresAt": "expires_at",
|
||||||
|
"createdAt": "created_at",
|
||||||
|
"updatedAt": "updated_at",
|
||||||
|
"commonName": "common_name",
|
||||||
|
"name": "name",
|
||||||
|
"status": "status",
|
||||||
|
"environment": "environment",
|
||||||
|
}
|
||||||
|
if filter.Sort != "" {
|
||||||
|
if mappedField, ok := sortFieldMap[filter.Sort]; ok {
|
||||||
|
sortField = mappedField
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if filter.SortDesc {
|
||||||
|
sortDir = "DESC"
|
||||||
|
} else {
|
||||||
|
sortDir = "ASC"
|
||||||
|
}
|
||||||
|
|
||||||
// Get paginated results
|
// Get paginated results
|
||||||
offset := (filter.Page - 1) * filter.PerPage
|
pageSize := filter.PerPage
|
||||||
|
if filter.PageSize > 0 && filter.PageSize <= 500 {
|
||||||
|
pageSize = filter.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
var limitClause string
|
||||||
|
var offset int
|
||||||
|
if filter.Cursor != "" {
|
||||||
|
// Cursor-based pagination
|
||||||
|
limitClause = fmt.Sprintf("LIMIT $%d", argCount)
|
||||||
|
args = append(args, pageSize)
|
||||||
|
argCount++
|
||||||
|
} else {
|
||||||
|
// Page-based pagination
|
||||||
|
offset = (filter.Page - 1) * pageSize
|
||||||
|
limitClause = fmt.Sprintf("LIMIT $%d OFFSET $%d", argCount, argCount+1)
|
||||||
|
args = append(args, pageSize, offset)
|
||||||
|
argCount += 2
|
||||||
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id,
|
||||||
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
|
certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at
|
||||||
FROM managed_certificates
|
FROM managed_certificates
|
||||||
%s
|
%s
|
||||||
ORDER BY created_at DESC
|
ORDER BY %s %s
|
||||||
LIMIT $%d OFFSET $%d
|
%s
|
||||||
`, whereClause, argCount, argCount+1)
|
`, whereClause, sortField, sortDir, limitClause)
|
||||||
|
|
||||||
args = append(args, filter.PerPage, offset)
|
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -401,3 +490,26 @@ func scanCertificate(scanner interface {
|
|||||||
|
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decodeCursor extracts a timestamp and ID from a cursor token.
|
||||||
|
func decodeCursor(cursor string) (time.Time, string, error) {
|
||||||
|
raw, err := base64.URLEncoding.DecodeString(cursor)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, "", fmt.Errorf("invalid cursor: %w", err)
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(string(raw), ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return time.Time{}, "", fmt.Errorf("invalid cursor format")
|
||||||
|
}
|
||||||
|
t, err := time.Parse(time.RFC3339Nano, parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, "", fmt.Errorf("invalid cursor timestamp: %w", err)
|
||||||
|
}
|
||||||
|
return t, parts[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeCursor creates an opaque cursor token from a timestamp and ID.
|
||||||
|
func encodeCursor(createdAt time.Time, id string) string {
|
||||||
|
raw := createdAt.Format(time.RFC3339Nano) + ":" + id
|
||||||
|
return base64.URLEncoding.EncodeToString([]byte(raw))
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
// CertificateService provides business logic for certificate management.
|
// CertificateService provides business logic for certificate management.
|
||||||
type CertificateService struct {
|
type CertificateService struct {
|
||||||
certRepo repository.CertificateRepository
|
certRepo repository.CertificateRepository
|
||||||
|
targetRepo repository.TargetRepository
|
||||||
revocationRepo repository.RevocationRepository
|
revocationRepo repository.RevocationRepository
|
||||||
profileRepo repository.CertificateProfileRepository
|
profileRepo repository.CertificateProfileRepository
|
||||||
policyService *PolicyService
|
policyService *PolicyService
|
||||||
@@ -55,6 +56,11 @@ func (s *CertificateService) SetProfileRepo(repo repository.CertificateProfileRe
|
|||||||
s.profileRepo = repo
|
s.profileRepo = repo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTargetRepo sets the target repository for deployment queries.
|
||||||
|
func (s *CertificateService) SetTargetRepo(repo repository.TargetRepository) {
|
||||||
|
s.targetRepo = repo
|
||||||
|
}
|
||||||
|
|
||||||
// List returns a paginated list of certificates matching the filter.
|
// List returns a paginated list of certificates matching the filter.
|
||||||
func (s *CertificateService) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
func (s *CertificateService) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
||||||
certs, total, err := s.certRepo.List(ctx, filter)
|
certs, total, err := s.certRepo.List(ctx, filter)
|
||||||
@@ -64,6 +70,22 @@ func (s *CertificateService) List(ctx context.Context, filter *repository.Certif
|
|||||||
return certs, total, nil
|
return certs, total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20).
|
||||||
|
// This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers).
|
||||||
|
func (s *CertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) {
|
||||||
|
certs, total, err := s.certRepo.List(context.Background(), filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for handler compatibility
|
||||||
|
result := make([]domain.ManagedCertificate, len(certs))
|
||||||
|
for i, cert := range certs {
|
||||||
|
result[i] = *cert
|
||||||
|
}
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get retrieves a certificate by ID.
|
// Get retrieves a certificate by ID.
|
||||||
func (s *CertificateService) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
func (s *CertificateService) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
cert, err := s.certRepo.Get(ctx, id)
|
cert, err := s.certRepo.Get(ctx, id)
|
||||||
@@ -597,3 +619,29 @@ func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string)
|
|||||||
NextUpdate: now.Add(1 * time.Hour),
|
NextUpdate: now.Add(1 * time.Hour),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateDeployments returns all deployment targets for a certificate (M20).
|
||||||
|
func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) {
|
||||||
|
// Verify certificate exists
|
||||||
|
_, err := s.certRepo.Get(context.Background(), certID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("certificate not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.targetRepo == nil {
|
||||||
|
return []domain.DeploymentTarget{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get targets from repository
|
||||||
|
targets, err := s.targetRepo.ListByCertificate(context.Background(), certID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list deployment targets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values
|
||||||
|
result := make([]domain.DeploymentTarget, len(targets))
|
||||||
|
for i, target := range targets {
|
||||||
|
result[i] = *target
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user