From e078a686bfde3190e6218e32d9b5c655768e7ee3 Mon Sep 17 00:00:00 2001 From: shankar0123 Date: Mon, 23 Mar 2026 18:56:02 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20M20=20Enhanced=20Query=20API=20?= =?UTF-8?q?=E2=80=94=20sort,=20time-range=20filters,=20cursor=20pagination?= =?UTF-8?q?,=20sparse=20fields,=20deployments=20endpoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit V2 (free) query enhancements for certificates: - `sort` param with direction (`?sort=-notAfter` for descending) - Time-range filters: `expires_before`, `expires_after`, `created_after`, `updated_after` - Cursor-based pagination (`?cursor=token&page_size=100`) alongside page-based - Sparse field selection (`?fields=id,commonName,status`) - Additional filters: `agent_id`, `profile_id` - New endpoint: `GET /api/v1/certificates/{id}/deployments` 25 new tests (12 handler + 13 e2e) covering all M20 features. Co-Authored-By: Claude Opus 4.6 --- cmd/server/main.go | 1 + .../api/handler/certificate_handler_test.go | 444 +++++++++++++++++- internal/api/handler/certificates.go | 147 +++++- internal/api/handler/response.go | 81 ++++ internal/api/router/router.go | 1 + internal/integration/e2e_test.go | 213 +++++++++ internal/integration/lifecycle_test.go | 1 + internal/repository/filters.go | 23 +- internal/repository/postgres/certificate.go | 124 ++++- internal/service/certificate.go | 48 ++ 10 files changed, 1041 insertions(+), 42 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 5415c4a..316cced 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -193,6 +193,7 @@ func main() { certificateService.SetNotificationService(notificationService) certificateService.SetIssuerRegistry(issuerRegistry) certificateService.SetProfileRepo(profileRepo) + certificateService.SetTargetRepo(targetRepo) renewalService := service.NewRenewalService(certificateRepo, jobRepo, renewalPolicyRepo, profileRepo, auditService, notificationService, issuerRegistry, cfg.Keygen.Mode) deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certificateRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) diff --git a/internal/api/handler/certificate_handler_test.go b/internal/api/handler/certificate_handler_test.go index 63c8986..07047d5 100644 --- a/internal/api/handler/certificate_handler_test.go +++ b/internal/api/handler/certificate_handler_test.go @@ -12,22 +12,25 @@ import ( "github.com/shankar0123/certctl/internal/api/middleware" "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" ) // MockCertificateService is a mock implementation of CertificateService interface. type MockCertificateService struct { - ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) - GetCertificateFn func(id string) (*domain.ManagedCertificate, error) - CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) - UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) - ArchiveCertificateFn func(id string) error - GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) - TriggerRenewalFn func(certID string) error - TriggerDeploymentFn func(certID string, targetID string) error - RevokeCertificateFn func(certID string, reason string) error - GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error) - GenerateDERCRLFn func(issuerID string) ([]byte, error) - GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error) + ListCertificatesFn func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) + ListCertificatesWithFilterFn func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) + GetCertificateFn func(id string) (*domain.ManagedCertificate, error) + CreateCertificateFn func(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) + UpdateCertificateFn func(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) + ArchiveCertificateFn func(id string) error + GetCertificateVersionsFn func(certID string, page, perPage int) ([]domain.CertificateVersion, int64, error) + TriggerRenewalFn func(certID string) error + TriggerDeploymentFn func(certID string, targetID string) error + RevokeCertificateFn func(certID string, reason string) error + GetRevokedCertificatesFn func() ([]*domain.CertificateRevocation, error) + GenerateDERCRLFn func(issuerID string) ([]byte, error) + GetOCSPResponseFn func(issuerID string, serialHex string) ([]byte, error) + GetCertificateDeploymentsFn func(certID string) ([]domain.DeploymentTarget, error) } func (m *MockCertificateService) ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { @@ -114,6 +117,20 @@ func (m *MockCertificateService) GetOCSPResponse(issuerID string, serialHex stri return nil, nil } +func (m *MockCertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if m.ListCertificatesWithFilterFn != nil { + return m.ListCertificatesWithFilterFn(filter) + } + return nil, 0, nil +} + +func (m *MockCertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) { + if m.GetCertificateDeploymentsFn != nil { + return m.GetCertificateDeploymentsFn(certID) + } + return nil, nil +} + // Helper function to create context with request ID. func contextWithRequestID() context.Context { return context.WithValue(context.Background(), middleware.RequestIDKey{}, "test-request-id-123") @@ -141,8 +158,8 @@ func TestListCertificates_Success(t *testing.T) { } mock := &MockCertificateService{ - ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { - if page == 1 && perPage == 50 { + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.Page == 1 && filter.PerPage == 50 { return []domain.ManagedCertificate{cert1, cert2}, 2, nil } return nil, 0, nil @@ -180,8 +197,8 @@ func TestListCertificates_Success(t *testing.T) { // Test ListCertificates - with filters func TestListCertificates_WithFilters(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { - if status == "Active" && environment == "prod" { + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.Status == "Active" && filter.Environment == "prod" { return []domain.ManagedCertificate{}, 0, nil } return nil, 0, nil @@ -219,7 +236,7 @@ func TestListCertificates_MethodNotAllowed(t *testing.T) { // Test ListCertificates - service error func TestListCertificates_ServiceError(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { return nil, 0, ErrMockServiceFailed }, } @@ -697,9 +714,9 @@ func TestTriggerDeployment_NoTargetID(t *testing.T) { // Test ListCertificates - invalid page parameter func TestListCertificates_InvalidPageParam(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { // Should default to page 1 - if page == 1 { + if filter.Page == 1 { return []domain.ManagedCertificate{}, 0, nil } return nil, 0, nil @@ -721,9 +738,9 @@ func TestListCertificates_InvalidPageParam(t *testing.T) { // Test ListCertificates - per_page exceeds max func TestListCertificates_PerPageExceedsMax(t *testing.T) { mock := &MockCertificateService{ - ListCertificatesFn: func(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) { + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { // Should cap perPage at 500 - if perPage == 50 { // defaults to 50 if > 500 + if filter.PerPage == 50 { // defaults to 50 if > 500 return []domain.ManagedCertificate{}, 0, nil } return nil, 0, nil @@ -1236,3 +1253,388 @@ func TestHandleOCSP_MethodNotAllowed(t *testing.T) { t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code) } } + +// === M20 Enhanced Query API Tests === + +// TestListCertificates_SortParam tests sort parameter parsing and passing to service. +func TestListCertificates_SortParam(t *testing.T) { + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + // Handler strips the '-' prefix and sets SortDesc = true + if filter.Sort != "notAfter" || !filter.SortDesc { + t.Errorf("expected sort=notAfter desc=true, got sort=%s desc=%v", filter.Sort, filter.SortDesc) + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?sort=-notAfter", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestListCertificates_SortParam_Ascending tests sort parameter without '-' prefix (ascending). +func TestListCertificates_SortParam_Ascending(t *testing.T) { + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.Sort != "createdAt" || filter.SortDesc { + t.Errorf("expected sort=createdAt desc=false, got sort=%s desc=%v", filter.Sort, filter.SortDesc) + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?sort=createdAt", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestListCertificates_TimeRangeFilters tests time-range filter parsing. +func TestListCertificates_TimeRangeFilters(t *testing.T) { + before := time.Now().AddDate(0, 0, 90) + after := time.Now().AddDate(0, 0, -90) + + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.ExpiresBefore == nil { + t.Error("expected ExpiresBefore to be set") + } + if filter.ExpiresAfter == nil { + t.Error("expected ExpiresAfter to be set") + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + + handler := NewCertificateHandler(mock) + url := fmt.Sprintf("/api/v1/certificates?expires_before=%s&expires_after=%s", + before.Format(time.RFC3339), after.Format(time.RFC3339)) + req := httptest.NewRequest(http.MethodGet, url, nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestListCertificates_CreatedAfterFilter tests created_after filter parsing. +func TestListCertificates_CreatedAfterFilter(t *testing.T) { + past := time.Now().AddDate(-1, 0, 0) + + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.CreatedAfter == nil { + t.Error("expected CreatedAfter to be set") + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + + handler := NewCertificateHandler(mock) + url := fmt.Sprintf("/api/v1/certificates?created_after=%s", past.Format(time.RFC3339)) + req := httptest.NewRequest(http.MethodGet, url, nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestListCertificates_CursorPagination tests cursor-based pagination response. +func TestListCertificates_CursorPagination(t *testing.T) { + cert := domain.ManagedCertificate{ + ID: "mc-cursor-test-1", + CommonName: "cursor.example.com", + CreatedAt: time.Now(), + } + + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + return []domain.ManagedCertificate{cert}, 1, nil + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?cursor=abc123&page_size=10", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + + var resp CursorPagedResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if resp.NextCursor == "" { + t.Error("expected NextCursor to be populated with cursor pagination") + } + if resp.PageSize != 10 { + t.Errorf("expected PageSize=10, got %d", resp.PageSize) + } +} + +// TestListCertificates_SparseFields tests field filtering in response. +func TestListCertificates_SparseFields(t *testing.T) { + cert := domain.ManagedCertificate{ + ID: "mc-sparse-test-1", + Name: "Sparse Test Cert", + CommonName: "sparse.example.com", + Environment: "staging", + Status: domain.CertificateStatusActive, + } + + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if len(filter.Fields) != 2 { + t.Errorf("expected 2 fields, got %d", len(filter.Fields)) + } + return []domain.ManagedCertificate{cert}, 1, nil + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?fields=id,common_name", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + + var resp PagedResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + // Response data should have sparse fields applied + data, ok := resp.Data.([]interface{}) + if !ok || len(data) == 0 { + t.Fatal("expected data array in response") + } + + certMap, ok := data[0].(map[string]interface{}) + if !ok { + t.Fatal("expected cert object in response") + } + + // Check that requested fields are present + if _, ok := certMap["id"]; !ok { + t.Error("expected 'id' field in filtered response") + } + if _, ok := certMap["common_name"]; !ok { + t.Error("expected 'common_name' field in filtered response") + } +} + +// TestListCertificates_ProfileFilter tests profile_id filter. +func TestListCertificates_ProfileFilter(t *testing.T) { + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.ProfileID != "prof-standard" { + t.Errorf("expected ProfileID=prof-standard, got %s", filter.ProfileID) + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?profile_id=prof-standard", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestListCertificates_AgentIDFilter tests agent_id filter. +func TestListCertificates_AgentIDFilter(t *testing.T) { + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.AgentID != "agent-prod-001" { + t.Errorf("expected AgentID=agent-prod-001, got %s", filter.AgentID) + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?agent_id=agent-prod-001", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestListCertificates_CombinedFilters tests multiple filters together. +func TestListCertificates_CombinedFilters(t *testing.T) { + mock := &MockCertificateService{ + ListCertificatesWithFilterFn: func(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + if filter.Status != "Active" || filter.Environment != "production" || filter.ProfileID != "prof-standard" { + t.Error("expected all filters to be set") + } + return []domain.ManagedCertificate{}, 0, nil + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates?status=Active&environment=production&profile_id=prof-standard", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.ListCertificates(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } +} + +// TestGetCertificateDeployments_Success tests retrieving deployments for a certificate. +func TestGetCertificateDeployments_Success(t *testing.T) { + deployments := []domain.DeploymentTarget{ + { + ID: "t-nginx-prod-1", + Name: "NGINX Production", + Type: "NGINX", + Config: json.RawMessage(`{"cert_path": "/etc/nginx/ssl/cert.pem"}`), + }, + { + ID: "t-haproxy-prod-1", + Name: "HAProxy Production", + Type: "HAProxy", + Config: json.RawMessage(`{"pem_path": "/etc/haproxy/ssl/cert.pem"}`), + }, + } + + mock := &MockCertificateService{ + GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { + if certID != "mc-prod-001" { + return nil, ErrMockNotFound + } + return deployments, nil + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-prod-001/deployments", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetCertificateDeployments(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if data, ok := resp["data"].([]interface{}); !ok || len(data) != 2 { + t.Errorf("expected 2 deployments in response") + } + + if total, ok := resp["total"].(float64); !ok || int(total) != 2 { + t.Errorf("expected total=2, got %v", resp["total"]) + } +} + +// TestGetCertificateDeployments_NotFound tests 404 for nonexistent certificate. +func TestGetCertificateDeployments_NotFound(t *testing.T) { + mock := &MockCertificateService{ + GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { + return nil, fmt.Errorf("certificate not found") + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-nonexistent/deployments", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetCertificateDeployments(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } +} + +// TestGetCertificateDeployments_Empty tests successful response with no deployments. +func TestGetCertificateDeployments_Empty(t *testing.T) { + mock := &MockCertificateService{ + GetCertificateDeploymentsFn: func(certID string) ([]domain.DeploymentTarget, error) { + if certID == "mc-no-deployments" { + return []domain.DeploymentTarget{}, nil + } + return nil, ErrMockNotFound + }, + } + + handler := NewCertificateHandler(mock) + req := httptest.NewRequest(http.MethodGet, "/api/v1/certificates/mc-no-deployments/deployments", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetCertificateDeployments(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if total, ok := resp["total"].(float64); !ok || int(total) != 0 { + t.Errorf("expected total=0, got %v", resp["total"]) + } +} + +// TestGetCertificateDeployments_MethodNotAllowed tests 405 for non-GET requests. +func TestGetCertificateDeployments_MethodNotAllowed(t *testing.T) { + mock := &MockCertificateService{} + handler := NewCertificateHandler(mock) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/certificates/mc-prod-001/deployments", nil) + req = req.WithContext(contextWithRequestID()) + w := httptest.NewRecorder() + + handler.GetCertificateDeployments(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("expected 405, got %d", w.Code) + } +} diff --git a/internal/api/handler/certificates.go b/internal/api/handler/certificates.go index 05ba8a2..dbea198 100644 --- a/internal/api/handler/certificates.go +++ b/internal/api/handler/certificates.go @@ -9,11 +9,13 @@ import ( "github.com/shankar0123/certctl/internal/api/middleware" "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" ) // CertificateService defines the service interface for certificate operations. type CertificateService interface { ListCertificates(status, environment, ownerID, teamID, issuerID string, page, perPage int) ([]domain.ManagedCertificate, int64, error) + ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) GetCertificate(id string) (*domain.ManagedCertificate, error) CreateCertificate(cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) UpdateCertificate(id string, cert domain.ManagedCertificate) (*domain.ManagedCertificate, error) @@ -25,6 +27,7 @@ type CertificateService interface { GetRevokedCertificates() ([]*domain.CertificateRevocation, error) GenerateDERCRL(issuerID string) ([]byte, error) GetOCSPResponse(issuerID string, serialHex string) ([]byte, error) + GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) } // CertificateHandler handles HTTP requests for certificate operations. @@ -38,7 +41,7 @@ func NewCertificateHandler(svc CertificateService) CertificateHandler { } // ListCertificates lists certificates with optional filtering. -// GET /api/v1/certificates?status=Active&environment=prod&owner_id=...&team_id=...&issuer_id=...&page=1&per_page=50 +// GET /api/v1/certificates?status=Active&environment=prod&owner_id=...&team_id=...&issuer_id=...&agent_id=...&profile_id=...&expires_before=...&expires_after=...&created_after=...&updated_after=...&sort=notAfter&sort_desc=false&cursor=...&page=1&per_page=50&fields=id,commonName,status func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { Error(w, http.StatusMethodNotAllowed, "Method not allowed") @@ -49,12 +52,56 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ // Parse query parameters query := r.URL.Query() - status := query.Get("status") - environment := query.Get("environment") - ownerID := query.Get("owner_id") - teamID := query.Get("team_id") - issuerID := query.Get("issuer_id") + // Basic filters + filter := &repository.CertificateFilter{ + Status: query.Get("status"), + Environment: query.Get("environment"), + OwnerID: query.Get("owner_id"), + TeamID: query.Get("team_id"), + IssuerID: query.Get("issuer_id"), + AgentID: query.Get("agent_id"), + ProfileID: query.Get("profile_id"), + } + + // Time-range filters + if eb := query.Get("expires_before"); eb != "" { + if t, err := time.Parse(time.RFC3339, eb); err == nil { + filter.ExpiresBefore = &t + } + } + if ea := query.Get("expires_after"); ea != "" { + if t, err := time.Parse(time.RFC3339, ea); err == nil { + filter.ExpiresAfter = &t + } + } + if ca := query.Get("created_after"); ca != "" { + if t, err := time.Parse(time.RFC3339, ca); err == nil { + filter.CreatedAfter = &t + } + } + if ua := query.Get("updated_after"); ua != "" { + if t, err := time.Parse(time.RFC3339, ua); err == nil { + filter.UpdatedAfter = &t + } + } + + // Sorting + if sort := query.Get("sort"); sort != "" { + // Handle sort direction prefix + if strings.HasPrefix(sort, "-") { + filter.Sort = sort[1:] + filter.SortDesc = true + } else { + filter.Sort = sort + filter.SortDesc = query.Get("sort_desc") == "true" + } + } + + // Cursor-based pagination + filter.Cursor = query.Get("cursor") + + // Page-based pagination page := 1 perPage := 50 if p := query.Get("page"); p != "" { @@ -67,21 +114,59 @@ func (h CertificateHandler) ListCertificates(w http.ResponseWriter, r *http.Requ perPage = parsed } } + if ps := query.Get("page_size"); ps != "" { + if parsed, err := strconv.Atoi(ps); err == nil && parsed > 0 && parsed <= 500 { + filter.PageSize = parsed + } + } + filter.Page = page + filter.PerPage = perPage - certs, total, err := h.svc.ListCertificates(status, environment, ownerID, teamID, issuerID, page, perPage) + // Sparse fields + if fieldsStr := query.Get("fields"); fieldsStr != "" { + filter.Fields = strings.Split(fieldsStr, ",") + } + + certs, total, err := h.svc.ListCertificatesWithFilter(filter) if err != nil { ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to list certificates", requestID) return } - response := PagedResponse{ - Data: certs, - Total: total, - Page: page, - PerPage: perPage, + // Apply sparse field filtering if requested + var responseData interface{} = certs + if len(filter.Fields) > 0 { + responseData = filterFields(certs, filter.Fields) } - JSON(w, http.StatusOK, response) + // Return cursor-based or page-based response depending on which pagination is used + if filter.Cursor != "" { + // Compute next cursor from last result + nextCursor := "" + if len(certs) > 0 { + lastCert := certs[len(certs)-1] + nextCursor = encodeCursor(lastCert.CreatedAt, lastCert.ID) + } + pageSize := filter.PageSize + if pageSize == 0 { + pageSize = filter.PerPage + } + response := CursorPagedResponse{ + Data: responseData, + Total: int64(total), + NextCursor: nextCursor, + PageSize: pageSize, + } + JSON(w, http.StatusOK, response) + } else { + response := PagedResponse{ + Data: responseData, + Total: int64(total), + Page: page, + PerPage: perPage, + } + JSON(w, http.StatusOK, response) + } } // GetCertificate retrieves a single certificate by ID. @@ -525,3 +610,39 @@ func (h CertificateHandler) HandleOCSP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write(derBytes) } + +// GetCertificateDeployments retrieves all deployment targets for a certificate. +// GET /api/v1/certificates/{id}/deployments +func (h CertificateHandler) GetCertificateDeployments(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + Error(w, http.StatusMethodNotAllowed, "Method not allowed") + return + } + + requestID := middleware.GetRequestID(r.Context()) + + // Extract certificate ID from path /api/v1/certificates/{id}/deployments + path := strings.TrimPrefix(r.URL.Path, "/api/v1/certificates/") + parts := strings.Split(path, "/") + if len(parts) < 2 || parts[0] == "" { + ErrorWithRequestID(w, http.StatusBadRequest, "Certificate ID is required", requestID) + return + } + certID := parts[0] + + deployments, err := h.svc.GetCertificateDeployments(certID) + if err != nil { + errMsg := err.Error() + if strings.Contains(errMsg, "not found") { + ErrorWithRequestID(w, http.StatusNotFound, "Certificate not found", requestID) + return + } + ErrorWithRequestID(w, http.StatusInternalServerError, "Failed to get deployments", requestID) + return + } + + JSON(w, http.StatusOK, map[string]interface{}{ + "data": deployments, + "total": len(deployments), + }) +} diff --git a/internal/api/handler/response.go b/internal/api/handler/response.go index d47105a..0bca636 100644 --- a/internal/api/handler/response.go +++ b/internal/api/handler/response.go @@ -1,8 +1,12 @@ package handler import ( + "encoding/base64" "encoding/json" + "fmt" "net/http" + "strings" + "time" ) // PagedResponse represents a paginated API response. @@ -13,6 +17,14 @@ type PagedResponse struct { PerPage int `json:"per_page"` } +// CursorPagedResponse represents a cursor-paginated API response. +type CursorPagedResponse struct { + Data interface{} `json:"data"` + Total int64 `json:"total"` + NextCursor string `json:"next_cursor,omitempty"` + PageSize int `json:"page_size"` +} + // ErrorResponse represents a standard error response. type ErrorResponse struct { Error string `json:"error"` @@ -49,3 +61,72 @@ func ErrorWithRequestID(w http.ResponseWriter, status int, message, requestID st w.WriteHeader(status) return json.NewEncoder(w).Encode(errResp) } + +// encodeCursor creates an opaque cursor token from a timestamp and ID. +func encodeCursor(createdAt time.Time, id string) string { + raw := createdAt.Format(time.RFC3339Nano) + ":" + id + return base64.URLEncoding.EncodeToString([]byte(raw)) +} + +// decodeCursor extracts a timestamp and ID from a cursor token. +func decodeCursor(cursor string) (time.Time, string, error) { + raw, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return time.Time{}, "", fmt.Errorf("invalid cursor: %w", err) + } + parts := strings.SplitN(string(raw), ":", 2) + if len(parts) != 2 { + return time.Time{}, "", fmt.Errorf("invalid cursor format") + } + t, err := time.Parse(time.RFC3339Nano, parts[0]) + if err != nil { + return time.Time{}, "", fmt.Errorf("invalid cursor timestamp: %w", err) + } + return t, parts[1], nil +} + +// filterFields removes fields not in the allowed list from the response data. +// Works with both single objects and slices. +func filterFields(data interface{}, fields []string) interface{} { + if len(fields) == 0 { + return data + } + + // Create field set for O(1) lookup + fieldSet := make(map[string]bool, len(fields)) + for _, f := range fields { + fieldSet[f] = true + } + + // Marshal to JSON, then unmarshal to generic structure + bytes, err := json.Marshal(data) + if err != nil { + return data + } + + // Try as array first + var arr []map[string]interface{} + if err := json.Unmarshal(bytes, &arr); err == nil { + for i := range arr { + for key := range arr[i] { + if !fieldSet[key] { + delete(arr[i], key) + } + } + } + return arr + } + + // Try as object + var obj map[string]interface{} + if err := json.Unmarshal(bytes, &obj); err == nil { + for key := range obj { + if !fieldSet[key] { + delete(obj, key) + } + } + return obj + } + + return data +} diff --git a/internal/api/router/router.go b/internal/api/router/router.go index 40b62e0..de9a8f5 100644 --- a/internal/api/router/router.go +++ b/internal/api/router/router.go @@ -88,6 +88,7 @@ func (r *Router) RegisterHandlers( r.Register("PUT /api/v1/certificates/{id}", http.HandlerFunc(certificates.UpdateCertificate)) r.Register("DELETE /api/v1/certificates/{id}", http.HandlerFunc(certificates.ArchiveCertificate)) r.Register("GET /api/v1/certificates/{id}/versions", http.HandlerFunc(certificates.GetCertificateVersions)) + r.Register("GET /api/v1/certificates/{id}/deployments", http.HandlerFunc(certificates.GetCertificateDeployments)) r.Register("POST /api/v1/certificates/{id}/renew", http.HandlerFunc(certificates.TriggerRenewal)) r.Register("POST /api/v1/certificates/{id}/deploy", http.HandlerFunc(certificates.TriggerDeployment)) r.Register("POST /api/v1/certificates/{id}/revoke", http.HandlerFunc(certificates.RevokeCertificate)) diff --git a/internal/integration/e2e_test.go b/internal/integration/e2e_test.go index febd7ea..99f281f 100644 --- a/internal/integration/e2e_test.go +++ b/internal/integration/e2e_test.go @@ -679,3 +679,216 @@ func TestIssuerAndTargetCRUD(t *testing.T) { } }) } + +// TestM20EnhancedQueryAPI exercises M20 query API enhancements: sorting, time-range filters, +// cursor pagination, sparse fields, profile/agent filters, and the deployments endpoint. +func TestM20EnhancedQueryAPI(t *testing.T) { + server, certRepo, _, _ := setupTestServer(t) + + // Setup: Create a certificate for testing + now := time.Now() + cert := &domain.ManagedCertificate{ + ID: "mc-m20-test-1", + Name: "M20 Test Cert", + CommonName: "m20.example.com", + Environment: "production", + Status: domain.CertificateStatusActive, + IssuerID: "iss-local", + OwnerID: "owner-ops", + TeamID: "team-platform", + CertificateProfileID: "prof-standard", + CreatedAt: now, + UpdatedAt: now, + } + certRepo.certs["mc-m20-test-1"] = cert + + t.Run("ListWithSortDescending", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?sort=-notAfter&page=1&per_page=10") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + var respBody map[string]interface{} + json.NewDecoder(resp.Body).Decode(&respBody) + if _, ok := respBody["data"]; !ok { + t.Error("expected data field in response") + } + }) + + t.Run("ListWithSortAscending", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?sort=createdAt&page=1&per_page=10") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + var respBody map[string]interface{} + json.NewDecoder(resp.Body).Decode(&respBody) + if _, ok := respBody["page"]; !ok { + t.Error("expected page-based pagination response") + } + }) + + t.Run("TimeRangeFilter_ExpiresBefore", func(t *testing.T) { + future := now.AddDate(0, 0, 365).Format(time.RFC3339) + resp, err := http.Get(server.URL + "/api/v1/certificates?expires_before=" + future) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Errorf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + }) + + t.Run("TimeRangeFilter_ExpiresAfter", func(t *testing.T) { + past := now.AddDate(0, 0, -90).Format(time.RFC3339) + resp, err := http.Get(server.URL + "/api/v1/certificates?expires_after=" + past) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("TimeRangeFilter_CreatedAfter", func(t *testing.T) { + past := now.AddDate(-1, 0, 0).Format(time.RFC3339) + resp, err := http.Get(server.URL + "/api/v1/certificates?created_after=" + past) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("SparseFields", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?fields=id,common_name,status") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Errorf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + var respBody map[string]interface{} + json.NewDecoder(resp.Body).Decode(&respBody) + if data, ok := respBody["data"].([]interface{}); ok && len(data) > 0 { + firstCert, ok := data[0].(map[string]interface{}) + if !ok { + t.Fatal("expected cert object in data array") + } + // Should have requested fields + if _, ok := firstCert["id"]; !ok { + t.Error("expected 'id' field in sparse response") + } + // Should NOT have unrequested fields like 'environment' + if _, ok := firstCert["environment"]; ok { + t.Error("did not expect 'environment' field in sparse response") + } + } + }) + + t.Run("ProfileFilter", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?profile_id=prof-standard") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("AgentIDFilter", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?agent_id=agent-prod-001") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("CursorPagination", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?cursor=abc123&page_size=10") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + var respBody map[string]interface{} + json.NewDecoder(resp.Body).Decode(&respBody) + if _, ok := respBody["next_cursor"]; !ok { + t.Error("expected next_cursor field with cursor pagination") + } + }) + + t.Run("CombinedFilters", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates?status=Active&environment=production&profile_id=prof-standard&sort=-createdAt&per_page=10") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("GetCertificateDeployments_Success", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates/mc-m20-test-1/deployments") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Errorf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + var respBody map[string]interface{} + json.NewDecoder(resp.Body).Decode(&respBody) + if _, ok := respBody["data"]; !ok { + t.Error("expected data field in response") + } + if _, ok := respBody["total"]; !ok { + t.Error("expected total field in response") + } + }) + + t.Run("GetCertificateDeployments_NotFound", func(t *testing.T) { + resp, err := http.Get(server.URL + "/api/v1/certificates/mc-nonexistent-m20/deployments") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("expected 404, got %d", resp.StatusCode) + } + }) + + t.Run("InvalidTimeRange", func(t *testing.T) { + // Invalid RFC3339 should be silently ignored (no filter applied) + resp, err := http.Get(server.URL + "/api/v1/certificates?expires_before=not-a-date&page=1&per_page=10") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 (invalid time ignored), got %d", resp.StatusCode) + } + }) +} diff --git a/internal/integration/lifecycle_test.go b/internal/integration/lifecycle_test.go index bce2999..4dba5fc 100644 --- a/internal/integration/lifecycle_test.go +++ b/internal/integration/lifecycle_test.go @@ -56,6 +56,7 @@ func TestCertificateLifecycle(t *testing.T) { certificateService.SetRevocationRepo(revocationRepo) certificateService.SetNotificationService(notificationService) certificateService.SetIssuerRegistry(issuerRegistry) + certificateService.SetTargetRepo(targetRepo) renewalService := service.NewRenewalService(certRepo, jobRepo, renewalPolicyRepo, nil, auditService, notificationService, issuerRegistry, "server") deploymentService := service.NewDeploymentService(jobRepo, targetRepo, agentRepo, certRepo, auditService, notificationService) jobService := service.NewJobService(jobRepo, renewalService, deploymentService, logger) diff --git a/internal/repository/filters.go b/internal/repository/filters.go index 1eef7e0..111a2d3 100644 --- a/internal/repository/filters.go +++ b/internal/repository/filters.go @@ -9,8 +9,27 @@ type CertificateFilter struct { OwnerID string TeamID string IssuerID string - Page int // 1-indexed; default 1 - PerPage int // default 50, max 500 + AgentID string // filter by agent_id (via deployment targets) + ProfileID string // filter by certificate_profile_id + Page int // 1-indexed; default 1 + PerPage int // default 50, max 500 + + // Time-range filters + ExpiresBefore *time.Time // certs expiring before this date + ExpiresAfter *time.Time // certs expiring after this date + CreatedAfter *time.Time // certs created after this date + UpdatedAfter *time.Time // certs updated after this date + + // Sorting + Sort string // field name to sort by (e.g., "notAfter", "createdAt", "commonName") + SortDesc bool // true = DESC, false = ASC + + // Cursor pagination (alternative to page-based) + Cursor string // opaque cursor token (base64-encoded "created_at:id") + PageSize int // for cursor-based pagination (alias for PerPage) + + // Sparse fields + Fields []string // if non-empty, only return these JSON field names } // JobFilter defines filtering criteria for job queries. diff --git a/internal/repository/postgres/certificate.go b/internal/repository/postgres/certificate.go index 195c3ca..9dfaf4b 100644 --- a/internal/repository/postgres/certificate.go +++ b/internal/repository/postgres/certificate.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "encoding/base64" "encoding/json" "fmt" "strings" @@ -68,12 +69,59 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer args = append(args, filter.IssuerID) argCount++ } + if filter.ProfileID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("certificate_profile_id = $%d", argCount)) + args = append(args, filter.ProfileID) + argCount++ + } + if filter.ExpiresBefore != nil { + whereConditions = append(whereConditions, fmt.Sprintf("expires_at < $%d", argCount)) + args = append(args, filter.ExpiresBefore) + argCount++ + } + if filter.ExpiresAfter != nil { + whereConditions = append(whereConditions, fmt.Sprintf("expires_at > $%d", argCount)) + args = append(args, filter.ExpiresAfter) + argCount++ + } + if filter.CreatedAfter != nil { + whereConditions = append(whereConditions, fmt.Sprintf("created_at > $%d", argCount)) + args = append(args, filter.CreatedAfter) + argCount++ + } + if filter.UpdatedAfter != nil { + whereConditions = append(whereConditions, fmt.Sprintf("updated_at > $%d", argCount)) + args = append(args, filter.UpdatedAfter) + argCount++ + } + if filter.AgentID != "" { + // Filter by agent_id via deployment_targets and certificate_target_mappings + whereConditions = append(whereConditions, fmt.Sprintf(`id IN ( + SELECT DISTINCT certificate_id FROM certificate_target_mappings ctm + JOIN deployment_targets dt ON ctm.target_id = dt.id + WHERE dt.agent_id = $%d + )`, argCount)) + args = append(args, filter.AgentID) + argCount++ + } whereClause := "" if len(whereConditions) > 0 { whereClause = "WHERE " + strings.Join(whereConditions, " AND ") } + // Handle cursor-based pagination + if filter.Cursor != "" { + createdAt, id, err := decodeCursor(filter.Cursor) + if err == nil { + // Add cursor condition: (created_at, id) < (cursor_time, cursor_id) + whereConditions = append(whereConditions, fmt.Sprintf("(created_at, id) < ($%d, $%d)", argCount, argCount+1)) + args = append(args, createdAt, id) + argCount += 2 + whereClause = "WHERE " + strings.Join(whereConditions, " AND ") + } + } + // Get total count countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause) var total int @@ -81,18 +129,59 @@ func (r *CertificateRepository) List(ctx context.Context, filter *repository.Cer return nil, 0, fmt.Errorf("failed to count certificates: %w", err) } + // Determine sort field and direction + sortField := "created_at" + sortDir := "DESC" + sortFieldMap := map[string]string{ + "notAfter": "expires_at", + "expiresAt": "expires_at", + "createdAt": "created_at", + "updatedAt": "updated_at", + "commonName": "common_name", + "name": "name", + "status": "status", + "environment": "environment", + } + if filter.Sort != "" { + if mappedField, ok := sortFieldMap[filter.Sort]; ok { + sortField = mappedField + } + } + if filter.SortDesc { + sortDir = "DESC" + } else { + sortDir = "ASC" + } + // Get paginated results - offset := (filter.Page - 1) * filter.PerPage + pageSize := filter.PerPage + if filter.PageSize > 0 && filter.PageSize <= 500 { + pageSize = filter.PageSize + } + + var limitClause string + var offset int + if filter.Cursor != "" { + // Cursor-based pagination + limitClause = fmt.Sprintf("LIMIT $%d", argCount) + args = append(args, pageSize) + argCount++ + } else { + // Page-based pagination + offset = (filter.Page - 1) * pageSize + limitClause = fmt.Sprintf("LIMIT $%d OFFSET $%d", argCount, argCount+1) + args = append(args, pageSize, offset) + argCount += 2 + } + query := fmt.Sprintf(` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, certificate_profile_id, status, expires_at, tags, last_renewal_at, last_deployment_at, revoked_at, revocation_reason, created_at, updated_at FROM managed_certificates %s - ORDER BY created_at DESC - LIMIT $%d OFFSET $%d - `, whereClause, argCount, argCount+1) - - args = append(args, filter.PerPage, offset) + ORDER BY %s %s + %s + `, whereClause, sortField, sortDir, limitClause) rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { @@ -401,3 +490,26 @@ func scanCertificate(scanner interface { return &cert, nil } + +// decodeCursor extracts a timestamp and ID from a cursor token. +func decodeCursor(cursor string) (time.Time, string, error) { + raw, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return time.Time{}, "", fmt.Errorf("invalid cursor: %w", err) + } + parts := strings.SplitN(string(raw), ":", 2) + if len(parts) != 2 { + return time.Time{}, "", fmt.Errorf("invalid cursor format") + } + t, err := time.Parse(time.RFC3339Nano, parts[0]) + if err != nil { + return time.Time{}, "", fmt.Errorf("invalid cursor timestamp: %w", err) + } + return t, parts[1], nil +} + +// encodeCursor creates an opaque cursor token from a timestamp and ID. +func encodeCursor(createdAt time.Time, id string) string { + raw := createdAt.Format(time.RFC3339Nano) + ":" + id + return base64.URLEncoding.EncodeToString([]byte(raw)) +} diff --git a/internal/service/certificate.go b/internal/service/certificate.go index 511df0b..e726219 100644 --- a/internal/service/certificate.go +++ b/internal/service/certificate.go @@ -14,6 +14,7 @@ import ( // CertificateService provides business logic for certificate management. type CertificateService struct { certRepo repository.CertificateRepository + targetRepo repository.TargetRepository revocationRepo repository.RevocationRepository profileRepo repository.CertificateProfileRepository policyService *PolicyService @@ -55,6 +56,11 @@ func (s *CertificateService) SetProfileRepo(repo repository.CertificateProfileRe s.profileRepo = repo } +// SetTargetRepo sets the target repository for deployment queries. +func (s *CertificateService) SetTargetRepo(repo repository.TargetRepository) { + s.targetRepo = repo +} + // List returns a paginated list of certificates matching the filter. func (s *CertificateService) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) { certs, total, err := s.certRepo.List(ctx, filter) @@ -64,6 +70,22 @@ func (s *CertificateService) List(ctx context.Context, filter *repository.Certif return certs, total, nil } +// ListCertificatesWithFilter returns a list of certificates with advanced filtering (M20). +// This method supports the new M20 filters and returns domain.ManagedCertificate (not pointers). +func (s *CertificateService) ListCertificatesWithFilter(filter *repository.CertificateFilter) ([]domain.ManagedCertificate, int, error) { + certs, total, err := s.certRepo.List(context.Background(), filter) + if err != nil { + return nil, 0, fmt.Errorf("failed to list certificates with filter: %w", err) + } + + // Convert pointers to values for handler compatibility + result := make([]domain.ManagedCertificate, len(certs)) + for i, cert := range certs { + result[i] = *cert + } + return result, total, nil +} + // Get retrieves a certificate by ID. func (s *CertificateService) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) { cert, err := s.certRepo.Get(ctx, id) @@ -597,3 +619,29 @@ func (s *CertificateService) GetOCSPResponse(issuerID string, serialHex string) NextUpdate: now.Add(1 * time.Hour), }) } + +// GetCertificateDeployments returns all deployment targets for a certificate (M20). +func (s *CertificateService) GetCertificateDeployments(certID string) ([]domain.DeploymentTarget, error) { + // Verify certificate exists + _, err := s.certRepo.Get(context.Background(), certID) + if err != nil { + return nil, fmt.Errorf("certificate not found: %w", err) + } + + if s.targetRepo == nil { + return []domain.DeploymentTarget{}, nil + } + + // Get targets from repository + targets, err := s.targetRepo.ListByCertificate(context.Background(), certID) + if err != nil { + return nil, fmt.Errorf("failed to list deployment targets: %w", err) + } + + // Convert pointers to values + result := make([]domain.DeploymentTarget, len(targets)) + for i, target := range targets { + result[i] = *target + } + return result, nil +}