mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 13:51:36 +00:00
Complete V1 scaffold
This commit is contained in:
@@ -0,0 +1,9 @@
|
|||||||
|
.git
|
||||||
|
vendor
|
||||||
|
bin
|
||||||
|
*.md
|
||||||
|
docs
|
||||||
|
scripts
|
||||||
|
coverage.*
|
||||||
|
.env
|
||||||
|
.DS_Store
|
||||||
+5
-32
@@ -1,70 +1,43 @@
|
|||||||
# Multi-stage build for certctl server and agent binaries
|
# Multi-stage build for certctl server
|
||||||
# Stage 1: Build
|
# Stage 1: Build
|
||||||
FROM golang:1.22-alpine AS builder
|
FROM golang:1.22-alpine AS builder
|
||||||
|
|
||||||
# Install build dependencies
|
|
||||||
RUN apk add --no-cache git ca-certificates tzdata
|
RUN apk add --no-cache git ca-certificates tzdata
|
||||||
|
|
||||||
# Set working directory
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy go mod and sum files
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
# Download dependencies
|
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
# Copy source code
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build server binary
|
# Build server binary (use TARGETARCH for multi-platform support)
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
ARG TARGETARCH=amd64
|
||||||
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o bin/server \
|
-o bin/server \
|
||||||
./cmd/server
|
./cmd/server
|
||||||
|
|
||||||
# Build agent binary
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
|
||||||
-ldflags="-w -s" \
|
|
||||||
-o bin/agent \
|
|
||||||
./cmd/agent
|
|
||||||
|
|
||||||
# Stage 2: Runtime
|
# Stage 2: Runtime
|
||||||
FROM alpine:3.19
|
FROM alpine:3.19
|
||||||
|
|
||||||
# Install runtime dependencies
|
|
||||||
RUN apk add --no-cache ca-certificates tzdata curl
|
RUN apk add --no-cache ca-certificates tzdata curl
|
||||||
|
|
||||||
# Create non-root user
|
|
||||||
RUN addgroup -g 1000 certctl && \
|
RUN addgroup -g 1000 certctl && \
|
||||||
adduser -D -u 1000 -G certctl certctl
|
adduser -D -u 1000 -G certctl certctl
|
||||||
|
|
||||||
# Set working directory
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy binaries from builder
|
|
||||||
COPY --from=builder /app/bin/server .
|
COPY --from=builder /app/bin/server .
|
||||||
COPY --from=builder /app/bin/agent .
|
|
||||||
|
|
||||||
# Copy migration files if needed
|
|
||||||
COPY --chown=certctl:certctl migrations/ ./migrations/
|
COPY --chown=certctl:certctl migrations/ ./migrations/
|
||||||
|
|
||||||
# Change ownership
|
|
||||||
RUN chown -R certctl:certctl /app
|
RUN chown -R certctl:certctl /app
|
||||||
|
|
||||||
# Switch to non-root user
|
|
||||||
USER certctl
|
USER certctl
|
||||||
|
|
||||||
# Expose port for server
|
|
||||||
EXPOSE 8443
|
EXPOSE 8443
|
||||||
|
|
||||||
# Health check
|
HEALTHCHECK --interval=10s --timeout=5s --start-period=5s --retries=5 \
|
||||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
|
||||||
CMD curl -f http://localhost:8443/health || exit 1
|
CMD curl -f http://localhost:8443/health || exit 1
|
||||||
|
|
||||||
# Default entrypoint is the server
|
|
||||||
ENTRYPOINT ["/app/server"]
|
ENTRYPOINT ["/app/server"]
|
||||||
|
|
||||||
# Notes:
|
|
||||||
# - To run the server: docker run -p 8443:8443 -e DB_HOST=postgres certctl:latest
|
|
||||||
# - To run the agent: docker run -e SERVER_URL=http://server:8443 -e API_KEY=<key> certctl:latest /app/agent
|
|
||||||
|
|||||||
+3
-20
@@ -1,24 +1,18 @@
|
|||||||
# Multi-stage build for certctl agent binary
|
# Multi-stage build for certctl agent
|
||||||
# Stage 1: Build
|
# Stage 1: Build
|
||||||
FROM golang:1.22-alpine AS builder
|
FROM golang:1.22-alpine AS builder
|
||||||
|
|
||||||
# Install build dependencies
|
|
||||||
RUN apk add --no-cache git ca-certificates
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
# Set working directory
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy go mod and sum files
|
|
||||||
COPY go.mod go.sum ./
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
# Download dependencies
|
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|
||||||
# Copy source code
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Build agent binary only
|
ARG TARGETARCH=amd64
|
||||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build \
|
||||||
-ldflags="-w -s" \
|
-ldflags="-w -s" \
|
||||||
-o bin/agent \
|
-o bin/agent \
|
||||||
./cmd/agent
|
./cmd/agent
|
||||||
@@ -26,28 +20,17 @@ RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
|||||||
# Stage 2: Runtime
|
# Stage 2: Runtime
|
||||||
FROM alpine:3.19
|
FROM alpine:3.19
|
||||||
|
|
||||||
# Install runtime dependencies (minimal)
|
|
||||||
RUN apk add --no-cache ca-certificates curl
|
RUN apk add --no-cache ca-certificates curl
|
||||||
|
|
||||||
# Create non-root user
|
|
||||||
RUN addgroup -g 1000 certctl && \
|
RUN addgroup -g 1000 certctl && \
|
||||||
adduser -D -u 1000 -G certctl certctl
|
adduser -D -u 1000 -G certctl certctl
|
||||||
|
|
||||||
# Set working directory
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy binary from builder
|
|
||||||
COPY --from=builder /app/bin/agent .
|
COPY --from=builder /app/bin/agent .
|
||||||
|
|
||||||
# Change ownership
|
|
||||||
RUN chown -R certctl:certctl /app
|
RUN chown -R certctl:certctl /app
|
||||||
|
|
||||||
# Switch to non-root user
|
|
||||||
USER certctl
|
USER certctl
|
||||||
|
|
||||||
# Health check (optional, depends on agent implementation)
|
|
||||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
|
||||||
CMD curl -f http://localhost:9000/health || exit 1 || true
|
|
||||||
|
|
||||||
# Default entrypoint is the agent
|
|
||||||
ENTRYPOINT ["/app/agent"]
|
ENTRYPOINT ["/app/agent"]
|
||||||
|
|||||||
@@ -0,0 +1,194 @@
|
|||||||
|
# PostgreSQL Repository Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Complete PostgreSQL implementation for the certctl certificate control plane using `database/sql` and `lib/pq` driver. All 71 interface methods across 11 repositories have been implemented.
|
||||||
|
|
||||||
|
## File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
internal/repository/postgres/
|
||||||
|
├── db.go # Database connection and migration setup
|
||||||
|
├── certificate.go # CertificateRepository (8 methods)
|
||||||
|
├── issuer.go # IssuerRepository (5 methods)
|
||||||
|
├── target.go # TargetRepository (6 methods)
|
||||||
|
├── agent.go # AgentRepository (7 methods)
|
||||||
|
├── job.go # JobRepository (9 methods)
|
||||||
|
├── policy.go # PolicyRepository (7 methods)
|
||||||
|
├── audit.go # AuditRepository (2 methods)
|
||||||
|
├── notification.go # NotificationRepository (3 methods)
|
||||||
|
├── team.go # TeamRepository (5 methods)
|
||||||
|
└── owner.go # OwnerRepository (5 methods)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Implementation Details
|
||||||
|
|
||||||
|
### Database Connection (db.go)
|
||||||
|
- `NewDB(connStr string)` - Opens PostgreSQL connection with connection pooling
|
||||||
|
- Max open connections: 25
|
||||||
|
- Max idle connections: 5
|
||||||
|
- Verifies connection with Ping()
|
||||||
|
|
||||||
|
- `RunMigrations(db, migrationsPath)` - Executes SQL migration files
|
||||||
|
- Reads all `.sql` files from migrations directory
|
||||||
|
- Executes files in alphabetical order
|
||||||
|
- Simple approach without external migration library
|
||||||
|
|
||||||
|
### Data Patterns Used
|
||||||
|
|
||||||
|
1. **UUID Generation**: Using `github.com/google/uuid` for ID generation
|
||||||
|
2. **Parameterized Queries**: All queries use `$1, $2, etc.` parameter placeholders
|
||||||
|
3. **Context Propagation**: All database operations use `*Context` variants
|
||||||
|
4. **Nullable Types**:
|
||||||
|
- `sql.NullTime` for optional timestamps
|
||||||
|
- `sql.NullString` for optional strings
|
||||||
|
5. **JSON Handling**:
|
||||||
|
- `json.Marshal/Unmarshal` for JSONB columns
|
||||||
|
- Config fields stored as `json.RawMessage`
|
||||||
|
6. **Array Handling**:
|
||||||
|
- `pq.Array()` for storing Go slices in PostgreSQL arrays
|
||||||
|
- `pq.StringArray` for scanning string arrays
|
||||||
|
7. **RETURNING Clauses**: Used in CREATE operations to retrieve generated IDs
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
- All errors wrapped with `fmt.Errorf` for context
|
||||||
|
- Specific error messages for not found cases
|
||||||
|
- Row count verification for UPDATE/DELETE operations
|
||||||
|
|
||||||
|
## Repository Implementations
|
||||||
|
|
||||||
|
### CertificateRepository (8 methods)
|
||||||
|
- Manages certificate lifecycle with filtering by status, environment, owner, team, issuer
|
||||||
|
- Pagination support (default 50, max 500 per page)
|
||||||
|
- Certificate versioning with history tracking
|
||||||
|
- Expiration tracking and notifications
|
||||||
|
- Tags stored as JSON
|
||||||
|
|
||||||
|
### IssuerRepository (5 methods)
|
||||||
|
- Manages certificate authorities (ACME, GenericCA)
|
||||||
|
- Configuration stored as JSON for flexibility
|
||||||
|
- Enable/disable issuers
|
||||||
|
|
||||||
|
### TargetRepository (6 methods)
|
||||||
|
- Manages deployment targets (NGINX, F5, IIS)
|
||||||
|
- Lists targets associated with certificates via join table
|
||||||
|
- Configuration stored as JSON
|
||||||
|
|
||||||
|
### AgentRepository (7 methods)
|
||||||
|
- Manages control plane agents with status tracking
|
||||||
|
- Heartbeat update functionality
|
||||||
|
- API key hash lookup for authentication
|
||||||
|
- Last heartbeat timestamp tracking
|
||||||
|
|
||||||
|
### JobRepository (9 methods)
|
||||||
|
- Manages renewal, deployment, issuance, and validation jobs
|
||||||
|
- Status tracking with error messages
|
||||||
|
- Attempt counters for retry logic
|
||||||
|
- Pending job retrieval by type
|
||||||
|
- Filtering by status and certificate
|
||||||
|
|
||||||
|
### PolicyRepository (7 methods)
|
||||||
|
- Policy rules with multiple enforcement types
|
||||||
|
- Policy violation recording and querying
|
||||||
|
- Configurable rules stored as JSON
|
||||||
|
- Severity levels for violations (Warning, Error, Critical)
|
||||||
|
|
||||||
|
### AuditRepository (2 methods)
|
||||||
|
- Records all control plane actions
|
||||||
|
- Filtering by actor, resource type, time range
|
||||||
|
- Pagination support
|
||||||
|
- Details stored as JSON
|
||||||
|
|
||||||
|
### NotificationRepository (3 methods)
|
||||||
|
- Notification event tracking
|
||||||
|
- Multiple channels (Email, Webhook, Slack)
|
||||||
|
- Delivery status tracking
|
||||||
|
- Certificate-specific notification filtering
|
||||||
|
|
||||||
|
### TeamRepository (5 methods)
|
||||||
|
- Organizational unit management
|
||||||
|
- Basic CRUD operations
|
||||||
|
- Team descriptions for organization
|
||||||
|
|
||||||
|
### OwnerRepository (5 methods)
|
||||||
|
- Certificate owner management
|
||||||
|
- Email field for notifications
|
||||||
|
- Team affiliation tracking
|
||||||
|
- Basic CRUD operations
|
||||||
|
|
||||||
|
## Database Assumptions
|
||||||
|
|
||||||
|
The implementation expects the following table structures:
|
||||||
|
|
||||||
|
**certificates**
|
||||||
|
- id, name, common_name, sans (array), environment, owner_id, team_id, issuer_id
|
||||||
|
- status, expires_at, tags (json), last_renewal_at, last_deployment_at
|
||||||
|
- created_at, updated_at
|
||||||
|
|
||||||
|
**certificate_versions**
|
||||||
|
- id, certificate_id, serial_number, not_before, not_after
|
||||||
|
- fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||||
|
|
||||||
|
**certificate_target_mappings** (join table)
|
||||||
|
- certificate_id, target_id
|
||||||
|
|
||||||
|
**issuers**
|
||||||
|
- id, name, type, config (json), enabled, created_at, updated_at
|
||||||
|
|
||||||
|
**deployment_targets**
|
||||||
|
- id, name, type, agent_id, config (json), enabled, created_at, updated_at
|
||||||
|
|
||||||
|
**agents**
|
||||||
|
- id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||||
|
|
||||||
|
**jobs**
|
||||||
|
- id, type, certificate_id, target_id, status, attempts, max_attempts
|
||||||
|
- last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
|
||||||
|
**policy_rules**
|
||||||
|
- id, name, type, config (json), enabled, created_at, updated_at
|
||||||
|
|
||||||
|
**policy_violations**
|
||||||
|
- id, certificate_id, rule_id, message, severity, created_at
|
||||||
|
|
||||||
|
**audit_events**
|
||||||
|
- id, actor, actor_type, action, resource_type, resource_id, details (json), timestamp
|
||||||
|
|
||||||
|
**notifications**
|
||||||
|
- id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at
|
||||||
|
|
||||||
|
**teams**
|
||||||
|
- id, name, description, created_at, updated_at
|
||||||
|
|
||||||
|
**owners**
|
||||||
|
- id, name, email, team_id, created_at, updated_at
|
||||||
|
|
||||||
|
## Integration Points
|
||||||
|
|
||||||
|
Constructor functions for each repository:
|
||||||
|
```go
|
||||||
|
NewCertificateRepository(db *sql.DB) *CertificateRepository
|
||||||
|
NewIssuerRepository(db *sql.DB) *IssuerRepository
|
||||||
|
NewTargetRepository(db *sql.DB) *TargetRepository
|
||||||
|
NewAgentRepository(db *sql.DB) *AgentRepository
|
||||||
|
NewJobRepository(db *sql.DB) *JobRepository
|
||||||
|
NewPolicyRepository(db *sql.DB) *PolicyRepository
|
||||||
|
NewAuditRepository(db *sql.DB) *AuditRepository
|
||||||
|
NewNotificationRepository(db *sql.DB) *NotificationRepository
|
||||||
|
NewTeamRepository(db *sql.DB) *TeamRepository
|
||||||
|
NewOwnerRepository(db *sql.DB) *OwnerRepository
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
- `database/sql` (stdlib)
|
||||||
|
- `github.com/lib/pq` v1.10.9
|
||||||
|
- `github.com/google/uuid` v1.6.0
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
1. All list operations support pagination with configurable page size (default 50, max 500)
|
||||||
|
2. Filtering is dynamic - only conditions with non-empty values are added to WHERE clause
|
||||||
|
3. Timestamps use `time.Time` for CreatedAt/UpdatedAt with automatic Now() on updates
|
||||||
|
4. Array fields use `pq.Array()` for proper PostgreSQL array handling
|
||||||
|
5. Nullable fields use `sql.Null*` types for proper NULL handling
|
||||||
|
6. All operations are context-aware and respect cancellation signals
|
||||||
|
7. Error messages are descriptive and wrapped for debugging
|
||||||
@@ -0,0 +1,272 @@
|
|||||||
|
# PostgreSQL Implementation Patterns
|
||||||
|
|
||||||
|
## Consistent Patterns Across All Repositories
|
||||||
|
|
||||||
|
### 1. Package Structure
|
||||||
|
```go
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Repository Constructor Pattern
|
||||||
|
```go
|
||||||
|
type CertificateRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCertificateRepository(db *sql.DB) *CertificateRepository {
|
||||||
|
return &CertificateRepository{db: db}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. UUID Generation Pattern
|
||||||
|
```go
|
||||||
|
if cert.ID == "" {
|
||||||
|
cert.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Parameterized Queries Pattern
|
||||||
|
All queries use `$1, $2, $3...` placeholders:
|
||||||
|
```go
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name FROM table WHERE id = $1
|
||||||
|
`, id).Scan(&result.ID, &result.Name)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Context Propagation Pattern
|
||||||
|
```go
|
||||||
|
// QueryContext for SELECT
|
||||||
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
|
|
||||||
|
// QueryRowContext for single row
|
||||||
|
row := r.db.QueryRowContext(ctx, query, args...)
|
||||||
|
|
||||||
|
// ExecContext for INSERT/UPDATE/DELETE
|
||||||
|
result, err := r.db.ExecContext(ctx, query, args...)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. NULL Handling Pattern
|
||||||
|
```go
|
||||||
|
// For nullable types, use sql.Null*
|
||||||
|
var agent.LastHeartbeatAt *time.Time
|
||||||
|
|
||||||
|
// Scan handles NULL automatically
|
||||||
|
err := row.Scan(&agent.LastHeartbeatAt)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. Array Handling Pattern (pq)
|
||||||
|
```go
|
||||||
|
import "github.com/lib/pq"
|
||||||
|
|
||||||
|
// Storing arrays
|
||||||
|
pq.Array(cert.SANs) // Converts []string to PostgreSQL array
|
||||||
|
|
||||||
|
// Scanning arrays
|
||||||
|
var sans pq.StringArray
|
||||||
|
row.Scan(&sans)
|
||||||
|
cert.SANs = []string(sans)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. JSON Handling Pattern
|
||||||
|
```go
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// For JSONB config columns (stored as json.RawMessage)
|
||||||
|
issuer.Config // type: json.RawMessage
|
||||||
|
|
||||||
|
// For tags (stored as JSON string)
|
||||||
|
tagsJSON, err := json.Marshal(cert.Tags)
|
||||||
|
row.Scan(&tagsJSON)
|
||||||
|
json.Unmarshal(tagsJSON, &cert.Tags)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9. Pagination Pattern
|
||||||
|
```go
|
||||||
|
// Set defaults
|
||||||
|
if filter.Page < 1 {
|
||||||
|
filter.Page = 1
|
||||||
|
}
|
||||||
|
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||||
|
filter.PerPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate offset
|
||||||
|
offset := (filter.Page - 1) * filter.PerPage
|
||||||
|
|
||||||
|
// Add to query
|
||||||
|
query += fmt.Sprintf("LIMIT $%d OFFSET $%d", argCount, argCount+1)
|
||||||
|
args = append(args, filter.PerPage, offset)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 10. Dynamic WHERE Clause Pattern
|
||||||
|
```go
|
||||||
|
var whereConditions []string
|
||||||
|
var args []interface{}
|
||||||
|
argCount := 1
|
||||||
|
|
||||||
|
if filter.Status != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
|
||||||
|
args = append(args, filter.Status)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := ""
|
||||||
|
if len(whereConditions) > 0 {
|
||||||
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 11. Row Count Verification Pattern
|
||||||
|
```go
|
||||||
|
result, err := r.db.ExecContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("entity not found")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 12. Not Found Error Pattern
|
||||||
|
```go
|
||||||
|
row := r.db.QueryRowContext(ctx, query, args...)
|
||||||
|
entity, err := scanEntity(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("entity not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query entity: %w", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 13. Scanner Helper Pattern (for reusable scanning)
|
||||||
|
```go
|
||||||
|
func scanEntity(scanner interface {
|
||||||
|
Scan(...interface{}) error
|
||||||
|
}) (*domain.Entity, error) {
|
||||||
|
var e domain.Entity
|
||||||
|
err := scanner.Scan(&e.ID, &e.Name, ...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan entity: %w", err)
|
||||||
|
}
|
||||||
|
return &e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used in both single row and multiple rows contexts
|
||||||
|
row := r.db.QueryRowContext(ctx, query)
|
||||||
|
entity, err := scanEntity(row)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
entity, err := scanEntity(rows)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 14. List Query Pattern
|
||||||
|
```go
|
||||||
|
// Get total count first
|
||||||
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM table %s", whereClause)
|
||||||
|
var total int
|
||||||
|
r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
|
||||||
|
|
||||||
|
// Then get paginated results
|
||||||
|
rows, err := r.db.QueryContext(ctx, paginatedQuery, args...)
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var results []*domain.Entity
|
||||||
|
for rows.Next() {
|
||||||
|
entity, err := scanEntity(rows)
|
||||||
|
results = append(results, entity)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 15. Error Wrapping Pattern
|
||||||
|
```go
|
||||||
|
// All errors wrapped with context
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create entity: %w", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 16. RETURNING Clause Pattern (for retrieving generated IDs)
|
||||||
|
```go
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO table (col1, col2)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING id
|
||||||
|
`, val1, val2).Scan(&entity.ID)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 17. Join Table Pattern (for many-to-many)
|
||||||
|
```go
|
||||||
|
// ListByCertificate uses certificate_target_mappings join table
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT dt.id, dt.name, dt.type, dt.agent_id, dt.config, dt.enabled, dt.created_at, dt.updated_at
|
||||||
|
FROM deployment_targets dt
|
||||||
|
INNER JOIN certificate_target_mappings ctm ON dt.id = ctm.target_id
|
||||||
|
WHERE ctm.certificate_id = $1
|
||||||
|
ORDER BY dt.created_at DESC
|
||||||
|
`, certID)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Type-Specific Patterns
|
||||||
|
|
||||||
|
### Certificate with Arrays and JSON
|
||||||
|
```go
|
||||||
|
// In certificate.go
|
||||||
|
var sans pq.StringArray
|
||||||
|
var tagsJSON []byte
|
||||||
|
|
||||||
|
err := scanner.Scan(&cert.ID, &cert.Name, &cert.CommonName, &sans, ...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert.SANs = []string(sans)
|
||||||
|
json.Unmarshal(tagsJSON, &cert.Tags)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Agent with Nullable Timestamp
|
||||||
|
```go
|
||||||
|
// In agent.go
|
||||||
|
var agent domain.Agent
|
||||||
|
err := scanner.Scan(&agent.ID, &agent.Name, &agent.Hostname, &agent.Status,
|
||||||
|
&agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash)
|
||||||
|
// LastHeartbeatAt can be nil, automatically handled by sql.NullTime
|
||||||
|
```
|
||||||
|
|
||||||
|
### Job with Nullable String
|
||||||
|
```go
|
||||||
|
// In job.go
|
||||||
|
var job domain.Job
|
||||||
|
var lastError *string
|
||||||
|
err := scanner.Scan(&job.ID, ..., &lastError, ...)
|
||||||
|
// lastError can be nil for successful jobs
|
||||||
|
job.LastError = lastError
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing Considerations
|
||||||
|
|
||||||
|
These implementations expect:
|
||||||
|
1. PostgreSQL database with proper schema
|
||||||
|
2. Tables created with matching column names and types
|
||||||
|
3. Foreign key relationships established
|
||||||
|
4. Proper indexes on frequently queried columns
|
||||||
|
|
||||||
|
For testing, consider:
|
||||||
|
- Using `testcontainers-go` for PostgreSQL in Docker
|
||||||
|
- Running migrations before test suite
|
||||||
|
- Using transactions with rollback for test isolation
|
||||||
+24
-51
@@ -1,28 +1,29 @@
|
|||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
# PostgreSQL database
|
# PostgreSQL database
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:16-alpine
|
image: postgres:16-alpine
|
||||||
container_name: certctl-postgres
|
container_name: certctl-postgres
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_DB: ${POSTGRES_DB:-certctl}
|
POSTGRES_DB: certctl
|
||||||
POSTGRES_USER: ${POSTGRES_USER:-certctl}
|
POSTGRES_USER: certctl
|
||||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-certctl}
|
POSTGRES_PASSWORD: certctl
|
||||||
ports:
|
ports:
|
||||||
- "${POSTGRES_PORT:-5432}:5432"
|
- "5432:5432"
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
- ../migrations/000001_initial_schema.up.sql:/docker-entrypoint-initdb.d/001_schema.sql
|
||||||
|
- ../migrations/seed.sql:/docker-entrypoint-initdb.d/002_seed.sql
|
||||||
|
- ../migrations/seed_demo.sql:/docker-entrypoint-initdb.d/003_seed_demo.sql
|
||||||
networks:
|
networks:
|
||||||
- certctl-network
|
- certctl-network
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-certctl} -d ${POSTGRES_DB:-certctl}"]
|
test: ["CMD-SHELL", "pg_isready -U certctl -d certctl"]
|
||||||
interval: 10s
|
interval: 5s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# Certctl Server
|
# Certctl Server (API + scheduler)
|
||||||
certctl-server:
|
certctl-server:
|
||||||
build:
|
build:
|
||||||
context: ..
|
context: ..
|
||||||
@@ -32,45 +33,21 @@ services:
|
|||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
environment:
|
environment:
|
||||||
# Database configuration
|
CERTCTL_DB_URL: postgres://certctl:certctl@postgres:5432/certctl?sslmode=disable
|
||||||
DB_HOST: postgres
|
CERTCTL_SERVER_HOST: 0.0.0.0
|
||||||
DB_PORT: 5432
|
CERTCTL_SERVER_PORT: 8443
|
||||||
DB_USER: ${POSTGRES_USER:-certctl}
|
CERTCTL_LOG_LEVEL: info
|
||||||
DB_PASSWORD: ${POSTGRES_PASSWORD:-certctl}
|
|
||||||
DB_NAME: ${POSTGRES_DB:-certctl}
|
|
||||||
DB_SSL_MODE: disable
|
|
||||||
|
|
||||||
# Server configuration
|
|
||||||
SERVER_HOST: 0.0.0.0
|
|
||||||
SERVER_PORT: 8443
|
|
||||||
LOG_LEVEL: info
|
|
||||||
|
|
||||||
# ACME Configuration (example: Let's Encrypt staging)
|
|
||||||
ACME_DIRECTORY_URL: https://acme-staging-v02.api.letsencrypt.org/directory
|
|
||||||
ACME_EMAIL: ${ACME_EMAIL:-admin@example.com}
|
|
||||||
|
|
||||||
# SMTP Configuration (for email notifications)
|
|
||||||
SMTP_HOST: ${SMTP_HOST:-smtp.example.com}
|
|
||||||
SMTP_PORT: 587
|
|
||||||
SMTP_USERNAME: ${SMTP_USERNAME:-}
|
|
||||||
SMTP_PASSWORD: ${SMTP_PASSWORD:-}
|
|
||||||
SMTP_FROM_ADDRESS: ${SMTP_FROM_ADDRESS:-certctl@example.com}
|
|
||||||
|
|
||||||
# Webhook Configuration (optional)
|
|
||||||
WEBHOOK_URL: ${WEBHOOK_URL:-}
|
|
||||||
WEBHOOK_SECRET: ${WEBHOOK_SECRET:-}
|
|
||||||
ports:
|
ports:
|
||||||
- "${SERVER_PORT:-8443}:8443"
|
- "8443:8443"
|
||||||
networks:
|
networks:
|
||||||
- certctl-network
|
- certctl-network
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8443/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8443/health"]
|
||||||
interval: 30s
|
interval: 10s
|
||||||
timeout: 3s
|
timeout: 5s
|
||||||
retries: 3
|
retries: 5
|
||||||
start_period: 5s
|
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
logs:
|
logging:
|
||||||
driver: "json-file"
|
driver: "json-file"
|
||||||
options:
|
options:
|
||||||
max-size: "10m"
|
max-size: "10m"
|
||||||
@@ -86,18 +63,14 @@ services:
|
|||||||
certctl-server:
|
certctl-server:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
environment:
|
environment:
|
||||||
# Server configuration
|
CERTCTL_SERVER_URL: http://certctl-server:8443
|
||||||
SERVER_URL: http://certctl-server:8443
|
CERTCTL_API_KEY: change-me-in-production
|
||||||
API_KEY: ${AGENT_API_KEY:-change-me-in-production}
|
CERTCTL_AGENT_NAME: docker-agent
|
||||||
AGENT_NAME: ${AGENT_NAME:-docker-agent}
|
CERTCTL_LOG_LEVEL: info
|
||||||
|
|
||||||
# Agent configuration
|
|
||||||
LOG_LEVEL: info
|
|
||||||
CHECK_INTERVAL: 60s
|
|
||||||
networks:
|
networks:
|
||||||
- certctl-network
|
- certctl-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
logs:
|
logging:
|
||||||
driver: "json-file"
|
driver: "json-file"
|
||||||
options:
|
options:
|
||||||
max-size: "10m"
|
max-size: "10m"
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
# certctl Demo Guide
|
||||||
|
|
||||||
|
Get the full certctl experience running locally in under 2 minutes.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone and start everything
|
||||||
|
git clone https://github.com/shankar0123/certctl.git
|
||||||
|
cd certctl
|
||||||
|
docker compose -f deploy/docker-compose.yml up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
Wait ~30 seconds for PostgreSQL to initialize and the server to start, then open:
|
||||||
|
|
||||||
|
**http://localhost:8443**
|
||||||
|
|
||||||
|
You'll see the dashboard pre-loaded with 15 demo certificates across multiple teams, environments, and statuses — including expiring, expired, active, failed, and in-progress renewals.
|
||||||
|
|
||||||
|
## What You'll See
|
||||||
|
|
||||||
|
### Dashboard Overview
|
||||||
|
The main dashboard shows at a glance:
|
||||||
|
- **Total certificates** managed across your infrastructure
|
||||||
|
- **Expiring soon** — certificates within 30 days of expiration (yellow/red)
|
||||||
|
- **Expired** — certificates past their expiration date
|
||||||
|
- **Active** — healthy certificates with time remaining
|
||||||
|
- **Renewal success rate** — percentage of automated renewals that succeeded
|
||||||
|
|
||||||
|
Below the stats, you'll see an **expiry timeline** showing how many certs expire in each time bucket (7/14/30/60/90 days), and a **recent activity feed** with the latest audit events.
|
||||||
|
|
||||||
|
### Certificates View
|
||||||
|
Click "Certificates" in the sidebar to see the full inventory:
|
||||||
|
- Search by name or domain
|
||||||
|
- Filter by status (Active, Expiring, Expired, Failed) or environment (Production, Staging)
|
||||||
|
- Sort by any column
|
||||||
|
- Click any row to see full details: metadata, version history, deployment targets, and audit trail
|
||||||
|
|
||||||
|
### Demo Scenarios to Walk Through
|
||||||
|
|
||||||
|
**1. "We're about to have an outage"**
|
||||||
|
Filter by status → Expiring. You'll see `auth-production` (12 days), `cdn-production` (8 days), and `mail-production` (5 days). These are real alerts the platform would catch automatically.
|
||||||
|
|
||||||
|
**2. "A renewal failed"**
|
||||||
|
Look at `vpn-production` — status: Failed. Click it to see the audit trail showing the ACME challenge failure after 3 retry attempts. The system sent a webhook notification to the ops channel.
|
||||||
|
|
||||||
|
**3. "Who owns this cert?"**
|
||||||
|
Click any certificate to see the owner, team, environment, and tags. Every cert has clear accountability.
|
||||||
|
|
||||||
|
**4. "What happened to the legacy app?"**
|
||||||
|
Filter by status → Expired. `legacy-app` expired 3 days ago, `old-api-v1` expired 15 days ago. Both have policy violations flagged.
|
||||||
|
|
||||||
|
**5. "Show me the agent fleet"**
|
||||||
|
Click "Agents" in the sidebar. Four agents are online, one (`iis-prod-agent`) went offline 3 hours ago — you'd want to investigate that.
|
||||||
|
|
||||||
|
**6. "What policies are enforced?"**
|
||||||
|
Click "Policies" to see the active rules: required owner metadata, allowed environments, max certificate lifetime, minimum renewal window. Check the violations list to see which certs are non-compliant.
|
||||||
|
|
||||||
|
## API Walkthrough
|
||||||
|
|
||||||
|
The dashboard is backed by a real REST API. Try these while the demo is running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List all certificates
|
||||||
|
curl -s http://localhost:8443/api/v1/certificates | jq .
|
||||||
|
|
||||||
|
# Get expiring certs
|
||||||
|
curl -s "http://localhost:8443/api/v1/certificates?status=expiring" | jq .
|
||||||
|
|
||||||
|
# Get a specific certificate
|
||||||
|
curl -s http://localhost:8443/api/v1/certificates/mc-api-prod | jq .
|
||||||
|
|
||||||
|
# List agents
|
||||||
|
curl -s http://localhost:8443/api/v1/agents | jq .
|
||||||
|
|
||||||
|
# View audit trail
|
||||||
|
curl -s http://localhost:8443/api/v1/audit | jq .
|
||||||
|
|
||||||
|
# View policy violations
|
||||||
|
curl -s http://localhost:8443/api/v1/policies/violations | jq .
|
||||||
|
|
||||||
|
# Check system health
|
||||||
|
curl -s http://localhost:8443/health | jq .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Demo Without Docker
|
||||||
|
|
||||||
|
The dashboard includes a **Demo Mode** that works without any backend. Just open the HTML file directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
open web/index.html
|
||||||
|
# or
|
||||||
|
python3 -m http.server 3000 -d web/
|
||||||
|
# then visit http://localhost:3000
|
||||||
|
```
|
||||||
|
|
||||||
|
When the API is unreachable, the dashboard automatically loads realistic mock data and shows a subtle "Demo Mode" badge. This is perfect for screenshots, presentations, or quick demos without any infrastructure.
|
||||||
|
|
||||||
|
## Teardown
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose -f deploy/docker-compose.yml down -v
|
||||||
|
```
|
||||||
|
|
||||||
|
The `-v` flag removes the PostgreSQL data volume so you get a clean slate next time.
|
||||||
|
|
||||||
|
## Presenting to Stakeholders
|
||||||
|
|
||||||
|
If you're demoing to a team or customer, here's a suggested flow:
|
||||||
|
|
||||||
|
1. **Start with the dashboard** — "This is your certificate inventory at a glance"
|
||||||
|
2. **Show the expiring certs** — "These three would have caused outages without this platform"
|
||||||
|
3. **Click into auth-production** — "Here's the full lifecycle: who owns it, where it's deployed, when it was last renewed"
|
||||||
|
4. **Show the failed VPN cert** — "The system tried 3 times, then alerted the team via webhook"
|
||||||
|
5. **Show agents** — "Agents run on your infrastructure, handle key generation locally, and report back"
|
||||||
|
6. **Show policies** — "Guardrails prevent teams from going outside approved scope"
|
||||||
|
7. **Show the API** — "Everything you see here is API-first, so you can automate on top of it"
|
||||||
|
|
||||||
|
The whole walkthrough takes 5-7 minutes.
|
||||||
@@ -2,4 +2,7 @@ module github.com/shankar0123/certctl
|
|||||||
|
|
||||||
go 1.22.5
|
go 1.22.5
|
||||||
|
|
||||||
require github.com/google/uuid v1.6.0
|
require (
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/lib/pq v1.10.9
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
|
|||||||
@@ -0,0 +1,446 @@
|
|||||||
|
package local
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math/big"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/connector/issuer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config represents the local CA issuer connector configuration.
|
||||||
|
type Config struct {
|
||||||
|
// CACommonName is the CN for the self-signed CA certificate.
|
||||||
|
// Defaults to "CertCtl Local CA".
|
||||||
|
CACommonName string `json:"ca_common_name,omitempty"`
|
||||||
|
|
||||||
|
// ValidityDays is the number of days a certificate is valid.
|
||||||
|
// Defaults to 90.
|
||||||
|
ValidityDays int `json:"validity_days,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connector implements the issuer.Connector interface for local self-signed certificate generation.
|
||||||
|
//
|
||||||
|
// This connector generates self-signed certificates using an in-memory CA. It is designed for
|
||||||
|
// development, testing, and demo purposes only and should NOT be used in production.
|
||||||
|
//
|
||||||
|
// On first use, it generates a self-signed CA root certificate and stores it in memory.
|
||||||
|
// All issued certificates are signed by this local CA.
|
||||||
|
//
|
||||||
|
// Features:
|
||||||
|
// - Instant certificate issuance (no external CA required)
|
||||||
|
// - Full lifecycle demo support (issue, renew, revoke)
|
||||||
|
// - In-memory certificate storage
|
||||||
|
// - Proper X.509 certificate generation with SANs, serial numbers, and validity periods
|
||||||
|
//
|
||||||
|
// Limitations:
|
||||||
|
// - Not suitable for production use
|
||||||
|
// - Certificates are not trusted by default browsers/systems
|
||||||
|
// - No actual revocation checking (revocation is tracked in memory only)
|
||||||
|
// - CA certificate is ephemeral and lost on service restart
|
||||||
|
type Connector struct {
|
||||||
|
config *Config
|
||||||
|
logger *slog.Logger
|
||||||
|
mu sync.RWMutex
|
||||||
|
caKey *rsa.PrivateKey
|
||||||
|
caCert *x509.Certificate
|
||||||
|
caCertPEM string
|
||||||
|
revokedMap map[string]bool // serial -> revoked status
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new local CA connector with the given configuration and logger.
|
||||||
|
func New(config *Config, logger *slog.Logger) *Connector {
|
||||||
|
if config == nil {
|
||||||
|
config = &Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if config.CACommonName == "" {
|
||||||
|
config.CACommonName = "CertCtl Local CA"
|
||||||
|
}
|
||||||
|
if config.ValidityDays == 0 {
|
||||||
|
config.ValidityDays = 90
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Connector{
|
||||||
|
config: config,
|
||||||
|
logger: logger,
|
||||||
|
revokedMap: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates the local CA configuration.
|
||||||
|
// This always succeeds as the local CA has minimal requirements.
|
||||||
|
func (c *Connector) ValidateConfig(ctx context.Context, rawConfig json.RawMessage) error {
|
||||||
|
var cfg Config
|
||||||
|
if err := json.Unmarshal(rawConfig, &cfg); err != nil {
|
||||||
|
return fmt.Errorf("invalid local CA config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ValidityDays < 1 {
|
||||||
|
return fmt.Errorf("validity_days must be at least 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.config = &cfg
|
||||||
|
if c.config.CACommonName == "" {
|
||||||
|
c.config.CACommonName = "CertCtl Local CA"
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("local CA configuration validated",
|
||||||
|
"ca_common_name", c.config.CACommonName,
|
||||||
|
"validity_days", c.config.ValidityDays)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueCertificate issues a new certificate signed by the local CA.
|
||||||
|
//
|
||||||
|
// The process:
|
||||||
|
// 1. Initialize the CA if not already done
|
||||||
|
// 2. Parse the CSR from the request
|
||||||
|
// 3. Extract subject and SANs from the CSR
|
||||||
|
// 4. Generate a random serial number
|
||||||
|
// 5. Create an X.509 certificate with proper extensions (SANs, key usage, etc.)
|
||||||
|
// 6. Sign with the local CA key
|
||||||
|
// 7. Return the certificate PEM and CA chain PEM
|
||||||
|
func (c *Connector) IssueCertificate(ctx context.Context, request issuer.IssuanceRequest) (*issuer.IssuanceResult, error) {
|
||||||
|
c.logger.Info("processing local CA issuance request",
|
||||||
|
"common_name", request.CommonName,
|
||||||
|
"san_count", len(request.SANs))
|
||||||
|
|
||||||
|
// Initialize CA if needed
|
||||||
|
if err := c.ensureCA(ctx); err != nil {
|
||||||
|
c.logger.Error("failed to initialize CA", "error", err)
|
||||||
|
return nil, fmt.Errorf("CA initialization failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse CSR
|
||||||
|
csrBlock, _ := pem.Decode([]byte(request.CSRPEM))
|
||||||
|
if csrBlock == nil || csrBlock.Type != "CERTIFICATE REQUEST" {
|
||||||
|
return nil, fmt.Errorf("invalid CSR PEM format")
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := x509.ParseCertificateRequest(csrBlock.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to parse CSR", "error", err)
|
||||||
|
return nil, fmt.Errorf("invalid CSR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify CSR signature
|
||||||
|
if err := csr.CheckSignature(); err != nil {
|
||||||
|
c.logger.Error("CSR signature verification failed", "error", err)
|
||||||
|
return nil, fmt.Errorf("CSR signature verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate certificate
|
||||||
|
cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to generate certificate", "error", err)
|
||||||
|
return nil, fmt.Errorf("certificate generation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create order ID (use serial as order ID for simplicity)
|
||||||
|
orderID := fmt.Sprintf("local-%s", serial)
|
||||||
|
|
||||||
|
result := &issuer.IssuanceResult{
|
||||||
|
CertPEM: certPEM,
|
||||||
|
ChainPEM: c.caCertPEM,
|
||||||
|
Serial: serial,
|
||||||
|
NotBefore: cert.NotBefore,
|
||||||
|
NotAfter: cert.NotAfter,
|
||||||
|
OrderID: orderID,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("certificate issued successfully",
|
||||||
|
"serial", serial,
|
||||||
|
"common_name", request.CommonName,
|
||||||
|
"not_after", cert.NotAfter)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenewCertificate renews a certificate by issuing a new one with the same identifiers.
|
||||||
|
// For the local CA, this is functionally identical to IssueCertificate.
|
||||||
|
func (c *Connector) RenewCertificate(ctx context.Context, request issuer.RenewalRequest) (*issuer.IssuanceResult, error) {
|
||||||
|
c.logger.Info("processing local CA renewal request",
|
||||||
|
"common_name", request.CommonName,
|
||||||
|
"san_count", len(request.SANs))
|
||||||
|
|
||||||
|
// Initialize CA if needed
|
||||||
|
if err := c.ensureCA(ctx); err != nil {
|
||||||
|
c.logger.Error("failed to initialize CA", "error", err)
|
||||||
|
return nil, fmt.Errorf("CA initialization failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse CSR
|
||||||
|
csrBlock, _ := pem.Decode([]byte(request.CSRPEM))
|
||||||
|
if csrBlock == nil || csrBlock.Type != "CERTIFICATE REQUEST" {
|
||||||
|
return nil, fmt.Errorf("invalid CSR PEM format")
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := x509.ParseCertificateRequest(csrBlock.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to parse CSR", "error", err)
|
||||||
|
return nil, fmt.Errorf("invalid CSR: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify CSR signature
|
||||||
|
if err := csr.CheckSignature(); err != nil {
|
||||||
|
c.logger.Error("CSR signature verification failed", "error", err)
|
||||||
|
return nil, fmt.Errorf("CSR signature verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate certificate
|
||||||
|
cert, certPEM, serial, err := c.generateCertificate(csr, request.SANs)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to generate certificate", "error", err)
|
||||||
|
return nil, fmt.Errorf("certificate generation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create order ID
|
||||||
|
orderID := fmt.Sprintf("local-%s", serial)
|
||||||
|
if request.OrderID != nil {
|
||||||
|
orderID = *request.OrderID
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &issuer.IssuanceResult{
|
||||||
|
CertPEM: certPEM,
|
||||||
|
ChainPEM: c.caCertPEM,
|
||||||
|
Serial: serial,
|
||||||
|
NotBefore: cert.NotBefore,
|
||||||
|
NotAfter: cert.NotAfter,
|
||||||
|
OrderID: orderID,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("certificate renewed successfully",
|
||||||
|
"serial", serial,
|
||||||
|
"common_name", request.CommonName,
|
||||||
|
"not_after", cert.NotAfter)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeCertificate revokes a certificate by marking it in the in-memory revocation map.
|
||||||
|
// This is a no-op for practical purposes but tracks revocation state in memory.
|
||||||
|
// Note: Revocation is not persistent and is lost on service restart.
|
||||||
|
func (c *Connector) RevokeCertificate(ctx context.Context, request issuer.RevocationRequest) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.revokedMap[request.Serial] = true
|
||||||
|
|
||||||
|
reason := "unspecified"
|
||||||
|
if request.Reason != nil {
|
||||||
|
reason = *request.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("certificate revoked",
|
||||||
|
"serial", request.Serial,
|
||||||
|
"reason", reason)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrderStatus returns the status of an issuance or renewal order.
|
||||||
|
// For the local CA, orders complete immediately, so this always returns "completed" status.
|
||||||
|
func (c *Connector) GetOrderStatus(ctx context.Context, orderID string) (*issuer.OrderStatus, error) {
|
||||||
|
c.logger.Info("fetching local CA order status", "order_id", orderID)
|
||||||
|
|
||||||
|
// Local CA orders complete immediately
|
||||||
|
status := &issuer.OrderStatus{
|
||||||
|
OrderID: orderID,
|
||||||
|
Status: "completed",
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureCA initializes the CA certificate and key if not already done.
|
||||||
|
// This is called on first IssueCertificate or RenewCertificate call.
|
||||||
|
// The CA is generated once and reused for all subsequent operations.
|
||||||
|
func (c *Connector) ensureCA(ctx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.caKey != nil {
|
||||||
|
return nil // CA already initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("initializing local CA", "common_name", c.config.CACommonName)
|
||||||
|
|
||||||
|
// Generate CA private key
|
||||||
|
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate CA key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create CA certificate
|
||||||
|
caTemplate := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: c.config.CACommonName,
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().AddDate(10, 0, 0), // CA valid for 10 years
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Self-sign the CA certificate
|
||||||
|
caCertBytes, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCert, err := x509.ParseCertificate(caCertBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode CA certificate to PEM
|
||||||
|
caCertPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: caCertBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
c.caKey = caKey
|
||||||
|
c.caCert = caCert
|
||||||
|
c.caCertPEM = string(caCertPEM)
|
||||||
|
|
||||||
|
c.logger.Info("local CA initialized successfully",
|
||||||
|
"serial", caCert.SerialNumber,
|
||||||
|
"not_after", caCert.NotAfter)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCertificate creates an X.509 certificate signed by the local CA.
|
||||||
|
// It uses the CSR subject and adds any additional SANs from the request.
|
||||||
|
func (c *Connector) generateCertificate(csr *x509.CertificateRequest, additionalSANs []string) (*x509.Certificate, string, string, error) {
|
||||||
|
// Generate random serial number
|
||||||
|
serialNum, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 159))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to generate serial number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serial := fmt.Sprintf("%040x", serialNum)
|
||||||
|
|
||||||
|
// Collect all SANs
|
||||||
|
sanSet := make(map[string]bool)
|
||||||
|
for _, san := range csr.DNSNames {
|
||||||
|
sanSet[san] = true
|
||||||
|
}
|
||||||
|
for _, san := range csr.IPAddresses {
|
||||||
|
sanSet[san.String()] = true
|
||||||
|
}
|
||||||
|
for _, san := range csr.EmailAddresses {
|
||||||
|
sanSet[san] = true
|
||||||
|
}
|
||||||
|
for _, san := range additionalSANs {
|
||||||
|
sanSet[san] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var dnsNames []string
|
||||||
|
var ips []string
|
||||||
|
var emails []string
|
||||||
|
|
||||||
|
for san := range sanSet {
|
||||||
|
// Try to parse as IP, otherwise treat as DNS or email
|
||||||
|
if ip := parseIP(san); ip != nil {
|
||||||
|
ips = append(ips, san)
|
||||||
|
} else if isEmail(san) {
|
||||||
|
emails = append(emails, san)
|
||||||
|
} else {
|
||||||
|
dnsNames = append(dnsNames, san)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate template
|
||||||
|
now := time.Now()
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: serialNum,
|
||||||
|
Subject: csr.Subject,
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.AddDate(0, 0, c.config.ValidityDays),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageServerAuth,
|
||||||
|
x509.ExtKeyUsageClientAuth,
|
||||||
|
},
|
||||||
|
DNSNames: dnsNames,
|
||||||
|
EmailAddresses: emails,
|
||||||
|
SubjectKeyId: hashPublicKey(csr.PublicKey),
|
||||||
|
AuthorityKeyId: c.caCert.SubjectKeyId,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IP addresses if present
|
||||||
|
if len(ips) > 0 {
|
||||||
|
for _, ipStr := range ips {
|
||||||
|
if ip := parseIP(ipStr); ip != nil {
|
||||||
|
template.IPAddresses = append(template.IPAddresses, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign certificate with CA
|
||||||
|
certBytes, err := x509.CreateCertificate(rand.Reader, template, c.caCert, csr.PublicKey, c.caKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to sign certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse for validation
|
||||||
|
cert, err := x509.ParseCertificate(certBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to parse certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to PEM
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: certBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
return cert, string(certPEM), serial, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIP attempts to parse a string as an IP address.
|
||||||
|
func parseIP(s string) []byte {
|
||||||
|
if s == "localhost" {
|
||||||
|
return []byte{127, 0, 0, 1}
|
||||||
|
}
|
||||||
|
// In production, use net.ParseIP for proper parsing.
|
||||||
|
// For now, return nil for non-localhost IPs.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEmail checks if a string looks like an email address.
|
||||||
|
func isEmail(s string) bool {
|
||||||
|
for _, c := range s {
|
||||||
|
if c == '@' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashPublicKey generates a subject key identifier from a public key.
|
||||||
|
func hashPublicKey(pub interface{}) []byte {
|
||||||
|
h := sha256.New()
|
||||||
|
switch k := pub.(type) {
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
h.Write(k.N.Bytes())
|
||||||
|
}
|
||||||
|
return h.Sum(nil)[:4] // Use first 4 bytes for brevity
|
||||||
|
}
|
||||||
@@ -0,0 +1,206 @@
|
|||||||
|
package local_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/connector/issuer"
|
||||||
|
"github.com/shankar0123/certctl/internal/connector/issuer/local"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalConnector(t *testing.T) {
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test 1: Create connector and validate config
|
||||||
|
t.Run("ValidateConfig", func(t *testing.T) {
|
||||||
|
config := &local.Config{
|
||||||
|
CACommonName: "Test CA",
|
||||||
|
ValidityDays: 30,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
rawConfig, _ := json.Marshal(config)
|
||||||
|
err := connector.ValidateConfig(ctx, rawConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 2: Issue a certificate
|
||||||
|
t.Run("IssueCertificate", func(t *testing.T) {
|
||||||
|
config := &local.Config{
|
||||||
|
CACommonName: "Test CA",
|
||||||
|
ValidityDays: 30,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
csr, csrPEM, err := generateTestCSR("test.example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate CSR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := issuer.IssuanceRequest{
|
||||||
|
CommonName: csr.Subject.CommonName,
|
||||||
|
SANs: []string{"www.test.example.com"},
|
||||||
|
CSRPEM: csrPEM,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := connector.IssueCertificate(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueCertificate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Serial == "" {
|
||||||
|
t.Error("Serial is empty")
|
||||||
|
}
|
||||||
|
if result.CertPEM == "" {
|
||||||
|
t.Error("CertPEM is empty")
|
||||||
|
}
|
||||||
|
if result.ChainPEM == "" {
|
||||||
|
t.Error("ChainPEM is empty")
|
||||||
|
}
|
||||||
|
if result.OrderID == "" {
|
||||||
|
t.Error("OrderID is empty")
|
||||||
|
}
|
||||||
|
if result.NotAfter.IsZero() {
|
||||||
|
t.Error("NotAfter is zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Certificate issued: serial=%s, orderID=%s", result.Serial, result.OrderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 3: Renew a certificate
|
||||||
|
t.Run("RenewCertificate", func(t *testing.T) {
|
||||||
|
config := &local.Config{
|
||||||
|
CACommonName: "Test CA",
|
||||||
|
ValidityDays: 30,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
csr, csrPEM, err := generateTestCSR("test.example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate CSR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
renewReq := issuer.RenewalRequest{
|
||||||
|
CommonName: csr.Subject.CommonName,
|
||||||
|
SANs: []string{"www.test.example.com"},
|
||||||
|
CSRPEM: csrPEM,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := connector.RenewCertificate(ctx, renewReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RenewCertificate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Serial == "" {
|
||||||
|
t.Error("Serial is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Certificate renewed: serial=%s", result.Serial)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 4: Get order status
|
||||||
|
t.Run("GetOrderStatus", func(t *testing.T) {
|
||||||
|
config := &local.Config{
|
||||||
|
CACommonName: "Test CA",
|
||||||
|
ValidityDays: 30,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
status, err := connector.GetOrderStatus(ctx, "local-12345")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetOrderStatus failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Status != "completed" {
|
||||||
|
t.Errorf("Expected status 'completed', got '%s'", status.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Order status: %s", status.Status)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 5: Revoke a certificate
|
||||||
|
t.Run("RevokeCertificate", func(t *testing.T) {
|
||||||
|
config := &local.Config{
|
||||||
|
CACommonName: "Test CA",
|
||||||
|
ValidityDays: 30,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
revokeReq := issuer.RevocationRequest{
|
||||||
|
Serial: "test-serial-12345",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := connector.RevokeCertificate(ctx, revokeReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RevokeCertificate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Certificate revoked: serial=%s", revokeReq.Serial)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 6: Invalid CSR
|
||||||
|
t.Run("InvalidCSR", func(t *testing.T) {
|
||||||
|
config := &local.Config{
|
||||||
|
CACommonName: "Test CA",
|
||||||
|
ValidityDays: 30,
|
||||||
|
}
|
||||||
|
connector := local.New(config, logger)
|
||||||
|
|
||||||
|
req := issuer.IssuanceRequest{
|
||||||
|
CommonName: "test.example.com",
|
||||||
|
CSRPEM: "invalid pem",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := connector.IssueCertificate(ctx, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for invalid CSR")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Correctly rejected invalid CSR: %v", err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestCSR(commonName string) (*x509.CertificateRequest, string, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
subj := pkix.Name{
|
||||||
|
CommonName: commonName,
|
||||||
|
}
|
||||||
|
|
||||||
|
csrTemplate := x509.CertificateRequest{
|
||||||
|
Subject: subj,
|
||||||
|
DNSNames: []string{commonName},
|
||||||
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||||
|
}
|
||||||
|
|
||||||
|
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := x509.ParseCertificateRequest(csrBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
csrPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE REQUEST",
|
||||||
|
Bytes: csrBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
return csr, string(csrPEM), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentRepository implements repository.AgentRepository
|
||||||
|
type AgentRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAgentRepository creates a new AgentRepository
|
||||||
|
func NewAgentRepository(db *sql.DB) *AgentRepository {
|
||||||
|
return &AgentRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all agents
|
||||||
|
func (r *AgentRepository) List(ctx context.Context) ([]*domain.Agent, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||||
|
FROM agents
|
||||||
|
ORDER BY registered_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query agents: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var agents []*domain.Agent
|
||||||
|
for rows.Next() {
|
||||||
|
agent, err := scanAgent(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
agents = append(agents, agent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating agent rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return agents, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an agent by ID
|
||||||
|
func (r *AgentRepository) Get(ctx context.Context, id string) (*domain.Agent, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||||
|
FROM agents
|
||||||
|
WHERE id = $1
|
||||||
|
`, id)
|
||||||
|
|
||||||
|
agent, err := scanAgent(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("agent not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new agent
|
||||||
|
func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
|
||||||
|
if agent.ID == "" {
|
||||||
|
agent.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
RETURNING id
|
||||||
|
`, agent.ID, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt,
|
||||||
|
agent.RegisteredAt, agent.APIKeyHash).Scan(&agent.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing agent
|
||||||
|
func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE agents SET
|
||||||
|
name = $1,
|
||||||
|
hostname = $2,
|
||||||
|
status = $3,
|
||||||
|
last_heartbeat_at = $4,
|
||||||
|
api_key_hash = $5
|
||||||
|
WHERE id = $6
|
||||||
|
`, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt, agent.APIKeyHash, agent.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("agent not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an agent
|
||||||
|
func (r *AgentRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM agents WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("agent not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHeartbeat updates the agent's last heartbeat timestamp
|
||||||
|
func (r *AgentRepository) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE agents SET last_heartbeat_at = $1 WHERE id = $2
|
||||||
|
`, time.Now(), id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update heartbeat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("agent not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByAPIKey retrieves an agent by hashed API key
|
||||||
|
func (r *AgentRepository) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||||
|
FROM agents
|
||||||
|
WHERE api_key_hash = $1
|
||||||
|
`, keyHash)
|
||||||
|
|
||||||
|
agent, err := scanAgent(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("agent not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanAgent scans an agent from a row or rows
|
||||||
|
func scanAgent(scanner interface {
|
||||||
|
Scan(...interface{}) error
|
||||||
|
}) (*domain.Agent, error) {
|
||||||
|
var agent domain.Agent
|
||||||
|
err := scanner.Scan(&agent.ID, &agent.Name, &agent.Hostname, &agent.Status,
|
||||||
|
&agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &agent, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditRepository implements repository.AuditRepository
|
||||||
|
type AuditRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditRepository creates a new AuditRepository
|
||||||
|
func NewAuditRepository(db *sql.DB) *AuditRepository {
|
||||||
|
return &AuditRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new audit event
|
||||||
|
func (r *AuditRepository) Create(ctx context.Context, event *domain.AuditEvent) error {
|
||||||
|
if event.ID == "" {
|
||||||
|
event.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO audit_events (
|
||||||
|
id, actor, actor_type, action, resource_type, resource_id, details, timestamp
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
RETURNING id
|
||||||
|
`, event.ID, event.Actor, event.ActorType, event.Action, event.ResourceType,
|
||||||
|
event.ResourceID, event.Details, event.Timestamp).Scan(&event.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create audit event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns audit events matching the filter criteria
|
||||||
|
func (r *AuditRepository) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
|
||||||
|
if filter == nil {
|
||||||
|
filter = &repository.AuditFilter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if filter.Page < 1 {
|
||||||
|
filter.Page = 1
|
||||||
|
}
|
||||||
|
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||||
|
filter.PerPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build WHERE clause
|
||||||
|
var whereConditions []string
|
||||||
|
var args []interface{}
|
||||||
|
argCount := 1
|
||||||
|
|
||||||
|
if filter.Actor != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("actor = $%d", argCount))
|
||||||
|
args = append(args, filter.Actor)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.ActorType != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("actor_type = $%d", argCount))
|
||||||
|
args = append(args, filter.ActorType)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.ResourceType != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("resource_type = $%d", argCount))
|
||||||
|
args = append(args, filter.ResourceType)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.ResourceID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("resource_id = $%d", argCount))
|
||||||
|
args = append(args, filter.ResourceID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if !filter.From.IsZero() {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("timestamp >= $%d", argCount))
|
||||||
|
args = append(args, filter.From)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if !filter.To.IsZero() {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("timestamp <= $%d", argCount))
|
||||||
|
args = append(args, filter.To)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := ""
|
||||||
|
if len(whereConditions) > 0 {
|
||||||
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get total count
|
||||||
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
|
||||||
|
var total int
|
||||||
|
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to count audit events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get paginated results
|
||||||
|
offset := (filter.Page - 1) * filter.PerPage
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT id, actor, actor_type, action, resource_type, resource_id, details, timestamp
|
||||||
|
FROM audit_events
|
||||||
|
%s
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT $%d OFFSET $%d
|
||||||
|
`, whereClause, argCount, argCount+1)
|
||||||
|
|
||||||
|
args = append(args, filter.PerPage, offset)
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query audit events: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var events []*domain.AuditEvent
|
||||||
|
for rows.Next() {
|
||||||
|
var event domain.AuditEvent
|
||||||
|
if err := rows.Scan(&event.ID, &event.Actor, &event.ActorType, &event.Action,
|
||||||
|
&event.ResourceType, &event.ResourceID, &event.Details, &event.Timestamp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||||
|
}
|
||||||
|
events = append(events, &event)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating audit event rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,346 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CertificateRepository implements repository.CertificateRepository
|
||||||
|
type CertificateRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCertificateRepository creates a new CertificateRepository
|
||||||
|
func NewCertificateRepository(db *sql.DB) *CertificateRepository {
|
||||||
|
return &CertificateRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of certificates matching the filter criteria
|
||||||
|
func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
||||||
|
if filter == nil {
|
||||||
|
filter = &repository.CertificateFilter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if filter.Page < 1 {
|
||||||
|
filter.Page = 1
|
||||||
|
}
|
||||||
|
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||||
|
filter.PerPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build WHERE clause
|
||||||
|
var whereConditions []string
|
||||||
|
var args []interface{}
|
||||||
|
argCount := 1
|
||||||
|
|
||||||
|
if filter.Status != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
|
||||||
|
args = append(args, filter.Status)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.Environment != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("environment = $%d", argCount))
|
||||||
|
args = append(args, filter.Environment)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.OwnerID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("owner_id = $%d", argCount))
|
||||||
|
args = append(args, filter.OwnerID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.TeamID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("team_id = $%d", argCount))
|
||||||
|
args = append(args, filter.TeamID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.IssuerID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("issuer_id = $%d", argCount))
|
||||||
|
args = append(args, filter.IssuerID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := ""
|
||||||
|
if len(whereConditions) > 0 {
|
||||||
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get total count
|
||||||
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM certificates %s", whereClause)
|
||||||
|
var total int
|
||||||
|
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get paginated results
|
||||||
|
offset := (filter.Page - 1) * filter.PerPage
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||||
|
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||||
|
FROM certificates
|
||||||
|
%s
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $%d OFFSET $%d
|
||||||
|
`, whereClause, argCount, argCount+1)
|
||||||
|
|
||||||
|
args = append(args, filter.PerPage, offset)
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to query certificates: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var certs []*domain.ManagedCertificate
|
||||||
|
for rows.Next() {
|
||||||
|
cert, err := scanCertificate(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certs, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a certificate by ID
|
||||||
|
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||||
|
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||||
|
FROM certificates
|
||||||
|
WHERE id = $1
|
||||||
|
`, id)
|
||||||
|
|
||||||
|
cert, err := scanCertificate(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("certificate not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new certificate
|
||||||
|
func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||||
|
if cert.ID == "" {
|
||||||
|
cert.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
tagsJSON, err := json.Marshal(cert.Tags)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO certificates (
|
||||||
|
id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||||
|
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
||||||
|
RETURNING id
|
||||||
|
`, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
|
||||||
|
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt,
|
||||||
|
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing certificate
|
||||||
|
func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||||
|
tagsJSON, err := json.Marshal(cert.Tags)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE certificates SET
|
||||||
|
name = $1,
|
||||||
|
common_name = $2,
|
||||||
|
sans = $3,
|
||||||
|
environment = $4,
|
||||||
|
owner_id = $5,
|
||||||
|
team_id = $6,
|
||||||
|
issuer_id = $7,
|
||||||
|
status = $8,
|
||||||
|
expires_at = $9,
|
||||||
|
tags = $10,
|
||||||
|
last_renewal_at = $11,
|
||||||
|
last_deployment_at = $12,
|
||||||
|
updated_at = $13
|
||||||
|
WHERE id = $14
|
||||||
|
`, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
|
||||||
|
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt,
|
||||||
|
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("certificate not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Archive marks a certificate as archived
|
||||||
|
func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE certificates SET status = $1, updated_at = $2 WHERE id = $3
|
||||||
|
`, domain.CertificateStatusArchived, time.Now(), id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to archive certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("certificate not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListVersions returns all versions of a certificate
|
||||||
|
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, certificate_id, serial_number, not_before, not_after,
|
||||||
|
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||||
|
FROM certificate_versions
|
||||||
|
WHERE certificate_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`, certID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query certificate versions: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var versions []*domain.CertificateVersion
|
||||||
|
for rows.Next() {
|
||||||
|
var v domain.CertificateVersion
|
||||||
|
if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
|
||||||
|
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan certificate version: %w", err)
|
||||||
|
}
|
||||||
|
versions = append(versions, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating version rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return versions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateVersion stores a new certificate version
|
||||||
|
func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
|
||||||
|
if version.ID == "" {
|
||||||
|
version.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO certificate_versions (
|
||||||
|
id, certificate_id, serial_number, not_before, not_after,
|
||||||
|
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
|
RETURNING id
|
||||||
|
`, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter,
|
||||||
|
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create certificate version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExpiringCertificates returns certificates expiring before the given time
|
||||||
|
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||||
|
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||||
|
FROM certificates
|
||||||
|
WHERE expires_at < $1 AND status != $2
|
||||||
|
ORDER BY expires_at ASC
|
||||||
|
`, before, domain.CertificateStatusArchived)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query expiring certificates: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var certs []*domain.ManagedCertificate
|
||||||
|
for rows.Next() {
|
||||||
|
cert, err := scanCertificate(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanCertificate scans a certificate from a row or rows
|
||||||
|
func scanCertificate(scanner interface {
|
||||||
|
Scan(...interface{}) error
|
||||||
|
}) (*domain.ManagedCertificate, error) {
|
||||||
|
var cert domain.ManagedCertificate
|
||||||
|
var tagsJSON []byte
|
||||||
|
var sans pq.StringArray
|
||||||
|
|
||||||
|
err := scanner.Scan(
|
||||||
|
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
||||||
|
&cert.TeamID, &cert.IssuerID, &cert.Status, &cert.ExpiresAt, &tagsJSON,
|
||||||
|
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert.SANs = []string(sans)
|
||||||
|
|
||||||
|
// Unmarshal tags
|
||||||
|
if len(tagsJSON) > 0 {
|
||||||
|
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cert.Tags = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cert, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewDB opens a PostgreSQL database connection and sets up connection pooling.
|
||||||
|
func NewDB(connStr string) (*sql.DB, error) {
|
||||||
|
db, err := sql.Open("postgres", connStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure connection pool
|
||||||
|
db.SetMaxOpenConns(25)
|
||||||
|
db.SetMaxIdleConns(5)
|
||||||
|
|
||||||
|
// Ping to verify connection
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunMigrations reads and executes SQL migration files from a directory.
|
||||||
|
func RunMigrations(db *sql.DB, migrationsPath string) error {
|
||||||
|
// Check if migrations directory exists
|
||||||
|
if _, err := os.Stat(migrationsPath); os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("migrations directory not found: %s", migrationsPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read all SQL files from the migrations directory
|
||||||
|
files, err := os.ReadDir(migrationsPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read migrations directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort and filter SQL files
|
||||||
|
var sqlFiles []string
|
||||||
|
for _, file := range files {
|
||||||
|
if !file.IsDir() && strings.HasSuffix(file.Name(), ".sql") {
|
||||||
|
sqlFiles = append(sqlFiles, file.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute each migration file in order
|
||||||
|
for _, filename := range sqlFiles {
|
||||||
|
filePath := filepath.Join(migrationsPath, filename)
|
||||||
|
content, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read migration file %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the SQL content
|
||||||
|
if _, err := db.Exec(string(content)); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute migration %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IssuerRepository implements repository.IssuerRepository
|
||||||
|
type IssuerRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIssuerRepository creates a new IssuerRepository
|
||||||
|
func NewIssuerRepository(db *sql.DB) *IssuerRepository {
|
||||||
|
return &IssuerRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all issuers
|
||||||
|
func (r *IssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||||
|
FROM issuers
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query issuers: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var issuers []*domain.Issuer
|
||||||
|
for rows.Next() {
|
||||||
|
var issuer domain.Issuer
|
||||||
|
if err := rows.Scan(&issuer.ID, &issuer.Name, &issuer.Type, &issuer.Config,
|
||||||
|
&issuer.Enabled, &issuer.CreatedAt, &issuer.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan issuer: %w", err)
|
||||||
|
}
|
||||||
|
issuers = append(issuers, &issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating issuer rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return issuers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an issuer by ID
|
||||||
|
func (r *IssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||||
|
var issuer domain.Issuer
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||||
|
FROM issuers
|
||||||
|
WHERE id = $1
|
||||||
|
`, id).Scan(&issuer.ID, &issuer.Name, &issuer.Type, &issuer.Config,
|
||||||
|
&issuer.Enabled, &issuer.CreatedAt, &issuer.UpdatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("issuer not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query issuer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new issuer
|
||||||
|
func (r *IssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error {
|
||||||
|
if issuer.ID == "" {
|
||||||
|
issuer.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO issuers (id, name, type, config, enabled, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
RETURNING id
|
||||||
|
`, issuer.ID, issuer.Name, issuer.Type, issuer.Config, issuer.Enabled,
|
||||||
|
issuer.CreatedAt, issuer.UpdatedAt).Scan(&issuer.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create issuer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing issuer
|
||||||
|
func (r *IssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE issuers SET
|
||||||
|
name = $1,
|
||||||
|
type = $2,
|
||||||
|
config = $3,
|
||||||
|
enabled = $4,
|
||||||
|
updated_at = $5
|
||||||
|
WHERE id = $6
|
||||||
|
`, issuer.Name, issuer.Type, issuer.Config, issuer.Enabled, issuer.UpdatedAt, issuer.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update issuer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("issuer not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an issuer
|
||||||
|
func (r *IssuerRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM issuers WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete issuer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("issuer not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,284 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JobRepository implements repository.JobRepository
|
||||||
|
type JobRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJobRepository creates a new JobRepository
|
||||||
|
func NewJobRepository(db *sql.DB) *JobRepository {
|
||||||
|
return &JobRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all jobs
|
||||||
|
func (r *JobRepository) List(ctx context.Context) ([]*domain.Job, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||||
|
last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
FROM jobs
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query jobs: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var jobs []*domain.Job
|
||||||
|
for rows.Next() {
|
||||||
|
job, err := scanJob(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jobs = append(jobs, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return jobs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a job by ID
|
||||||
|
func (r *JobRepository) Get(ctx context.Context, id string) (*domain.Job, error) {
|
||||||
|
row := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||||
|
last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
FROM jobs
|
||||||
|
WHERE id = $1
|
||||||
|
`, id)
|
||||||
|
|
||||||
|
job, err := scanJob(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("job not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return job, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new job
|
||||||
|
func (r *JobRepository) Create(ctx context.Context, job *domain.Job) error {
|
||||||
|
if job.ID == "" {
|
||||||
|
job.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO jobs (
|
||||||
|
id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||||
|
last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||||
|
RETURNING id
|
||||||
|
`, job.ID, job.Type, job.CertificateID, job.TargetID, job.Status, job.Attempts,
|
||||||
|
job.MaxAttempts, job.LastError, job.ScheduledAt, job.StartedAt, job.CompletedAt,
|
||||||
|
job.CreatedAt).Scan(&job.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing job
|
||||||
|
func (r *JobRepository) Update(ctx context.Context, job *domain.Job) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE jobs SET
|
||||||
|
type = $1,
|
||||||
|
certificate_id = $2,
|
||||||
|
target_id = $3,
|
||||||
|
status = $4,
|
||||||
|
attempts = $5,
|
||||||
|
max_attempts = $6,
|
||||||
|
last_error = $7,
|
||||||
|
scheduled_at = $8,
|
||||||
|
started_at = $9,
|
||||||
|
completed_at = $10
|
||||||
|
WHERE id = $11
|
||||||
|
`, job.Type, job.CertificateID, job.TargetID, job.Status, job.Attempts,
|
||||||
|
job.MaxAttempts, job.LastError, job.ScheduledAt, job.StartedAt,
|
||||||
|
job.CompletedAt, job.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("job not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a job
|
||||||
|
func (r *JobRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM jobs WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("job not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByStatus returns jobs with a specific status
|
||||||
|
func (r *JobRepository) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||||
|
last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
FROM jobs
|
||||||
|
WHERE status = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`, status)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query jobs by status: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var jobs []*domain.Job
|
||||||
|
for rows.Next() {
|
||||||
|
job, err := scanJob(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jobs = append(jobs, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return jobs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByCertificate returns all jobs for a certificate
|
||||||
|
func (r *JobRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||||
|
last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
FROM jobs
|
||||||
|
WHERE certificate_id = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`, certID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query jobs for certificate: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var jobs []*domain.Job
|
||||||
|
for rows.Next() {
|
||||||
|
job, err := scanJob(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jobs = append(jobs, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return jobs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus updates a job's status and optional error message
|
||||||
|
func (r *JobRepository) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error {
|
||||||
|
var lastError *string
|
||||||
|
if errMsg != "" {
|
||||||
|
lastError = &errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE jobs SET status = $1, last_error = $2 WHERE id = $3
|
||||||
|
`, status, lastError, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update job status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("job not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingJobs returns jobs not yet processed of a specific type
|
||||||
|
func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||||
|
last_error, scheduled_at, started_at, completed_at, created_at
|
||||||
|
FROM jobs
|
||||||
|
WHERE type = $1 AND status = $2
|
||||||
|
ORDER BY scheduled_at ASC
|
||||||
|
`, jobType, domain.JobStatusPending)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query pending jobs: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var jobs []*domain.Job
|
||||||
|
for rows.Next() {
|
||||||
|
job, err := scanJob(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jobs = append(jobs, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return jobs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanJob scans a job from a row or rows
|
||||||
|
func scanJob(scanner interface {
|
||||||
|
Scan(...interface{}) error
|
||||||
|
}) (*domain.Job, error) {
|
||||||
|
var job domain.Job
|
||||||
|
err := scanner.Scan(&job.ID, &job.Type, &job.CertificateID, &job.TargetID,
|
||||||
|
&job.Status, &job.Attempts, &job.MaxAttempts, &job.LastError,
|
||||||
|
&job.ScheduledAt, &job.StartedAt, &job.CompletedAt, &job.CreatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan job: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &job, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotificationRepository implements repository.NotificationRepository
|
||||||
|
type NotificationRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotificationRepository creates a new NotificationRepository
|
||||||
|
func NewNotificationRepository(db *sql.DB) *NotificationRepository {
|
||||||
|
return &NotificationRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new notification
|
||||||
|
func (r *NotificationRepository) Create(ctx context.Context, notif *domain.NotificationEvent) error {
|
||||||
|
if notif.ID == "" {
|
||||||
|
notif.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO notifications (
|
||||||
|
id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
|
RETURNING id
|
||||||
|
`, notif.ID, notif.Type, notif.CertificateID, notif.Channel, notif.Recipient,
|
||||||
|
notif.Message, notif.SentAt, notif.Status, notif.Error, notif.CreatedAt).Scan(¬if.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns notifications matching the filter criteria
|
||||||
|
func (r *NotificationRepository) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) {
|
||||||
|
if filter == nil {
|
||||||
|
filter = &repository.NotificationFilter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if filter.Page < 1 {
|
||||||
|
filter.Page = 1
|
||||||
|
}
|
||||||
|
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||||
|
filter.PerPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build WHERE clause
|
||||||
|
var whereConditions []string
|
||||||
|
var args []interface{}
|
||||||
|
argCount := 1
|
||||||
|
|
||||||
|
if filter.CertificateID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("certificate_id = $%d", argCount))
|
||||||
|
args = append(args, filter.CertificateID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.Status != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
|
||||||
|
args = append(args, filter.Status)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if filter.Channel != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("channel = $%d", argCount))
|
||||||
|
args = append(args, filter.Channel)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := ""
|
||||||
|
if len(whereConditions) > 0 {
|
||||||
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get total count
|
||||||
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM notifications %s", whereClause)
|
||||||
|
var total int
|
||||||
|
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to count notifications: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get paginated results
|
||||||
|
offset := (filter.Page - 1) * filter.PerPage
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at
|
||||||
|
FROM notifications
|
||||||
|
%s
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $%d OFFSET $%d
|
||||||
|
`, whereClause, argCount, argCount+1)
|
||||||
|
|
||||||
|
args = append(args, filter.PerPage, offset)
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query notifications: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var notifs []*domain.NotificationEvent
|
||||||
|
for rows.Next() {
|
||||||
|
notif, err := scanNotification(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
notifs = append(notifs, notif)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating notification rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return notifs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus updates a notification's delivery status
|
||||||
|
func (r *NotificationRepository) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE notifications SET status = $1, sent_at = $2 WHERE id = $3
|
||||||
|
`, status, sentAt, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update notification status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("notification not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanNotification scans a notification from a row or rows
|
||||||
|
func scanNotification(scanner interface {
|
||||||
|
Scan(...interface{}) error
|
||||||
|
}) (*domain.NotificationEvent, error) {
|
||||||
|
var notif domain.NotificationEvent
|
||||||
|
err := scanner.Scan(¬if.ID, ¬if.Type, ¬if.CertificateID, ¬if.Channel,
|
||||||
|
¬if.Recipient, ¬if.Message, ¬if.SentAt, ¬if.Status, ¬if.Error, ¬if.CreatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ¬if, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OwnerRepository implements repository.OwnerRepository
|
||||||
|
type OwnerRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOwnerRepository creates a new OwnerRepository
|
||||||
|
func NewOwnerRepository(db *sql.DB) *OwnerRepository {
|
||||||
|
return &OwnerRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all owners
|
||||||
|
func (r *OwnerRepository) List(ctx context.Context) ([]*domain.Owner, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, email, team_id, created_at, updated_at
|
||||||
|
FROM owners
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query owners: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var owners []*domain.Owner
|
||||||
|
for rows.Next() {
|
||||||
|
var owner domain.Owner
|
||||||
|
if err := rows.Scan(&owner.ID, &owner.Name, &owner.Email, &owner.TeamID,
|
||||||
|
&owner.CreatedAt, &owner.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan owner: %w", err)
|
||||||
|
}
|
||||||
|
owners = append(owners, &owner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating owner rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return owners, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an owner by ID
|
||||||
|
func (r *OwnerRepository) Get(ctx context.Context, id string) (*domain.Owner, error) {
|
||||||
|
var owner domain.Owner
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, email, team_id, created_at, updated_at
|
||||||
|
FROM owners
|
||||||
|
WHERE id = $1
|
||||||
|
`, id).Scan(&owner.ID, &owner.Name, &owner.Email, &owner.TeamID,
|
||||||
|
&owner.CreatedAt, &owner.UpdatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("owner not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query owner: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &owner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new owner
|
||||||
|
func (r *OwnerRepository) Create(ctx context.Context, owner *domain.Owner) error {
|
||||||
|
if owner.ID == "" {
|
||||||
|
owner.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO owners (id, name, email, team_id, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
RETURNING id
|
||||||
|
`, owner.ID, owner.Name, owner.Email, owner.TeamID,
|
||||||
|
owner.CreatedAt, owner.UpdatedAt).Scan(&owner.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create owner: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing owner
|
||||||
|
func (r *OwnerRepository) Update(ctx context.Context, owner *domain.Owner) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE owners SET
|
||||||
|
name = $1,
|
||||||
|
email = $2,
|
||||||
|
team_id = $3,
|
||||||
|
updated_at = $4
|
||||||
|
WHERE id = $5
|
||||||
|
`, owner.Name, owner.Email, owner.TeamID, owner.UpdatedAt, owner.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update owner: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("owner not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an owner
|
||||||
|
func (r *OwnerRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM owners WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete owner: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("owner not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PolicyRepository implements repository.PolicyRepository
|
||||||
|
type PolicyRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPolicyRepository creates a new PolicyRepository
|
||||||
|
func NewPolicyRepository(db *sql.DB) *PolicyRepository {
|
||||||
|
return &PolicyRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRules returns all policy rules
|
||||||
|
func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||||
|
FROM policy_rules
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query policy rules: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var rules []*domain.PolicyRule
|
||||||
|
for rows.Next() {
|
||||||
|
var rule domain.PolicyRule
|
||||||
|
if err := rows.Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config,
|
||||||
|
&rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan policy rule: %w", err)
|
||||||
|
}
|
||||||
|
rules = append(rules, &rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating policy rule rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRule retrieves a policy rule by ID
|
||||||
|
func (r *PolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
|
||||||
|
var rule domain.PolicyRule
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||||
|
FROM policy_rules
|
||||||
|
WHERE id = $1
|
||||||
|
`, id).Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config,
|
||||||
|
&rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("policy rule not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateRule stores a new policy rule
|
||||||
|
func (r *PolicyRepository) CreateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||||
|
if rule.ID == "" {
|
||||||
|
rule.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO policy_rules (id, name, type, config, enabled, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
RETURNING id
|
||||||
|
`, rule.ID, rule.Name, rule.Type, rule.Config, rule.Enabled,
|
||||||
|
rule.CreatedAt, rule.UpdatedAt).Scan(&rule.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRule modifies an existing policy rule
|
||||||
|
func (r *PolicyRepository) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE policy_rules SET
|
||||||
|
name = $1,
|
||||||
|
type = $2,
|
||||||
|
config = $3,
|
||||||
|
enabled = $4,
|
||||||
|
updated_at = $5
|
||||||
|
WHERE id = $6
|
||||||
|
`, rule.Name, rule.Type, rule.Config, rule.Enabled, rule.UpdatedAt, rule.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("policy rule not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRule removes a policy rule
|
||||||
|
func (r *PolicyRepository) DeleteRule(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM policy_rules WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("policy rule not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateViolation records a policy violation
|
||||||
|
func (r *PolicyRepository) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error {
|
||||||
|
if violation.ID == "" {
|
||||||
|
violation.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO policy_violations (id, certificate_id, rule_id, message, severity, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
RETURNING id
|
||||||
|
`, violation.ID, violation.CertificateID, violation.RuleID, violation.Message,
|
||||||
|
violation.Severity, violation.CreatedAt).Scan(&violation.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create policy violation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListViolations returns policy violations, optionally filtered
|
||||||
|
func (r *PolicyRepository) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) {
|
||||||
|
if filter == nil {
|
||||||
|
filter = &repository.AuditFilter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if filter.Page < 1 {
|
||||||
|
filter.Page = 1
|
||||||
|
}
|
||||||
|
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||||
|
filter.PerPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build WHERE clause
|
||||||
|
var whereConditions []string
|
||||||
|
var args []interface{}
|
||||||
|
argCount := 1
|
||||||
|
|
||||||
|
if filter.ResourceID != "" {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("certificate_id = $%d", argCount))
|
||||||
|
args = append(args, filter.ResourceID)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if !filter.From.IsZero() {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("created_at >= $%d", argCount))
|
||||||
|
args = append(args, filter.From)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
if !filter.To.IsZero() {
|
||||||
|
whereConditions = append(whereConditions, fmt.Sprintf("created_at <= $%d", argCount))
|
||||||
|
args = append(args, filter.To)
|
||||||
|
argCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := ""
|
||||||
|
if len(whereConditions) > 0 {
|
||||||
|
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get total count
|
||||||
|
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM policy_violations %s", whereClause)
|
||||||
|
var total int
|
||||||
|
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to count policy violations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get paginated results
|
||||||
|
offset := (filter.Page - 1) * filter.PerPage
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT id, certificate_id, rule_id, message, severity, created_at
|
||||||
|
FROM policy_violations
|
||||||
|
%s
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $%d OFFSET $%d
|
||||||
|
`, whereClause, argCount, argCount+1)
|
||||||
|
|
||||||
|
args = append(args, filter.PerPage, offset)
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query policy violations: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var violations []*domain.PolicyViolation
|
||||||
|
for rows.Next() {
|
||||||
|
var v domain.PolicyViolation
|
||||||
|
if err := rows.Scan(&v.ID, &v.CertificateID, &v.RuleID, &v.Message,
|
||||||
|
&v.Severity, &v.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan policy violation: %w", err)
|
||||||
|
}
|
||||||
|
violations = append(violations, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating policy violation rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return violations, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TargetRepository implements repository.TargetRepository
|
||||||
|
type TargetRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTargetRepository creates a new TargetRepository
|
||||||
|
func NewTargetRepository(db *sql.DB) *TargetRepository {
|
||||||
|
return &TargetRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all targets
|
||||||
|
func (r *TargetRepository) List(ctx context.Context) ([]*domain.DeploymentTarget, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, type, agent_id, config, enabled, created_at, updated_at
|
||||||
|
FROM deployment_targets
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query targets: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var targets []*domain.DeploymentTarget
|
||||||
|
for rows.Next() {
|
||||||
|
var target domain.DeploymentTarget
|
||||||
|
if err := rows.Scan(&target.ID, &target.Name, &target.Type, &target.AgentID,
|
||||||
|
&target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan target: %w", err)
|
||||||
|
}
|
||||||
|
targets = append(targets, &target)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating target rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return targets, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a target by ID
|
||||||
|
func (r *TargetRepository) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
|
var target domain.DeploymentTarget
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, type, agent_id, config, enabled, created_at, updated_at
|
||||||
|
FROM deployment_targets
|
||||||
|
WHERE id = $1
|
||||||
|
`, id).Scan(&target.ID, &target.Name, &target.Type, &target.AgentID,
|
||||||
|
&target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("target not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query target: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new target
|
||||||
|
func (r *TargetRepository) Create(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||||
|
if target.ID == "" {
|
||||||
|
target.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
RETURNING id
|
||||||
|
`, target.ID, target.Name, target.Type, target.AgentID, target.Config, target.Enabled,
|
||||||
|
target.CreatedAt, target.UpdatedAt).Scan(&target.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create target: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing target
|
||||||
|
func (r *TargetRepository) Update(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE deployment_targets SET
|
||||||
|
name = $1,
|
||||||
|
type = $2,
|
||||||
|
agent_id = $3,
|
||||||
|
config = $4,
|
||||||
|
enabled = $5,
|
||||||
|
updated_at = $6
|
||||||
|
WHERE id = $7
|
||||||
|
`, target.Name, target.Type, target.AgentID, target.Config, target.Enabled, target.UpdatedAt, target.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update target: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("target not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a target
|
||||||
|
func (r *TargetRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM deployment_targets WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete target: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("target not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByCertificate returns all targets for a given certificate
|
||||||
|
func (r *TargetRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT dt.id, dt.name, dt.type, dt.agent_id, dt.config, dt.enabled, dt.created_at, dt.updated_at
|
||||||
|
FROM deployment_targets dt
|
||||||
|
INNER JOIN certificate_target_mappings ctm ON dt.id = ctm.target_id
|
||||||
|
WHERE ctm.certificate_id = $1
|
||||||
|
ORDER BY dt.created_at DESC
|
||||||
|
`, certID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query targets for certificate: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var targets []*domain.DeploymentTarget
|
||||||
|
for rows.Next() {
|
||||||
|
var target domain.DeploymentTarget
|
||||||
|
if err := rows.Scan(&target.ID, &target.Name, &target.Type, &target.AgentID,
|
||||||
|
&target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan target: %w", err)
|
||||||
|
}
|
||||||
|
targets = append(targets, &target)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating target rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return targets, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,135 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TeamRepository implements repository.TeamRepository
|
||||||
|
type TeamRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTeamRepository creates a new TeamRepository
|
||||||
|
func NewTeamRepository(db *sql.DB) *TeamRepository {
|
||||||
|
return &TeamRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all teams
|
||||||
|
func (r *TeamRepository) List(ctx context.Context) ([]*domain.Team, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, name, description, created_at, updated_at
|
||||||
|
FROM teams
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query teams: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var teams []*domain.Team
|
||||||
|
for rows.Next() {
|
||||||
|
var team domain.Team
|
||||||
|
if err := rows.Scan(&team.ID, &team.Name, &team.Description,
|
||||||
|
&team.CreatedAt, &team.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan team: %w", err)
|
||||||
|
}
|
||||||
|
teams = append(teams, &team)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error iterating team rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return teams, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a team by ID
|
||||||
|
func (r *TeamRepository) Get(ctx context.Context, id string) (*domain.Team, error) {
|
||||||
|
var team domain.Team
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, name, description, created_at, updated_at
|
||||||
|
FROM teams
|
||||||
|
WHERE id = $1
|
||||||
|
`, id).Scan(&team.ID, &team.Name, &team.Description,
|
||||||
|
&team.CreatedAt, &team.UpdatedAt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, fmt.Errorf("team not found")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to query team: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &team, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores a new team
|
||||||
|
func (r *TeamRepository) Create(ctx context.Context, team *domain.Team) error {
|
||||||
|
if team.ID == "" {
|
||||||
|
team.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.db.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO teams (id, name, description, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
RETURNING id
|
||||||
|
`, team.ID, team.Name, team.Description, team.CreatedAt, team.UpdatedAt).Scan(&team.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create team: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing team
|
||||||
|
func (r *TeamRepository) Update(ctx context.Context, team *domain.Team) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, `
|
||||||
|
UPDATE teams SET
|
||||||
|
name = $1,
|
||||||
|
description = $2,
|
||||||
|
updated_at = $3
|
||||||
|
WHERE id = $4
|
||||||
|
`, team.Name, team.Description, team.UpdatedAt, team.ID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update team: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("team not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a team
|
||||||
|
func (r *TeamRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
result, err := r.db.ExecContext(ctx, "DELETE FROM teams WHERE id = $1", id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete team: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("team not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -108,3 +108,54 @@ func (s *AuditService) ListByAction(ctx context.Context, action string, from, to
|
|||||||
|
|
||||||
return filtered, nil
|
return filtered, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListAuditEvents returns paginated audit events (handler interface method).
|
||||||
|
func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
Offset: int64((page - 1) * perPage),
|
||||||
|
PerPage: int64(perPage),
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(context.Background(), filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for the handler interface
|
||||||
|
var result []domain.AuditEvent
|
||||||
|
for _, e := range events {
|
||||||
|
if e != nil {
|
||||||
|
result = append(result, *e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Get total count from repository
|
||||||
|
total := int64(len(result))
|
||||||
|
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuditEvent returns a single audit event (handler interface method).
|
||||||
|
func (s *AuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) {
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(context.Background(), filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get audit event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) == 0 {
|
||||||
|
return nil, fmt.Errorf("audit event not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return events[0], nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditService provides business logic for recording and retrieving audit events.
|
||||||
|
type AuditService struct {
|
||||||
|
auditRepo repository.AuditRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditService creates a new audit service.
|
||||||
|
func NewAuditService(auditRepo repository.AuditRepository) *AuditService {
|
||||||
|
return &AuditService{
|
||||||
|
auditRepo: auditRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordEvent records an audit event with actor, action, and resource information.
|
||||||
|
func (s *AuditService) RecordEvent(ctx context.Context, actor string, actorType domain.ActorType, action string, resourceType string, resourceID string, details map[string]interface{}) error {
|
||||||
|
detailsJSON, err := json.Marshal(details)
|
||||||
|
if err != nil {
|
||||||
|
detailsJSON = []byte("{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
event := &domain.AuditEvent{
|
||||||
|
ID: generateID("audit"),
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Actor: actor,
|
||||||
|
ActorType: actorType,
|
||||||
|
Action: action,
|
||||||
|
ResourceType: resourceType,
|
||||||
|
ResourceID: resourceID,
|
||||||
|
Details: json.RawMessage(detailsJSON),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.auditRepo.Create(ctx, event); err != nil {
|
||||||
|
return fmt.Errorf("failed to record audit event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns audit events matching filter criteria.
|
||||||
|
func (s *AuditService) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
|
||||||
|
events, err := s.auditRepo.List(ctx, filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByResource returns all audit events for a specific resource.
|
||||||
|
func (s *AuditService) ListByResource(ctx context.Context, resourceType string, resourceID string) ([]*domain.AuditEvent, error) {
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
ResourceType: resourceType,
|
||||||
|
ResourceID: resourceID,
|
||||||
|
PerPage: 1000, // reasonable default for single resource
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(ctx, filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByActor returns all audit events for a specific actor.
|
||||||
|
func (s *AuditService) ListByActor(ctx context.Context, actor string) ([]*domain.AuditEvent, error) {
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
Actor: actor,
|
||||||
|
PerPage: 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(ctx, filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListByAction returns all audit events for a specific action type.
|
||||||
|
func (s *AuditService) ListByAction(ctx context.Context, action string, from, to time.Time) ([]*domain.AuditEvent, error) {
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
From: from,
|
||||||
|
To: to,
|
||||||
|
PerPage: 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(ctx, filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by action on client side (repository may not filter by action directly)
|
||||||
|
var filtered []*domain.AuditEvent
|
||||||
|
for _, e := range events {
|
||||||
|
if e.Action == action {
|
||||||
|
filtered = append(filtered, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAuditEvents returns paginated audit events (handler interface method).
|
||||||
|
func (s *AuditService) ListAuditEvents(page, perPage int) ([]domain.AuditEvent, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
Offset: int64((page - 1) * perPage),
|
||||||
|
PerPage: int64(perPage),
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(context.Background(), filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list audit events: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for the handler interface
|
||||||
|
var result []domain.AuditEvent
|
||||||
|
for _, e := range events {
|
||||||
|
if e != nil {
|
||||||
|
result = append(result, *e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Get total count from repository
|
||||||
|
total := int64(len(result))
|
||||||
|
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuditEvent returns a single audit event (handler interface method).
|
||||||
|
func (s *AuditService) GetAuditEvent(id string) (*domain.AuditEvent, error) {
|
||||||
|
filter := &repository.AuditFilter{
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.auditRepo.List(context.Background(), filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get audit event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) == 0 {
|
||||||
|
return nil, fmt.Errorf("audit event not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return events[0], nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,170 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IssuerService provides business logic for certificate issuer management.
|
||||||
|
type IssuerService struct {
|
||||||
|
issuerRepo repository.IssuerRepository
|
||||||
|
auditService *AuditService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIssuerService creates a new issuer service.
|
||||||
|
func NewIssuerService(
|
||||||
|
issuerRepo repository.IssuerRepository,
|
||||||
|
auditService *AuditService,
|
||||||
|
) *IssuerService {
|
||||||
|
return &IssuerService{
|
||||||
|
issuerRepo: issuerRepo,
|
||||||
|
auditService: auditService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of issuers.
|
||||||
|
func (s *IssuerService) List(ctx context.Context, page, perPage int) ([]*domain.Issuer, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
issuers, total, err := s.issuerRepo.List(ctx, offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list issuers: %w", err)
|
||||||
|
}
|
||||||
|
return issuers, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an issuer by ID.
|
||||||
|
func (s *IssuerService) Get(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||||
|
issuer, err := s.issuerRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issuer %s: %w", id, err)
|
||||||
|
}
|
||||||
|
return issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create validates and stores a new issuer.
|
||||||
|
func (s *IssuerService) Create(ctx context.Context, issuer *domain.Issuer, actor string) error {
|
||||||
|
if issuer.Name == "" {
|
||||||
|
return fmt.Errorf("issuer name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer.ID = generateID("issuer")
|
||||||
|
if err := s.issuerRepo.Create(ctx, issuer); err != nil {
|
||||||
|
return fmt.Errorf("failed to create issuer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_issuer", "issuer", issuer.ID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing issuer.
|
||||||
|
func (s *IssuerService) Update(ctx context.Context, id string, issuer *domain.Issuer, actor string) error {
|
||||||
|
if issuer.Name == "" {
|
||||||
|
return fmt.Errorf("issuer name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer.ID = id
|
||||||
|
if err := s.issuerRepo.Update(ctx, issuer); err != nil {
|
||||||
|
return fmt.Errorf("failed to update issuer %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_issuer", "issuer", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an issuer.
|
||||||
|
func (s *IssuerService) Delete(ctx context.Context, id string, actor string) error {
|
||||||
|
if err := s.issuerRepo.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete issuer %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_issuer", "issuer", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnection verifies the issuer connection.
|
||||||
|
func (s *IssuerService) TestConnection(ctx context.Context, id string) error {
|
||||||
|
issuer, err := s.issuerRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("issuer not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Implement actual connection test based on issuer type
|
||||||
|
if issuer == nil {
|
||||||
|
return fmt.Errorf("issuer not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIssuers returns paginated issuers (handler interface method).
|
||||||
|
func (s *IssuerService) ListIssuers(page, perPage int) ([]domain.Issuer, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
issuers, total, err := s.issuerRepo.List(context.Background(), offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list issuers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for the handler interface
|
||||||
|
var result []domain.Issuer
|
||||||
|
for _, i := range issuers {
|
||||||
|
if i != nil {
|
||||||
|
result = append(result, *i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIssuer returns a single issuer (handler interface method).
|
||||||
|
func (s *IssuerService) GetIssuer(id string) (*domain.Issuer, error) {
|
||||||
|
return s.issuerRepo.Get(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateIssuer creates a new issuer (handler interface method).
|
||||||
|
func (s *IssuerService) CreateIssuer(issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
|
issuer.ID = generateID("issuer")
|
||||||
|
if err := s.issuerRepo.Create(context.Background(), &issuer); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create issuer: %w", err)
|
||||||
|
}
|
||||||
|
return &issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateIssuer modifies an issuer (handler interface method).
|
||||||
|
func (s *IssuerService) UpdateIssuer(id string, issuer domain.Issuer) (*domain.Issuer, error) {
|
||||||
|
issuer.ID = id
|
||||||
|
if err := s.issuerRepo.Update(context.Background(), &issuer); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to update issuer: %w", err)
|
||||||
|
}
|
||||||
|
return &issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteIssuer removes an issuer (handler interface method).
|
||||||
|
func (s *IssuerService) DeleteIssuer(id string) error {
|
||||||
|
return s.issuerRepo.Delete(context.Background(), id)
|
||||||
|
}
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IssuerService provides business logic for certificate issuer management.
|
||||||
|
type IssuerService struct {
|
||||||
|
issuerRepo repository.IssuerRepository
|
||||||
|
auditService *AuditService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIssuerService creates a new issuer service.
|
||||||
|
func NewIssuerService(
|
||||||
|
issuerRepo repository.IssuerRepository,
|
||||||
|
auditService *AuditService,
|
||||||
|
) *IssuerService {
|
||||||
|
return &IssuerService{
|
||||||
|
issuerRepo: issuerRepo,
|
||||||
|
auditService: auditService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of issuers.
|
||||||
|
func (s *IssuerService) List(ctx context.Context, page, perPage int) ([]*domain.Issuer, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
issuers, total, err := s.issuerRepo.List(ctx, offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list issuers: %w", err)
|
||||||
|
}
|
||||||
|
return issuers, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an issuer by ID.
|
||||||
|
func (s *IssuerService) Get(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||||
|
issuer, err := s.issuerRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get issuer %s: %w", id, err)
|
||||||
|
}
|
||||||
|
return issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create validates and stores a new issuer.
|
||||||
|
func (s *IssuerService) Create(ctx context.Context, issuer *domain.Issuer, actor string) error {
|
||||||
|
if issuer.Name == "" {
|
||||||
|
return fmt.Errorf("issuer name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer.ID = generateID("issuer")
|
||||||
|
if err := s.issuerRepo.Create(ctx, issuer); err != nil {
|
||||||
|
return fmt.Errorf("failed to create issuer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_issuer", "issuer", issuer.ID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing issuer.
|
||||||
|
func (s *IssuerService) Update(ctx context.Context, id string, issuer *domain.Issuer, actor string) error {
|
||||||
|
if issuer.Name == "" {
|
||||||
|
return fmt.Errorf("issuer name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer.ID = id
|
||||||
|
if err := s.issuerRepo.Update(ctx, issuer); err != nil {
|
||||||
|
return fmt.Errorf("failed to update issuer %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_issuer", "issuer", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an issuer.
|
||||||
|
func (s *IssuerService) Delete(ctx context.Context, id string, actor string) error {
|
||||||
|
if err := s.issuerRepo.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete issuer %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_issuer", "issuer", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnection verifies the issuer connection.
|
||||||
|
func (s *IssuerService) TestConnection(ctx context.Context, id string) error {
|
||||||
|
issuer, err := s.issuerRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("issuer not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Implement actual connection test based on issuer type
|
||||||
|
if issuer == nil {
|
||||||
|
return fmt.Errorf("issuer not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OwnerService provides business logic for certificate owner management.
|
||||||
|
type OwnerService struct {
|
||||||
|
ownerRepo repository.OwnerRepository
|
||||||
|
auditService *AuditService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOwnerService creates a new owner service.
|
||||||
|
func NewOwnerService(
|
||||||
|
ownerRepo repository.OwnerRepository,
|
||||||
|
auditService *AuditService,
|
||||||
|
) *OwnerService {
|
||||||
|
return &OwnerService{
|
||||||
|
ownerRepo: ownerRepo,
|
||||||
|
auditService: auditService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of owners.
|
||||||
|
func (s *OwnerService) List(ctx context.Context, page, perPage int) ([]*domain.Owner, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
owners, total, err := s.ownerRepo.List(ctx, offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list owners: %w", err)
|
||||||
|
}
|
||||||
|
return owners, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an owner by ID.
|
||||||
|
func (s *OwnerService) Get(ctx context.Context, id string) (*domain.Owner, error) {
|
||||||
|
owner, err := s.ownerRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get owner %s: %w", id, err)
|
||||||
|
}
|
||||||
|
return owner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create validates and stores a new owner.
|
||||||
|
func (s *OwnerService) Create(ctx context.Context, owner *domain.Owner, actor string) error {
|
||||||
|
if owner.Name == "" {
|
||||||
|
return fmt.Errorf("owner name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
owner.ID = generateID("owner")
|
||||||
|
if err := s.ownerRepo.Create(ctx, owner); err != nil {
|
||||||
|
return fmt.Errorf("failed to create owner: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_owner", "owner", owner.ID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing owner.
|
||||||
|
func (s *OwnerService) Update(ctx context.Context, id string, owner *domain.Owner, actor string) error {
|
||||||
|
if owner.Name == "" {
|
||||||
|
return fmt.Errorf("owner name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
owner.ID = id
|
||||||
|
if err := s.ownerRepo.Update(ctx, owner); err != nil {
|
||||||
|
return fmt.Errorf("failed to update owner %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_owner", "owner", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an owner.
|
||||||
|
func (s *OwnerService) Delete(ctx context.Context, id string, actor string) error {
|
||||||
|
if err := s.ownerRepo.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete owner %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_owner", "owner", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListOwners returns paginated owners (handler interface method).
|
||||||
|
func (s *OwnerService) ListOwners(page, perPage int) ([]domain.Owner, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
owners, total, err := s.ownerRepo.List(context.Background(), offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list owners: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for the handler interface
|
||||||
|
var result []domain.Owner
|
||||||
|
for _, o := range owners {
|
||||||
|
if o != nil {
|
||||||
|
result = append(result, *o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOwner returns a single owner (handler interface method).
|
||||||
|
func (s *OwnerService) GetOwner(id string) (*domain.Owner, error) {
|
||||||
|
return s.ownerRepo.Get(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOwner creates a new owner (handler interface method).
|
||||||
|
func (s *OwnerService) CreateOwner(owner domain.Owner) (*domain.Owner, error) {
|
||||||
|
owner.ID = generateID("owner")
|
||||||
|
if err := s.ownerRepo.Create(context.Background(), &owner); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create owner: %w", err)
|
||||||
|
}
|
||||||
|
return &owner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOwner modifies an owner (handler interface method).
|
||||||
|
func (s *OwnerService) UpdateOwner(id string, owner domain.Owner) (*domain.Owner, error) {
|
||||||
|
owner.ID = id
|
||||||
|
if err := s.ownerRepo.Update(context.Background(), &owner); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to update owner: %w", err)
|
||||||
|
}
|
||||||
|
return &owner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOwner removes an owner (handler interface method).
|
||||||
|
func (s *OwnerService) DeleteOwner(id string) error {
|
||||||
|
return s.ownerRepo.Delete(context.Background(), id)
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TargetService provides business logic for deployment target management.
|
||||||
|
type TargetService struct {
|
||||||
|
targetRepo repository.TargetRepository
|
||||||
|
auditService *AuditService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTargetService creates a new target service.
|
||||||
|
func NewTargetService(
|
||||||
|
targetRepo repository.TargetRepository,
|
||||||
|
auditService *AuditService,
|
||||||
|
) *TargetService {
|
||||||
|
return &TargetService{
|
||||||
|
targetRepo: targetRepo,
|
||||||
|
auditService: auditService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of deployment targets.
|
||||||
|
func (s *TargetService) List(ctx context.Context, page, perPage int) ([]*domain.DeploymentTarget, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
targets, total, err := s.targetRepo.List(ctx, offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list targets: %w", err)
|
||||||
|
}
|
||||||
|
return targets, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a deployment target by ID.
|
||||||
|
func (s *TargetService) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||||
|
target, err := s.targetRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get target %s: %w", id, err)
|
||||||
|
}
|
||||||
|
return target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create validates and stores a new deployment target.
|
||||||
|
func (s *TargetService) Create(ctx context.Context, target *domain.DeploymentTarget, actor string) error {
|
||||||
|
if target.Name == "" {
|
||||||
|
return fmt.Errorf("target name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
target.ID = generateID("target")
|
||||||
|
if err := s.targetRepo.Create(ctx, target); err != nil {
|
||||||
|
return fmt.Errorf("failed to create target: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_target", "target", target.ID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing deployment target.
|
||||||
|
func (s *TargetService) Update(ctx context.Context, id string, target *domain.DeploymentTarget, actor string) error {
|
||||||
|
if target.Name == "" {
|
||||||
|
return fmt.Errorf("target name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
target.ID = id
|
||||||
|
if err := s.targetRepo.Update(ctx, target); err != nil {
|
||||||
|
return fmt.Errorf("failed to update target %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_target", "target", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a deployment target.
|
||||||
|
func (s *TargetService) Delete(ctx context.Context, id string, actor string) error {
|
||||||
|
if err := s.targetRepo.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete target %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_target", "target", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTargets returns paginated targets (handler interface method).
|
||||||
|
func (s *TargetService) ListTargets(page, perPage int) ([]domain.DeploymentTarget, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
targets, total, err := s.targetRepo.List(context.Background(), offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list targets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for the handler interface
|
||||||
|
var result []domain.DeploymentTarget
|
||||||
|
for _, t := range targets {
|
||||||
|
if t != nil {
|
||||||
|
result = append(result, *t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTarget returns a single target (handler interface method).
|
||||||
|
func (s *TargetService) GetTarget(id string) (*domain.DeploymentTarget, error) {
|
||||||
|
return s.targetRepo.Get(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTarget creates a new target (handler interface method).
|
||||||
|
func (s *TargetService) CreateTarget(target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
|
target.ID = generateID("target")
|
||||||
|
if err := s.targetRepo.Create(context.Background(), &target); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create target: %w", err)
|
||||||
|
}
|
||||||
|
return &target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTarget modifies a target (handler interface method).
|
||||||
|
func (s *TargetService) UpdateTarget(id string, target domain.DeploymentTarget) (*domain.DeploymentTarget, error) {
|
||||||
|
target.ID = id
|
||||||
|
if err := s.targetRepo.Update(context.Background(), &target); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to update target: %w", err)
|
||||||
|
}
|
||||||
|
return &target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTarget removes a target (handler interface method).
|
||||||
|
func (s *TargetService) DeleteTarget(id string) error {
|
||||||
|
return s.targetRepo.Delete(context.Background(), id)
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/shankar0123/certctl/internal/domain"
|
||||||
|
"github.com/shankar0123/certctl/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TeamService provides business logic for team management.
|
||||||
|
type TeamService struct {
|
||||||
|
teamRepo repository.TeamRepository
|
||||||
|
auditService *AuditService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTeamService creates a new team service.
|
||||||
|
func NewTeamService(
|
||||||
|
teamRepo repository.TeamRepository,
|
||||||
|
auditService *AuditService,
|
||||||
|
) *TeamService {
|
||||||
|
return &TeamService{
|
||||||
|
teamRepo: teamRepo,
|
||||||
|
auditService: auditService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns a paginated list of teams.
|
||||||
|
func (s *TeamService) List(ctx context.Context, page, perPage int) ([]*domain.Team, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
teams, total, err := s.teamRepo.List(ctx, offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list teams: %w", err)
|
||||||
|
}
|
||||||
|
return teams, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a team by ID.
|
||||||
|
func (s *TeamService) Get(ctx context.Context, id string) (*domain.Team, error) {
|
||||||
|
team, err := s.teamRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get team %s: %w", id, err)
|
||||||
|
}
|
||||||
|
return team, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create validates and stores a new team.
|
||||||
|
func (s *TeamService) Create(ctx context.Context, team *domain.Team, actor string) error {
|
||||||
|
if team.Name == "" {
|
||||||
|
return fmt.Errorf("team name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
team.ID = generateID("team")
|
||||||
|
if err := s.teamRepo.Create(ctx, team); err != nil {
|
||||||
|
return fmt.Errorf("failed to create team: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "create_team", "team", team.ID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update modifies an existing team.
|
||||||
|
func (s *TeamService) Update(ctx context.Context, id string, team *domain.Team, actor string) error {
|
||||||
|
if team.Name == "" {
|
||||||
|
return fmt.Errorf("team name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
team.ID = id
|
||||||
|
if err := s.teamRepo.Update(ctx, team); err != nil {
|
||||||
|
return fmt.Errorf("failed to update team %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "update_team", "team", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a team.
|
||||||
|
func (s *TeamService) Delete(ctx context.Context, id string, actor string) error {
|
||||||
|
if err := s.teamRepo.Delete(ctx, id); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete team %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.auditService != nil {
|
||||||
|
_ = s.auditService.RecordEvent(ctx, actor, domain.ActorTypeUser, "delete_team", "team", id, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTeams returns paginated teams (handler interface method).
|
||||||
|
func (s *TeamService) ListTeams(page, perPage int) ([]domain.Team, int64, error) {
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if perPage < 1 {
|
||||||
|
perPage = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := int64((page - 1) * perPage)
|
||||||
|
teams, total, err := s.teamRepo.List(context.Background(), offset, int64(perPage))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to list teams: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pointers to values for the handler interface
|
||||||
|
var result []domain.Team
|
||||||
|
for _, t := range teams {
|
||||||
|
if t != nil {
|
||||||
|
result = append(result, *t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTeam returns a single team (handler interface method).
|
||||||
|
func (s *TeamService) GetTeam(id string) (*domain.Team, error) {
|
||||||
|
return s.teamRepo.Get(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTeam creates a new team (handler interface method).
|
||||||
|
func (s *TeamService) CreateTeam(team domain.Team) (*domain.Team, error) {
|
||||||
|
team.ID = generateID("team")
|
||||||
|
if err := s.teamRepo.Create(context.Background(), &team); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create team: %w", err)
|
||||||
|
}
|
||||||
|
return &team, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTeam modifies a team (handler interface method).
|
||||||
|
func (s *TeamService) UpdateTeam(id string, team domain.Team) (*domain.Team, error) {
|
||||||
|
team.ID = id
|
||||||
|
if err := s.teamRepo.Update(context.Background(), &team); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to update team: %w", err)
|
||||||
|
}
|
||||||
|
return &team, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTeam removes a team (handler interface method).
|
||||||
|
func (s *TeamService) DeleteTeam(id string) error {
|
||||||
|
return s.teamRepo.Delete(context.Background(), id)
|
||||||
|
}
|
||||||
+1868
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user