// Copyright 2026 certctl LLC. All rights reserved. // SPDX-License-Identifier: BUSL-1.1 package postgres import ( "context" "database/sql" "fmt" "strings" "github.com/certctl-io/certctl/internal/domain" "github.com/certctl-io/certctl/internal/repository" "github.com/google/uuid" ) // AuditRepository implements repository.AuditRepository type AuditRepository struct { db *sql.DB } // NewAuditRepository creates a new AuditRepository func NewAuditRepository(db *sql.DB) *AuditRepository { return &AuditRepository{db: db} } // Create stores a new audit event using the repository's package-level // *sql.DB. Use CreateWithTx when the audit event must be atomic with // another database operation in a service-layer transaction. func (r *AuditRepository) Create(ctx context.Context, event *domain.AuditEvent) error { return r.CreateWithTx(ctx, r.db, event) } // CreateWithTx stores a new audit event using the supplied Querier. // Pass *sql.Tx (typically from postgres.WithinTx) to participate in a // caller's transaction; pass *sql.DB or call Create for stand-alone // inserts. The SQL and side-effect contract is identical to Create — // CreateWithTx is the load-bearing path that closes the audit's // atomicity blocker (audit row must be transactional with the // operation that triggered it). func (r *AuditRepository) CreateWithTx(ctx context.Context, q repository.Querier, event *domain.AuditEvent) error { if event.ID == "" { event.ID = uuid.New().String() } // Bundle 1 Phase 8: empty EventCategory defaults to // cert_lifecycle (matches the migration's DEFAULT clause + the // DB CHECK constraint). The boundary catches callers that // haven't yet been migrated to the categorized API. if event.EventCategory == "" { event.EventCategory = domain.EventCategoryCertLifecycle } err := q.QueryRowContext(ctx, ` INSERT INTO audit_events ( id, actor, actor_type, action, resource_type, resource_id, details, timestamp, event_category ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id `, event.ID, event.Actor, event.ActorType, event.Action, event.ResourceType, event.ResourceID, event.Details, event.Timestamp, event.EventCategory).Scan(&event.ID) if err != nil { return fmt.Errorf("failed to create audit event: %w", err) } return nil } // List returns audit events matching the filter criteria func (r *AuditRepository) List(ctx context.Context, filter *repository.AuditFilter) ([]*domain.AuditEvent, error) { if filter == nil { filter = &repository.AuditFilter{} } // Set defaults if filter.Page < 1 { filter.Page = 1 } if filter.PerPage == 0 || filter.PerPage > 500 { filter.PerPage = 50 } // Build WHERE clause var whereConditions []string var args []interface{} argCount := 1 if filter.Actor != "" { whereConditions = append(whereConditions, fmt.Sprintf("actor = $%d", argCount)) args = append(args, filter.Actor) argCount++ } if filter.ActorType != "" { whereConditions = append(whereConditions, fmt.Sprintf("actor_type = $%d", argCount)) args = append(args, filter.ActorType) argCount++ } if filter.ResourceType != "" { whereConditions = append(whereConditions, fmt.Sprintf("resource_type = $%d", argCount)) args = append(args, filter.ResourceType) argCount++ } if filter.ResourceID != "" { whereConditions = append(whereConditions, fmt.Sprintf("resource_id = $%d", argCount)) args = append(args, filter.ResourceID) argCount++ } if !filter.From.IsZero() { whereConditions = append(whereConditions, fmt.Sprintf("timestamp >= $%d", argCount)) args = append(args, filter.From) argCount++ } if !filter.To.IsZero() { whereConditions = append(whereConditions, fmt.Sprintf("timestamp <= $%d", argCount)) args = append(args, filter.To) argCount++ } if filter.EventCategory != "" { whereConditions = append(whereConditions, fmt.Sprintf("event_category = $%d", argCount)) args = append(args, filter.EventCategory) argCount++ } whereClause := "" if len(whereConditions) > 0 { whereClause = "WHERE " + strings.Join(whereConditions, " AND ") } // Get total count countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause) var total int if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { return nil, fmt.Errorf("failed to count audit events: %w", err) } // Get paginated results offset := (filter.Page - 1) * filter.PerPage query := fmt.Sprintf(` SELECT id, actor, actor_type, action, resource_type, resource_id, details, timestamp, event_category FROM audit_events %s ORDER BY timestamp DESC LIMIT $%d OFFSET $%d `, whereClause, argCount, argCount+1) args = append(args, filter.PerPage, offset) rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("failed to query audit events: %w", err) } defer rows.Close() var events []*domain.AuditEvent for rows.Next() { var event domain.AuditEvent if err := rows.Scan(&event.ID, &event.Actor, &event.ActorType, &event.Action, &event.ResourceType, &event.ResourceID, &event.Details, &event.Timestamp, &event.EventCategory); err != nil { return nil, fmt.Errorf("failed to scan audit event: %w", err) } events = append(events, &event) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating audit event rows: %w", err) } return events, nil } // VerifyHashChain calls the migration 000047 audit_events_verify_chain() // stored function and returns its three OUT parameters. This is the // Sprint 6 COMP-001-HASH tamper-evidence verifier — the scheduler's // auditChainVerifyLoop invokes it every CERTCTL_AUDIT_CHAIN_VERIFY_INTERVAL // tick and emits the certctl_audit_chain_break_detected counter on any // non-empty brokenAtID. // // The chain walk happens entirely server-side (plpgsql, STABLE). For an // audit_events table with N rows the cost is O(N) per call; we expect // modest fleets (single-digit-millions of events) so the per-tick cost // is bounded. Operators with very large audit tables can lengthen the // interval — the metric is sticky once incremented, so even an hourly // walk is enough lead time to surface tampering for human investigation. func (r *AuditRepository) VerifyHashChain(ctx context.Context) (brokenAtID string, brokenAtPos int, rowCount int, err error) { var ( brokenID sql.NullString pos sql.NullInt32 total sql.NullInt32 ) row := r.db.QueryRowContext(ctx, `SELECT first_break_id, first_break_pos, row_count FROM audit_events_verify_chain()`) if err := row.Scan(&brokenID, &pos, &total); err != nil { return "", -1, 0, fmt.Errorf("audit_events_verify_chain: %w", err) } if brokenID.Valid { brokenAtID = brokenID.String } if pos.Valid { brokenAtPos = int(pos.Int32) } else { brokenAtPos = -1 } if total.Valid { rowCount = int(total.Int32) } return brokenAtID, brokenAtPos, rowCount, nil }