mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-09 09:08:51 +00:00
Complete V1 scaffold
This commit is contained in:
@@ -0,0 +1,193 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// AgentRepository implements repository.AgentRepository
|
||||
type AgentRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewAgentRepository creates a new AgentRepository
|
||||
func NewAgentRepository(db *sql.DB) *AgentRepository {
|
||||
return &AgentRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns all agents
|
||||
func (r *AgentRepository) List(ctx context.Context) ([]*domain.Agent, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||
FROM agents
|
||||
ORDER BY registered_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query agents: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var agents []*domain.Agent
|
||||
for rows.Next() {
|
||||
agent, err := scanAgent(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
agents = append(agents, agent)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating agent rows: %w", err)
|
||||
}
|
||||
|
||||
return agents, nil
|
||||
}
|
||||
|
||||
// Get retrieves an agent by ID
|
||||
func (r *AgentRepository) Get(ctx context.Context, id string) (*domain.Agent, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||
FROM agents
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
|
||||
agent, err := scanAgent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("agent not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query agent: %w", err)
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
// Create stores a new agent
|
||||
func (r *AgentRepository) Create(ctx context.Context, agent *domain.Agent) error {
|
||||
if agent.ID == "" {
|
||||
agent.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO agents (id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id
|
||||
`, agent.ID, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt,
|
||||
agent.RegisteredAt, agent.APIKeyHash).Scan(&agent.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing agent
|
||||
func (r *AgentRepository) Update(ctx context.Context, agent *domain.Agent) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE agents SET
|
||||
name = $1,
|
||||
hostname = $2,
|
||||
status = $3,
|
||||
last_heartbeat_at = $4,
|
||||
api_key_hash = $5
|
||||
WHERE id = $6
|
||||
`, agent.Name, agent.Hostname, agent.Status, agent.LastHeartbeatAt, agent.APIKeyHash, agent.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update agent: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("agent not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes an agent
|
||||
func (r *AgentRepository) Delete(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM agents WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete agent: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("agent not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHeartbeat updates the agent's last heartbeat timestamp
|
||||
func (r *AgentRepository) UpdateHeartbeat(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE agents SET last_heartbeat_at = $1 WHERE id = $2
|
||||
`, time.Now(), id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update heartbeat: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("agent not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByAPIKey retrieves an agent by hashed API key
|
||||
func (r *AgentRepository) GetByAPIKey(ctx context.Context, keyHash string) (*domain.Agent, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, hostname, status, last_heartbeat_at, registered_at, api_key_hash
|
||||
FROM agents
|
||||
WHERE api_key_hash = $1
|
||||
`, keyHash)
|
||||
|
||||
agent, err := scanAgent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("agent not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query agent: %w", err)
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
// scanAgent scans an agent from a row or rows
|
||||
func scanAgent(scanner interface {
|
||||
Scan(...interface{}) error
|
||||
}) (*domain.Agent, error) {
|
||||
var agent domain.Agent
|
||||
err := scanner.Scan(&agent.ID, &agent.Name, &agent.Hostname, &agent.Status,
|
||||
&agent.LastHeartbeatAt, &agent.RegisteredAt, &agent.APIKeyHash)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan agent: %w", err)
|
||||
}
|
||||
|
||||
return &agent, nil
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// AuditRepository implements repository.AuditRepository
|
||||
type AuditRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewAuditRepository creates a new AuditRepository
|
||||
func NewAuditRepository(db *sql.DB) *AuditRepository {
|
||||
return &AuditRepository{db: db}
|
||||
}
|
||||
|
||||
// Create stores a new audit event
|
||||
func (r *AuditRepository) Create(ctx context.Context, event *domain.AuditEvent) error {
|
||||
if event.ID == "" {
|
||||
event.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO audit_events (
|
||||
id, actor, actor_type, action, resource_type, resource_id, details, timestamp
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id
|
||||
`, event.ID, event.Actor, event.ActorType, event.Action, event.ResourceType,
|
||||
event.ResourceID, event.Details, event.Timestamp).Scan(&event.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create audit event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns audit events matching the filter criteria
|
||||
func (r *AuditRepository) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) {
|
||||
if filter == nil {
|
||||
filter = &repository.AuditFilter{}
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if filter.Page < 1 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||
filter.PerPage = 50
|
||||
}
|
||||
|
||||
// Build WHERE clause
|
||||
var whereConditions []string
|
||||
var args []interface{}
|
||||
argCount := 1
|
||||
|
||||
if filter.Actor != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("actor = $%d", argCount))
|
||||
args = append(args, filter.Actor)
|
||||
argCount++
|
||||
}
|
||||
if filter.ActorType != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("actor_type = $%d", argCount))
|
||||
args = append(args, filter.ActorType)
|
||||
argCount++
|
||||
}
|
||||
if filter.ResourceType != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("resource_type = $%d", argCount))
|
||||
args = append(args, filter.ResourceType)
|
||||
argCount++
|
||||
}
|
||||
if filter.ResourceID != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("resource_id = $%d", argCount))
|
||||
args = append(args, filter.ResourceID)
|
||||
argCount++
|
||||
}
|
||||
if !filter.From.IsZero() {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("timestamp >= $%d", argCount))
|
||||
args = append(args, filter.From)
|
||||
argCount++
|
||||
}
|
||||
if !filter.To.IsZero() {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("timestamp <= $%d", argCount))
|
||||
args = append(args, filter.To)
|
||||
argCount++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(whereConditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("failed to count audit events: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
offset := (filter.Page - 1) * filter.PerPage
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, actor, actor_type, action, resource_type, resource_id, details, timestamp
|
||||
FROM audit_events
|
||||
%s
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argCount, argCount+1)
|
||||
|
||||
args = append(args, filter.PerPage, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query audit events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*domain.AuditEvent
|
||||
for rows.Next() {
|
||||
var event domain.AuditEvent
|
||||
if err := rows.Scan(&event.ID, &event.Actor, &event.ActorType, &event.Action,
|
||||
&event.ResourceType, &event.ResourceID, &event.Details, &event.Timestamp); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, &event)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating audit event rows: %w", err)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
@@ -0,0 +1,346 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// CertificateRepository implements repository.CertificateRepository
|
||||
type CertificateRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCertificateRepository creates a new CertificateRepository
|
||||
func NewCertificateRepository(db *sql.DB) *CertificateRepository {
|
||||
return &CertificateRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns a paginated list of certificates matching the filter criteria
|
||||
func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) {
|
||||
if filter == nil {
|
||||
filter = &repository.CertificateFilter{}
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if filter.Page < 1 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||
filter.PerPage = 50
|
||||
}
|
||||
|
||||
// Build WHERE clause
|
||||
var whereConditions []string
|
||||
var args []interface{}
|
||||
argCount := 1
|
||||
|
||||
if filter.Status != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
|
||||
args = append(args, filter.Status)
|
||||
argCount++
|
||||
}
|
||||
if filter.Environment != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("environment = $%d", argCount))
|
||||
args = append(args, filter.Environment)
|
||||
argCount++
|
||||
}
|
||||
if filter.OwnerID != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("owner_id = $%d", argCount))
|
||||
args = append(args, filter.OwnerID)
|
||||
argCount++
|
||||
}
|
||||
if filter.TeamID != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("team_id = $%d", argCount))
|
||||
args = append(args, filter.TeamID)
|
||||
argCount++
|
||||
}
|
||||
if filter.IssuerID != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("issuer_id = $%d", argCount))
|
||||
args = append(args, filter.IssuerID)
|
||||
argCount++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(whereConditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM certificates %s", whereClause)
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count certificates: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
offset := (filter.Page - 1) * filter.PerPage
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||
FROM certificates
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argCount, argCount+1)
|
||||
|
||||
args = append(args, filter.PerPage, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query certificates: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var certs []*domain.ManagedCertificate
|
||||
for rows.Next() {
|
||||
cert, err := scanCertificate(rows)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err)
|
||||
}
|
||||
|
||||
return certs, total, nil
|
||||
}
|
||||
|
||||
// Get retrieves a certificate by ID
|
||||
func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||
FROM certificates
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
|
||||
cert, err := scanCertificate(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("certificate not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query certificate: %w", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// Create stores a new certificate
|
||||
func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||
if cert.ID == "" {
|
||||
cert.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
tagsJSON, err := json.Marshal(cert.Tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||
}
|
||||
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO certificates (
|
||||
id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
||||
RETURNING id
|
||||
`, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
|
||||
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt,
|
||||
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing certificate
|
||||
func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error {
|
||||
tagsJSON, err := json.Marshal(cert.Tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE certificates SET
|
||||
name = $1,
|
||||
common_name = $2,
|
||||
sans = $3,
|
||||
environment = $4,
|
||||
owner_id = $5,
|
||||
team_id = $6,
|
||||
issuer_id = $7,
|
||||
status = $8,
|
||||
expires_at = $9,
|
||||
tags = $10,
|
||||
last_renewal_at = $11,
|
||||
last_deployment_at = $12,
|
||||
updated_at = $13
|
||||
WHERE id = $14
|
||||
`, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment,
|
||||
cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt,
|
||||
tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update certificate: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("certificate not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Archive marks a certificate as archived
|
||||
func (r *CertificateRepository) Archive(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE certificates SET status = $1, updated_at = $2 WHERE id = $3
|
||||
`, domain.CertificateStatusArchived, time.Now(), id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to archive certificate: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("certificate not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListVersions returns all versions of a certificate
|
||||
func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, certificate_id, serial_number, not_before, not_after,
|
||||
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||
FROM certificate_versions
|
||||
WHERE certificate_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, certID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query certificate versions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var versions []*domain.CertificateVersion
|
||||
for rows.Next() {
|
||||
var v domain.CertificateVersion
|
||||
if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter,
|
||||
&v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan certificate version: %w", err)
|
||||
}
|
||||
versions = append(versions, &v)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating version rows: %w", err)
|
||||
}
|
||||
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
// CreateVersion stores a new certificate version
|
||||
func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error {
|
||||
if version.ID == "" {
|
||||
version.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO certificate_versions (
|
||||
id, certificate_id, serial_number, not_before, not_after,
|
||||
fingerprint_sha256, pem_chain, csr_pem, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id
|
||||
`, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter,
|
||||
version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExpiringCertificates returns certificates expiring before the given time
|
||||
func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id,
|
||||
status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at
|
||||
FROM certificates
|
||||
WHERE expires_at < $1 AND status != $2
|
||||
ORDER BY expires_at ASC
|
||||
`, before, domain.CertificateStatusArchived)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query expiring certificates: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var certs []*domain.ManagedCertificate
|
||||
for rows.Next() {
|
||||
cert, err := scanCertificate(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err)
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// scanCertificate scans a certificate from a row or rows
|
||||
func scanCertificate(scanner interface {
|
||||
Scan(...interface{}) error
|
||||
}) (*domain.ManagedCertificate, error) {
|
||||
var cert domain.ManagedCertificate
|
||||
var tagsJSON []byte
|
||||
var sans pq.StringArray
|
||||
|
||||
err := scanner.Scan(
|
||||
&cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID,
|
||||
&cert.TeamID, &cert.IssuerID, &cert.Status, &cert.ExpiresAt, &tagsJSON,
|
||||
&cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan certificate: %w", err)
|
||||
}
|
||||
|
||||
cert.SANs = []string(sans)
|
||||
|
||||
// Unmarshal tags
|
||||
if len(tagsJSON) > 0 {
|
||||
if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal tags: %w", err)
|
||||
}
|
||||
} else {
|
||||
cert.Tags = make(map[string]string)
|
||||
}
|
||||
|
||||
return &cert, nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// NewDB opens a PostgreSQL database connection and sets up connection pooling.
|
||||
func NewDB(connStr string) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
|
||||
// Ping to verify connection
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// RunMigrations reads and executes SQL migration files from a directory.
|
||||
func RunMigrations(db *sql.DB, migrationsPath string) error {
|
||||
// Check if migrations directory exists
|
||||
if _, err := os.Stat(migrationsPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("migrations directory not found: %s", migrationsPath)
|
||||
}
|
||||
|
||||
// Read all SQL files from the migrations directory
|
||||
files, err := os.ReadDir(migrationsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migrations directory: %w", err)
|
||||
}
|
||||
|
||||
// Sort and filter SQL files
|
||||
var sqlFiles []string
|
||||
for _, file := range files {
|
||||
if !file.IsDir() && strings.HasSuffix(file.Name(), ".sql") {
|
||||
sqlFiles = append(sqlFiles, file.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// Execute each migration file in order
|
||||
for _, filename := range sqlFiles {
|
||||
filePath := filepath.Join(migrationsPath, filename)
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migration file %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Execute the SQL content
|
||||
if _, err := db.Exec(string(content)); err != nil {
|
||||
return fmt.Errorf("failed to execute migration %s: %w", filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// IssuerRepository implements repository.IssuerRepository
|
||||
type IssuerRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewIssuerRepository creates a new IssuerRepository
|
||||
func NewIssuerRepository(db *sql.DB) *IssuerRepository {
|
||||
return &IssuerRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns all issuers
|
||||
func (r *IssuerRepository) List(ctx context.Context) ([]*domain.Issuer, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||
FROM issuers
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query issuers: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var issuers []*domain.Issuer
|
||||
for rows.Next() {
|
||||
var issuer domain.Issuer
|
||||
if err := rows.Scan(&issuer.ID, &issuer.Name, &issuer.Type, &issuer.Config,
|
||||
&issuer.Enabled, &issuer.CreatedAt, &issuer.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan issuer: %w", err)
|
||||
}
|
||||
issuers = append(issuers, &issuer)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating issuer rows: %w", err)
|
||||
}
|
||||
|
||||
return issuers, nil
|
||||
}
|
||||
|
||||
// Get retrieves an issuer by ID
|
||||
func (r *IssuerRepository) Get(ctx context.Context, id string) (*domain.Issuer, error) {
|
||||
var issuer domain.Issuer
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||
FROM issuers
|
||||
WHERE id = $1
|
||||
`, id).Scan(&issuer.ID, &issuer.Name, &issuer.Type, &issuer.Config,
|
||||
&issuer.Enabled, &issuer.CreatedAt, &issuer.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("issuer not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query issuer: %w", err)
|
||||
}
|
||||
|
||||
return &issuer, nil
|
||||
}
|
||||
|
||||
// Create stores a new issuer
|
||||
func (r *IssuerRepository) Create(ctx context.Context, issuer *domain.Issuer) error {
|
||||
if issuer.ID == "" {
|
||||
issuer.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO issuers (id, name, type, config, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id
|
||||
`, issuer.ID, issuer.Name, issuer.Type, issuer.Config, issuer.Enabled,
|
||||
issuer.CreatedAt, issuer.UpdatedAt).Scan(&issuer.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create issuer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing issuer
|
||||
func (r *IssuerRepository) Update(ctx context.Context, issuer *domain.Issuer) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE issuers SET
|
||||
name = $1,
|
||||
type = $2,
|
||||
config = $3,
|
||||
enabled = $4,
|
||||
updated_at = $5
|
||||
WHERE id = $6
|
||||
`, issuer.Name, issuer.Type, issuer.Config, issuer.Enabled, issuer.UpdatedAt, issuer.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update issuer: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("issuer not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes an issuer
|
||||
func (r *IssuerRepository) Delete(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM issuers WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete issuer: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("issuer not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// JobRepository implements repository.JobRepository
|
||||
type JobRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewJobRepository creates a new JobRepository
|
||||
func NewJobRepository(db *sql.DB) *JobRepository {
|
||||
return &JobRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns all jobs
|
||||
func (r *JobRepository) List(ctx context.Context) ([]*domain.Job, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query jobs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var jobs []*domain.Job
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// Get retrieves a job by ID
|
||||
func (r *JobRepository) Get(ctx context.Context, id string) (*domain.Job, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
|
||||
job, err := scanJob(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("job not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query job: %w", err)
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// Create stores a new job
|
||||
func (r *JobRepository) Create(ctx context.Context, job *domain.Job) error {
|
||||
if job.ID == "" {
|
||||
job.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO jobs (
|
||||
id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
RETURNING id
|
||||
`, job.ID, job.Type, job.CertificateID, job.TargetID, job.Status, job.Attempts,
|
||||
job.MaxAttempts, job.LastError, job.ScheduledAt, job.StartedAt, job.CompletedAt,
|
||||
job.CreatedAt).Scan(&job.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create job: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing job
|
||||
func (r *JobRepository) Update(ctx context.Context, job *domain.Job) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE jobs SET
|
||||
type = $1,
|
||||
certificate_id = $2,
|
||||
target_id = $3,
|
||||
status = $4,
|
||||
attempts = $5,
|
||||
max_attempts = $6,
|
||||
last_error = $7,
|
||||
scheduled_at = $8,
|
||||
started_at = $9,
|
||||
completed_at = $10
|
||||
WHERE id = $11
|
||||
`, job.Type, job.CertificateID, job.TargetID, job.Status, job.Attempts,
|
||||
job.MaxAttempts, job.LastError, job.ScheduledAt, job.StartedAt,
|
||||
job.CompletedAt, job.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update job: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("job not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a job
|
||||
func (r *JobRepository) Delete(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM jobs WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete job: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("job not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByStatus returns jobs with a specific status
|
||||
func (r *JobRepository) ListByStatus(ctx context.Context, status domain.JobStatus) ([]*domain.Job, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
WHERE status = $1
|
||||
ORDER BY created_at DESC
|
||||
`, status)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query jobs by status: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var jobs []*domain.Job
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// ListByCertificate returns all jobs for a certificate
|
||||
func (r *JobRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.Job, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
WHERE certificate_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, certID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query jobs for certificate: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var jobs []*domain.Job
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a job's status and optional error message
|
||||
func (r *JobRepository) UpdateStatus(ctx context.Context, id string, status domain.JobStatus, errMsg string) error {
|
||||
var lastError *string
|
||||
if errMsg != "" {
|
||||
lastError = &errMsg
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE jobs SET status = $1, last_error = $2 WHERE id = $3
|
||||
`, status, lastError, id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update job status: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("job not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPendingJobs returns jobs not yet processed of a specific type
|
||||
func (r *JobRepository) GetPendingJobs(ctx context.Context, jobType domain.JobType) ([]*domain.Job, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, type, certificate_id, target_id, status, attempts, max_attempts,
|
||||
last_error, scheduled_at, started_at, completed_at, created_at
|
||||
FROM jobs
|
||||
WHERE type = $1 AND status = $2
|
||||
ORDER BY scheduled_at ASC
|
||||
`, jobType, domain.JobStatusPending)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query pending jobs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var jobs []*domain.Job
|
||||
for rows.Next() {
|
||||
job, err := scanJob(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating job rows: %w", err)
|
||||
}
|
||||
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
// scanJob scans a job from a row or rows
|
||||
func scanJob(scanner interface {
|
||||
Scan(...interface{}) error
|
||||
}) (*domain.Job, error) {
|
||||
var job domain.Job
|
||||
err := scanner.Scan(&job.ID, &job.Type, &job.CertificateID, &job.TargetID,
|
||||
&job.Status, &job.Attempts, &job.MaxAttempts, &job.LastError,
|
||||
&job.ScheduledAt, &job.StartedAt, &job.CompletedAt, &job.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan job: %w", err)
|
||||
}
|
||||
|
||||
return &job, nil
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// NotificationRepository implements repository.NotificationRepository
|
||||
type NotificationRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewNotificationRepository creates a new NotificationRepository
|
||||
func NewNotificationRepository(db *sql.DB) *NotificationRepository {
|
||||
return &NotificationRepository{db: db}
|
||||
}
|
||||
|
||||
// Create stores a new notification
|
||||
func (r *NotificationRepository) Create(ctx context.Context, notif *domain.NotificationEvent) error {
|
||||
if notif.ID == "" {
|
||||
notif.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO notifications (
|
||||
id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id
|
||||
`, notif.ID, notif.Type, notif.CertificateID, notif.Channel, notif.Recipient,
|
||||
notif.Message, notif.SentAt, notif.Status, notif.Error, notif.CreatedAt).Scan(¬if.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create notification: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns notifications matching the filter criteria
|
||||
func (r *NotificationRepository) List(ctx context.Context, filter *repository.NotificationFilter) ([]*domain.NotificationEvent, error) {
|
||||
if filter == nil {
|
||||
filter = &repository.NotificationFilter{}
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if filter.Page < 1 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||
filter.PerPage = 50
|
||||
}
|
||||
|
||||
// Build WHERE clause
|
||||
var whereConditions []string
|
||||
var args []interface{}
|
||||
argCount := 1
|
||||
|
||||
if filter.CertificateID != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("certificate_id = $%d", argCount))
|
||||
args = append(args, filter.CertificateID)
|
||||
argCount++
|
||||
}
|
||||
if filter.Status != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount))
|
||||
args = append(args, filter.Status)
|
||||
argCount++
|
||||
}
|
||||
if filter.Channel != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("channel = $%d", argCount))
|
||||
args = append(args, filter.Channel)
|
||||
argCount++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(whereConditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM notifications %s", whereClause)
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("failed to count notifications: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
offset := (filter.Page - 1) * filter.PerPage
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, type, certificate_id, channel, recipient, message, sent_at, status, error, created_at
|
||||
FROM notifications
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argCount, argCount+1)
|
||||
|
||||
args = append(args, filter.PerPage, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query notifications: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var notifs []*domain.NotificationEvent
|
||||
for rows.Next() {
|
||||
notif, err := scanNotification(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
notifs = append(notifs, notif)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating notification rows: %w", err)
|
||||
}
|
||||
|
||||
return notifs, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a notification's delivery status
|
||||
func (r *NotificationRepository) UpdateStatus(ctx context.Context, id string, status string, sentAt time.Time) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE notifications SET status = $1, sent_at = $2 WHERE id = $3
|
||||
`, status, sentAt, id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update notification status: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("notification not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanNotification scans a notification from a row or rows
|
||||
func scanNotification(scanner interface {
|
||||
Scan(...interface{}) error
|
||||
}) (*domain.NotificationEvent, error) {
|
||||
var notif domain.NotificationEvent
|
||||
err := scanner.Scan(¬if.ID, ¬if.Type, ¬if.CertificateID, ¬if.Channel,
|
||||
¬if.Recipient, ¬if.Message, ¬if.SentAt, ¬if.Status, ¬if.Error, ¬if.CreatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan notification: %w", err)
|
||||
}
|
||||
|
||||
return ¬if, nil
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// OwnerRepository implements repository.OwnerRepository
|
||||
type OwnerRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewOwnerRepository creates a new OwnerRepository
|
||||
func NewOwnerRepository(db *sql.DB) *OwnerRepository {
|
||||
return &OwnerRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns all owners
|
||||
func (r *OwnerRepository) List(ctx context.Context) ([]*domain.Owner, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, email, team_id, created_at, updated_at
|
||||
FROM owners
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query owners: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var owners []*domain.Owner
|
||||
for rows.Next() {
|
||||
var owner domain.Owner
|
||||
if err := rows.Scan(&owner.ID, &owner.Name, &owner.Email, &owner.TeamID,
|
||||
&owner.CreatedAt, &owner.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan owner: %w", err)
|
||||
}
|
||||
owners = append(owners, &owner)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating owner rows: %w", err)
|
||||
}
|
||||
|
||||
return owners, nil
|
||||
}
|
||||
|
||||
// Get retrieves an owner by ID
|
||||
func (r *OwnerRepository) Get(ctx context.Context, id string) (*domain.Owner, error) {
|
||||
var owner domain.Owner
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, email, team_id, created_at, updated_at
|
||||
FROM owners
|
||||
WHERE id = $1
|
||||
`, id).Scan(&owner.ID, &owner.Name, &owner.Email, &owner.TeamID,
|
||||
&owner.CreatedAt, &owner.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("owner not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query owner: %w", err)
|
||||
}
|
||||
|
||||
return &owner, nil
|
||||
}
|
||||
|
||||
// Create stores a new owner
|
||||
func (r *OwnerRepository) Create(ctx context.Context, owner *domain.Owner) error {
|
||||
if owner.ID == "" {
|
||||
owner.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO owners (id, name, email, team_id, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id
|
||||
`, owner.ID, owner.Name, owner.Email, owner.TeamID,
|
||||
owner.CreatedAt, owner.UpdatedAt).Scan(&owner.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create owner: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing owner
|
||||
func (r *OwnerRepository) Update(ctx context.Context, owner *domain.Owner) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE owners SET
|
||||
name = $1,
|
||||
email = $2,
|
||||
team_id = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $5
|
||||
`, owner.Name, owner.Email, owner.TeamID, owner.UpdatedAt, owner.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update owner: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("owner not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes an owner
|
||||
func (r *OwnerRepository) Delete(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM owners WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete owner: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("owner not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
"github.com/shankar0123/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// PolicyRepository implements repository.PolicyRepository
|
||||
type PolicyRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewPolicyRepository creates a new PolicyRepository
|
||||
func NewPolicyRepository(db *sql.DB) *PolicyRepository {
|
||||
return &PolicyRepository{db: db}
|
||||
}
|
||||
|
||||
// ListRules returns all policy rules
|
||||
func (r *PolicyRepository) ListRules(ctx context.Context) ([]*domain.PolicyRule, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||
FROM policy_rules
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query policy rules: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rules []*domain.PolicyRule
|
||||
for rows.Next() {
|
||||
var rule domain.PolicyRule
|
||||
if err := rows.Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config,
|
||||
&rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan policy rule: %w", err)
|
||||
}
|
||||
rules = append(rules, &rule)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating policy rule rows: %w", err)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// GetRule retrieves a policy rule by ID
|
||||
func (r *PolicyRepository) GetRule(ctx context.Context, id string) (*domain.PolicyRule, error) {
|
||||
var rule domain.PolicyRule
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, type, config, enabled, created_at, updated_at
|
||||
FROM policy_rules
|
||||
WHERE id = $1
|
||||
`, id).Scan(&rule.ID, &rule.Name, &rule.Type, &rule.Config,
|
||||
&rule.Enabled, &rule.CreatedAt, &rule.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("policy rule not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query policy rule: %w", err)
|
||||
}
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// CreateRule stores a new policy rule
|
||||
func (r *PolicyRepository) CreateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||
if rule.ID == "" {
|
||||
rule.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO policy_rules (id, name, type, config, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id
|
||||
`, rule.ID, rule.Name, rule.Type, rule.Config, rule.Enabled,
|
||||
rule.CreatedAt, rule.UpdatedAt).Scan(&rule.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create policy rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRule modifies an existing policy rule
|
||||
func (r *PolicyRepository) UpdateRule(ctx context.Context, rule *domain.PolicyRule) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE policy_rules SET
|
||||
name = $1,
|
||||
type = $2,
|
||||
config = $3,
|
||||
enabled = $4,
|
||||
updated_at = $5
|
||||
WHERE id = $6
|
||||
`, rule.Name, rule.Type, rule.Config, rule.Enabled, rule.UpdatedAt, rule.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update policy rule: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("policy rule not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRule removes a policy rule
|
||||
func (r *PolicyRepository) DeleteRule(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM policy_rules WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete policy rule: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("policy rule not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateViolation records a policy violation
|
||||
func (r *PolicyRepository) CreateViolation(ctx context.Context, violation *domain.PolicyViolation) error {
|
||||
if violation.ID == "" {
|
||||
violation.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO policy_violations (id, certificate_id, rule_id, message, severity, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id
|
||||
`, violation.ID, violation.CertificateID, violation.RuleID, violation.Message,
|
||||
violation.Severity, violation.CreatedAt).Scan(&violation.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create policy violation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListViolations returns policy violations, optionally filtered
|
||||
func (r *PolicyRepository) ListViolations(ctx context.Context, filter *repository.AuditFilter) ([]*domain.PolicyViolation, error) {
|
||||
if filter == nil {
|
||||
filter = &repository.AuditFilter{}
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if filter.Page < 1 {
|
||||
filter.Page = 1
|
||||
}
|
||||
if filter.PerPage == 0 || filter.PerPage > 500 {
|
||||
filter.PerPage = 50
|
||||
}
|
||||
|
||||
// Build WHERE clause
|
||||
var whereConditions []string
|
||||
var args []interface{}
|
||||
argCount := 1
|
||||
|
||||
if filter.ResourceID != "" {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("certificate_id = $%d", argCount))
|
||||
args = append(args, filter.ResourceID)
|
||||
argCount++
|
||||
}
|
||||
if !filter.From.IsZero() {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("created_at >= $%d", argCount))
|
||||
args = append(args, filter.From)
|
||||
argCount++
|
||||
}
|
||||
if !filter.To.IsZero() {
|
||||
whereConditions = append(whereConditions, fmt.Sprintf("created_at <= $%d", argCount))
|
||||
args = append(args, filter.To)
|
||||
argCount++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(whereConditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM policy_violations %s", whereClause)
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("failed to count policy violations: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
offset := (filter.Page - 1) * filter.PerPage
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, certificate_id, rule_id, message, severity, created_at
|
||||
FROM policy_violations
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argCount, argCount+1)
|
||||
|
||||
args = append(args, filter.PerPage, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query policy violations: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var violations []*domain.PolicyViolation
|
||||
for rows.Next() {
|
||||
var v domain.PolicyViolation
|
||||
if err := rows.Scan(&v.ID, &v.CertificateID, &v.RuleID, &v.Message,
|
||||
&v.Severity, &v.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan policy violation: %w", err)
|
||||
}
|
||||
violations = append(violations, &v)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating policy violation rows: %w", err)
|
||||
}
|
||||
|
||||
return violations, nil
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// TargetRepository implements repository.TargetRepository
|
||||
type TargetRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewTargetRepository creates a new TargetRepository
|
||||
func NewTargetRepository(db *sql.DB) *TargetRepository {
|
||||
return &TargetRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns all targets
|
||||
func (r *TargetRepository) List(ctx context.Context) ([]*domain.DeploymentTarget, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, type, agent_id, config, enabled, created_at, updated_at
|
||||
FROM deployment_targets
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query targets: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var targets []*domain.DeploymentTarget
|
||||
for rows.Next() {
|
||||
var target domain.DeploymentTarget
|
||||
if err := rows.Scan(&target.ID, &target.Name, &target.Type, &target.AgentID,
|
||||
&target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan target: %w", err)
|
||||
}
|
||||
targets = append(targets, &target)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating target rows: %w", err)
|
||||
}
|
||||
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
// Get retrieves a target by ID
|
||||
func (r *TargetRepository) Get(ctx context.Context, id string) (*domain.DeploymentTarget, error) {
|
||||
var target domain.DeploymentTarget
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, type, agent_id, config, enabled, created_at, updated_at
|
||||
FROM deployment_targets
|
||||
WHERE id = $1
|
||||
`, id).Scan(&target.ID, &target.Name, &target.Type, &target.AgentID,
|
||||
&target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("target not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query target: %w", err)
|
||||
}
|
||||
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
// Create stores a new target
|
||||
func (r *TargetRepository) Create(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||
if target.ID == "" {
|
||||
target.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO deployment_targets (id, name, type, agent_id, config, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id
|
||||
`, target.ID, target.Name, target.Type, target.AgentID, target.Config, target.Enabled,
|
||||
target.CreatedAt, target.UpdatedAt).Scan(&target.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create target: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing target
|
||||
func (r *TargetRepository) Update(ctx context.Context, target *domain.DeploymentTarget) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE deployment_targets SET
|
||||
name = $1,
|
||||
type = $2,
|
||||
agent_id = $3,
|
||||
config = $4,
|
||||
enabled = $5,
|
||||
updated_at = $6
|
||||
WHERE id = $7
|
||||
`, target.Name, target.Type, target.AgentID, target.Config, target.Enabled, target.UpdatedAt, target.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update target: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("target not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a target
|
||||
func (r *TargetRepository) Delete(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM deployment_targets WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete target: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("target not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByCertificate returns all targets for a given certificate
|
||||
func (r *TargetRepository) ListByCertificate(ctx context.Context, certID string) ([]*domain.DeploymentTarget, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT dt.id, dt.name, dt.type, dt.agent_id, dt.config, dt.enabled, dt.created_at, dt.updated_at
|
||||
FROM deployment_targets dt
|
||||
INNER JOIN certificate_target_mappings ctm ON dt.id = ctm.target_id
|
||||
WHERE ctm.certificate_id = $1
|
||||
ORDER BY dt.created_at DESC
|
||||
`, certID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query targets for certificate: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var targets []*domain.DeploymentTarget
|
||||
for rows.Next() {
|
||||
var target domain.DeploymentTarget
|
||||
if err := rows.Scan(&target.ID, &target.Name, &target.Type, &target.AgentID,
|
||||
&target.Config, &target.Enabled, &target.CreatedAt, &target.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan target: %w", err)
|
||||
}
|
||||
targets = append(targets, &target)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating target rows: %w", err)
|
||||
}
|
||||
|
||||
return targets, nil
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shankar0123/certctl/internal/domain"
|
||||
)
|
||||
|
||||
// TeamRepository implements repository.TeamRepository
|
||||
type TeamRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewTeamRepository creates a new TeamRepository
|
||||
func NewTeamRepository(db *sql.DB) *TeamRepository {
|
||||
return &TeamRepository{db: db}
|
||||
}
|
||||
|
||||
// List returns all teams
|
||||
func (r *TeamRepository) List(ctx context.Context) ([]*domain.Team, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, description, created_at, updated_at
|
||||
FROM teams
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query teams: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var teams []*domain.Team
|
||||
for rows.Next() {
|
||||
var team domain.Team
|
||||
if err := rows.Scan(&team.ID, &team.Name, &team.Description,
|
||||
&team.CreatedAt, &team.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan team: %w", err)
|
||||
}
|
||||
teams = append(teams, &team)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating team rows: %w", err)
|
||||
}
|
||||
|
||||
return teams, nil
|
||||
}
|
||||
|
||||
// Get retrieves a team by ID
|
||||
func (r *TeamRepository) Get(ctx context.Context, id string) (*domain.Team, error) {
|
||||
var team domain.Team
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, description, created_at, updated_at
|
||||
FROM teams
|
||||
WHERE id = $1
|
||||
`, id).Scan(&team.ID, &team.Name, &team.Description,
|
||||
&team.CreatedAt, &team.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("team not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to query team: %w", err)
|
||||
}
|
||||
|
||||
return &team, nil
|
||||
}
|
||||
|
||||
// Create stores a new team
|
||||
func (r *TeamRepository) Create(ctx context.Context, team *domain.Team) error {
|
||||
if team.ID == "" {
|
||||
team.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO teams (id, name, description, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id
|
||||
`, team.ID, team.Name, team.Description, team.CreatedAt, team.UpdatedAt).Scan(&team.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create team: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing team
|
||||
func (r *TeamRepository) Update(ctx context.Context, team *domain.Team) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE teams SET
|
||||
name = $1,
|
||||
description = $2,
|
||||
updated_at = $3
|
||||
WHERE id = $4
|
||||
`, team.Name, team.Description, team.UpdatedAt, team.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update team: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("team not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a team
|
||||
func (r *TeamRepository) Delete(ctx context.Context, id string) error {
|
||||
result, err := r.db.ExecContext(ctx, "DELETE FROM teams WHERE id = $1", id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete team: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("team not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user