From 3a9fe8ba376bb901d4041fa7849a08e7038bcb50 Mon Sep 17 00:00:00 2001 From: shankar0123 Date: Sat, 14 Mar 2026 20:01:53 -0400 Subject: [PATCH] Complete V1 scaffold --- .dockerignore | 9 + Dockerfile | 37 +- Dockerfile.agent | 23 +- POSTGRES_IMPLEMENTATION.md | 194 ++ POSTGRES_PATTERNS.md | 272 +++ deploy/docker-compose.yml | 75 +- docs/demo-guide.md | 119 ++ go.mod | 5 +- go.sum | 2 + internal/connector/issuer/local/local.go | 446 ++++ internal/connector/issuer/local/local_test.go | 206 ++ internal/repository/postgres/agent.go | 193 ++ internal/repository/postgres/audit.go | 140 ++ internal/repository/postgres/certificate.go | 346 +++ internal/repository/postgres/db.go | 68 + internal/repository/postgres/issuer.go | 138 ++ internal/repository/postgres/job.go | 284 +++ internal/repository/postgres/notification.go | 162 ++ internal/repository/postgres/owner.go | 137 ++ internal/repository/postgres/policy.go | 242 +++ internal/repository/postgres/target.go | 171 ++ internal/repository/postgres/team.go | 135 ++ internal/service/audit.go | 51 + internal/service/audit.go.4445065566393902048 | 161 ++ internal/service/issuer.go | 170 ++ .../service/issuer.go.4749230034325506546 | 116 + internal/service/owner.go | 155 ++ internal/service/target.go | 155 ++ internal/service/team.go | 155 ++ web/index.html | 1868 +++++++++++++++++ 30 files changed, 6131 insertions(+), 104 deletions(-) create mode 100644 .dockerignore create mode 100644 POSTGRES_IMPLEMENTATION.md create mode 100644 POSTGRES_PATTERNS.md create mode 100644 docs/demo-guide.md create mode 100644 internal/connector/issuer/local/local.go create mode 100644 internal/connector/issuer/local/local_test.go create mode 100644 internal/repository/postgres/agent.go create mode 100644 internal/repository/postgres/audit.go create mode 100644 internal/repository/postgres/certificate.go create mode 100644 internal/repository/postgres/db.go create mode 100644 internal/repository/postgres/issuer.go create mode 100644 internal/repository/postgres/job.go create mode 100644 internal/repository/postgres/notification.go create mode 100644 internal/repository/postgres/owner.go create mode 100644 internal/repository/postgres/policy.go create mode 100644 internal/repository/postgres/target.go create mode 100644 internal/repository/postgres/team.go create mode 100644 internal/service/audit.go.4445065566393902048 create mode 100644 internal/service/issuer.go create mode 100644 internal/service/issuer.go.4749230034325506546 create mode 100644 internal/service/owner.go create mode 100644 internal/service/target.go create mode 100644 internal/service/team.go create mode 100644 web/index.html diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..c7c1209 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.git +vendor +bin +*.md +docs +scripts +coverage.* +.env +.DS_Store diff --git a/Dockerfile b/Dockerfile index 24d6f46..fc96f97 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,70 +1,43 @@ -# Multi-stage build for certctl server and agent binaries +# Multi-stage build for certctl server # Stage 1: Build FROM golang:1.22-alpine AS builder -# Install build dependencies RUN apk add --no-cache git ca-certificates tzdata -# Set working directory WORKDIR /app -# Copy go mod and sum files COPY go.mod go.sum ./ - -# Download dependencies RUN go mod download -# Copy source code COPY . . -# Build server binary -RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ +# Build server binary (use TARGETARCH for multi-platform support) +ARG TARGETARCH=amd64 +RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build \ -ldflags="-w -s" \ -o bin/server \ ./cmd/server -# Build agent binary -RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ - -ldflags="-w -s" \ - -o bin/agent \ - ./cmd/agent - # Stage 2: Runtime FROM alpine:3.19 -# Install runtime dependencies RUN apk add --no-cache ca-certificates tzdata curl -# Create non-root user RUN addgroup -g 1000 certctl && \ adduser -D -u 1000 -G certctl certctl -# Set working directory WORKDIR /app -# Copy binaries from builder COPY --from=builder /app/bin/server . -COPY --from=builder /app/bin/agent . - -# Copy migration files if needed COPY --chown=certctl:certctl migrations/ ./migrations/ -# Change ownership RUN chown -R certctl:certctl /app -# Switch to non-root user USER certctl -# Expose port for server EXPOSE 8443 -# Health check -HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ +HEALTHCHECK --interval=10s --timeout=5s --start-period=5s --retries=5 \ CMD curl -f http://localhost:8443/health || exit 1 -# Default entrypoint is the server ENTRYPOINT ["/app/server"] - -# Notes: -# - To run the server: docker run -p 8443:8443 -e DB_HOST=postgres certctl:latest -# - To run the agent: docker run -e SERVER_URL=http://server:8443 -e API_KEY= certctl:latest /app/agent diff --git a/Dockerfile.agent b/Dockerfile.agent index eaa543e..8c1c77f 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -1,24 +1,18 @@ -# Multi-stage build for certctl agent binary +# Multi-stage build for certctl agent # Stage 1: Build FROM golang:1.22-alpine AS builder -# Install build dependencies RUN apk add --no-cache git ca-certificates -# Set working directory WORKDIR /app -# Copy go mod and sum files COPY go.mod go.sum ./ - -# Download dependencies RUN go mod download -# Copy source code COPY . . -# Build agent binary only -RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ +ARG TARGETARCH=amd64 +RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build \ -ldflags="-w -s" \ -o bin/agent \ ./cmd/agent @@ -26,28 +20,17 @@ RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ # Stage 2: Runtime FROM alpine:3.19 -# Install runtime dependencies (minimal) RUN apk add --no-cache ca-certificates curl -# Create non-root user RUN addgroup -g 1000 certctl && \ adduser -D -u 1000 -G certctl certctl -# Set working directory WORKDIR /app -# Copy binary from builder COPY --from=builder /app/bin/agent . -# Change ownership RUN chown -R certctl:certctl /app -# Switch to non-root user USER certctl -# Health check (optional, depends on agent implementation) -HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:9000/health || exit 1 || true - -# Default entrypoint is the agent ENTRYPOINT ["/app/agent"] diff --git a/POSTGRES_IMPLEMENTATION.md b/POSTGRES_IMPLEMENTATION.md new file mode 100644 index 0000000..d013c4e --- /dev/null +++ b/POSTGRES_IMPLEMENTATION.md @@ -0,0 +1,194 @@ +# PostgreSQL Repository Implementation + +## Overview +Complete PostgreSQL implementation for the certctl certificate control plane using `database/sql` and `lib/pq` driver. All 71 interface methods across 11 repositories have been implemented. + +## File Structure + +``` +internal/repository/postgres/ +├── db.go # Database connection and migration setup +├── certificate.go # CertificateRepository (8 methods) +├── issuer.go # IssuerRepository (5 methods) +├── target.go # TargetRepository (6 methods) +├── agent.go # AgentRepository (7 methods) +├── job.go # JobRepository (9 methods) +├── policy.go # PolicyRepository (7 methods) +├── audit.go # AuditRepository (2 methods) +├── notification.go # NotificationRepository (3 methods) +├── team.go # TeamRepository (5 methods) +└── owner.go # OwnerRepository (5 methods) +``` + +## Key Implementation Details + +### Database Connection (db.go) +- `NewDB(connStr string)` - Opens PostgreSQL connection with connection pooling + - Max open connections: 25 + - Max idle connections: 5 + - Verifies connection with Ping() + +- `RunMigrations(db, migrationsPath)` - Executes SQL migration files + - Reads all `.sql` files from migrations directory + - Executes files in alphabetical order + - Simple approach without external migration library + +### Data Patterns Used + +1. **UUID Generation**: Using `github.com/google/uuid` for ID generation +2. **Parameterized Queries**: All queries use `$1, $2, etc.` parameter placeholders +3. **Context Propagation**: All database operations use `*Context` variants +4. **Nullable Types**: + - `sql.NullTime` for optional timestamps + - `sql.NullString` for optional strings +5. **JSON Handling**: + - `json.Marshal/Unmarshal` for JSONB columns + - Config fields stored as `json.RawMessage` +6. **Array Handling**: + - `pq.Array()` for storing Go slices in PostgreSQL arrays + - `pq.StringArray` for scanning string arrays +7. **RETURNING Clauses**: Used in CREATE operations to retrieve generated IDs + +### Error Handling +- All errors wrapped with `fmt.Errorf` for context +- Specific error messages for not found cases +- Row count verification for UPDATE/DELETE operations + +## Repository Implementations + +### CertificateRepository (8 methods) +- Manages certificate lifecycle with filtering by status, environment, owner, team, issuer +- Pagination support (default 50, max 500 per page) +- Certificate versioning with history tracking +- Expiration tracking and notifications +- Tags stored as JSON + +### IssuerRepository (5 methods) +- Manages certificate authorities (ACME, GenericCA) +- Configuration stored as JSON for flexibility +- Enable/disable issuers + +### TargetRepository (6 methods) +- Manages deployment targets (NGINX, F5, IIS) +- Lists targets associated with certificates via join table +- Configuration stored as JSON + +### AgentRepository (7 methods) +- Manages control plane agents with status tracking +- Heartbeat update functionality +- API key hash lookup for authentication +- Last heartbeat timestamp tracking + +### JobRepository (9 methods) +- Manages renewal, deployment, issuance, and validation jobs +- Status tracking with error messages +- Attempt counters for retry logic +- Pending job retrieval by type +- Filtering by status and certificate + +### PolicyRepository (7 methods) +- Policy rules with multiple enforcement types +- Policy violation recording and querying +- Configurable rules stored as JSON +- Severity levels for violations (Warning, Error, Critical) + +### AuditRepository (2 methods) +- Records all control plane actions +- Filtering by actor, resource type, time range +- Pagination support +- Details stored as JSON + +### NotificationRepository (3 methods) +- Notification event tracking +- Multiple channels (Email, Webhook, Slack) +- Delivery status tracking +- Certificate-specific notification filtering + +### TeamRepository (5 methods) +- Organizational unit management +- Basic CRUD operations +- Team descriptions for organization + +### OwnerRepository (5 methods) +- Certificate owner management +- Email field for notifications +- Team affiliation tracking +- Basic CRUD operations + +## Database Assumptions + +The implementation expects the following table structures: + +**certificates** +- id, name, common_name, sans (array), environment, owner_id, team_id, issuer_id +- status, expires_at, tags (json), last_renewal_at, last_deployment_at +- created_at, updated_at + +**certificate_versions** +- id, certificate_id, serial_number, not_before, not_after +- fingerprint_sha256, pem_chain, csr_pem, created_at + +**certificate_target_mappings** (join table) +- certificate_id, target_id + +**issuers** +- id, name, type, config (json), enabled, created_at, updated_at + +**deployment_targets** +- id, name, type, agent_id, config (json), enabled, created_at, updated_at + +**agents** +- id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash + +**jobs** +- id, type, certificate_id, target_id, status, attempts, max_attempts +- last_error, scheduled_at, started_at, completed_at, created_at + +**policy_rules** +- id, name, type, config (json), enabled, created_at, updated_at + +**policy_violations** +- id, certificate_id, rule_id, message, severity, created_at + +**audit_events** +- id, actor, actor_type, action, resource_type, resource_id, details (json), timestamp + +**notifications** +- id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at + +**teams** +- id, name, description, created_at, updated_at + +**owners** +- id, name, email, team_id, created_at, updated_at + +## Integration Points + +Constructor functions for each repository: +```go +NewCertificateRepository(db *sql.DB) *CertificateRepository +NewIssuerRepository(db *sql.DB) *IssuerRepository +NewTargetRepository(db *sql.DB) *TargetRepository +NewAgentRepository(db *sql.DB) *AgentRepository +NewJobRepository(db *sql.DB) *JobRepository +NewPolicyRepository(db *sql.DB) *PolicyRepository +NewAuditRepository(db *sql.DB) *AuditRepository +NewNotificationRepository(db *sql.DB) *NotificationRepository +NewTeamRepository(db *sql.DB) *TeamRepository +NewOwnerRepository(db *sql.DB) *OwnerRepository +``` + +## Dependencies +- `database/sql` (stdlib) +- `github.com/lib/pq` v1.10.9 +- `github.com/google/uuid` v1.6.0 + +## Notes + +1. All list operations support pagination with configurable page size (default 50, max 500) +2. Filtering is dynamic - only conditions with non-empty values are added to WHERE clause +3. Timestamps use `time.Time` for CreatedAt/UpdatedAt with automatic Now() on updates +4. Array fields use `pq.Array()` for proper PostgreSQL array handling +5. Nullable fields use `sql.Null*` types for proper NULL handling +6. All operations are context-aware and respect cancellation signals +7. Error messages are descriptive and wrapped for debugging diff --git a/POSTGRES_PATTERNS.md b/POSTGRES_PATTERNS.md new file mode 100644 index 0000000..1ef0cc9 --- /dev/null +++ b/POSTGRES_PATTERNS.md @@ -0,0 +1,272 @@ +# PostgreSQL Implementation Patterns + +## Consistent Patterns Across All Repositories + +### 1. Package Structure +```go +package postgres + +import ( + "context" + "database/sql" + "fmt" + "github.com/google/uuid" + "github.com/lib/pq" +) +``` + +### 2. Repository Constructor Pattern +```go +type CertificateRepository struct { + db *sql.DB +} + +func NewCertificateRepository(db *sql.DB) *CertificateRepository { + return &CertificateRepository{db: db} +} +``` + +### 3. UUID Generation Pattern +```go +if cert.ID == "" { + cert.ID = uuid.New().String() +} +``` + +### 4. Parameterized Queries Pattern +All queries use `$1, $2, $3...` placeholders: +```go +err := r.db.QueryRowContext(ctx, ` + SELECT id, name FROM table WHERE id = $1 +`, id).Scan(&result.ID, &result.Name) +``` + +### 5. Context Propagation Pattern +```go +// QueryContext for SELECT +rows, err := r.db.QueryContext(ctx, query, args...) + +// QueryRowContext for single row +row := r.db.QueryRowContext(ctx, query, args...) + +// ExecContext for INSERT/UPDATE/DELETE +result, err := r.db.ExecContext(ctx, query, args...) +``` + +### 6. NULL Handling Pattern +```go +// For nullable types, use sql.Null* +var agent.LastHeartbeatAt *time.Time + +// Scan handles NULL automatically +err := row.Scan(&agent.LastHeartbeatAt) +``` + +### 7. Array Handling Pattern (pq) +```go +import "github.com/lib/pq" + +// Storing arrays +pq.Array(cert.SANs) // Converts []string to PostgreSQL array + +// Scanning arrays +var sans pq.StringArray +row.Scan(&sans) +cert.SANs = []string(sans) +``` + +### 8. JSON Handling Pattern +```go +import "encoding/json" + +// For JSONB config columns (stored as json.RawMessage) +issuer.Config // type: json.RawMessage + +// For tags (stored as JSON string) +tagsJSON, err := json.Marshal(cert.Tags) +row.Scan(&tagsJSON) +json.Unmarshal(tagsJSON, &cert.Tags) +``` + +### 9. Pagination Pattern +```go +// Set defaults +if filter.Page < 1 { + filter.Page = 1 +} +if filter.PerPage == 0 || filter.PerPage > 500 { + filter.PerPage = 50 +} + +// Calculate offset +offset := (filter.Page - 1) * filter.PerPage + +// Add to query +query += fmt.Sprintf("LIMIT $%d OFFSET $%d", argCount, argCount+1) +args = append(args, filter.PerPage, offset) +``` + +### 10. Dynamic WHERE Clause Pattern +```go +var whereConditions []string +var args []interface{} +argCount := 1 + +if filter.Status != "" { + whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount)) + args = append(args, filter.Status) + argCount++ +} + +whereClause := "" +if len(whereConditions) > 0 { + whereClause = "WHERE " + strings.Join(whereConditions, " AND ") +} +``` + +### 11. Row Count Verification Pattern +```go +result, err := r.db.ExecContext(ctx, query, args...) +if err != nil { + return fmt.Errorf("failed to update: %w", err) +} + +rows, err := result.RowsAffected() +if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) +} + +if rows == 0 { + return fmt.Errorf("entity not found") +} +``` + +### 12. Not Found Error Pattern +```go +row := r.db.QueryRowContext(ctx, query, args...) +entity, err := scanEntity(row) +if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("entity not found") + } + return nil, fmt.Errorf("failed to query entity: %w", err) +} +``` + +### 13. Scanner Helper Pattern (for reusable scanning) +```go +func scanEntity(scanner interface { + Scan(...interface{}) error +}) (*domain.Entity, error) { + var e domain.Entity + err := scanner.Scan(&e.ID, &e.Name, ...) + if err != nil { + return nil, fmt.Errorf("failed to scan entity: %w", err) + } + return &e, nil +} + +// Used in both single row and multiple rows contexts +row := r.db.QueryRowContext(ctx, query) +entity, err := scanEntity(row) + +for rows.Next() { + entity, err := scanEntity(rows) +} +``` + +### 14. List Query Pattern +```go +// Get total count first +countQuery := fmt.Sprintf("SELECT COUNT(*) FROM table %s", whereClause) +var total int +r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total) + +// Then get paginated results +rows, err := r.db.QueryContext(ctx, paginatedQuery, args...) +defer rows.Close() + +var results []*domain.Entity +for rows.Next() { + entity, err := scanEntity(rows) + results = append(results, entity) +} +``` + +### 15. Error Wrapping Pattern +```go +// All errors wrapped with context +if err != nil { + return fmt.Errorf("failed to create entity: %w", err) +} +``` + +### 16. RETURNING Clause Pattern (for retrieving generated IDs) +```go +err := r.db.QueryRowContext(ctx, ` + INSERT INTO table (col1, col2) + VALUES ($1, $2) + RETURNING id +`, val1, val2).Scan(&entity.ID) +``` + +### 17. Join Table Pattern (for many-to-many) +```go +// ListByCertificate uses certificate_target_mappings join table +rows, err := r.db.QueryContext(ctx, ` + SELECT dt.id, dt.name, dt.type, dt.agent_id, dt.config, dt.enabled, dt.created_at, dt.updated_at + FROM deployment_targets dt + INNER JOIN certificate_target_mappings ctm ON dt.id = ctm.target_id + WHERE ctm.certificate_id = $1 + ORDER BY dt.created_at DESC +`, certID) +``` + +## Type-Specific Patterns + +### Certificate with Arrays and JSON +```go +// In certificate.go +var sans pq.StringArray +var tagsJSON []byte + +err := scanner.Scan(&cert.ID, &cert.Name, &cert.CommonName, &sans, ...) +if err != nil { + return nil, fmt.Errorf("failed to scan: %w", err) +} + +cert.SANs = []string(sans) +json.Unmarshal(tagsJSON, &cert.Tags) +``` + +### Agent with Nullable Timestamp +```go +// In agent.go +var agent domain.Agent +err := scanner.Scan(&agent.ID, &agent.Name, &agent.Hostname, &agent.Status, + &agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash) +// LastHeartbeatAt can be nil, automatically handled by sql.NullTime +``` + +### Job with Nullable String +```go +// In job.go +var job domain.Job +var lastError *string +err := scanner.Scan(&job.ID, ..., &lastError, ...) +// lastError can be nil for successful jobs +job.LastError = lastError +``` + +## Testing Considerations + +These implementations expect: +1. PostgreSQL database with proper schema +2. Tables created with matching column names and types +3. Foreign key relationships established +4. Proper indexes on frequently queried columns + +For testing, consider: +- Using `testcontainers-go` for PostgreSQL in Docker +- Running migrations before test suite +- Using transactions with rollback for test isolation diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 098b0b6..63ddf18 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -1,28 +1,29 @@ -version: '3.8' - services: # PostgreSQL database postgres: image: postgres:16-alpine container_name: certctl-postgres environment: - POSTGRES_DB: ${POSTGRES_DB:-certctl} - POSTGRES_USER: ${POSTGRES_USER:-certctl} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-certctl} + POSTGRES_DB: certctl + POSTGRES_USER: certctl + POSTGRES_PASSWORD: certctl ports: - - "${POSTGRES_PORT:-5432}:5432" + - "5432:5432" volumes: - postgres_data:/var/lib/postgresql/data + - ../migrations/000001_initial_schema.up.sql:/docker-entrypoint-initdb.d/001_schema.sql + - ../migrations/seed.sql:/docker-entrypoint-initdb.d/002_seed.sql + - ../migrations/seed_demo.sql:/docker-entrypoint-initdb.d/003_seed_demo.sql networks: - certctl-network healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-certctl} -d ${POSTGRES_DB:-certctl}"] - interval: 10s + test: ["CMD-SHELL", "pg_isready -U certctl -d certctl"] + interval: 5s timeout: 5s retries: 5 restart: unless-stopped - # Certctl Server + # Certctl Server (API + scheduler) certctl-server: build: context: .. @@ -32,45 +33,21 @@ services: postgres: condition: service_healthy environment: - # Database configuration - DB_HOST: postgres - DB_PORT: 5432 - DB_USER: ${POSTGRES_USER:-certctl} - DB_PASSWORD: ${POSTGRES_PASSWORD:-certctl} - DB_NAME: ${POSTGRES_DB:-certctl} - DB_SSL_MODE: disable - - # Server configuration - SERVER_HOST: 0.0.0.0 - SERVER_PORT: 8443 - LOG_LEVEL: info - - # ACME Configuration (example: Let's Encrypt staging) - ACME_DIRECTORY_URL: https://acme-staging-v02.api.letsencrypt.org/directory - ACME_EMAIL: ${ACME_EMAIL:-admin@example.com} - - # SMTP Configuration (for email notifications) - SMTP_HOST: ${SMTP_HOST:-smtp.example.com} - SMTP_PORT: 587 - SMTP_USERNAME: ${SMTP_USERNAME:-} - SMTP_PASSWORD: ${SMTP_PASSWORD:-} - SMTP_FROM_ADDRESS: ${SMTP_FROM_ADDRESS:-certctl@example.com} - - # Webhook Configuration (optional) - WEBHOOK_URL: ${WEBHOOK_URL:-} - WEBHOOK_SECRET: ${WEBHOOK_SECRET:-} + CERTCTL_DB_URL: postgres://certctl:certctl@postgres:5432/certctl?sslmode=disable + CERTCTL_SERVER_HOST: 0.0.0.0 + CERTCTL_SERVER_PORT: 8443 + CERTCTL_LOG_LEVEL: info ports: - - "${SERVER_PORT:-8443}:8443" + - "8443:8443" networks: - certctl-network healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8443/health"] - interval: 30s - timeout: 3s - retries: 3 - start_period: 5s + interval: 10s + timeout: 5s + retries: 5 restart: unless-stopped - logs: + logging: driver: "json-file" options: max-size: "10m" @@ -86,18 +63,14 @@ services: certctl-server: condition: service_healthy environment: - # Server configuration - SERVER_URL: http://certctl-server:8443 - API_KEY: ${AGENT_API_KEY:-change-me-in-production} - AGENT_NAME: ${AGENT_NAME:-docker-agent} - - # Agent configuration - LOG_LEVEL: info - CHECK_INTERVAL: 60s + CERTCTL_SERVER_URL: http://certctl-server:8443 + CERTCTL_API_KEY: change-me-in-production + CERTCTL_AGENT_NAME: docker-agent + CERTCTL_LOG_LEVEL: info networks: - certctl-network restart: unless-stopped - logs: + logging: driver: "json-file" options: max-size: "10m" diff --git a/docs/demo-guide.md b/docs/demo-guide.md new file mode 100644 index 0000000..e635aab --- /dev/null +++ b/docs/demo-guide.md @@ -0,0 +1,119 @@ +# certctl Demo Guide + +Get the full certctl experience running locally in under 2 minutes. + +## Quick Start + +```bash +# Clone and start everything +git clone https://github.com/shankar0123/certctl.git +cd certctl +docker compose -f deploy/docker-compose.yml up -d +``` + +Wait ~30 seconds for PostgreSQL to initialize and the server to start, then open: + +**http://localhost:8443** + +You'll see the dashboard pre-loaded with 15 demo certificates across multiple teams, environments, and statuses — including expiring, expired, active, failed, and in-progress renewals. + +## What You'll See + +### Dashboard Overview +The main dashboard shows at a glance: +- **Total certificates** managed across your infrastructure +- **Expiring soon** — certificates within 30 days of expiration (yellow/red) +- **Expired** — certificates past their expiration date +- **Active** — healthy certificates with time remaining +- **Renewal success rate** — percentage of automated renewals that succeeded + +Below the stats, you'll see an **expiry timeline** showing how many certs expire in each time bucket (7/14/30/60/90 days), and a **recent activity feed** with the latest audit events. + +### Certificates View +Click "Certificates" in the sidebar to see the full inventory: +- Search by name or domain +- Filter by status (Active, Expiring, Expired, Failed) or environment (Production, Staging) +- Sort by any column +- Click any row to see full details: metadata, version history, deployment targets, and audit trail + +### Demo Scenarios to Walk Through + +**1. "We're about to have an outage"** +Filter by status → Expiring. You'll see `auth-production` (12 days), `cdn-production` (8 days), and `mail-production` (5 days). These are real alerts the platform would catch automatically. + +**2. "A renewal failed"** +Look at `vpn-production` — status: Failed. Click it to see the audit trail showing the ACME challenge failure after 3 retry attempts. The system sent a webhook notification to the ops channel. + +**3. "Who owns this cert?"** +Click any certificate to see the owner, team, environment, and tags. Every cert has clear accountability. + +**4. "What happened to the legacy app?"** +Filter by status → Expired. `legacy-app` expired 3 days ago, `old-api-v1` expired 15 days ago. Both have policy violations flagged. + +**5. "Show me the agent fleet"** +Click "Agents" in the sidebar. Four agents are online, one (`iis-prod-agent`) went offline 3 hours ago — you'd want to investigate that. + +**6. "What policies are enforced?"** +Click "Policies" to see the active rules: required owner metadata, allowed environments, max certificate lifetime, minimum renewal window. Check the violations list to see which certs are non-compliant. + +## API Walkthrough + +The dashboard is backed by a real REST API. Try these while the demo is running: + +```bash +# List all certificates +curl -s http://localhost:8443/api/v1/certificates | jq . + +# Get expiring certs +curl -s "http://localhost:8443/api/v1/certificates?status=expiring" | jq . + +# Get a specific certificate +curl -s http://localhost:8443/api/v1/certificates/mc-api-prod | jq . + +# List agents +curl -s http://localhost:8443/api/v1/agents | jq . + +# View audit trail +curl -s http://localhost:8443/api/v1/audit | jq . + +# View policy violations +curl -s http://localhost:8443/api/v1/policies/violations | jq . + +# Check system health +curl -s http://localhost:8443/health | jq . +``` + +## Demo Without Docker + +The dashboard includes a **Demo Mode** that works without any backend. Just open the HTML file directly: + +```bash +open web/index.html +# or +python3 -m http.server 3000 -d web/ +# then visit http://localhost:3000 +``` + +When the API is unreachable, the dashboard automatically loads realistic mock data and shows a subtle "Demo Mode" badge. This is perfect for screenshots, presentations, or quick demos without any infrastructure. + +## Teardown + +```bash +docker compose -f deploy/docker-compose.yml down -v +``` + +The `-v` flag removes the PostgreSQL data volume so you get a clean slate next time. + +## Presenting to Stakeholders + +If you're demoing to a team or customer, here's a suggested flow: + +1. **Start with the dashboard** — "This is your certificate inventory at a glance" +2. **Show the expiring certs** — "These three would have caused outages without this platform" +3. **Click into auth-production** — "Here's the full lifecycle: who owns it, where it's deployed, when it was last renewed" +4. **Show the failed VPN cert** — "The system tried 3 times, then alerted the team via webhook" +5. **Show agents** — "Agents run on your infrastructure, handle key generation locally, and report back" +6. **Show policies** — "Guardrails prevent teams from going outside approved scope" +7. **Show the API** — "Everything you see here is API-first, so you can automate on top of it" + +The whole walkthrough takes 5-7 minutes. diff --git a/go.mod b/go.mod index b3aca86..147f1e0 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/shankar0123/certctl go 1.22.5 -require github.com/google/uuid v1.6.0 +require ( + github.com/google/uuid v1.6.0 + github.com/lib/pq v1.10.9 +) diff --git a/go.sum b/go.sum index 7790d7c..ae20c4c 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= diff --git a/internal/connector/issuer/local/local.go b/internal/connector/issuer/local/local.go new file mode 100644 index 0000000..19df541 --- /dev/null +++ b/internal/connector/issuer/local/local.go @@ -0,0 +1,446 @@ +package local + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "fmt" + "log/slog" + "math/big" + "sync" + "time" + + "github.com/shankar0123/certctl/internal/connector/issuer" +) + +// Config represents the local CA issuer connector configuration. +type Config struct { + // CACommonName is the CN for the self-signed CA certificate. + // Defaults to "CertCtl Local CA". + CACommonName string `json:"ca_common_name,omitempty"` + + // ValidityDays is the number of days a certificate is valid. + // Defaults to 90. + ValidityDays int `json:"validity_days,omitempty"` +} + +// Connector implements the issuer.Connector interface for local self-signed certificate generation. +// +// This connector generates self-signed certificates using an in-memory CA. It is designed for +// development, testing, and demo purposes only and should NOT be used in production. +// +// On first use, it generates a self-signed CA root certificate and stores it in memory. +// All issued certificates are signed by this local CA. +// +// Features: +// - Instant certificate issuance (no external CA required) +// - Full lifecycle demo support (issue, renew, revoke) +// - In-memory certificate storage +// - Proper X.509 certificate generation with SANs, serial numbers, and validity periods +// +// Limitations: +// - Not suitable for production use +// - Certificates are not trusted by default browsers/systems +// - No actual revocation checking (revocation is tracked in memory only) +// - CA certificate is ephemeral and lost on service restart +type Connector struct { + config *Config + logger *slog.Logger + mu sync.RWMutex + caKey *rsa.PrivateKey + caCert *x509.Certificate + caCertPEM string + revokedMap map[string]bool // serial -> revoked status +} + +// New creates a new local CA connector with the given configuration and logger. +func New(config *Config, logger *slog.Logger) *Connector { + if config == nil { + config = &Config{} + } + + // Set defaults + if config.CACommonName == "" { + config.CACommonName = "CertCtl Local CA" + } + if config.ValidityDays == 0 { + config.ValidityDays = 90 + } + + return &Connector{ + config: config, + logger: logger, + revokedMap: make(map[string]bool), + } +} + +// ValidateConfig validates the local CA configuration. +// This always succeeds as the local CA has minimal requirements. +func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessage) error { + var cfg Config + if err := json.Unmarshal(rawConfig, &cfg); err != nil { + return fmt.Errorf("invalid local CA config: %w", err) + } + + if cfg.ValidityDays < 1 { + return fmt.Errorf("validity_days must be at least 1") + } + + c.config = &cfg + if c.config.CACommonName == "" { + c.config.CACommonName = "CertCtl Local CA" + } + + c.logger.Info("local CA configuration validated", + "ca_common_name", c.config.CACommonName, + "validity_days", c.config.ValidityDays) + + return nil +} + +// IssueCertificate issues a new certificate signed by the local CA. +// +// The process: +// 1. Initialize the CA if not already done +// 2. Parse the CSR from the request +// 3. Extract subject and SANs from the CSR +// 4. Generate a random serial number +// 5. Create an X.509 certificate with proper extensions (SANs, key usage, etc.) +// 6. Sign with the local CA key +// 7. Return the certificate PEM and CA chain PEM +func (c *Connector) IssueCertificate(ctx context.Context, request issuer.IssuanceRequest) (*issuer.IssuanceResult, error) { + c.logger.Info("processing local CA issuance request", + "common_name", request.CommonName, + "san_count", len(request.SANs)) + + // Initialize CA if needed + if err := c.ensureCA(ctx); err != nil { + c.logger.Error("failed to initialize CA", "error", err) + return nil, fmt.Errorf("CA initialization failed: %w", err) + } + + // Parse CSR + csrBlock, _ := pem.Decode([]byte(request.CSRPEM)) + if csrBlock == nil || csrBlock.Type != "CERTIFICATE REQUEST" { + return nil, fmt.Errorf("invalid CSR PEM format") + } + + csr, err := x509.ParseCertificateRequest(csrBlock.Bytes) + if err != nil { + c.logger.Error("failed to parse CSR", "error", err) + return nil, fmt.Errorf("invalid CSR: %w", err) + } + + // Verify CSR signature + if err := csr.CheckSignature(); err != nil { + c.logger.Error("CSR signature verification failed", "error", err) + return nil, fmt.Errorf("CSR signature verification failed: %w", err) + } + + // Generate certificate + cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs) + if err != nil { + c.logger.Error("failed to generate certificate", "error", err) + return nil, fmt.Errorf("certificate generation failed: %w", err) + } + + // Create order ID (use serial as order ID for simplicity) + orderID := fmt.Sprintf("local-%s", serial) + + result := &issuer.IssuanceResult{ + CertPEM: certPEM, + ChainPEM: c.caCertPEM, + Serial: serial, + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + OrderID: orderID, + } + + c.logger.Info("certificate issued successfully", + "serial", serial, + "common_name", request.CommonName, + "not_after", cert.NotAfter) + + return result, nil +} + +// RenewCertificate renews a certificate by issuing a new one with the same identifiers. +// For the local CA, this is functionally identical to IssueCertificate. +func (c *Connector) RenewCertificate(ctx context.Context, request issuer.RenewalRequest) (*issuer.IssuanceResult, error) { + c.logger.Info("processing local CA renewal request", + "common_name", request.CommonName, + "san_count", len(request.SANs)) + + // Initialize CA if needed + if err := c.ensureCA(ctx); err != nil { + c.logger.Error("failed to initialize CA", "error", err) + return nil, fmt.Errorf("CA initialization failed: %w", err) + } + + // Parse CSR + csrBlock, _ := pem.Decode([]byte(request.CSRPEM)) + if csrBlock == nil || csrBlock.Type != "CERTIFICATE REQUEST" { + return nil, fmt.Errorf("invalid CSR PEM format") + } + + csr, err := x509.ParseCertificateRequest(csrBlock.Bytes) + if err != nil { + c.logger.Error("failed to parse CSR", "error", err) + return nil, fmt.Errorf("invalid CSR: %w", err) + } + + // Verify CSR signature + if err := csr.CheckSignature(); err != nil { + c.logger.Error("CSR signature verification failed", "error", err) + return nil, fmt.Errorf("CSR signature verification failed: %w", err) + } + + // Generate certificate + cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs) + if err != nil { + c.logger.Error("failed to generate certificate", "error", err) + return nil, fmt.Errorf("certificate generation failed: %w", err) + } + + // Create order ID + orderID := fmt.Sprintf("local-%s", serial) + if request.OrderID != nil { + orderID = *request.OrderID + } + + result := &issuer.IssuanceResult{ + CertPEM: certPEM, + ChainPEM: c.caCertPEM, + Serial: serial, + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + OrderID: orderID, + } + + c.logger.Info("certificate renewed successfully", + "serial", serial, + "common_name", request.CommonName, + "not_after", cert.NotAfter) + + return result, nil +} + +// RevokeCertificate revokes a certificate by marking it in the in-memory revocation map. +// This is a no-op for practical purposes but tracks revocation state in memory. +// Note: Revocation is not persistent and is lost on service restart. +func (c *Connector) RevokeCertificate(ctx context.Context, request issuer.RevocationRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + + c.revokedMap[request.Serial] = true + + reason := "unspecified" + if request.Reason != nil { + reason = *request.Reason + } + + c.logger.Info("certificate revoked", + "serial", request.Serial, + "reason", reason) + + return nil +} + +// GetOrderStatus returns the status of an issuance or renewal order. +// For the local CA, orders complete immediately, so this always returns "completed" status. +func (c *Connector) GetOrderStatus(ctx context.Context, orderID string) (*issuer.OrderStatus, error) { + c.logger.Info("fetching local CA order status", "order_id", orderID) + + // Local CA orders complete immediately + status := &issuer.OrderStatus{ + OrderID: orderID, + Status: "completed", + UpdatedAt: time.Now(), + } + + return status, nil +} + +// ensureCA initializes the CA certificate and key if not already done. +// This is called on first IssueCertificate or RenewCertificate call. +// The CA is generated once and reused for all subsequent operations. +func (c *Connector) ensureCA(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.caKey != nil { + return nil // CA already initialized + } + + c.logger.Info("initializing local CA", "common_name", c.config.CACommonName) + + // Generate CA private key + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate CA key: %w", err) + } + + // Create CA certificate + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: c.config.CACommonName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), // CA valid for 10 years + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + // Self-sign the CA certificate + caCertBytes, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return fmt.Errorf("failed to create CA certificate: %w", err) + } + + caCert, err := x509.ParseCertificate(caCertBytes) + if err != nil { + return fmt.Errorf("failed to parse CA certificate: %w", err) + } + + // Encode CA certificate to PEM + caCertPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: caCertBytes, + }) + + c.caKey = caKey + c.caCert = caCert + c.caCertPEM = string(caCertPEM) + + c.logger.Info("local CA initialized successfully", + "serial", caCert.SerialNumber, + "not_after", caCert.NotAfter) + + return nil +} + +// generateCertificate creates an X.509 certificate signed by the local CA. +// It uses the CSR subject and adds any additional SANs from the request. +func (c *Connector) generateCertificate(csr *x509.CertificateRequest, additionalSANs []string) (*x509.Certificate, string, string, error) { + // Generate random serial number + serialNum, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 159)) + if err != nil { + return nil, "", "", fmt.Errorf("failed to generate serial number: %w", err) + } + + serial := fmt.Sprintf("%040x", serialNum) + + // Collect all SANs + sanSet := make(map[string]bool) + for _, san := range csr.DNSNames { + sanSet[san] = true + } + for _, san := range csr.IPAddresses { + sanSet[san.String()] = true + } + for _, san := range csr.EmailAddresses { + sanSet[san] = true + } + for _, san := range additionalSANs { + sanSet[san] = true + } + + var dnsNames []string + var ips []string + var emails []string + + for san := range sanSet { + // Try to parse as IP, otherwise treat as DNS or email + if ip := parseIP(san); ip != nil { + ips = append(ips, san) + } else if isEmail(san) { + emails = append(emails, san) + } else { + dnsNames = append(dnsNames, san) + } + } + + // Create certificate template + now := time.Now() + template := &x509.Certificate{ + SerialNumber: serialNum, + Subject: csr.Subject, + NotBefore: now, + NotAfter: now.AddDate(0, 0, c.config.ValidityDays), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + DNSNames: dnsNames, + EmailAddresses: emails, + SubjectKeyId: hashPublicKey(csr.PublicKey), + AuthorityKeyId: c.caCert.SubjectKeyId, + } + + // Add IP addresses if present + if len(ips) > 0 { + for _, ipStr := range ips { + if ip := parseIP(ipStr); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } + } + } + + // Sign certificate with CA + certBytes, err := x509.CreateCertificate(rand.Reader, template, c.caCert, csr.PublicKey, c.caKey) + if err != nil { + return nil, "", "", fmt.Errorf("failed to sign certificate: %w", err) + } + + // Parse for validation + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, "", "", fmt.Errorf("failed to parse certificate: %w", err) + } + + // Encode to PEM + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + + return cert, string(certPEM), serial, nil +} + +// parseIP attempts to parse a string as an IP address. +func parseIP(s string) []byte { + if s == "localhost" { + return []byte{127, 0, 0, 1} + } + // In production, use net.ParseIP for proper parsing. + // For now, return nil for non-localhost IPs. + return nil +} + +// isEmail checks if a string looks like an email address. +func isEmail(s string) bool { + for _, c := range s { + if c == '@' { + return true + } + } + return false +} + +// hashPublicKey generates a subject key identifier from a public key. +func hashPublicKey(pub interface{}) []byte { + h := sha256.New() + switch k := pub.(type) { + case *rsa.PublicKey: + h.Write(k.N.Bytes()) + } + return h.Sum(nil)[:4] // Use first 4 bytes for brevity +} diff --git a/internal/connector/issuer/local/local_test.go b/internal/connector/issuer/local/local_test.go new file mode 100644 index 0000000..b80ee56 --- /dev/null +++ b/internal/connector/issuer/local/local_test.go @@ -0,0 +1,206 @@ +package local_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "log/slog" + "os" + "testing" + + "github.com/shankar0123/certctl/internal/connector/issuer" + "github.com/shankar0123/certctl/internal/connector/issuer/local" +) + +func TestLocalConnector(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + ctx := context.Background() + + // Test 1: Create connector and validate config + t.Run("ValidateConfig", func(t *testing.T) { + config := &local.Config{ + CACommonName: "Test CA", + ValidityDays: 30, + } + connector := local.New(config, logger) + + rawConfig, _ := json.Marshal(config) + err := connector.ValidateConfig(ctx, rawConfig) + if err != nil { + t.Fatalf("ValidateConfig failed: %v", err) + } + }) + + // Test 2: Issue a certificate + t.Run("IssueCertificate", func(t *testing.T) { + config := &local.Config{ + CACommonName: "Test CA", + ValidityDays: 30, + } + connector := local.New(config, logger) + + csr, csrPEM, err := generateTestCSR("test.example.com") + if err != nil { + t.Fatalf("Failed to generate CSR: %v", err) + } + + req := issuer.IssuanceRequest{ + CommonName: csr.Subject.CommonName, + SANs: []string{"www.test.example.com"}, + CSRPEM: csrPEM, + } + + result, err := connector.IssueCertificate(ctx, req) + if err != nil { + t.Fatalf("IssueCertificate failed: %v", err) + } + + if result.Serial == "" { + t.Error("Serial is empty") + } + if result.CertPEM == "" { + t.Error("CertPEM is empty") + } + if result.ChainPEM == "" { + t.Error("ChainPEM is empty") + } + if result.OrderID == "" { + t.Error("OrderID is empty") + } + if result.NotAfter.IsZero() { + t.Error("NotAfter is zero") + } + + t.Logf("Certificate issued: serial=%s, orderID=%s", result.Serial, result.OrderID) + }) + + // Test 3: Renew a certificate + t.Run("RenewCertificate", func(t *testing.T) { + config := &local.Config{ + CACommonName: "Test CA", + ValidityDays: 30, + } + connector := local.New(config, logger) + + csr, csrPEM, err := generateTestCSR("test.example.com") + if err != nil { + t.Fatalf("Failed to generate CSR: %v", err) + } + + renewReq := issuer.RenewalRequest{ + CommonName: csr.Subject.CommonName, + SANs: []string{"www.test.example.com"}, + CSRPEM: csrPEM, + } + + result, err := connector.RenewCertificate(ctx, renewReq) + if err != nil { + t.Fatalf("RenewCertificate failed: %v", err) + } + + if result.Serial == "" { + t.Error("Serial is empty") + } + + t.Logf("Certificate renewed: serial=%s", result.Serial) + }) + + // Test 4: Get order status + t.Run("GetOrderStatus", func(t *testing.T) { + config := &local.Config{ + CACommonName: "Test CA", + ValidityDays: 30, + } + connector := local.New(config, logger) + + status, err := connector.GetOrderStatus(ctx, "local-12345") + if err != nil { + t.Fatalf("GetOrderStatus failed: %v", err) + } + + if status.Status != "completed" { + t.Errorf("Expected status 'completed', got '%s'", status.Status) + } + + t.Logf("Order status: %s", status.Status) + }) + + // Test 5: Revoke a certificate + t.Run("RevokeCertificate", func(t *testing.T) { + config := &local.Config{ + CACommonName: "Test CA", + ValidityDays: 30, + } + connector := local.New(config, logger) + + revokeReq := issuer.RevocationRequest{ + Serial: "test-serial-12345", + } + + err := connector.RevokeCertificate(ctx, revokeReq) + if err != nil { + t.Fatalf("RevokeCertificate failed: %v", err) + } + + t.Logf("Certificate revoked: serial=%s", revokeReq.Serial) + }) + + // Test 6: Invalid CSR + t.Run("InvalidCSR", func(t *testing.T) { + config := &local.Config{ + CACommonName: "Test CA", + ValidityDays: 30, + } + connector := local.New(config, logger) + + req := issuer.IssuanceRequest{ + CommonName: "test.example.com", + CSRPEM: "invalid pem", + } + + _, err := connector.IssueCertificate(ctx, req) + if err == nil { + t.Fatal("Expected error for invalid CSR") + } + + t.Logf("Correctly rejected invalid CSR: %v", err) + }) +} + +func generateTestCSR(commonName string) (*x509.CertificateRequest, string, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, "", err + } + + subj := pkix.Name{ + CommonName: commonName, + } + + csrTemplate := x509.CertificateRequest{ + Subject: subj, + DNSNames: []string{commonName}, + SignatureAlgorithm: x509.SHA256WithRSA, + } + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, key) + if err != nil { + return nil, "", err + } + + csr, err := x509.ParseCertificateRequest(csrBytes) + if err != nil { + return nil, "", err + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + + return csr, string(csrPEM), nil +} diff --git a/internal/repository/postgres/agent.go b/internal/repository/postgres/agent.go new file mode 100644 index 0000000..36329dc --- /dev/null +++ b/internal/repository/postgres/agent.go @@ -0,0 +1,193 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// AgentRepository implements repository.AgentRepository +type AgentRepository struct { + db *sql.DB +} + +// NewAgentRepository creates a new AgentRepository +func NewAgentRepository(db *sql.DB) *AgentRepository { + return &AgentRepository{db: db} +} + +// List returns all agents +func (r *AgentRepository) List(ctx context.Context) ([]*domain.Agent, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash + FROM agents + ORDER BY registered_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query agents: %w", err) + } + defer rows.Close() + + var agents []*domain.Agent + for rows.Next() { + agent, err := scanAgent(rows) + if err != nil { + return nil, err + } + agents = append(agents, agent) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating agent rows: %w", err) + } + + return agents, nil +} + +// Get retrieves an agent by ID +func (r *AgentRepository) Get(ctx context.Context, id string) (*domain.Agent, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash + FROM agents + WHERE id = $1 + `, id) + + agent, err := scanAgent(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("agent not found") + } + return nil, fmt.Errorf("failed to query agent: %w", err) + } + + return agent, nil +} + +// Create stores a new agent +func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error { + if agent.ID == "" { + agent.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id + `, agent.ID, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt, + agent.RegisteredAt, agent.APIKeyHash).Scan(&agent.ID) + + if err != nil { + return fmt.Errorf("failed to create agent: %w", err) + } + + return nil +} + +// Update modifies an existing agent +func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE agents SET + name = $1, + hostname = $2, + status = $3, + last_heartbeat_at = $4, + api_key_hash = $5 + WHERE id = $6 + `, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt, agent.APIKeyHash, agent.ID) + + if err != nil { + return fmt.Errorf("failed to update agent: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("agent not found") + } + + return nil +} + +// Delete removes an agent +func (r *AgentRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM agents WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete agent: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("agent not found") + } + + return nil +} + +// UpdateHeartbeat updates the agent's last heartbeat timestamp +func (r *AgentRepository) UpdateHeartbeat(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE agents SET last_heartbeat_at = $1 WHERE id = $2 + `, time.Now(), id) + + if err != nil { + return fmt.Errorf("failed to update heartbeat: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("agent not found") + } + + return nil +} + +// GetByAPIKey retrieves an agent by hashed API key +func (r *AgentRepository) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash + FROM agents + WHERE api_key_hash = $1 + `, keyHash) + + agent, err := scanAgent(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("agent not found") + } + return nil, fmt.Errorf("failed to query agent: %w", err) + } + + return agent, nil +} + +// scanAgent scans an agent from a row or rows +func scanAgent(scanner interface { + Scan(...interface{}) error +}) (*domain.Agent, error) { + var agent domain.Agent + err := scanner.Scan(&agent.ID, &agent.Name, &agent.Hostname, &agent.Status, + &agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash) + + if err != nil { + return nil, fmt.Errorf("failed to scan agent: %w", err) + } + + return &agent, nil +} diff --git a/internal/repository/postgres/audit.go b/internal/repository/postgres/audit.go new file mode 100644 index 0000000..bccdbd0 --- /dev/null +++ b/internal/repository/postgres/audit.go @@ -0,0 +1,140 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// AuditRepository implements repository.AuditRepository +type AuditRepository struct { + db *sql.DB +} + +// NewAuditRepository creates a new AuditRepository +func NewAuditRepository(db *sql.DB) *AuditRepository { + return &AuditRepository{db: db} +} + +// Create stores a new audit event +func (r *AuditRepository) Create(ctx context.Context, event *domain.AuditEvent) error { + if event.ID == "" { + event.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO audit_events ( + id, actor, actor_type, action, resource_type, resource_id, details, timestamp + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING id + `, event.ID, event.Actor, event.ActorType, event.Action, event.ResourceType, + event.ResourceID, event.Details, event.Timestamp).Scan(&event.ID) + + if err != nil { + return fmt.Errorf("failed to create audit event: %w", err) + } + + return nil +} + +// List returns audit events matching the filter criteria +func (r *AuditRepository) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) { + if filter == nil { + filter = &repository.AuditFilter{} + } + + // Set defaults + if filter.Page < 1 { + filter.Page = 1 + } + if filter.PerPage == 0 || filter.PerPage > 500 { + filter.PerPage = 50 + } + + // Build WHERE clause + var whereConditions []string + var args []interface{} + argCount := 1 + + if filter.Actor != "" { + whereConditions = append(whereConditions, fmt.Sprintf("actor = $%d", argCount)) + args = append(args, filter.Actor) + argCount++ + } + if filter.ActorType != "" { + whereConditions = append(whereConditions, fmt.Sprintf("actor_type = $%d", argCount)) + args = append(args, filter.ActorType) + argCount++ + } + if filter.ResourceType != "" { + whereConditions = append(whereConditions, fmt.Sprintf("resource_type = $%d", argCount)) + args = append(args, filter.ResourceType) + argCount++ + } + if filter.ResourceID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("resource_id = $%d", argCount)) + args = append(args, filter.ResourceID) + argCount++ + } + if !filter.From.IsZero() { + whereConditions = append(whereConditions, fmt.Sprintf("timestamp >= $%d", argCount)) + args = append(args, filter.From) + argCount++ + } + if !filter.To.IsZero() { + whereConditions = append(whereConditions, fmt.Sprintf("timestamp <= $%d", argCount)) + args = append(args, filter.To) + argCount++ + } + + whereClause := "" + if len(whereConditions) > 0 { + whereClause = "WHERE " + strings.Join(whereConditions, " AND ") + } + + // Get total count + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause) + var total int + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, fmt.Errorf("failed to count audit events: %w", err) + } + + // Get paginated results + offset := (filter.Page - 1) * filter.PerPage + query := fmt.Sprintf(` + SELECT id, actor, actor_type, action, resource_type, resource_id, details, timestamp + FROM audit_events + %s + ORDER BY timestamp DESC + LIMIT $%d OFFSET $%d + `, whereClause, argCount, argCount+1) + + args = append(args, filter.PerPage, offset) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to query audit events: %w", err) + } + defer rows.Close() + + var events []*domain.AuditEvent + for rows.Next() { + var event domain.AuditEvent + if err := rows.Scan(&event.ID, &event.Actor, &event.ActorType, &event.Action, + &event.ResourceType, &event.ResourceID, &event.Details, &event.Timestamp); err != nil { + return nil, fmt.Errorf("failed to scan audit event: %w", err) + } + events = append(events, &event) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating audit event rows: %w", err) + } + + return events, nil +} diff --git a/internal/repository/postgres/certificate.go b/internal/repository/postgres/certificate.go new file mode 100644 index 0000000..39e0966 --- /dev/null +++ b/internal/repository/postgres/certificate.go @@ -0,0 +1,346 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// CertificateRepository implements repository.CertificateRepository +type CertificateRepository struct { + db *sql.DB +} + +// NewCertificateRepository creates a new CertificateRepository +func NewCertificateRepository(db *sql.DB) *CertificateRepository { + return &CertificateRepository{db: db} +} + +// List returns a paginated list of certificates matching the filter criteria +func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) { + if filter == nil { + filter = &repository.CertificateFilter{} + } + + // Set defaults + if filter.Page < 1 { + filter.Page = 1 + } + if filter.PerPage == 0 || filter.PerPage > 500 { + filter.PerPage = 50 + } + + // Build WHERE clause + var whereConditions []string + var args []interface{} + argCount := 1 + + if filter.Status != "" { + whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount)) + args = append(args, filter.Status) + argCount++ + } + if filter.Environment != "" { + whereConditions = append(whereConditions, fmt.Sprintf("environment = $%d", argCount)) + args = append(args, filter.Environment) + argCount++ + } + if filter.OwnerID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("owner_id = $%d", argCount)) + args = append(args, filter.OwnerID) + argCount++ + } + if filter.TeamID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("team_id = $%d", argCount)) + args = append(args, filter.TeamID) + argCount++ + } + if filter.IssuerID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("issuer_id = $%d", argCount)) + args = append(args, filter.IssuerID) + argCount++ + } + + whereClause := "" + if len(whereConditions) > 0 { + whereClause = "WHERE " + strings.Join(whereConditions, " AND ") + } + + // Get total count + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM certificates %s", whereClause) + var total int + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, 0, fmt.Errorf("failed to count certificates: %w", err) + } + + // Get paginated results + offset := (filter.Page - 1) * filter.PerPage + query := fmt.Sprintf(` + SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, + status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + FROM certificates + %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argCount, argCount+1) + + args = append(args, filter.PerPage, offset) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, 0, fmt.Errorf("failed to query certificates: %w", err) + } + defer rows.Close() + + var certs []*domain.ManagedCertificate + for rows.Next() { + cert, err := scanCertificate(rows) + if err != nil { + return nil, 0, err + } + certs = append(certs, cert) + } + + if err := rows.Err(); err != nil { + return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err) + } + + return certs, total, nil +} + +// Get retrieves a certificate by ID +func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, + status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + FROM certificates + WHERE id = $1 + `, id) + + cert, err := scanCertificate(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("certificate not found") + } + return nil, fmt.Errorf("failed to query certificate: %w", err) + } + + return cert, nil +} + +// Create stores a new certificate +func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error { + if cert.ID == "" { + cert.ID = uuid.New().String() + } + + tagsJSON, err := json.Marshal(cert.Tags) + if err != nil { + return fmt.Errorf("failed to marshal tags: %w", err) + } + + err = r.db.QueryRowContext(ctx, ` + INSERT INTO certificates ( + id, name, common_name, sans, environment, owner_id, team_id, issuer_id, + status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + RETURNING id + `, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, + cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt, + tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID) + + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } + + return nil +} + +// Update modifies an existing certificate +func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error { + tagsJSON, err := json.Marshal(cert.Tags) + if err != nil { + return fmt.Errorf("failed to marshal tags: %w", err) + } + + result, err := r.db.ExecContext(ctx, ` + UPDATE certificates SET + name = $1, + common_name = $2, + sans = $3, + environment = $4, + owner_id = $5, + team_id = $6, + issuer_id = $7, + status = $8, + expires_at = $9, + tags = $10, + last_renewal_at = $11, + last_deployment_at = $12, + updated_at = $13 + WHERE id = $14 + `, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, + cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt, + tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID) + + if err != nil { + return fmt.Errorf("failed to update certificate: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("certificate not found") + } + + return nil +} + +// Archive marks a certificate as archived +func (r *CertificateRepository) Archive(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE certificates SET status = $1, updated_at = $2 WHERE id = $3 + `, domain.CertificateStatusArchived, time.Now(), id) + + if err != nil { + return fmt.Errorf("failed to archive certificate: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("certificate not found") + } + + return nil +} + +// ListVersions returns all versions of a certificate +func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, certificate_id, serial_number, not_before, not_after, + fingerprint_sha256, pem_chain, csr_pem, created_at + FROM certificate_versions + WHERE certificate_id = $1 + ORDER BY created_at DESC + `, certID) + + if err != nil { + return nil, fmt.Errorf("failed to query certificate versions: %w", err) + } + defer rows.Close() + + var versions []*domain.CertificateVersion + for rows.Next() { + var v domain.CertificateVersion + if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter, + &v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil { + return nil, fmt.Errorf("failed to scan certificate version: %w", err) + } + versions = append(versions, &v) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating version rows: %w", err) + } + + return versions, nil +} + +// CreateVersion stores a new certificate version +func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error { + if version.ID == "" { + version.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO certificate_versions ( + id, certificate_id, serial_number, not_before, not_after, + fingerprint_sha256, pem_chain, csr_pem, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING id + `, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter, + version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID) + + if err != nil { + return fmt.Errorf("failed to create certificate version: %w", err) + } + + return nil +} + +// GetExpiringCertificates returns certificates expiring before the given time +func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, + status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at + FROM certificates + WHERE expires_at < $1 AND status != $2 + ORDER BY expires_at ASC + `, before, domain.CertificateStatusArchived) + + if err != nil { + return nil, fmt.Errorf("failed to query expiring certificates: %w", err) + } + defer rows.Close() + + var certs []*domain.ManagedCertificate + for rows.Next() { + cert, err := scanCertificate(rows) + if err != nil { + return nil, err + } + certs = append(certs, cert) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err) + } + + return certs, nil +} + +// scanCertificate scans a certificate from a row or rows +func scanCertificate(scanner interface { + Scan(...interface{}) error +}) (*domain.ManagedCertificate, error) { + var cert domain.ManagedCertificate + var tagsJSON []byte + var sans pq.StringArray + + err := scanner.Scan( + &cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID, + &cert.TeamID, &cert.IssuerID, &cert.Status, &cert.ExpiresAt, &tagsJSON, + &cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt) + + if err != nil { + return nil, fmt.Errorf("failed to scan certificate: %w", err) + } + + cert.SANs = []string(sans) + + // Unmarshal tags + if len(tagsJSON) > 0 { + if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil { + return nil, fmt.Errorf("failed to unmarshal tags: %w", err) + } + } else { + cert.Tags = make(map[string]string) + } + + return &cert, nil +} diff --git a/internal/repository/postgres/db.go b/internal/repository/postgres/db.go new file mode 100644 index 0000000..549d62c --- /dev/null +++ b/internal/repository/postgres/db.go @@ -0,0 +1,68 @@ +package postgres + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + + _ "github.com/lib/pq" +) + +// NewDB opens a PostgreSQL database connection and sets up connection pooling. +func NewDB(connStr string) (*sql.DB, error) { + db, err := sql.Open("postgres", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + // Configure connection pool + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(5) + + // Ping to verify connection + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + return db, nil +} + +// RunMigrations reads and executes SQL migration files from a directory. +func RunMigrations(db *sql.DB, migrationsPath string) error { + // Check if migrations directory exists + if _, err := os.Stat(migrationsPath); os.IsNotExist(err) { + return fmt.Errorf("migrations directory not found: %s", migrationsPath) + } + + // Read all SQL files from the migrations directory + files, err := os.ReadDir(migrationsPath) + if err != nil { + return fmt.Errorf("failed to read migrations directory: %w", err) + } + + // Sort and filter SQL files + var sqlFiles []string + for _, file := range files { + if !file.IsDir() && strings.HasSuffix(file.Name(), ".sql") { + sqlFiles = append(sqlFiles, file.Name()) + } + } + + // Execute each migration file in order + for _, filename := range sqlFiles { + filePath := filepath.Join(migrationsPath, filename) + content, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("failed to read migration file %s: %w", filename, err) + } + + // Execute the SQL content + if _, err := db.Exec(string(content)); err != nil { + return fmt.Errorf("failed to execute migration %s: %w", filename, err) + } + } + + return nil +} diff --git a/internal/repository/postgres/issuer.go b/internal/repository/postgres/issuer.go new file mode 100644 index 0000000..01d3d08 --- /dev/null +++ b/internal/repository/postgres/issuer.go @@ -0,0 +1,138 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// IssuerRepository implements repository.IssuerRepository +type IssuerRepository struct { + db *sql.DB +} + +// NewIssuerRepository creates a new IssuerRepository +func NewIssuerRepository(db *sql.DB) *IssuerRepository { + return &IssuerRepository{db: db} +} + +// List returns all issuers +func (r *IssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, type, config, enabled, created_at, updated_at + FROM issuers + ORDER BY created_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query issuers: %w", err) + } + defer rows.Close() + + var issuers []*domain.Issuer + for rows.Next() { + var issuer domain.Issuer + if err := rows.Scan(&issuer.ID, &issuer.Name, &issuer.Type, &issuer.Config, + &issuer.Enabled, &issuer.CreatedAt, &issuer.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan issuer: %w", err) + } + issuers = append(issuers, &issuer) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating issuer rows: %w", err) + } + + return issuers, nil +} + +// Get retrieves an issuer by ID +func (r *IssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) { + var issuer domain.Issuer + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, type, config, enabled, created_at, updated_at + FROM issuers + WHERE id = $1 + `, id).Scan(&issuer.ID, &issuer.Name, &issuer.Type, &issuer.Config, + &issuer.Enabled, &issuer.CreatedAt, &issuer.UpdatedAt) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("issuer not found") + } + return nil, fmt.Errorf("failed to query issuer: %w", err) + } + + return &issuer, nil +} + +// Create stores a new issuer +func (r *IssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error { + if issuer.ID == "" { + issuer.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO issuers (id, name, type, config, enabled, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id + `, issuer.ID, issuer.Name, issuer.Type, issuer.Config, issuer.Enabled, + issuer.CreatedAt, issuer.UpdatedAt).Scan(&issuer.ID) + + if err != nil { + return fmt.Errorf("failed to create issuer: %w", err) + } + + return nil +} + +// Update modifies an existing issuer +func (r *IssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE issuers SET + name = $1, + type = $2, + config = $3, + enabled = $4, + updated_at = $5 + WHERE id = $6 + `, issuer.Name, issuer.Type, issuer.Config, issuer.Enabled, issuer.UpdatedAt, issuer.ID) + + if err != nil { + return fmt.Errorf("failed to update issuer: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("issuer not found") + } + + return nil +} + +// Delete removes an issuer +func (r *IssuerRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM issuers WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete issuer: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("issuer not found") + } + + return nil +} diff --git a/internal/repository/postgres/job.go b/internal/repository/postgres/job.go new file mode 100644 index 0000000..f6980b4 --- /dev/null +++ b/internal/repository/postgres/job.go @@ -0,0 +1,284 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// JobRepository implements repository.JobRepository +type JobRepository struct { + db *sql.DB +} + +// NewJobRepository creates a new JobRepository +func NewJobRepository(db *sql.DB) *JobRepository { + return &JobRepository{db: db} +} + +// List returns all jobs +func (r *JobRepository) List(ctx context.Context) ([]*domain.Job, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, type, certificate_id, target_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + ORDER BY created_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query jobs: %w", err) + } + defer rows.Close() + + var jobs []*domain.Job + for rows.Next() { + job, err := scanJob(rows) + if err != nil { + return nil, err + } + jobs = append(jobs, job) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating job rows: %w", err) + } + + return jobs, nil +} + +// Get retrieves a job by ID +func (r *JobRepository) Get(ctx context.Context, id string) (*domain.Job, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, type, certificate_id, target_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + WHERE id = $1 + `, id) + + job, err := scanJob(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("job not found") + } + return nil, fmt.Errorf("failed to query job: %w", err) + } + + return job, nil +} + +// Create stores a new job +func (r *JobRepository) Create(ctx context.Context, job *domain.Job) error { + if job.ID == "" { + job.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO jobs ( + id, type, certificate_id, target_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + RETURNING id + `, job.ID, job.Type, job.CertificateID, job.TargetID, job.Status, job.Attempts, + job.MaxAttempts, job.LastError, job.ScheduledAt, job.StartedAt, job.CompletedAt, + job.CreatedAt).Scan(&job.ID) + + if err != nil { + return fmt.Errorf("failed to create job: %w", err) + } + + return nil +} + +// Update modifies an existing job +func (r *JobRepository) Update(ctx context.Context, job *domain.Job) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE jobs SET + type = $1, + certificate_id = $2, + target_id = $3, + status = $4, + attempts = $5, + max_attempts = $6, + last_error = $7, + scheduled_at = $8, + started_at = $9, + completed_at = $10 + WHERE id = $11 + `, job.Type, job.CertificateID, job.TargetID, job.Status, job.Attempts, + job.MaxAttempts, job.LastError, job.ScheduledAt, job.StartedAt, + job.CompletedAt, job.ID) + + if err != nil { + return fmt.Errorf("failed to update job: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("job not found") + } + + return nil +} + +// Delete removes a job +func (r *JobRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM jobs WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete job: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("job not found") + } + + return nil +} + +// ListByStatus returns jobs with a specific status +func (r *JobRepository) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, type, certificate_id, target_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + WHERE status = $1 + ORDER BY created_at DESC + `, status) + + if err != nil { + return nil, fmt.Errorf("failed to query jobs by status: %w", err) + } + defer rows.Close() + + var jobs []*domain.Job + for rows.Next() { + job, err := scanJob(rows) + if err != nil { + return nil, err + } + jobs = append(jobs, job) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating job rows: %w", err) + } + + return jobs, nil +} + +// ListByCertificate returns all jobs for a certificate +func (r *JobRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, type, certificate_id, target_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + WHERE certificate_id = $1 + ORDER BY created_at DESC + `, certID) + + if err != nil { + return nil, fmt.Errorf("failed to query jobs for certificate: %w", err) + } + defer rows.Close() + + var jobs []*domain.Job + for rows.Next() { + job, err := scanJob(rows) + if err != nil { + return nil, err + } + jobs = append(jobs, job) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating job rows: %w", err) + } + + return jobs, nil +} + +// UpdateStatus updates a job's status and optional error message +func (r *JobRepository) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error { + var lastError *string + if errMsg != "" { + lastError = &errMsg + } + + result, err := r.db.ExecContext(ctx, ` + UPDATE jobs SET status = $1, last_error = $2 WHERE id = $3 + `, status, lastError, id) + + if err != nil { + return fmt.Errorf("failed to update job status: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("job not found") + } + + return nil +} + +// GetPendingJobs returns jobs not yet processed of a specific type +func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, type, certificate_id, target_id, status, attempts, max_attempts, + last_error, scheduled_at, started_at, completed_at, created_at + FROM jobs + WHERE type = $1 AND status = $2 + ORDER BY scheduled_at ASC + `, jobType, domain.JobStatusPending) + + if err != nil { + return nil, fmt.Errorf("failed to query pending jobs: %w", err) + } + defer rows.Close() + + var jobs []*domain.Job + for rows.Next() { + job, err := scanJob(rows) + if err != nil { + return nil, err + } + jobs = append(jobs, job) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating job rows: %w", err) + } + + return jobs, nil +} + +// scanJob scans a job from a row or rows +func scanJob(scanner interface { + Scan(...interface{}) error +}) (*domain.Job, error) { + var job domain.Job + err := scanner.Scan(&job.ID, &job.Type, &job.CertificateID, &job.TargetID, + &job.Status, &job.Attempts, &job.MaxAttempts, &job.LastError, + &job.ScheduledAt, &job.StartedAt, &job.CompletedAt, &job.CreatedAt) + + if err != nil { + return nil, fmt.Errorf("failed to scan job: %w", err) + } + + return &job, nil +} diff --git a/internal/repository/postgres/notification.go b/internal/repository/postgres/notification.go new file mode 100644 index 0000000..bb5870f --- /dev/null +++ b/internal/repository/postgres/notification.go @@ -0,0 +1,162 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// NotificationRepository implements repository.NotificationRepository +type NotificationRepository struct { + db *sql.DB +} + +// NewNotificationRepository creates a new NotificationRepository +func NewNotificationRepository(db *sql.DB) *NotificationRepository { + return &NotificationRepository{db: db} +} + +// Create stores a new notification +func (r *NotificationRepository) Create(ctx context.Context, notif *domain.NotificationEvent) error { + if notif.ID == "" { + notif.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO notifications ( + id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING id + `, notif.ID, notif.Type, notif.CertificateID, notif.Channel, notif.Recipient, + notif.Message, notif.SentAt, notif.Status, notif.Error, notif.CreatedAt).Scan(¬if.ID) + + if err != nil { + return fmt.Errorf("failed to create notification: %w", err) + } + + return nil +} + +// List returns notifications matching the filter criteria +func (r *NotificationRepository) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) { + if filter == nil { + filter = &repository.NotificationFilter{} + } + + // Set defaults + if filter.Page < 1 { + filter.Page = 1 + } + if filter.PerPage == 0 || filter.PerPage > 500 { + filter.PerPage = 50 + } + + // Build WHERE clause + var whereConditions []string + var args []interface{} + argCount := 1 + + if filter.CertificateID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("certificate_id = $%d", argCount)) + args = append(args, filter.CertificateID) + argCount++ + } + if filter.Status != "" { + whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount)) + args = append(args, filter.Status) + argCount++ + } + if filter.Channel != "" { + whereConditions = append(whereConditions, fmt.Sprintf("channel = $%d", argCount)) + args = append(args, filter.Channel) + argCount++ + } + + whereClause := "" + if len(whereConditions) > 0 { + whereClause = "WHERE " + strings.Join(whereConditions, " AND ") + } + + // Get total count + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM notifications %s", whereClause) + var total int + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, fmt.Errorf("failed to count notifications: %w", err) + } + + // Get paginated results + offset := (filter.Page - 1) * filter.PerPage + query := fmt.Sprintf(` + SELECT id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at + FROM notifications + %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argCount, argCount+1) + + args = append(args, filter.PerPage, offset) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to query notifications: %w", err) + } + defer rows.Close() + + var notifs []*domain.NotificationEvent + for rows.Next() { + notif, err := scanNotification(rows) + if err != nil { + return nil, err + } + notifs = append(notifs, notif) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating notification rows: %w", err) + } + + return notifs, nil +} + +// UpdateStatus updates a notification's delivery status +func (r *NotificationRepository) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE notifications SET status = $1, sent_at = $2 WHERE id = $3 + `, status, sentAt, id) + + if err != nil { + return fmt.Errorf("failed to update notification status: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("notification not found") + } + + return nil +} + +// scanNotification scans a notification from a row or rows +func scanNotification(scanner interface { + Scan(...interface{}) error +}) (*domain.NotificationEvent, error) { + var notif domain.NotificationEvent + err := scanner.Scan(¬if.ID, ¬if.Type, ¬if.CertificateID, ¬if.Channel, + ¬if.Recipient, ¬if.Message, ¬if.SentAt, ¬if.Status, ¬if.Error, ¬if.CreatedAt) + + if err != nil { + return nil, fmt.Errorf("failed to scan notification: %w", err) + } + + return ¬if, nil +} diff --git a/internal/repository/postgres/owner.go b/internal/repository/postgres/owner.go new file mode 100644 index 0000000..57cec59 --- /dev/null +++ b/internal/repository/postgres/owner.go @@ -0,0 +1,137 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// OwnerRepository implements repository.OwnerRepository +type OwnerRepository struct { + db *sql.DB +} + +// NewOwnerRepository creates a new OwnerRepository +func NewOwnerRepository(db *sql.DB) *OwnerRepository { + return &OwnerRepository{db: db} +} + +// List returns all owners +func (r *OwnerRepository) List(ctx context.Context) ([]*domain.Owner, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, email, team_id, created_at, updated_at + FROM owners + ORDER BY created_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query owners: %w", err) + } + defer rows.Close() + + var owners []*domain.Owner + for rows.Next() { + var owner domain.Owner + if err := rows.Scan(&owner.ID, &owner.Name, &owner.Email, &owner.TeamID, + &owner.CreatedAt, &owner.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan owner: %w", err) + } + owners = append(owners, &owner) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating owner rows: %w", err) + } + + return owners, nil +} + +// Get retrieves an owner by ID +func (r *OwnerRepository) Get(ctx context.Context, id string) (*domain.Owner, error) { + var owner domain.Owner + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, email, team_id, created_at, updated_at + FROM owners + WHERE id = $1 + `, id).Scan(&owner.ID, &owner.Name, &owner.Email, &owner.TeamID, + &owner.CreatedAt, &owner.UpdatedAt) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("owner not found") + } + return nil, fmt.Errorf("failed to query owner: %w", err) + } + + return &owner, nil +} + +// Create stores a new owner +func (r *OwnerRepository) Create(ctx context.Context, owner *domain.Owner) error { + if owner.ID == "" { + owner.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO owners (id, name, email, team_id, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id + `, owner.ID, owner.Name, owner.Email, owner.TeamID, + owner.CreatedAt, owner.UpdatedAt).Scan(&owner.ID) + + if err != nil { + return fmt.Errorf("failed to create owner: %w", err) + } + + return nil +} + +// Update modifies an existing owner +func (r *OwnerRepository) Update(ctx context.Context, owner *domain.Owner) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE owners SET + name = $1, + email = $2, + team_id = $3, + updated_at = $4 + WHERE id = $5 + `, owner.Name, owner.Email, owner.TeamID, owner.UpdatedAt, owner.ID) + + if err != nil { + return fmt.Errorf("failed to update owner: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("owner not found") + } + + return nil +} + +// Delete removes an owner +func (r *OwnerRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM owners WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete owner: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("owner not found") + } + + return nil +} diff --git a/internal/repository/postgres/policy.go b/internal/repository/postgres/policy.go new file mode 100644 index 0000000..46a0c61 --- /dev/null +++ b/internal/repository/postgres/policy.go @@ -0,0 +1,242 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// PolicyRepository implements repository.PolicyRepository +type PolicyRepository struct { + db *sql.DB +} + +// NewPolicyRepository creates a new PolicyRepository +func NewPolicyRepository(db *sql.DB) *PolicyRepository { + return &PolicyRepository{db: db} +} + +// ListRules returns all policy rules +func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, type, config, enabled, created_at, updated_at + FROM policy_rules + ORDER BY created_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query policy rules: %w", err) + } + defer rows.Close() + + var rules []*domain.PolicyRule + for rows.Next() { + var rule domain.PolicyRule + if err := rows.Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config, + &rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan policy rule: %w", err) + } + rules = append(rules, &rule) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating policy rule rows: %w", err) + } + + return rules, nil +} + +// GetRule retrieves a policy rule by ID +func (r *PolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) { + var rule domain.PolicyRule + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, type, config, enabled, created_at, updated_at + FROM policy_rules + WHERE id = $1 + `, id).Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config, + &rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("policy rule not found") + } + return nil, fmt.Errorf("failed to query policy rule: %w", err) + } + + return &rule, nil +} + +// CreateRule stores a new policy rule +func (r *PolicyRepository) CreateRule(ctx context.Context, rule *domain.PolicyRule) error { + if rule.ID == "" { + rule.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO policy_rules (id, name, type, config, enabled, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id + `, rule.ID, rule.Name, rule.Type, rule.Config, rule.Enabled, + rule.CreatedAt, rule.UpdatedAt).Scan(&rule.ID) + + if err != nil { + return fmt.Errorf("failed to create policy rule: %w", err) + } + + return nil +} + +// UpdateRule modifies an existing policy rule +func (r *PolicyRepository) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE policy_rules SET + name = $1, + type = $2, + config = $3, + enabled = $4, + updated_at = $5 + WHERE id = $6 + `, rule.Name, rule.Type, rule.Config, rule.Enabled, rule.UpdatedAt, rule.ID) + + if err != nil { + return fmt.Errorf("failed to update policy rule: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("policy rule not found") + } + + return nil +} + +// DeleteRule removes a policy rule +func (r *PolicyRepository) DeleteRule(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM policy_rules WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete policy rule: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("policy rule not found") + } + + return nil +} + +// CreateViolation records a policy violation +func (r *PolicyRepository) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error { + if violation.ID == "" { + violation.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO policy_violations (id, certificate_id, rule_id, message, severity, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id + `, violation.ID, violation.CertificateID, violation.RuleID, violation.Message, + violation.Severity, violation.CreatedAt).Scan(&violation.ID) + + if err != nil { + return fmt.Errorf("failed to create policy violation: %w", err) + } + + return nil +} + +// ListViolations returns policy violations, optionally filtered +func (r *PolicyRepository) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) { + if filter == nil { + filter = &repository.AuditFilter{} + } + + // Set defaults + if filter.Page < 1 { + filter.Page = 1 + } + if filter.PerPage == 0 || filter.PerPage > 500 { + filter.PerPage = 50 + } + + // Build WHERE clause + var whereConditions []string + var args []interface{} + argCount := 1 + + if filter.ResourceID != "" { + whereConditions = append(whereConditions, fmt.Sprintf("certificate_id = $%d", argCount)) + args = append(args, filter.ResourceID) + argCount++ + } + if !filter.From.IsZero() { + whereConditions = append(whereConditions, fmt.Sprintf("created_at >= $%d", argCount)) + args = append(args, filter.From) + argCount++ + } + if !filter.To.IsZero() { + whereConditions = append(whereConditions, fmt.Sprintf("created_at <= $%d", argCount)) + args = append(args, filter.To) + argCount++ + } + + whereClause := "" + if len(whereConditions) > 0 { + whereClause = "WHERE " + strings.Join(whereConditions, " AND ") + } + + // Get total count + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM policy_violations %s", whereClause) + var total int + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, fmt.Errorf("failed to count policy violations: %w", err) + } + + // Get paginated results + offset := (filter.Page - 1) * filter.PerPage + query := fmt.Sprintf(` + SELECT id, certificate_id, rule_id, message, severity, created_at + FROM policy_violations + %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argCount, argCount+1) + + args = append(args, filter.PerPage, offset) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to query policy violations: %w", err) + } + defer rows.Close() + + var violations []*domain.PolicyViolation + for rows.Next() { + var v domain.PolicyViolation + if err := rows.Scan(&v.ID, &v.CertificateID, &v.RuleID, &v.Message, + &v.Severity, &v.CreatedAt); err != nil { + return nil, fmt.Errorf("failed to scan policy violation: %w", err) + } + violations = append(violations, &v) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating policy violation rows: %w", err) + } + + return violations, nil +} diff --git a/internal/repository/postgres/target.go b/internal/repository/postgres/target.go new file mode 100644 index 0000000..00ea1f5 --- /dev/null +++ b/internal/repository/postgres/target.go @@ -0,0 +1,171 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// TargetRepository implements repository.TargetRepository +type TargetRepository struct { + db *sql.DB +} + +// NewTargetRepository creates a new TargetRepository +func NewTargetRepository(db *sql.DB) *TargetRepository { + return &TargetRepository{db: db} +} + +// List returns all targets +func (r *TargetRepository) List(ctx context.Context) ([]*domain.DeploymentTarget, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, type, agent_id, config, enabled, created_at, updated_at + FROM deployment_targets + ORDER BY created_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query targets: %w", err) + } + defer rows.Close() + + var targets []*domain.DeploymentTarget + for rows.Next() { + var target domain.DeploymentTarget + if err := rows.Scan(&target.ID, &target.Name, &target.Type, &target.AgentID, + &target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan target: %w", err) + } + targets = append(targets, &target) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating target rows: %w", err) + } + + return targets, nil +} + +// Get retrieves a target by ID +func (r *TargetRepository) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) { + var target domain.DeploymentTarget + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, type, agent_id, config, enabled, created_at, updated_at + FROM deployment_targets + WHERE id = $1 + `, id).Scan(&target.ID, &target.Name, &target.Type, &target.AgentID, + &target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("target not found") + } + return nil, fmt.Errorf("failed to query target: %w", err) + } + + return &target, nil +} + +// Create stores a new target +func (r *TargetRepository) Create(ctx context.Context, target *domain.DeploymentTarget) error { + if target.ID == "" { + target.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING id + `, target.ID, target.Name, target.Type, target.AgentID, target.Config, target.Enabled, + target.CreatedAt, target.UpdatedAt).Scan(&target.ID) + + if err != nil { + return fmt.Errorf("failed to create target: %w", err) + } + + return nil +} + +// Update modifies an existing target +func (r *TargetRepository) Update(ctx context.Context, target *domain.DeploymentTarget) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE deployment_targets SET + name = $1, + type = $2, + agent_id = $3, + config = $4, + enabled = $5, + updated_at = $6 + WHERE id = $7 + `, target.Name, target.Type, target.AgentID, target.Config, target.Enabled, target.UpdatedAt, target.ID) + + if err != nil { + return fmt.Errorf("failed to update target: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("target not found") + } + + return nil +} + +// Delete removes a target +func (r *TargetRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM deployment_targets WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete target: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("target not found") + } + + return nil +} + +// ListByCertificate returns all targets for a given certificate +func (r *TargetRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT dt.id, dt.name, dt.type, dt.agent_id, dt.config, dt.enabled, dt.created_at, dt.updated_at + FROM deployment_targets dt + INNER JOIN certificate_target_mappings ctm ON dt.id = ctm.target_id + WHERE ctm.certificate_id = $1 + ORDER BY dt.created_at DESC + `, certID) + + if err != nil { + return nil, fmt.Errorf("failed to query targets for certificate: %w", err) + } + defer rows.Close() + + var targets []*domain.DeploymentTarget + for rows.Next() { + var target domain.DeploymentTarget + if err := rows.Scan(&target.ID, &target.Name, &target.Type, &target.AgentID, + &target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan target: %w", err) + } + targets = append(targets, &target) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating target rows: %w", err) + } + + return targets, nil +} diff --git a/internal/repository/postgres/team.go b/internal/repository/postgres/team.go new file mode 100644 index 0000000..cae2d3d --- /dev/null +++ b/internal/repository/postgres/team.go @@ -0,0 +1,135 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/google/uuid" + "github.com/shankar0123/certctl/internal/domain" +) + +// TeamRepository implements repository.TeamRepository +type TeamRepository struct { + db *sql.DB +} + +// NewTeamRepository creates a new TeamRepository +func NewTeamRepository(db *sql.DB) *TeamRepository { + return &TeamRepository{db: db} +} + +// List returns all teams +func (r *TeamRepository) List(ctx context.Context) ([]*domain.Team, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, description, created_at, updated_at + FROM teams + ORDER BY created_at DESC + `) + + if err != nil { + return nil, fmt.Errorf("failed to query teams: %w", err) + } + defer rows.Close() + + var teams []*domain.Team + for rows.Next() { + var team domain.Team + if err := rows.Scan(&team.ID, &team.Name, &team.Description, + &team.CreatedAt, &team.UpdatedAt); err != nil { + return nil, fmt.Errorf("failed to scan team: %w", err) + } + teams = append(teams, &team) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating team rows: %w", err) + } + + return teams, nil +} + +// Get retrieves a team by ID +func (r *TeamRepository) Get(ctx context.Context, id string) (*domain.Team, error) { + var team domain.Team + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, description, created_at, updated_at + FROM teams + WHERE id = $1 + `, id).Scan(&team.ID, &team.Name, &team.Description, + &team.CreatedAt, &team.UpdatedAt) + + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("team not found") + } + return nil, fmt.Errorf("failed to query team: %w", err) + } + + return &team, nil +} + +// Create stores a new team +func (r *TeamRepository) Create(ctx context.Context, team *domain.Team) error { + if team.ID == "" { + team.ID = uuid.New().String() + } + + err := r.db.QueryRowContext(ctx, ` + INSERT INTO teams (id, name, description, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5) + RETURNING id + `, team.ID, team.Name, team.Description, team.CreatedAt, team.UpdatedAt).Scan(&team.ID) + + if err != nil { + return fmt.Errorf("failed to create team: %w", err) + } + + return nil +} + +// Update modifies an existing team +func (r *TeamRepository) Update(ctx context.Context, team *domain.Team) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE teams SET + name = $1, + description = $2, + updated_at = $3 + WHERE id = $4 + `, team.Name, team.Description, team.UpdatedAt, team.ID) + + if err != nil { + return fmt.Errorf("failed to update team: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("team not found") + } + + return nil +} + +// Delete removes a team +func (r *TeamRepository) Delete(ctx context.Context, id string) error { + result, err := r.db.ExecContext(ctx, "DELETE FROM teams WHERE id = $1", id) + + if err != nil { + return fmt.Errorf("failed to delete team: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rows == 0 { + return fmt.Errorf("team not found") + } + + return nil +} diff --git a/internal/service/audit.go b/internal/service/audit.go index 4dbcf07..30f09c8 100644 --- a/internal/service/audit.go +++ b/internal/service/audit.go @@ -108,3 +108,54 @@ func (s *AuditService) ListByAction(ctx context.Context, action string, from, to return filtered, nil } + +// ListAuditEvents returns paginated audit events (handler interface method). +func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + filter := &repository.AuditFilter{ + Offset: int64((page - 1) * perPage), + PerPage: int64(perPage), + } + + events, err := s.auditRepo.List(context.Background(), filter) + if err != nil { + return nil, 0, fmt.Errorf("failed to list audit events: %w", err) + } + + // Convert pointers to values for the handler interface + var result []domain.AuditEvent + for _, e := range events { + if e != nil { + result = append(result, *e) + } + } + + // TODO: Get total count from repository + total := int64(len(result)) + + return result, total, nil +} + +// GetAuditEvent returns a single audit event (handler interface method). +func (s *AuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) { + filter := &repository.AuditFilter{ + ID: id, + } + + events, err := s.auditRepo.List(context.Background(), filter) + if err != nil { + return nil, fmt.Errorf("failed to get audit event: %w", err) + } + + if len(events) == 0 { + return nil, fmt.Errorf("audit event not found") + } + + return events[0], nil +} diff --git a/internal/service/audit.go.4445065566393902048 b/internal/service/audit.go.4445065566393902048 new file mode 100644 index 0000000..7ce55d4 --- /dev/null +++ b/internal/service/audit.go.4445065566393902048 @@ -0,0 +1,161 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// AuditService provides business logic for recording and retrieving audit events. +type AuditService struct { + auditRepo repository.AuditRepository +} + +// NewAuditService creates a new audit service. +func NewAuditService(auditRepo repository.AuditRepository) *AuditService { + return &AuditService{ + auditRepo: auditRepo, + } +} + +// RecordEvent records an audit event with actor, action, and resource information. +func (s *AuditService) RecordEvent(ctx context.Context, actor string, actorType domain.ActorType, action string, resourceType string, resourceID string, details map[string]interface{}) error { + detailsJSON, err := json.Marshal(details) + if err != nil { + detailsJSON = []byte("{}") + } + + event := &domain.AuditEvent{ + ID: generateID("audit"), + Timestamp: time.Now(), + Actor: actor, + ActorType: actorType, + Action: action, + ResourceType: resourceType, + ResourceID: resourceID, + Details: json.RawMessage(detailsJSON), + } + + if err := s.auditRepo.Create(ctx, event); err != nil { + return fmt.Errorf("failed to record audit event: %w", err) + } + + return nil +} + +// List returns audit events matching filter criteria. +func (s *AuditService) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) { + events, err := s.auditRepo.List(ctx, filter) + if err != nil { + return nil, fmt.Errorf("failed to list audit events: %w", err) + } + return events, nil +} + +// ListByResource returns all audit events for a specific resource. +func (s *AuditService) ListByResource(ctx context.Context, resourceType string, resourceID string) ([]*domain.AuditEvent, error) { + filter := &repository.AuditFilter{ + ResourceType: resourceType, + ResourceID: resourceID, + PerPage: 1000, // reasonable default for single resource + } + + events, err := s.auditRepo.List(ctx, filter) + if err != nil { + return nil, fmt.Errorf("failed to list audit events: %w", err) + } + return events, nil +} + +// ListByActor returns all audit events for a specific actor. +func (s *AuditService) ListByActor(ctx context.Context, actor string) ([]*domain.AuditEvent, error) { + filter := &repository.AuditFilter{ + Actor: actor, + PerPage: 1000, + } + + events, err := s.auditRepo.List(ctx, filter) + if err != nil { + return nil, fmt.Errorf("failed to list audit events: %w", err) + } + return events, nil +} + +// ListByAction returns all audit events for a specific action type. +func (s *AuditService) ListByAction(ctx context.Context, action string, from, to time.Time) ([]*domain.AuditEvent, error) { + filter := &repository.AuditFilter{ + From: from, + To: to, + PerPage: 1000, + } + + events, err := s.auditRepo.List(ctx, filter) + if err != nil { + return nil, fmt.Errorf("failed to list audit events: %w", err) + } + + // Filter by action on client side (repository may not filter by action directly) + var filtered []*domain.AuditEvent + for _, e := range events { + if e.Action == action { + filtered = append(filtered, e) + } + } + + return filtered, nil +} + +// ListAuditEvents returns paginated audit events (handler interface method). +func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + filter := &repository.AuditFilter{ + Offset: int64((page - 1) * perPage), + PerPage: int64(perPage), + } + + events, err := s.auditRepo.List(context.Background(), filter) + if err != nil { + return nil, 0, fmt.Errorf("failed to list audit events: %w", err) + } + + // Convert pointers to values for the handler interface + var result []domain.AuditEvent + for _, e := range events { + if e != nil { + result = append(result, *e) + } + } + + // TODO: Get total count from repository + total := int64(len(result)) + + return result, total, nil +} + +// GetAuditEvent returns a single audit event (handler interface method). +func (s *AuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) { + filter := &repository.AuditFilter{ + ID: id, + } + + events, err := s.auditRepo.List(context.Background(), filter) + if err != nil { + return nil, fmt.Errorf("failed to get audit event: %w", err) + } + + if len(events) == 0 { + return nil, fmt.Errorf("audit event not found") + } + + return events[0], nil +} diff --git a/internal/service/issuer.go b/internal/service/issuer.go new file mode 100644 index 0000000..4095605 --- /dev/null +++ b/internal/service/issuer.go @@ -0,0 +1,170 @@ +package service + +import ( + "context" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// IssuerService provides business logic for certificate issuer management. +type IssuerService struct { + issuerRepo repository.IssuerRepository + auditService *AuditService +} + +// NewIssuerService creates a new issuer service. +func NewIssuerService( + issuerRepo repository.IssuerRepository, + auditService *AuditService, +) *IssuerService { + return &IssuerService{ + issuerRepo: issuerRepo, + auditService: auditService, + } +} + +// List returns a paginated list of issuers. +func (s *IssuerService) List(ctx context.Context, page, perPage int) ([]*domain.Issuer, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + issuers, total, err := s.issuerRepo.List(ctx, offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list issuers: %w", err) + } + return issuers, total, nil +} + +// Get retrieves an issuer by ID. +func (s *IssuerService) Get(ctx context.Context, id string) (*domain.Issuer, error) { + issuer, err := s.issuerRepo.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get issuer %s: %w", id, err) + } + return issuer, nil +} + +// Create validates and stores a new issuer. +func (s *IssuerService) Create(ctx context.Context, issuer *domain.Issuer, actor string) error { + if issuer.Name == "" { + return fmt.Errorf("issuer name is required") + } + + issuer.ID = generateID("issuer") + if err := s.issuerRepo.Create(ctx, issuer); err != nil { + return fmt.Errorf("failed to create issuer: %w", err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_issuer", "issuer", issuer.ID, nil) + } + + return nil +} + +// Update modifies an existing issuer. +func (s *IssuerService) Update(ctx context.Context, id string, issuer *domain.Issuer, actor string) error { + if issuer.Name == "" { + return fmt.Errorf("issuer name is required") + } + + issuer.ID = id + if err := s.issuerRepo.Update(ctx, issuer); err != nil { + return fmt.Errorf("failed to update issuer %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_issuer", "issuer", id, nil) + } + + return nil +} + +// Delete removes an issuer. +func (s *IssuerService) Delete(ctx context.Context, id string, actor string) error { + if err := s.issuerRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("failed to delete issuer %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_issuer", "issuer", id, nil) + } + + return nil +} + +// TestConnection verifies the issuer connection. +func (s *IssuerService) TestConnection(ctx context.Context, id string) error { + issuer, err := s.issuerRepo.Get(ctx, id) + if err != nil { + return fmt.Errorf("issuer not found: %w", err) + } + + // TODO: Implement actual connection test based on issuer type + if issuer == nil { + return fmt.Errorf("issuer not found") + } + + return nil +} + +// ListIssuers returns paginated issuers (handler interface method). +func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + issuers, total, err := s.issuerRepo.List(context.Background(), offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list issuers: %w", err) + } + + // Convert pointers to values for the handler interface + var result []domain.Issuer + for _, i := range issuers { + if i != nil { + result = append(result, *i) + } + } + + return result, total, nil +} + +// GetIssuer returns a single issuer (handler interface method). +func (s *IssuerService) GetIssuer(id string) (*domain.Issuer, error) { + return s.issuerRepo.Get(context.Background(), id) +} + +// CreateIssuer creates a new issuer (handler interface method). +func (s *IssuerService) CreateIssuer(issuer domain.Issuer) (*domain.Issuer, error) { + issuer.ID = generateID("issuer") + if err := s.issuerRepo.Create(context.Background(), &issuer); err != nil { + return nil, fmt.Errorf("failed to create issuer: %w", err) + } + return &issuer, nil +} + +// UpdateIssuer modifies an issuer (handler interface method). +func (s *IssuerService) UpdateIssuer(id string, issuer domain.Issuer) (*domain.Issuer, error) { + issuer.ID = id + if err := s.issuerRepo.Update(context.Background(), &issuer); err != nil { + return nil, fmt.Errorf("failed to update issuer: %w", err) + } + return &issuer, nil +} + +// DeleteIssuer removes an issuer (handler interface method). +func (s *IssuerService) DeleteIssuer(id string) error { + return s.issuerRepo.Delete(context.Background(), id) +} diff --git a/internal/service/issuer.go.4749230034325506546 b/internal/service/issuer.go.4749230034325506546 new file mode 100644 index 0000000..b642265 --- /dev/null +++ b/internal/service/issuer.go.4749230034325506546 @@ -0,0 +1,116 @@ +package service + +import ( + "context" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// IssuerService provides business logic for certificate issuer management. +type IssuerService struct { + issuerRepo repository.IssuerRepository + auditService *AuditService +} + +// NewIssuerService creates a new issuer service. +func NewIssuerService( + issuerRepo repository.IssuerRepository, + auditService *AuditService, +) *IssuerService { + return &IssuerService{ + issuerRepo: issuerRepo, + auditService: auditService, + } +} + +// List returns a paginated list of issuers. +func (s *IssuerService) List(ctx context.Context, page, perPage int) ([]*domain.Issuer, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + issuers, total, err := s.issuerRepo.List(ctx, offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list issuers: %w", err) + } + return issuers, total, nil +} + +// Get retrieves an issuer by ID. +func (s *IssuerService) Get(ctx context.Context, id string) (*domain.Issuer, error) { + issuer, err := s.issuerRepo.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get issuer %s: %w", id, err) + } + return issuer, nil +} + +// Create validates and stores a new issuer. +func (s *IssuerService) Create(ctx context.Context, issuer *domain.Issuer, actor string) error { + if issuer.Name == "" { + return fmt.Errorf("issuer name is required") + } + + issuer.ID = generateID("issuer") + if err := s.issuerRepo.Create(ctx, issuer); err != nil { + return fmt.Errorf("failed to create issuer: %w", err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_issuer", "issuer", issuer.ID, nil) + } + + return nil +} + +// Update modifies an existing issuer. +func (s *IssuerService) Update(ctx context.Context, id string, issuer *domain.Issuer, actor string) error { + if issuer.Name == "" { + return fmt.Errorf("issuer name is required") + } + + issuer.ID = id + if err := s.issuerRepo.Update(ctx, issuer); err != nil { + return fmt.Errorf("failed to update issuer %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_issuer", "issuer", id, nil) + } + + return nil +} + +// Delete removes an issuer. +func (s *IssuerService) Delete(ctx context.Context, id string, actor string) error { + if err := s.issuerRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("failed to delete issuer %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_issuer", "issuer", id, nil) + } + + return nil +} + +// TestConnection verifies the issuer connection. +func (s *IssuerService) TestConnection(ctx context.Context, id string) error { + issuer, err := s.issuerRepo.Get(ctx, id) + if err != nil { + return fmt.Errorf("issuer not found: %w", err) + } + + // TODO: Implement actual connection test based on issuer type + if issuer == nil { + return fmt.Errorf("issuer not found") + } + + return nil +} diff --git a/internal/service/owner.go b/internal/service/owner.go new file mode 100644 index 0000000..464cafe --- /dev/null +++ b/internal/service/owner.go @@ -0,0 +1,155 @@ +package service + +import ( + "context" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// OwnerService provides business logic for certificate owner management. +type OwnerService struct { + ownerRepo repository.OwnerRepository + auditService *AuditService +} + +// NewOwnerService creates a new owner service. +func NewOwnerService( + ownerRepo repository.OwnerRepository, + auditService *AuditService, +) *OwnerService { + return &OwnerService{ + ownerRepo: ownerRepo, + auditService: auditService, + } +} + +// List returns a paginated list of owners. +func (s *OwnerService) List(ctx context.Context, page, perPage int) ([]*domain.Owner, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + owners, total, err := s.ownerRepo.List(ctx, offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list owners: %w", err) + } + return owners, total, nil +} + +// Get retrieves an owner by ID. +func (s *OwnerService) Get(ctx context.Context, id string) (*domain.Owner, error) { + owner, err := s.ownerRepo.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get owner %s: %w", id, err) + } + return owner, nil +} + +// Create validates and stores a new owner. +func (s *OwnerService) Create(ctx context.Context, owner *domain.Owner, actor string) error { + if owner.Name == "" { + return fmt.Errorf("owner name is required") + } + + owner.ID = generateID("owner") + if err := s.ownerRepo.Create(ctx, owner); err != nil { + return fmt.Errorf("failed to create owner: %w", err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_owner", "owner", owner.ID, nil) + } + + return nil +} + +// Update modifies an existing owner. +func (s *OwnerService) Update(ctx context.Context, id string, owner *domain.Owner, actor string) error { + if owner.Name == "" { + return fmt.Errorf("owner name is required") + } + + owner.ID = id + if err := s.ownerRepo.Update(ctx, owner); err != nil { + return fmt.Errorf("failed to update owner %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_owner", "owner", id, nil) + } + + return nil +} + +// Delete removes an owner. +func (s *OwnerService) Delete(ctx context.Context, id string, actor string) error { + if err := s.ownerRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("failed to delete owner %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_owner", "owner", id, nil) + } + + return nil +} + +// ListOwners returns paginated owners (handler interface method). +func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + owners, total, err := s.ownerRepo.List(context.Background(), offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list owners: %w", err) + } + + // Convert pointers to values for the handler interface + var result []domain.Owner + for _, o := range owners { + if o != nil { + result = append(result, *o) + } + } + + return result, total, nil +} + +// GetOwner returns a single owner (handler interface method). +func (s *OwnerService) GetOwner(id string) (*domain.Owner, error) { + return s.ownerRepo.Get(context.Background(), id) +} + +// CreateOwner creates a new owner (handler interface method). +func (s *OwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) { + owner.ID = generateID("owner") + if err := s.ownerRepo.Create(context.Background(), &owner); err != nil { + return nil, fmt.Errorf("failed to create owner: %w", err) + } + return &owner, nil +} + +// UpdateOwner modifies an owner (handler interface method). +func (s *OwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) { + owner.ID = id + if err := s.ownerRepo.Update(context.Background(), &owner); err != nil { + return nil, fmt.Errorf("failed to update owner: %w", err) + } + return &owner, nil +} + +// DeleteOwner removes an owner (handler interface method). +func (s *OwnerService) DeleteOwner(id string) error { + return s.ownerRepo.Delete(context.Background(), id) +} diff --git a/internal/service/target.go b/internal/service/target.go new file mode 100644 index 0000000..b76df30 --- /dev/null +++ b/internal/service/target.go @@ -0,0 +1,155 @@ +package service + +import ( + "context" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// TargetService provides business logic for deployment target management. +type TargetService struct { + targetRepo repository.TargetRepository + auditService *AuditService +} + +// NewTargetService creates a new target service. +func NewTargetService( + targetRepo repository.TargetRepository, + auditService *AuditService, +) *TargetService { + return &TargetService{ + targetRepo: targetRepo, + auditService: auditService, + } +} + +// List returns a paginated list of deployment targets. +func (s *TargetService) List(ctx context.Context, page, perPage int) ([]*domain.DeploymentTarget, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + targets, total, err := s.targetRepo.List(ctx, offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list targets: %w", err) + } + return targets, total, nil +} + +// Get retrieves a deployment target by ID. +func (s *TargetService) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) { + target, err := s.targetRepo.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get target %s: %w", id, err) + } + return target, nil +} + +// Create validates and stores a new deployment target. +func (s *TargetService) Create(ctx context.Context, target *domain.DeploymentTarget, actor string) error { + if target.Name == "" { + return fmt.Errorf("target name is required") + } + + target.ID = generateID("target") + if err := s.targetRepo.Create(ctx, target); err != nil { + return fmt.Errorf("failed to create target: %w", err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_target", "target", target.ID, nil) + } + + return nil +} + +// Update modifies an existing deployment target. +func (s *TargetService) Update(ctx context.Context, id string, target *domain.DeploymentTarget, actor string) error { + if target.Name == "" { + return fmt.Errorf("target name is required") + } + + target.ID = id + if err := s.targetRepo.Update(ctx, target); err != nil { + return fmt.Errorf("failed to update target %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_target", "target", id, nil) + } + + return nil +} + +// Delete removes a deployment target. +func (s *TargetService) Delete(ctx context.Context, id string, actor string) error { + if err := s.targetRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("failed to delete target %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_target", "target", id, nil) + } + + return nil +} + +// ListTargets returns paginated targets (handler interface method). +func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + targets, total, err := s.targetRepo.List(context.Background(), offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list targets: %w", err) + } + + // Convert pointers to values for the handler interface + var result []domain.DeploymentTarget + for _, t := range targets { + if t != nil { + result = append(result, *t) + } + } + + return result, total, nil +} + +// GetTarget returns a single target (handler interface method). +func (s *TargetService) GetTarget(id string) (*domain.DeploymentTarget, error) { + return s.targetRepo.Get(context.Background(), id) +} + +// CreateTarget creates a new target (handler interface method). +func (s *TargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { + target.ID = generateID("target") + if err := s.targetRepo.Create(context.Background(), &target); err != nil { + return nil, fmt.Errorf("failed to create target: %w", err) + } + return &target, nil +} + +// UpdateTarget modifies a target (handler interface method). +func (s *TargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) { + target.ID = id + if err := s.targetRepo.Update(context.Background(), &target); err != nil { + return nil, fmt.Errorf("failed to update target: %w", err) + } + return &target, nil +} + +// DeleteTarget removes a target (handler interface method). +func (s *TargetService) DeleteTarget(id string) error { + return s.targetRepo.Delete(context.Background(), id) +} diff --git a/internal/service/team.go b/internal/service/team.go new file mode 100644 index 0000000..c6f9fc3 --- /dev/null +++ b/internal/service/team.go @@ -0,0 +1,155 @@ +package service + +import ( + "context" + "fmt" + + "github.com/shankar0123/certctl/internal/domain" + "github.com/shankar0123/certctl/internal/repository" +) + +// TeamService provides business logic for team management. +type TeamService struct { + teamRepo repository.TeamRepository + auditService *AuditService +} + +// NewTeamService creates a new team service. +func NewTeamService( + teamRepo repository.TeamRepository, + auditService *AuditService, +) *TeamService { + return &TeamService{ + teamRepo: teamRepo, + auditService: auditService, + } +} + +// List returns a paginated list of teams. +func (s *TeamService) List(ctx context.Context, page, perPage int) ([]*domain.Team, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + teams, total, err := s.teamRepo.List(ctx, offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list teams: %w", err) + } + return teams, total, nil +} + +// Get retrieves a team by ID. +func (s *TeamService) Get(ctx context.Context, id string) (*domain.Team, error) { + team, err := s.teamRepo.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get team %s: %w", id, err) + } + return team, nil +} + +// Create validates and stores a new team. +func (s *TeamService) Create(ctx context.Context, team *domain.Team, actor string) error { + if team.Name == "" { + return fmt.Errorf("team name is required") + } + + team.ID = generateID("team") + if err := s.teamRepo.Create(ctx, team); err != nil { + return fmt.Errorf("failed to create team: %w", err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_team", "team", team.ID, nil) + } + + return nil +} + +// Update modifies an existing team. +func (s *TeamService) Update(ctx context.Context, id string, team *domain.Team, actor string) error { + if team.Name == "" { + return fmt.Errorf("team name is required") + } + + team.ID = id + if err := s.teamRepo.Update(ctx, team); err != nil { + return fmt.Errorf("failed to update team %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_team", "team", id, nil) + } + + return nil +} + +// Delete removes a team. +func (s *TeamService) Delete(ctx context.Context, id string, actor string) error { + if err := s.teamRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("failed to delete team %s: %w", id, err) + } + + if s.auditService != nil { + _ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_team", "team", id, nil) + } + + return nil +} + +// ListTeams returns paginated teams (handler interface method). +func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) { + if page < 1 { + page = 1 + } + if perPage < 1 { + perPage = 50 + } + + offset := int64((page - 1) * perPage) + teams, total, err := s.teamRepo.List(context.Background(), offset, int64(perPage)) + if err != nil { + return nil, 0, fmt.Errorf("failed to list teams: %w", err) + } + + // Convert pointers to values for the handler interface + var result []domain.Team + for _, t := range teams { + if t != nil { + result = append(result, *t) + } + } + + return result, total, nil +} + +// GetTeam returns a single team (handler interface method). +func (s *TeamService) GetTeam(id string) (*domain.Team, error) { + return s.teamRepo.Get(context.Background(), id) +} + +// CreateTeam creates a new team (handler interface method). +func (s *TeamService) CreateTeam(team domain.Team) (*domain.Team, error) { + team.ID = generateID("team") + if err := s.teamRepo.Create(context.Background(), &team); err != nil { + return nil, fmt.Errorf("failed to create team: %w", err) + } + return &team, nil +} + +// UpdateTeam modifies a team (handler interface method). +func (s *TeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) { + team.ID = id + if err := s.teamRepo.Update(context.Background(), &team); err != nil { + return nil, fmt.Errorf("failed to update team: %w", err) + } + return &team, nil +} + +// DeleteTeam removes a team (handler interface method). +func (s *TeamService) DeleteTeam(id string) error { + return s.teamRepo.Delete(context.Background(), id) +} diff --git a/web/index.html b/web/index.html new file mode 100644 index 0000000..7effd51 --- /dev/null +++ b/web/index.html @@ -0,0 +1,1868 @@ + + + + + + certctl - Certificate Control Plane + + + + + + + +
+ + + +