package postgres import ( "context" "database/sql" "encoding/json" "fmt" "strings" "time" "github.com/google/uuid" "github.com/lib/pq" "github.com/shankar0123/certctl/internal/domain" "github.com/shankar0123/certctl/internal/repository" ) // CertificateRepository implements repository.CertificateRepository type CertificateRepository struct { db *sql.DB } // NewCertificateRepository creates a new CertificateRepository func NewCertificateRepository(db *sql.DB) *CertificateRepository { return &CertificateRepository{db: db} } // List returns a paginated list of certificates matching the filter criteria func (r *CertificateRepository) List(ctx context.Context, filter *repository.CertificateFilter) ([]*domain.ManagedCertificate, int, error) { if filter == nil { filter = &repository.CertificateFilter{} } // Set defaults if filter.Page < 1 { filter.Page = 1 } if filter.PerPage == 0 || filter.PerPage > 500 { filter.PerPage = 50 } // Build WHERE clause var whereConditions []string var args []interface{} argCount := 1 if filter.Status != "" { whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argCount)) args = append(args, filter.Status) argCount++ } if filter.Environment != "" { whereConditions = append(whereConditions, fmt.Sprintf("environment = $%d", argCount)) args = append(args, filter.Environment) argCount++ } if filter.OwnerID != "" { whereConditions = append(whereConditions, fmt.Sprintf("owner_id = $%d", argCount)) args = append(args, filter.OwnerID) argCount++ } if filter.TeamID != "" { whereConditions = append(whereConditions, fmt.Sprintf("team_id = $%d", argCount)) args = append(args, filter.TeamID) argCount++ } if filter.IssuerID != "" { whereConditions = append(whereConditions, fmt.Sprintf("issuer_id = $%d", argCount)) args = append(args, filter.IssuerID) argCount++ } whereClause := "" if len(whereConditions) > 0 { whereClause = "WHERE " + strings.Join(whereConditions, " AND ") } // Get total count countQuery := fmt.Sprintf("SELECT COUNT(*) FROM managed_certificates %s", whereClause) var total int if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { return nil, 0, fmt.Errorf("failed to count certificates: %w", err) } // Get paginated results offset := (filter.Page - 1) * filter.PerPage query := fmt.Sprintf(` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates %s ORDER BY created_at DESC LIMIT $%d OFFSET $%d `, whereClause, argCount, argCount+1) args = append(args, filter.PerPage, offset) rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return nil, 0, fmt.Errorf("failed to query certificates: %w", err) } defer rows.Close() var certs []*domain.ManagedCertificate for rows.Next() { cert, err := scanCertificate(rows) if err != nil { return nil, 0, err } certs = append(certs, cert) } if err := rows.Err(); err != nil { return nil, 0, fmt.Errorf("error iterating certificate rows: %w", err) } return certs, total, nil } // Get retrieves a certificate by ID func (r *CertificateRepository) Get(ctx context.Context, id string) (*domain.ManagedCertificate, error) { row := r.db.QueryRowContext(ctx, ` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates WHERE id = $1 `, id) cert, err := scanCertificate(row) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("certificate not found") } return nil, fmt.Errorf("failed to query certificate: %w", err) } return cert, nil } // Create stores a new certificate func (r *CertificateRepository) Create(ctx context.Context, cert *domain.ManagedCertificate) error { if cert.ID == "" { cert.ID = uuid.New().String() } tagsJSON, err := json.Marshal(cert.Tags) if err != nil { return fmt.Errorf("failed to marshal tags: %w", err) } err = r.db.QueryRowContext(ctx, ` INSERT INTO managed_certificates ( id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) RETURNING id `, cert.ID, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, cert.OwnerID, cert.TeamID, cert.IssuerID, cert.RenewalPolicyID, cert.Status, cert.ExpiresAt, tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.CreatedAt, cert.UpdatedAt).Scan(&cert.ID) if err != nil { return fmt.Errorf("failed to create certificate: %w", err) } return nil } // Update modifies an existing certificate func (r *CertificateRepository) Update(ctx context.Context, cert *domain.ManagedCertificate) error { tagsJSON, err := json.Marshal(cert.Tags) if err != nil { return fmt.Errorf("failed to marshal tags: %w", err) } result, err := r.db.ExecContext(ctx, ` UPDATE managed_certificates SET name = $1, common_name = $2, sans = $3, environment = $4, owner_id = $5, team_id = $6, issuer_id = $7, status = $8, expires_at = $9, tags = $10, last_renewal_at = $11, last_deployment_at = $12, updated_at = $13 WHERE id = $14 `, cert.Name, cert.CommonName, pq.Array(cert.SANs), cert.Environment, cert.OwnerID, cert.TeamID, cert.IssuerID, cert.Status, cert.ExpiresAt, tagsJSON, cert.LastRenewalAt, cert.LastDeploymentAt, cert.UpdatedAt, cert.ID) if err != nil { return fmt.Errorf("failed to update certificate: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get rows affected: %w", err) } if rows == 0 { return fmt.Errorf("certificate not found") } return nil } // Archive marks a certificate as archived func (r *CertificateRepository) Archive(ctx context.Context, id string) error { result, err := r.db.ExecContext(ctx, ` UPDATE managed_certificates SET status = $1, updated_at = $2 WHERE id = $3 `, domain.CertificateStatusArchived, time.Now(), id) if err != nil { return fmt.Errorf("failed to archive certificate: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("failed to get rows affected: %w", err) } if rows == 0 { return fmt.Errorf("certificate not found") } return nil } // ListVersions returns all versions of a certificate func (r *CertificateRepository) ListVersions(ctx context.Context, certID string) ([]*domain.CertificateVersion, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, certificate_id, serial_number, not_before, not_after, fingerprint_sha256, pem_chain, csr_pem, created_at FROM certificate_versions WHERE certificate_id = $1 ORDER BY created_at DESC `, certID) if err != nil { return nil, fmt.Errorf("failed to query certificate versions: %w", err) } defer rows.Close() var versions []*domain.CertificateVersion for rows.Next() { var v domain.CertificateVersion if err := rows.Scan(&v.ID, &v.CertificateID, &v.SerialNumber, &v.NotBefore, &v.NotAfter, &v.FingerprintSHA256, &v.PEMChain, &v.CSRPEM, &v.CreatedAt); err != nil { return nil, fmt.Errorf("failed to scan certificate version: %w", err) } versions = append(versions, &v) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating version rows: %w", err) } return versions, nil } // CreateVersion stores a new certificate version func (r *CertificateRepository) CreateVersion(ctx context.Context, version *domain.CertificateVersion) error { if version.ID == "" { version.ID = uuid.New().String() } err := r.db.QueryRowContext(ctx, ` INSERT INTO certificate_versions ( id, certificate_id, serial_number, not_before, not_after, fingerprint_sha256, pem_chain, csr_pem, created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id `, version.ID, version.CertificateID, version.SerialNumber, version.NotBefore, version.NotAfter, version.FingerprintSHA256, version.PEMChain, version.CSRPEM, version.CreatedAt).Scan(&version.ID) if err != nil { return fmt.Errorf("failed to create certificate version: %w", err) } return nil } // GetExpiringCertificates returns certificates expiring before the given time func (r *CertificateRepository) GetExpiringCertificates(ctx context.Context, before time.Time) ([]*domain.ManagedCertificate, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, name, common_name, sans, environment, owner_id, team_id, issuer_id, renewal_policy_id, status, expires_at, tags, last_renewal_at, last_deployment_at, created_at, updated_at FROM managed_certificates WHERE expires_at < $1 AND status != $2 ORDER BY expires_at ASC `, before, domain.CertificateStatusArchived) if err != nil { return nil, fmt.Errorf("failed to query expiring certificates: %w", err) } defer rows.Close() var certs []*domain.ManagedCertificate for rows.Next() { cert, err := scanCertificate(rows) if err != nil { return nil, err } certs = append(certs, cert) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating expiring certificate rows: %w", err) } return certs, nil } // scanCertificate scans a certificate from a row or rows func scanCertificate(scanner interface { Scan(...interface{}) error }) (*domain.ManagedCertificate, error) { var cert domain.ManagedCertificate var tagsJSON []byte var sans pq.StringArray err := scanner.Scan( &cert.ID, &cert.Name, &cert.CommonName, &sans, &cert.Environment, &cert.OwnerID, &cert.TeamID, &cert.IssuerID, &cert.RenewalPolicyID, &cert.Status, &cert.ExpiresAt, &tagsJSON, &cert.LastRenewalAt, &cert.LastDeploymentAt, &cert.CreatedAt, &cert.UpdatedAt) if err != nil { return nil, fmt.Errorf("failed to scan certificate: %w", err) } cert.SANs = []string(sans) // Unmarshal tags if len(tagsJSON) > 0 { if err := json.Unmarshal(tagsJSON, &cert.Tags); err != nil { return nil, fmt.Errorf("failed to unmarshal tags: %w", err) } } else { cert.Tags = make(map[string]string) } return &cert, nil }