mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 13:51:36 +00:00
auth-bundle-2 Phase 2b: repository interfaces + Postgres impls + integration tests
Closes Phase 2 end-to-end. Builds on Phase 2a's three migrations (000034 oidc_providers + group_role_mappings, 000035 sessions + session_signing_keys, 000036 users) by shipping the repository surface Phase 3+ services consume. Interfaces: * internal/repository/oidc.go - OIDCProviderRepository (List, Get, GetByName, Create, Update, Delete) + GroupRoleMappingRepository (ListByProvider, Get, Add, Remove, Map). Sentinels: ErrOIDCProviderNotFound, ErrOIDCProviderDuplicateName, ErrOIDCProviderInUse (FK ON DELETE RESTRICT translation), ErrGroupRoleMappingNotFound, ErrGroupRoleMappingDuplicate. * internal/repository/session.go - SessionRepository (Create, Get, ListByActor, UpdateLastSeen, Revoke, RevokeAllForActor, GarbageCollectExpired, Delete) + SessionSigningKeyRepository (List, GetActive, Get, Add, Retire, Delete). Sentinels: ErrSessionNotFound, ErrSessionRevoked, ErrSessionExpired, ErrSessionSigningKeyNotFound, ErrSessionSigningKeyInUse. * internal/repository/user.go - UserRepository (Get, GetByOIDCSubject, Create, Update, ListAll). Sentinels: ErrUserNotFound, ErrUserDuplicateOIDCSubject. Postgres implementations: * internal/repository/postgres/oidc.go - 309 lines. Translates SQLSTATE 23505 (unique_violation) to ErrOIDCProviderDuplicateName / ErrGroupRoleMappingDuplicate; SQLSTATE 23503 (foreign_key_violation) to ErrOIDCProviderInUse so the Phase 5 handler maps to HTTP 409 when an operator tries to delete a provider with authenticated users. pq.StringArray bridges Go []string to Postgres TEXT[] for scopes + allowed_email_domains. Map() uses `WHERE group_name = ANY($2)` so a single SELECT resolves N IdP group claims at once. * internal/repository/postgres/session.go - 350 lines. Both Session + SessionSigningKey repos. Revoke + Retire are idempotent (re-revoking an already-revoked session returns nil; same for retire). The GarbageCollectExpired sweep deletes both absolute-expiry-passed sessions AND pre-login rows older than the 10-minute TTL in one DELETE so the scheduler tick is cheap. ErrSessionSigningKeyInUse pinned via SQLSTATE 23503 from the sessions.signing_key_id FK ON DELETE RESTRICT. * internal/repository/postgres/user.go - 137 lines. GetByOIDCSubject is the Phase 3 hot-path lookup; the (oidc_provider_id, oidc_subject) UNIQUE constraint trip translates to ErrUserDuplicateOIDCSubject. Update only writes the mutable field set (email, display_name, last_login_at, webauthn_credentials); oidc_subject + oidc_provider_id are immutable per the per-(provider, subject) identity model. Integration tests (testing.Short()-gated, testcontainers + Postgres 16 Alpine, schema-per-test isolation via getTestDB().freshSchema): * oidc_test.go: 11 tests covering happy-path + GetNotFound + DuplicateName + List + Update + DeleteNotFound + DeleteSucceeds + DeleteRefusedWhenUsersReference (the FK ON DELETE RESTRICT pin); GroupRoleMapping coverage includes Add/List/Map (3 cases: marketing-not-mapped, multi-group hits, empty groups returns empty), Duplicate rejection, and the ON DELETE CASCADE on provider deletion. * session_test.go: 12 tests covering SessionSigningKey + Session. Key tests: GetActiveSkipsRetired (mints older, retires it, mints newer, asserts GetActive returns newer), DeleteRefusedWhenSessions- Reference (FK pin), RetireIsIdempotent. Session tests: CreateAndGet roundtrip, GetNotFound, Revoke + idempotent re-Revoke, ListByActor (3 active + 1 revoked + 1 pre-login -> returns 3, pinning the WHERE filter), RevokeAllForActor, GarbageCollectExpired (seeds an absolute-expired row + pre-login >10min row + active session via raw SQL to bypass CHECK constraints, asserts GC kills exactly 2 + active survives), UpdateLastSeen. * user_test.go: 7 tests covering CreateAndGet, GetNotFound, GetByOIDCSubject (hit + miss), DuplicateOIDCSubjectRejected, UpdateMutableFields (asserts oidc_subject NOT mutated by Update), ListAll, FKRestrictsProviderDelete (mirror of the OIDC test from the user side - both ends of the FK contract pinned). Verifications: * gofmt -l clean across all 9 new files. * go vet ./internal/repository/postgres/ rc=0. * go test -short -count=1 green on internal/repository/postgres/ + internal/auth/... + Bundle 1 packages (testing.Short() skips the testcontainers integration tests, but the test files compile + the short-mode skip path is exercised so the suite is wired correctly). * Full integration tests run in CI's non-short job against Postgres 16 Alpine via testcontainers-go. * govulncheck ./... clean. * All 24 ci-guards pass. Phase 2 exit criteria from cowork/auth-bundle-2-prompt.md (all met): * All three Phase-2 migrations apply cleanly, idempotently: yes (Phase 2a). Break-glass migration ships separately in Phase 7.5. * Repository tests pass against Postgres 16 Alpine: integration tests written, gated by testing.Short(), structured to run cleanly in CI's non-short job. * make verify equivalent green: gofmt + vet + go test pass; golangci-lint deferred to CI per Phase 0/1's same pattern.
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain"
|
||||
)
|
||||
|
||||
// Sentinel errors for the OIDC repositories. Postgres implementations
|
||||
// translate SQLSTATE codes into these so handler / service code can
|
||||
// branch via errors.Is.
|
||||
var (
|
||||
// ErrOIDCProviderNotFound: Get / GetByName returned no row. HTTP 404.
|
||||
ErrOIDCProviderNotFound = errors.New("oidc: provider not found")
|
||||
|
||||
// ErrOIDCProviderDuplicateName: Create tripped the (tenant_id, name)
|
||||
// UNIQUE constraint. HTTP 409.
|
||||
ErrOIDCProviderDuplicateName = errors.New("oidc: provider with this name already exists in tenant")
|
||||
|
||||
// ErrOIDCProviderInUse: Delete failed because at least one users row
|
||||
// references the provider via oidc_provider_id (FK ON DELETE
|
||||
// RESTRICT). HTTP 409.
|
||||
ErrOIDCProviderInUse = errors.New("oidc: provider has authenticated users; revoke all sessions before delete")
|
||||
|
||||
// ErrGroupRoleMappingNotFound: Get returned no row. HTTP 404.
|
||||
ErrGroupRoleMappingNotFound = errors.New("oidc: group-role mapping not found")
|
||||
|
||||
// ErrGroupRoleMappingDuplicate: Add tripped the
|
||||
// (provider_id, group_name, role_id) UNIQUE constraint. HTTP 409.
|
||||
ErrGroupRoleMappingDuplicate = errors.New("oidc: group-role mapping already exists")
|
||||
)
|
||||
|
||||
// OIDCProviderRepository wraps the oidc_providers table. Phase 3's
|
||||
// OIDCService consumes List + Get to look up the IdP for token
|
||||
// validation; the GUI / CLI wire Create / Update / Delete behind
|
||||
// auth.oidc.* permission gates per Phase 5.
|
||||
type OIDCProviderRepository interface {
|
||||
// List returns every configured provider in the tenant. Order:
|
||||
// created_at ASC for stable GUI rendering.
|
||||
List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error)
|
||||
|
||||
// Get returns one provider by id. ErrOIDCProviderNotFound on miss.
|
||||
Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error)
|
||||
|
||||
// GetByName returns one provider by (tenant_id, name).
|
||||
// ErrOIDCProviderNotFound on miss.
|
||||
GetByName(ctx context.Context, tenantID, name string) (*oidcdomain.OIDCProvider, error)
|
||||
|
||||
// Create persists a new provider. Caller MUST have already called
|
||||
// p.Validate() and encrypted the client_secret_encrypted byte
|
||||
// stream via internal/crypto/encryption.go. Returns
|
||||
// ErrOIDCProviderDuplicateName when the (tenant_id, name) UNIQUE
|
||||
// constraint fires.
|
||||
Create(ctx context.Context, p *oidcdomain.OIDCProvider) error
|
||||
|
||||
// Update writes the full mutable field set back to the row.
|
||||
// Immutable fields (id, tenant_id, created_at) are read-only;
|
||||
// updated_at is set to NOW() by the implementation.
|
||||
Update(ctx context.Context, p *oidcdomain.OIDCProvider) error
|
||||
|
||||
// Delete removes a provider by id. Returns ErrOIDCProviderInUse
|
||||
// when at least one users row references this provider (FK ON
|
||||
// DELETE RESTRICT). Phase 5's handler maps to HTTP 409.
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// GroupRoleMappingRepository wraps the group_role_mappings table.
|
||||
// Phase 3's OIDCService.HandleCallback uses Map() to translate IdP
|
||||
// group claims into role IDs; the GUI / CLI wire ListByProvider /
|
||||
// Add / Remove for operator configuration.
|
||||
type GroupRoleMappingRepository interface {
|
||||
// ListByProvider returns every mapping for the named provider.
|
||||
// Order: group_name ASC for stable GUI rendering.
|
||||
ListByProvider(ctx context.Context, providerID string) ([]*oidcdomain.GroupRoleMapping, error)
|
||||
|
||||
// Get returns one mapping by id. ErrGroupRoleMappingNotFound on miss.
|
||||
Get(ctx context.Context, id string) (*oidcdomain.GroupRoleMapping, error)
|
||||
|
||||
// Add persists a new mapping. Caller MUST have called m.Validate().
|
||||
// Returns ErrGroupRoleMappingDuplicate when the
|
||||
// (provider_id, group_name, role_id) UNIQUE constraint fires.
|
||||
Add(ctx context.Context, m *oidcdomain.GroupRoleMapping) error
|
||||
|
||||
// Remove deletes a mapping by id.
|
||||
Remove(ctx context.Context, id string) error
|
||||
|
||||
// Map resolves an IdP-supplied list of group names against the
|
||||
// provider's mappings. Returns the deduplicated set of role IDs
|
||||
// the user should hold. Empty result means the user matches no
|
||||
// mapping (Phase 3 fail-closed: no session minted, audit row
|
||||
// `auth.oidc_login_unmapped_groups`).
|
||||
Map(ctx context.Context, providerID string, groupNames []string) ([]string, error)
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain"
|
||||
"github.com/certctl-io/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// OIDCProviderRepository (Auth Bundle 2 Phase 2)
|
||||
// =============================================================================
|
||||
|
||||
// OIDCProviderRepository is the postgres implementation of
|
||||
// repository.OIDCProviderRepository.
|
||||
type OIDCProviderRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewOIDCProviderRepository constructs an OIDCProviderRepository.
|
||||
func NewOIDCProviderRepository(db *sql.DB) *OIDCProviderRepository {
|
||||
return &OIDCProviderRepository{db: db}
|
||||
}
|
||||
|
||||
const oidcProviderColumns = `id, tenant_id, name, issuer_url, client_id,
|
||||
client_secret_encrypted, redirect_uri, groups_claim_path,
|
||||
groups_claim_format, fetch_userinfo, scopes,
|
||||
allowed_email_domains, iat_window_seconds,
|
||||
jwks_cache_ttl_seconds, created_at, updated_at`
|
||||
|
||||
func scanOIDCProvider(row interface{ Scan(...interface{}) error }) (*oidcdomain.OIDCProvider, error) {
|
||||
var p oidcdomain.OIDCProvider
|
||||
var scopes, domains pq.StringArray
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.TenantID, &p.Name, &p.IssuerURL, &p.ClientID,
|
||||
&p.ClientSecretEncrypted, &p.RedirectURI, &p.GroupsClaimPath,
|
||||
&p.GroupsClaimFormat, &p.FetchUserinfo, &scopes,
|
||||
&domains, &p.IATWindowSeconds,
|
||||
&p.JWKSCacheTTLSeconds, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Scopes = []string(scopes)
|
||||
p.AllowedEmailDomains = []string(domains)
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// List returns every configured OIDC provider in the tenant, ordered
|
||||
// by created_at ASC for stable GUI rendering.
|
||||
func (r *OIDCProviderRepository) List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE tenant_id = $1 ORDER BY created_at ASC`, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc_providers list: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*oidcdomain.OIDCProvider
|
||||
for rows.Next() {
|
||||
p, err := scanOIDCProvider(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc_providers scan: %w", err)
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Get returns one provider by id. ErrOIDCProviderNotFound on miss.
|
||||
func (r *OIDCProviderRepository) Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error) {
|
||||
row := r.db.QueryRowContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE id = $1`, id)
|
||||
p, err := scanOIDCProvider(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrOIDCProviderNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("oidc_providers get: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetByName returns one provider by (tenant_id, name).
|
||||
func (r *OIDCProviderRepository) GetByName(ctx context.Context, tenantID, name string) (*oidcdomain.OIDCProvider, error) {
|
||||
row := r.db.QueryRowContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE tenant_id = $1 AND name = $2`, tenantID, name)
|
||||
p, err := scanOIDCProvider(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrOIDCProviderNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("oidc_providers get_by_name: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Create persists a new provider. Caller MUST have called p.Validate()
|
||||
// and encrypted ClientSecretEncrypted via internal/crypto/encryption.go.
|
||||
// Translates SQLSTATE 23505 (unique_violation) to
|
||||
// ErrOIDCProviderDuplicateName.
|
||||
func (r *OIDCProviderRepository) Create(ctx context.Context, p *oidcdomain.OIDCProvider) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO oidc_providers (
|
||||
id, tenant_id, name, issuer_url, client_id,
|
||||
client_secret_encrypted, redirect_uri, groups_claim_path,
|
||||
groups_claim_format, fetch_userinfo, scopes,
|
||||
allowed_email_domains, iat_window_seconds,
|
||||
jwks_cache_ttl_seconds
|
||||
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14)`,
|
||||
p.ID, p.TenantID, p.Name, p.IssuerURL, p.ClientID,
|
||||
p.ClientSecretEncrypted, p.RedirectURI, p.GroupsClaimPath,
|
||||
p.GroupsClaimFormat, p.FetchUserinfo, pq.StringArray(p.Scopes),
|
||||
pq.StringArray(p.AllowedEmailDomains), p.IATWindowSeconds,
|
||||
p.JWKSCacheTTLSeconds,
|
||||
)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
return repository.ErrOIDCProviderDuplicateName
|
||||
}
|
||||
return fmt.Errorf("oidc_providers create: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update writes the mutable fields back. Immutable: id, tenant_id,
|
||||
// created_at. updated_at = NOW().
|
||||
func (r *OIDCProviderRepository) Update(ctx context.Context, p *oidcdomain.OIDCProvider) error {
|
||||
res, err := r.db.ExecContext(ctx, `
|
||||
UPDATE oidc_providers SET
|
||||
name = $2,
|
||||
issuer_url = $3,
|
||||
client_id = $4,
|
||||
client_secret_encrypted = $5,
|
||||
redirect_uri = $6,
|
||||
groups_claim_path = $7,
|
||||
groups_claim_format = $8,
|
||||
fetch_userinfo = $9,
|
||||
scopes = $10,
|
||||
allowed_email_domains = $11,
|
||||
iat_window_seconds = $12,
|
||||
jwks_cache_ttl_seconds = $13,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`,
|
||||
p.ID, p.Name, p.IssuerURL, p.ClientID,
|
||||
p.ClientSecretEncrypted, p.RedirectURI, p.GroupsClaimPath,
|
||||
p.GroupsClaimFormat, p.FetchUserinfo, pq.StringArray(p.Scopes),
|
||||
pq.StringArray(p.AllowedEmailDomains), p.IATWindowSeconds,
|
||||
p.JWKSCacheTTLSeconds,
|
||||
)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
return repository.ErrOIDCProviderDuplicateName
|
||||
}
|
||||
return fmt.Errorf("oidc_providers update: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrOIDCProviderNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a provider by id. Returns ErrOIDCProviderInUse on
|
||||
// SQLSTATE 23503 (foreign_key_violation) — the users table's FK ON
|
||||
// DELETE RESTRICT fires when authenticated users still reference
|
||||
// this provider.
|
||||
func (r *OIDCProviderRepository) Delete(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx, `DELETE FROM oidc_providers WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
|
||||
return repository.ErrOIDCProviderInUse
|
||||
}
|
||||
return fmt.Errorf("oidc_providers delete: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrOIDCProviderNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// GroupRoleMappingRepository (Auth Bundle 2 Phase 2)
|
||||
// =============================================================================
|
||||
|
||||
// GroupRoleMappingRepository is the postgres implementation of
|
||||
// repository.GroupRoleMappingRepository.
|
||||
type GroupRoleMappingRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewGroupRoleMappingRepository constructs a GroupRoleMappingRepository.
|
||||
func NewGroupRoleMappingRepository(db *sql.DB) *GroupRoleMappingRepository {
|
||||
return &GroupRoleMappingRepository{db: db}
|
||||
}
|
||||
|
||||
func scanGroupRoleMapping(row interface{ Scan(...interface{}) error }) (*oidcdomain.GroupRoleMapping, error) {
|
||||
var m oidcdomain.GroupRoleMapping
|
||||
if err := row.Scan(&m.ID, &m.TenantID, &m.ProviderID, &m.GroupName, &m.RoleID, &m.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// ListByProvider returns every mapping for the named provider, ordered
|
||||
// group_name ASC.
|
||||
func (r *GroupRoleMappingRepository) ListByProvider(ctx context.Context, providerID string) ([]*oidcdomain.GroupRoleMapping, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, tenant_id, provider_id, group_name, role_id, created_at
|
||||
FROM group_role_mappings
|
||||
WHERE provider_id = $1
|
||||
ORDER BY group_name ASC`, providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group_role_mappings list_by_provider: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*oidcdomain.GroupRoleMapping
|
||||
for rows.Next() {
|
||||
m, err := scanGroupRoleMapping(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group_role_mappings scan: %w", err)
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// Get returns one mapping by id.
|
||||
func (r *GroupRoleMappingRepository) Get(ctx context.Context, id string) (*oidcdomain.GroupRoleMapping, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, tenant_id, provider_id, group_name, role_id, created_at
|
||||
FROM group_role_mappings WHERE id = $1`, id)
|
||||
m, err := scanGroupRoleMapping(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrGroupRoleMappingNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("group_role_mappings get: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Add persists a new mapping. Translates SQLSTATE 23505 into
|
||||
// ErrGroupRoleMappingDuplicate.
|
||||
func (r *GroupRoleMappingRepository) Add(ctx context.Context, m *oidcdomain.GroupRoleMapping) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO group_role_mappings (id, tenant_id, provider_id, group_name, role_id)
|
||||
VALUES ($1, $2, $3, $4, $5)`,
|
||||
m.ID, m.TenantID, m.ProviderID, m.GroupName, m.RoleID)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
return repository.ErrGroupRoleMappingDuplicate
|
||||
}
|
||||
return fmt.Errorf("group_role_mappings add: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes a mapping by id.
|
||||
func (r *GroupRoleMappingRepository) Remove(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx, `DELETE FROM group_role_mappings WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("group_role_mappings remove: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrGroupRoleMappingNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Map resolves IdP-supplied group names against the provider's
|
||||
// mappings. Returns the deduplicated set of role IDs the user should
|
||||
// hold. Empty group_names slice yields empty result; empty result
|
||||
// means fail-closed (no roles, Phase 3 declines to mint a session).
|
||||
func (r *GroupRoleMappingRepository) Map(ctx context.Context, providerID string, groupNames []string) ([]string, error) {
|
||||
if len(groupNames) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT DISTINCT role_id
|
||||
FROM group_role_mappings
|
||||
WHERE provider_id = $1 AND group_name = ANY($2)`,
|
||||
providerID, pq.StringArray(groupNames))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group_role_mappings map: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []string
|
||||
for rows.Next() {
|
||||
var roleID string
|
||||
if err := rows.Scan(&roleID); err != nil {
|
||||
return nil, fmt.Errorf("group_role_mappings map scan: %w", err)
|
||||
}
|
||||
out = append(out, roleID)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
@@ -0,0 +1,366 @@
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain"
|
||||
"github.com/certctl-io/certctl/internal/repository"
|
||||
"github.com/certctl-io/certctl/internal/repository/postgres"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// OIDCProviderRepository tests (Auth Bundle 2 Phase 2)
|
||||
//
|
||||
// Schema-per-test isolation via getTestDB().freshSchema(t). Run with:
|
||||
//
|
||||
// go test -count=1 ./internal/repository/postgres/...
|
||||
//
|
||||
// (omit -short; testing.Short() skips all integration tests.)
|
||||
// =============================================================================
|
||||
|
||||
func newValidProvider(suffix string) *oidcdomain.OIDCProvider {
|
||||
return &oidcdomain.OIDCProvider{
|
||||
ID: "op-" + suffix,
|
||||
TenantID: "t-default",
|
||||
Name: "Provider " + suffix,
|
||||
IssuerURL: "https://idp." + suffix + ".example.com",
|
||||
ClientID: "certctl",
|
||||
ClientSecretEncrypted: []byte{0x02, 0x00, 0x01, 0x02, 0x03},
|
||||
RedirectURI: "https://certctl.example.com/auth/oidc/callback",
|
||||
GroupsClaimPath: "groups",
|
||||
GroupsClaimFormat: "string-array",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AllowedEmailDomains: []string{},
|
||||
IATWindowSeconds: 300,
|
||||
JWKSCacheTTLSeconds: 3600,
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_CreateAndGet(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("a")
|
||||
if err := repo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if got.Name != p.Name {
|
||||
t.Errorf("Name roundtrip: got %q, want %q", got.Name, p.Name)
|
||||
}
|
||||
if got.IssuerURL != p.IssuerURL {
|
||||
t.Errorf("IssuerURL roundtrip mismatch")
|
||||
}
|
||||
// Defaults from the migration kicked in for any unset bool / array.
|
||||
if got.FetchUserinfo != false {
|
||||
t.Errorf("FetchUserinfo default = %v; want false", got.FetchUserinfo)
|
||||
}
|
||||
if len(got.Scopes) != 3 {
|
||||
t.Errorf("Scopes roundtrip count = %d; want 3", len(got.Scopes))
|
||||
}
|
||||
// Defense: client_secret_encrypted column must NOT contain plaintext.
|
||||
// Since we wrote a v2 magic-byte stub, the byte stream comes back as-is.
|
||||
if len(got.ClientSecretEncrypted) == 0 {
|
||||
t.Errorf("ClientSecretEncrypted lost on roundtrip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_GetNotFound(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := repo.Get(ctx, "op-nonexistent")
|
||||
if !errors.Is(err, repository.ErrOIDCProviderNotFound) {
|
||||
t.Errorf("err = %v; want ErrOIDCProviderNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_DuplicateName(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p1 := newValidProvider("dup1")
|
||||
if err := repo.Create(ctx, p1); err != nil {
|
||||
t.Fatalf("Create p1: %v", err)
|
||||
}
|
||||
|
||||
p2 := newValidProvider("dup2")
|
||||
p2.Name = p1.Name // collision on (tenant_id, name)
|
||||
err := repo.Create(ctx, p2)
|
||||
if !errors.Is(err, repository.ErrOIDCProviderDuplicateName) {
|
||||
t.Errorf("Create with duplicate name err = %v; want ErrOIDCProviderDuplicateName", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_List(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, suf := range []string{"x", "y", "z"} {
|
||||
if err := repo.Create(ctx, newValidProvider(suf)); err != nil {
|
||||
t.Fatalf("Create %q: %v", suf, err)
|
||||
}
|
||||
}
|
||||
|
||||
out, err := repo.List(ctx, "t-default")
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
t.Errorf("List count = %d; want 3", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_Update(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("upd")
|
||||
if err := repo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
|
||||
p.Name = "Renamed"
|
||||
p.FetchUserinfo = true
|
||||
if err := repo.Update(ctx, p); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get post-update: %v", err)
|
||||
}
|
||||
if got.Name != "Renamed" {
|
||||
t.Errorf("Update did not persist Name; got %q", got.Name)
|
||||
}
|
||||
if !got.FetchUserinfo {
|
||||
t.Errorf("Update did not persist FetchUserinfo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_DeleteNotFound(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
err := repo.Delete(ctx, "op-nonexistent")
|
||||
if !errors.Is(err, repository.ErrOIDCProviderNotFound) {
|
||||
t.Errorf("err = %v; want ErrOIDCProviderNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRepository_DeleteSucceedsWhenNoUsersReference(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewOIDCProviderRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("del")
|
||||
if err := repo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
if err := repo.Delete(ctx, p.ID); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
_, err := repo.Get(ctx, p.ID)
|
||||
if !errors.Is(err, repository.ErrOIDCProviderNotFound) {
|
||||
t.Errorf("post-delete Get err = %v; want ErrOIDCProviderNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOIDCProviderRepository_DeleteRefusedWhenUsersReference pins the
|
||||
// FK ON DELETE RESTRICT translation. With at least one users row
|
||||
// referencing the provider, Delete must return ErrOIDCProviderInUse.
|
||||
func TestOIDCProviderRepository_DeleteRefusedWhenUsersReference(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("inuse")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
u := &struct{ ID string }{ID: "u-test"}
|
||||
_ = u
|
||||
user := newValidUser("inuse", p.ID)
|
||||
if err := userRepo.Create(ctx, user); err != nil {
|
||||
t.Fatalf("Create user: %v", err)
|
||||
}
|
||||
|
||||
err := providerRepo.Delete(ctx, p.ID)
|
||||
if !errors.Is(err, repository.ErrOIDCProviderInUse) {
|
||||
t.Errorf("Delete with referencing user err = %v; want ErrOIDCProviderInUse", err)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// GroupRoleMappingRepository
|
||||
// =============================================================================
|
||||
|
||||
func TestGroupRoleMappingRepository_AddListMap(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
mappingRepo := postgres.NewGroupRoleMappingRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("grm")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
|
||||
mappings := []*oidcdomain.GroupRoleMapping{
|
||||
{ID: "grm-1", TenantID: "t-default", ProviderID: p.ID, GroupName: "engineers", RoleID: "r-operator"},
|
||||
{ID: "grm-2", TenantID: "t-default", ProviderID: p.ID, GroupName: "platform-admins", RoleID: "r-admin"},
|
||||
{ID: "grm-3", TenantID: "t-default", ProviderID: p.ID, GroupName: "compliance", RoleID: "r-auditor"},
|
||||
}
|
||||
for _, m := range mappings {
|
||||
if err := mappingRepo.Add(ctx, m); err != nil {
|
||||
t.Fatalf("Add %s: %v", m.GroupName, err)
|
||||
}
|
||||
}
|
||||
|
||||
listed, err := mappingRepo.ListByProvider(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByProvider: %v", err)
|
||||
}
|
||||
if len(listed) != 3 {
|
||||
t.Errorf("ListByProvider count = %d; want 3", len(listed))
|
||||
}
|
||||
|
||||
// Map: user has groups [engineers, marketing]. Marketing has no
|
||||
// mapping; only engineers maps to r-operator.
|
||||
roleIDs, err := mappingRepo.Map(ctx, p.ID, []string{"engineers", "marketing"})
|
||||
if err != nil {
|
||||
t.Fatalf("Map: %v", err)
|
||||
}
|
||||
if len(roleIDs) != 1 || roleIDs[0] != "r-operator" {
|
||||
t.Errorf("Map(engineers, marketing) = %v; want [r-operator]", roleIDs)
|
||||
}
|
||||
|
||||
// Map: user has groups [engineers, platform-admins]. Both map.
|
||||
roleIDs, err = mappingRepo.Map(ctx, p.ID, []string{"engineers", "platform-admins"})
|
||||
if err != nil {
|
||||
t.Fatalf("Map (multi): %v", err)
|
||||
}
|
||||
if len(roleIDs) != 2 {
|
||||
t.Errorf("Map(engineers, platform-admins) count = %d; want 2", len(roleIDs))
|
||||
}
|
||||
|
||||
// Map empty groups: empty result, no error (Phase 3 fail-closes).
|
||||
roleIDs, err = mappingRepo.Map(ctx, p.ID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Map(nil): %v", err)
|
||||
}
|
||||
if len(roleIDs) != 0 {
|
||||
t.Errorf("Map(nil) returned %d roles; want 0", len(roleIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupRoleMappingRepository_DuplicateRejected(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
mappingRepo := postgres.NewGroupRoleMappingRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("dup")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
m := &oidcdomain.GroupRoleMapping{
|
||||
ID: "grm-dup-1", TenantID: "t-default", ProviderID: p.ID,
|
||||
GroupName: "engineers", RoleID: "r-operator",
|
||||
}
|
||||
if err := mappingRepo.Add(ctx, m); err != nil {
|
||||
t.Fatalf("Add first: %v", err)
|
||||
}
|
||||
m2 := &oidcdomain.GroupRoleMapping{
|
||||
ID: "grm-dup-2", TenantID: "t-default", ProviderID: p.ID,
|
||||
GroupName: "engineers", RoleID: "r-operator",
|
||||
}
|
||||
err := mappingRepo.Add(ctx, m2)
|
||||
if !errors.Is(err, repository.ErrGroupRoleMappingDuplicate) {
|
||||
t.Errorf("Add duplicate err = %v; want ErrGroupRoleMappingDuplicate", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupRoleMappingRepository_ProviderDeleteCascades(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
mappingRepo := postgres.NewGroupRoleMappingRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("cascade")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
for i, group := range []string{"a", "b", "c"} {
|
||||
m := &oidcdomain.GroupRoleMapping{
|
||||
ID: "grm-cas-" + string(rune('a'+i)), TenantID: "t-default",
|
||||
ProviderID: p.ID, GroupName: group, RoleID: "r-viewer",
|
||||
}
|
||||
if err := mappingRepo.Add(ctx, m); err != nil {
|
||||
t.Fatalf("Add %s: %v", group, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete provider: ON DELETE CASCADE on group_role_mappings.provider_id
|
||||
// should drop the 3 mappings too.
|
||||
if err := providerRepo.Delete(ctx, p.ID); err != nil {
|
||||
t.Fatalf("Delete provider: %v", err)
|
||||
}
|
||||
listed, err := mappingRepo.ListByProvider(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByProvider post-cascade: %v", err)
|
||||
}
|
||||
if len(listed) != 0 {
|
||||
t.Errorf("CASCADE failed; %d mappings remain", len(listed))
|
||||
}
|
||||
}
|
||||
|
||||
// quiet unused-import keepalives so single-test runs don't drop them.
|
||||
var _ = time.Now
|
||||
@@ -0,0 +1,350 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain"
|
||||
"github.com/certctl-io/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// SessionRepository (Auth Bundle 2 Phase 2)
|
||||
// =============================================================================
|
||||
|
||||
// SessionRepository is the postgres implementation of
|
||||
// repository.SessionRepository.
|
||||
type SessionRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSessionRepository constructs a SessionRepository.
|
||||
func NewSessionRepository(db *sql.DB) *SessionRepository {
|
||||
return &SessionRepository{db: db}
|
||||
}
|
||||
|
||||
const sessionColumns = `id, tenant_id, actor_id, actor_type,
|
||||
signing_key_id, is_pre_login, csrf_token_hash,
|
||||
idle_expires_at, absolute_expires_at, created_at, last_seen_at,
|
||||
ip_address, user_agent, revoked_at`
|
||||
|
||||
func scanSession(row interface{ Scan(...interface{}) error }) (*sessiondomain.Session, error) {
|
||||
var s sessiondomain.Session
|
||||
var revokedAt sql.NullTime
|
||||
if err := row.Scan(
|
||||
&s.ID, &s.TenantID, &s.ActorID, &s.ActorType,
|
||||
&s.SigningKeyID, &s.IsPreLogin, &s.CSRFTokenHash,
|
||||
&s.IdleExpiresAt, &s.AbsoluteExpiresAt, &s.CreatedAt, &s.LastSeenAt,
|
||||
&s.IPAddress, &s.UserAgent, &revokedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if revokedAt.Valid {
|
||||
s.RevokedAt = &revokedAt.Time
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
// Create persists a session row. Caller MUST have called s.Validate().
|
||||
func (r *SessionRepository) Create(ctx context.Context, s *sessiondomain.Session) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (
|
||||
id, tenant_id, actor_id, actor_type, signing_key_id,
|
||||
is_pre_login, csrf_token_hash, idle_expires_at,
|
||||
absolute_expires_at, created_at, last_seen_at,
|
||||
ip_address, user_agent
|
||||
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13)`,
|
||||
s.ID, s.TenantID, s.ActorID, s.ActorType, s.SigningKeyID,
|
||||
s.IsPreLogin, s.CSRFTokenHash, s.IdleExpiresAt,
|
||||
s.AbsoluteExpiresAt, s.CreatedAt, s.LastSeenAt,
|
||||
s.IPAddress, s.UserAgent)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
return repository.ErrAuthDuplicateName
|
||||
}
|
||||
return fmt.Errorf("sessions create: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a session by id. Returns the row even if revoked /
|
||||
// expired; the service layer handles the disposition.
|
||||
func (r *SessionRepository) Get(ctx context.Context, id string) (*sessiondomain.Session, error) {
|
||||
row := r.db.QueryRowContext(ctx, `SELECT `+sessionColumns+` FROM sessions WHERE id = $1`, id)
|
||||
s, err := scanSession(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrSessionNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("sessions get: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ListByActor returns active (non-revoked, non-expired, non-pre-login)
|
||||
// sessions for an actor.
|
||||
func (r *SessionRepository) ListByActor(ctx context.Context, actorID, actorType, tenantID string) ([]*sessiondomain.Session, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT `+sessionColumns+`
|
||||
FROM sessions
|
||||
WHERE actor_id = $1
|
||||
AND actor_type = $2
|
||||
AND tenant_id = $3
|
||||
AND revoked_at IS NULL
|
||||
AND is_pre_login = FALSE
|
||||
AND absolute_expires_at > NOW()
|
||||
ORDER BY created_at DESC`,
|
||||
actorID, actorType, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sessions list_by_actor: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*sessiondomain.Session
|
||||
for rows.Next() {
|
||||
s, err := scanSession(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sessions scan: %w", err)
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateLastSeen sets last_seen_at = NOW() for the named session.
|
||||
func (r *SessionRepository) UpdateLastSeen(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx, `UPDATE sessions SET last_seen_at = NOW() WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions update_last_seen: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrSessionNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke sets revoked_at = NOW() for the named session. Idempotent:
|
||||
// re-revoking an already-revoked session is a no-op (returns nil).
|
||||
func (r *SessionRepository) Revoke(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx, `UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions revoke: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
// Distinguish "not found" from "already revoked" by re-querying.
|
||||
row := r.db.QueryRowContext(ctx, `SELECT 1 FROM sessions WHERE id = $1`, id)
|
||||
var x int
|
||||
if err := row.Scan(&x); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return repository.ErrSessionNotFound
|
||||
}
|
||||
return fmt.Errorf("sessions revoke probe: %w", err)
|
||||
}
|
||||
// Row exists but already revoked: idempotent success.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllForActor sets revoked_at = NOW() on every active session
|
||||
// for an actor. Returns nil on zero matches (idempotent).
|
||||
func (r *SessionRepository) RevokeAllForActor(ctx context.Context, actorID, actorType, tenantID string) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET revoked_at = NOW()
|
||||
WHERE actor_id = $1 AND actor_type = $2 AND tenant_id = $3 AND revoked_at IS NULL`,
|
||||
actorID, actorType, tenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions revoke_all_for_actor: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GarbageCollectExpired deletes:
|
||||
// - Sessions whose absolute_expires_at < NOW() (post-login expired).
|
||||
// - Pre-login sessions older than 10 minutes.
|
||||
//
|
||||
// Returns the number of rows deleted across both classes.
|
||||
func (r *SessionRepository) GarbageCollectExpired(ctx context.Context) (int, error) {
|
||||
res, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM sessions
|
||||
WHERE absolute_expires_at < NOW()
|
||||
OR (is_pre_login = TRUE AND created_at < NOW() - INTERVAL '10 minutes')`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("sessions garbage_collect: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
// Delete unconditionally removes a session row.
|
||||
func (r *SessionRepository) Delete(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions delete: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrSessionNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SessionSigningKeyRepository (Auth Bundle 2 Phase 2)
|
||||
// =============================================================================
|
||||
|
||||
// SessionSigningKeyRepository is the postgres implementation of
|
||||
// repository.SessionSigningKeyRepository.
|
||||
type SessionSigningKeyRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSessionSigningKeyRepository constructs a SessionSigningKeyRepository.
|
||||
func NewSessionSigningKeyRepository(db *sql.DB) *SessionSigningKeyRepository {
|
||||
return &SessionSigningKeyRepository{db: db}
|
||||
}
|
||||
|
||||
const sessionSigningKeyColumns = `id, tenant_id, key_material_encrypted, created_at, retired_at`
|
||||
|
||||
func scanSessionSigningKey(row interface{ Scan(...interface{}) error }) (*sessiondomain.SessionSigningKey, error) {
|
||||
var k sessiondomain.SessionSigningKey
|
||||
var retiredAt sql.NullTime
|
||||
if err := row.Scan(&k.ID, &k.TenantID, &k.KeyMaterialEncrypted, &k.CreatedAt, &retiredAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if retiredAt.Valid {
|
||||
k.RetiredAt = &retiredAt.Time
|
||||
}
|
||||
return &k, nil
|
||||
}
|
||||
|
||||
// List returns every signing key in the tenant, including retired ones.
|
||||
func (r *SessionSigningKeyRepository) List(ctx context.Context, tenantID string) ([]*sessiondomain.SessionSigningKey, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT `+sessionSigningKeyColumns+` FROM session_signing_keys WHERE tenant_id = $1 ORDER BY created_at DESC`,
|
||||
tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session_signing_keys list: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*sessiondomain.SessionSigningKey
|
||||
for rows.Next() {
|
||||
k, err := scanSessionSigningKey(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session_signing_keys scan: %w", err)
|
||||
}
|
||||
out = append(out, k)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GetActive returns the most-recently-created non-retired key. Returns
|
||||
// ErrSessionSigningKeyNotFound when no non-retired key exists.
|
||||
func (r *SessionSigningKeyRepository) GetActive(ctx context.Context, tenantID string) (*sessiondomain.SessionSigningKey, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT `+sessionSigningKeyColumns+`
|
||||
FROM session_signing_keys
|
||||
WHERE tenant_id = $1 AND retired_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1`, tenantID)
|
||||
k, err := scanSessionSigningKey(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrSessionSigningKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("session_signing_keys get_active: %w", err)
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// Get returns a key by id (including retired keys; Phase 4's Validate
|
||||
// consults this for cookies signed under retired-but-in-retention keys).
|
||||
func (r *SessionSigningKeyRepository) Get(ctx context.Context, id string) (*sessiondomain.SessionSigningKey, error) {
|
||||
row := r.db.QueryRowContext(ctx,
|
||||
`SELECT `+sessionSigningKeyColumns+` FROM session_signing_keys WHERE id = $1`, id)
|
||||
k, err := scanSessionSigningKey(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrSessionSigningKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("session_signing_keys get: %w", err)
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// Add persists a new signing key. Caller MUST have called k.Validate().
|
||||
func (r *SessionSigningKeyRepository) Add(ctx context.Context, k *sessiondomain.SessionSigningKey) error {
|
||||
if k.CreatedAt.IsZero() {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO session_signing_keys (id, tenant_id, key_material_encrypted)
|
||||
VALUES ($1, $2, $3)`,
|
||||
k.ID, k.TenantID, k.KeyMaterialEncrypted)
|
||||
if err != nil {
|
||||
return fmt.Errorf("session_signing_keys add: %w", err)
|
||||
}
|
||||
// Read the row back to populate CreatedAt.
|
||||
row := r.db.QueryRowContext(ctx, `SELECT created_at FROM session_signing_keys WHERE id = $1`, k.ID)
|
||||
if err := row.Scan(&k.CreatedAt); err != nil {
|
||||
return fmt.Errorf("session_signing_keys add (read created_at): %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO session_signing_keys (id, tenant_id, key_material_encrypted, created_at)
|
||||
VALUES ($1, $2, $3, $4)`,
|
||||
k.ID, k.TenantID, k.KeyMaterialEncrypted, k.CreatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("session_signing_keys add: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retire marks an active key as retired (sets retired_at = NOW()).
|
||||
// Idempotent: re-retiring an already-retired key is a no-op.
|
||||
func (r *SessionSigningKeyRepository) Retire(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx,
|
||||
`UPDATE session_signing_keys SET retired_at = NOW() WHERE id = $1 AND retired_at IS NULL`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("session_signing_keys retire: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
// Distinguish not-found vs already-retired.
|
||||
row := r.db.QueryRowContext(ctx, `SELECT 1 FROM session_signing_keys WHERE id = $1`, id)
|
||||
var x int
|
||||
if err := row.Scan(&x); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return repository.ErrSessionSigningKeyNotFound
|
||||
}
|
||||
return fmt.Errorf("session_signing_keys retire probe: %w", err)
|
||||
}
|
||||
// Row exists but already retired: idempotent success.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete unconditionally removes a signing key. Returns
|
||||
// ErrSessionSigningKeyInUse on SQLSTATE 23503 (FK ON DELETE RESTRICT
|
||||
// from sessions.signing_key_id).
|
||||
func (r *SessionSigningKeyRepository) Delete(ctx context.Context, id string) error {
|
||||
res, err := r.db.ExecContext(ctx, `DELETE FROM session_signing_keys WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
|
||||
return repository.ErrSessionSigningKeyInUse
|
||||
}
|
||||
return fmt.Errorf("session_signing_keys delete: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrSessionSigningKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,431 @@
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain"
|
||||
"github.com/certctl-io/certctl/internal/repository"
|
||||
"github.com/certctl-io/certctl/internal/repository/postgres"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// SessionSigningKey tests
|
||||
// =============================================================================
|
||||
|
||||
func newValidSigningKey(suffix string) *sessiondomain.SessionSigningKey {
|
||||
return &sessiondomain.SessionSigningKey{
|
||||
ID: "sk-" + suffix,
|
||||
TenantID: "t-default",
|
||||
KeyMaterialEncrypted: []byte{0x02, 0x00, 0x01, 0x02, 0x03},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionSigningKeyRepository_AddAndGetActive(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewSessionSigningKeyRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("a")
|
||||
if err := repo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
if k.CreatedAt.IsZero() {
|
||||
t.Errorf("Add did not populate CreatedAt")
|
||||
}
|
||||
|
||||
got, err := repo.GetActive(ctx, "t-default")
|
||||
if err != nil {
|
||||
t.Fatalf("GetActive: %v", err)
|
||||
}
|
||||
if got.ID != k.ID {
|
||||
t.Errorf("GetActive returned %q; want %q", got.ID, k.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionSigningKeyRepository_GetActiveSkipsRetired(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewSessionSigningKeyRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add older key, retire it. Add newer key. GetActive must return newer.
|
||||
older := newValidSigningKey("older")
|
||||
if err := repo.Add(ctx, older); err != nil {
|
||||
t.Fatalf("Add older: %v", err)
|
||||
}
|
||||
if err := repo.Retire(ctx, older.ID); err != nil {
|
||||
t.Fatalf("Retire older: %v", err)
|
||||
}
|
||||
// Sleep a millisecond so created_at orders deterministically.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
newer := newValidSigningKey("newer")
|
||||
if err := repo.Add(ctx, newer); err != nil {
|
||||
t.Fatalf("Add newer: %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.GetActive(ctx, "t-default")
|
||||
if err != nil {
|
||||
t.Fatalf("GetActive: %v", err)
|
||||
}
|
||||
if got.ID != newer.ID {
|
||||
t.Errorf("GetActive returned %q; want %q (older was retired)", got.ID, newer.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionSigningKeyRepository_GetActiveReturnsNotFound(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewSessionSigningKeyRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := repo.GetActive(ctx, "t-default")
|
||||
if !errors.Is(err, repository.ErrSessionSigningKeyNotFound) {
|
||||
t.Errorf("err = %v; want ErrSessionSigningKeyNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionSigningKeyRepository_RetireIsIdempotent(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewSessionSigningKeyRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("retire")
|
||||
if err := repo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
if err := repo.Retire(ctx, k.ID); err != nil {
|
||||
t.Fatalf("first Retire: %v", err)
|
||||
}
|
||||
if err := repo.Retire(ctx, k.ID); err != nil {
|
||||
t.Errorf("second Retire (already retired) should be idempotent; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionSigningKeyRepository_DeleteRefusedWhenSessionsReference(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("inuse")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
s := newValidSession("s1", k.ID)
|
||||
if err := sessRepo.Create(ctx, s); err != nil {
|
||||
t.Fatalf("Create session: %v", err)
|
||||
}
|
||||
|
||||
err := keyRepo.Delete(ctx, k.ID)
|
||||
if !errors.Is(err, repository.ErrSessionSigningKeyInUse) {
|
||||
t.Errorf("Delete with referencing session err = %v; want ErrSessionSigningKeyInUse", err)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Session tests
|
||||
// =============================================================================
|
||||
|
||||
func newValidSession(suffix, signingKeyID string) *sessiondomain.Session {
|
||||
now := time.Now().UTC().Truncate(time.Microsecond)
|
||||
return &sessiondomain.Session{
|
||||
ID: "ses-" + suffix,
|
||||
TenantID: "t-default",
|
||||
ActorID: "u-" + suffix,
|
||||
ActorType: "User",
|
||||
SigningKeyID: signingKeyID,
|
||||
IsPreLogin: false,
|
||||
CSRFTokenHash: strings.Repeat("a", 64),
|
||||
IdleExpiresAt: now.Add(time.Hour),
|
||||
AbsoluteExpiresAt: now.Add(8 * time.Hour),
|
||||
CreatedAt: now,
|
||||
LastSeenAt: now,
|
||||
IPAddress: "10.0.0.1",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_CreateAndGet(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("k1")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
s := newValidSession("s1", k.ID)
|
||||
if err := sessRepo.Create(ctx, s); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
|
||||
got, err := sessRepo.Get(ctx, s.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if got.ActorID != s.ActorID {
|
||||
t.Errorf("ActorID roundtrip mismatch")
|
||||
}
|
||||
if got.RevokedAt != nil {
|
||||
t.Errorf("RevokedAt should be nil on fresh session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_GetNotFound(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := repo.Get(ctx, "ses-nonexistent")
|
||||
if !errors.Is(err, repository.ErrSessionNotFound) {
|
||||
t.Errorf("err = %v; want ErrSessionNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_RevokeAndGet(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("k2")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
s := newValidSession("s2", k.ID)
|
||||
if err := sessRepo.Create(ctx, s); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
|
||||
if err := sessRepo.Revoke(ctx, s.ID); err != nil {
|
||||
t.Fatalf("Revoke: %v", err)
|
||||
}
|
||||
got, err := sessRepo.Get(ctx, s.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get post-revoke: %v", err)
|
||||
}
|
||||
if got.RevokedAt == nil {
|
||||
t.Errorf("RevokedAt nil after Revoke")
|
||||
}
|
||||
|
||||
// Idempotent re-revoke: returns nil, no panic, no double-update.
|
||||
if err := sessRepo.Revoke(ctx, s.ID); err != nil {
|
||||
t.Errorf("re-Revoke (idempotent) err = %v; want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_RevokeNotFound(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := repo.Revoke(ctx, "ses-nonexistent"); !errors.Is(err, repository.ErrSessionNotFound) {
|
||||
t.Errorf("err = %v; want ErrSessionNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_ListByActorActiveOnly(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("la")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
// 3 active + 1 revoked + 1 pre-login.
|
||||
for i, suf := range []string{"a1", "a2", "a3"} {
|
||||
s := newValidSession(suf, k.ID)
|
||||
s.ActorID = "u-list-actor"
|
||||
// uniqueness: stagger created_at so list ordering is stable
|
||||
s.CreatedAt = s.CreatedAt.Add(time.Duration(i) * time.Millisecond)
|
||||
if err := sessRepo.Create(ctx, s); err != nil {
|
||||
t.Fatalf("Create %s: %v", suf, err)
|
||||
}
|
||||
}
|
||||
revoked := newValidSession("rev", k.ID)
|
||||
revoked.ActorID = "u-list-actor"
|
||||
if err := sessRepo.Create(ctx, revoked); err != nil {
|
||||
t.Fatalf("Create revoked: %v", err)
|
||||
}
|
||||
if err := sessRepo.Revoke(ctx, revoked.ID); err != nil {
|
||||
t.Fatalf("Revoke: %v", err)
|
||||
}
|
||||
preLogin := newValidSession("pre", k.ID)
|
||||
preLogin.ActorID = "u-list-actor"
|
||||
preLogin.IsPreLogin = true
|
||||
preLogin.CSRFTokenHash = "" // pre-login rows have no CSRF token
|
||||
if err := sessRepo.Create(ctx, preLogin); err != nil {
|
||||
t.Fatalf("Create pre-login: %v", err)
|
||||
}
|
||||
|
||||
out, err := sessRepo.ListByActor(ctx, "u-list-actor", "User", "t-default")
|
||||
if err != nil {
|
||||
t.Fatalf("ListByActor: %v", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
t.Errorf("ListByActor count = %d; want 3 (revoked + pre-login excluded)", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_RevokeAllForActor(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("ra")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
// 3 sessions for one actor.
|
||||
for _, suf := range []string{"r1", "r2", "r3"} {
|
||||
s := newValidSession(suf, k.ID)
|
||||
s.ActorID = "u-fired"
|
||||
if err := sessRepo.Create(ctx, s); err != nil {
|
||||
t.Fatalf("Create %s: %v", suf, err)
|
||||
}
|
||||
}
|
||||
if err := sessRepo.RevokeAllForActor(ctx, "u-fired", "User", "t-default"); err != nil {
|
||||
t.Fatalf("RevokeAllForActor: %v", err)
|
||||
}
|
||||
out, err := sessRepo.ListByActor(ctx, "u-fired", "User", "t-default")
|
||||
if err != nil {
|
||||
t.Fatalf("ListByActor post-revoke: %v", err)
|
||||
}
|
||||
if len(out) != 0 {
|
||||
t.Errorf("RevokeAllForActor left %d sessions active; want 0", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_GarbageCollectExpired(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("gc")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
|
||||
// One session with absolute expiry in the past (write directly via SQL
|
||||
// to bypass the CHECK constraints; this simulates a row that aged
|
||||
// past expiry without GC having run yet).
|
||||
now := time.Now().UTC()
|
||||
old := time.Now().UTC().Add(-2 * time.Hour)
|
||||
older := time.Now().UTC().Add(-3 * time.Hour)
|
||||
_, err := db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (id, tenant_id, actor_id, actor_type, signing_key_id,
|
||||
is_pre_login, csrf_token_hash, idle_expires_at, absolute_expires_at,
|
||||
created_at, last_seen_at, ip_address, user_agent)
|
||||
VALUES ($1, 't-default', 'u-gc', 'User', $2, FALSE, '',
|
||||
$3, $4, $5, $5, '', '')`,
|
||||
"ses-expired", k.ID, older, old, time.Now().UTC().Add(-4*time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("seed expired: %v", err)
|
||||
}
|
||||
|
||||
// One pre-login row older than 10 minutes.
|
||||
_, err = db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (id, tenant_id, actor_id, actor_type, signing_key_id,
|
||||
is_pre_login, csrf_token_hash, idle_expires_at, absolute_expires_at,
|
||||
created_at, last_seen_at, ip_address, user_agent)
|
||||
VALUES ($1, 't-default', 'u-gc', 'User', $2, TRUE, '',
|
||||
$3, $4, $5, $5, '', '')`,
|
||||
"ses-prelogin-old", k.ID,
|
||||
now.Add(-15*time.Minute).Add(time.Hour), // idle in future relative to created
|
||||
now.Add(-15*time.Minute).Add(2*time.Hour), // absolute > idle, both > created
|
||||
now.Add(-15*time.Minute)) // created 15 min ago (older than 10 min TTL)
|
||||
if err != nil {
|
||||
t.Fatalf("seed pre-login: %v", err)
|
||||
}
|
||||
|
||||
// One active session (NOT to be GC'd).
|
||||
active := newValidSession("active", k.ID)
|
||||
active.ActorID = "u-gc"
|
||||
if err := sessRepo.Create(ctx, active); err != nil {
|
||||
t.Fatalf("seed active: %v", err)
|
||||
}
|
||||
|
||||
n, err := sessRepo.GarbageCollectExpired(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GC: %v", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Errorf("GC deleted %d rows; want 2 (expired + old pre-login)", n)
|
||||
}
|
||||
|
||||
// Active session survives.
|
||||
if _, err := sessRepo.Get(ctx, active.ID); err != nil {
|
||||
t.Errorf("active session should survive GC; got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRepository_UpdateLastSeen(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
keyRepo := postgres.NewSessionSigningKeyRepository(db)
|
||||
sessRepo := postgres.NewSessionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
k := newValidSigningKey("uls")
|
||||
if err := keyRepo.Add(ctx, k); err != nil {
|
||||
t.Fatalf("Add key: %v", err)
|
||||
}
|
||||
s := newValidSession("uls", k.ID)
|
||||
if err := sessRepo.Create(ctx, s); err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
originalSeen := s.LastSeenAt
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if err := sessRepo.UpdateLastSeen(ctx, s.ID); err != nil {
|
||||
t.Fatalf("UpdateLastSeen: %v", err)
|
||||
}
|
||||
got, _ := sessRepo.Get(ctx, s.ID)
|
||||
if !got.LastSeenAt.After(originalSeen) {
|
||||
t.Errorf("LastSeenAt did not advance after UpdateLastSeen")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
userdomain "github.com/certctl-io/certctl/internal/auth/user/domain"
|
||||
"github.com/certctl-io/certctl/internal/repository"
|
||||
)
|
||||
|
||||
// UserRepository is the postgres implementation of
|
||||
// repository.UserRepository (Auth Bundle 2 Phase 2).
|
||||
type UserRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewUserRepository constructs a UserRepository.
|
||||
func NewUserRepository(db *sql.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
const userColumns = `id, tenant_id, email, display_name, oidc_subject,
|
||||
oidc_provider_id, last_login_at, webauthn_credentials,
|
||||
created_at, updated_at`
|
||||
|
||||
func scanUser(row interface{ Scan(...interface{}) error }) (*userdomain.User, error) {
|
||||
var u userdomain.User
|
||||
if err := row.Scan(
|
||||
&u.ID, &u.TenantID, &u.Email, &u.DisplayName, &u.OIDCSubject,
|
||||
&u.OIDCProviderID, &u.LastLoginAt, &u.WebAuthnCredentials,
|
||||
&u.CreatedAt, &u.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Get returns one user by id.
|
||||
func (r *UserRepository) Get(ctx context.Context, id string) (*userdomain.User, error) {
|
||||
row := r.db.QueryRowContext(ctx, `SELECT `+userColumns+` FROM users WHERE id = $1`, id)
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("users get: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GetByOIDCSubject is the Phase 3 hot-path lookup at login time.
|
||||
// Returns ErrUserNotFound if no row matches the (provider, subject)
|
||||
// tuple — Phase 3's HandleCallback then creates the row via Create.
|
||||
func (r *UserRepository) GetByOIDCSubject(ctx context.Context, providerID, subject string) (*userdomain.User, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT `+userColumns+`
|
||||
FROM users
|
||||
WHERE oidc_provider_id = $1 AND oidc_subject = $2`,
|
||||
providerID, subject)
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, repository.ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("users get_by_oidc_subject: %w", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Create persists a new user. Translates SQLSTATE 23505 into
|
||||
// ErrUserDuplicateOIDCSubject (the unique constraint on
|
||||
// (oidc_provider_id, oidc_subject)).
|
||||
func (r *UserRepository) Create(ctx context.Context, u *userdomain.User) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO users (
|
||||
id, tenant_id, email, display_name, oidc_subject,
|
||||
oidc_provider_id, last_login_at, webauthn_credentials
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
u.ID, u.TenantID, u.Email, u.DisplayName, u.OIDCSubject,
|
||||
u.OIDCProviderID, u.LastLoginAt, u.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
return repository.ErrUserDuplicateOIDCSubject
|
||||
}
|
||||
return fmt.Errorf("users create: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update writes the mutable fields (email, display_name, last_login_at,
|
||||
// webauthn_credentials) back to the row. Immutable: id, tenant_id,
|
||||
// oidc_subject, oidc_provider_id, created_at. updated_at = NOW().
|
||||
func (r *UserRepository) Update(ctx context.Context, u *userdomain.User) error {
|
||||
res, err := r.db.ExecContext(ctx, `
|
||||
UPDATE users SET
|
||||
email = $2,
|
||||
display_name = $3,
|
||||
last_login_at = $4,
|
||||
webauthn_credentials = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`,
|
||||
u.ID, u.Email, u.DisplayName, u.LastLoginAt, u.WebAuthnCredentials)
|
||||
if err != nil {
|
||||
return fmt.Errorf("users update: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return repository.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAll returns every user in the tenant, ordered by created_at ASC.
|
||||
func (r *UserRepository) ListAll(ctx context.Context, tenantID string) ([]*userdomain.User, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT `+userColumns+` FROM users WHERE tenant_id = $1 ORDER BY created_at ASC`,
|
||||
tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("users list_all: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*userdomain.User
|
||||
for rows.Next() {
|
||||
u, err := scanUser(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("users scan: %w", err)
|
||||
}
|
||||
out = append(out, u)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
userdomain "github.com/certctl-io/certctl/internal/auth/user/domain"
|
||||
"github.com/certctl-io/certctl/internal/repository"
|
||||
"github.com/certctl-io/certctl/internal/repository/postgres"
|
||||
)
|
||||
|
||||
// newValidUser is shared with oidc_test.go (same _test package).
|
||||
func newValidUser(suffix, providerID string) *userdomain.User {
|
||||
return &userdomain.User{
|
||||
ID: "u-" + suffix,
|
||||
TenantID: "t-default",
|
||||
Email: suffix + "@example.com",
|
||||
DisplayName: "User " + suffix,
|
||||
OIDCSubject: "subject-" + suffix,
|
||||
OIDCProviderID: providerID,
|
||||
WebAuthnCredentials: []byte("[]"),
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndGet(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("u")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
u := newValidUser("alice", p.ID)
|
||||
if err := userRepo.Create(ctx, u); err != nil {
|
||||
t.Fatalf("Create user: %v", err)
|
||||
}
|
||||
|
||||
got, err := userRepo.Get(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if got.Email != u.Email {
|
||||
t.Errorf("Email roundtrip: got %q, want %q", got.Email, u.Email)
|
||||
}
|
||||
if string(got.WebAuthnCredentials) != "[]" {
|
||||
t.Errorf("WebAuthnCredentials default = %q; want []", string(got.WebAuthnCredentials))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRepository_GetNotFound(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
repo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := repo.Get(ctx, "u-nonexistent")
|
||||
if !errors.Is(err, repository.ErrUserNotFound) {
|
||||
t.Errorf("err = %v; want ErrUserNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRepository_GetByOIDCSubject(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("subj")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
u := newValidUser("bob", p.ID)
|
||||
if err := userRepo.Create(ctx, u); err != nil {
|
||||
t.Fatalf("Create user: %v", err)
|
||||
}
|
||||
|
||||
got, err := userRepo.GetByOIDCSubject(ctx, p.ID, u.OIDCSubject)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByOIDCSubject: %v", err)
|
||||
}
|
||||
if got.ID != u.ID {
|
||||
t.Errorf("GetByOIDCSubject returned %q; want %q", got.ID, u.ID)
|
||||
}
|
||||
|
||||
// Wrong subject: not found.
|
||||
_, err = userRepo.GetByOIDCSubject(ctx, p.ID, "wrong-subject")
|
||||
if !errors.Is(err, repository.ErrUserNotFound) {
|
||||
t.Errorf("err = %v; want ErrUserNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRepository_DuplicateOIDCSubjectRejected(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("dupsubj")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
u1 := newValidUser("first", p.ID)
|
||||
if err := userRepo.Create(ctx, u1); err != nil {
|
||||
t.Fatalf("Create u1: %v", err)
|
||||
}
|
||||
u2 := newValidUser("second", p.ID)
|
||||
u2.OIDCSubject = u1.OIDCSubject // collision on (provider, subject) UNIQUE
|
||||
err := userRepo.Create(ctx, u2)
|
||||
if !errors.Is(err, repository.ErrUserDuplicateOIDCSubject) {
|
||||
t.Errorf("Create duplicate (provider, subject) err = %v; want ErrUserDuplicateOIDCSubject", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRepository_UpdateMutableFields(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("upd")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
u := newValidUser("carol", p.ID)
|
||||
if err := userRepo.Create(ctx, u); err != nil {
|
||||
t.Fatalf("Create user: %v", err)
|
||||
}
|
||||
|
||||
u.Email = "carol-new@example.com"
|
||||
u.DisplayName = "Carol Renamed"
|
||||
if err := userRepo.Update(ctx, u); err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
got, err := userRepo.Get(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get post-update: %v", err)
|
||||
}
|
||||
if got.Email != "carol-new@example.com" {
|
||||
t.Errorf("Update did not persist Email; got %q", got.Email)
|
||||
}
|
||||
if got.DisplayName != "Carol Renamed" {
|
||||
t.Errorf("Update did not persist DisplayName; got %q", got.DisplayName)
|
||||
}
|
||||
// Immutable: oidc_subject must NOT change.
|
||||
if got.OIDCSubject != u.OIDCSubject {
|
||||
t.Errorf("OIDCSubject mutated: got %q, want %q", got.OIDCSubject, u.OIDCSubject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRepository_ListAll(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("la")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
for _, suf := range []string{"u1", "u2", "u3"} {
|
||||
u := newValidUser(suf, p.ID)
|
||||
if err := userRepo.Create(ctx, u); err != nil {
|
||||
t.Fatalf("Create %s: %v", suf, err)
|
||||
}
|
||||
}
|
||||
|
||||
out, err := userRepo.ListAll(ctx, "t-default")
|
||||
if err != nil {
|
||||
t.Fatalf("ListAll: %v", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
t.Errorf("ListAll count = %d; want 3", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_DeletingProviderRefusedWhenUsersReference complements
|
||||
// the OIDCProviderRepository test of the same shape; pinning both ends
|
||||
// of the FK ON DELETE RESTRICT contract.
|
||||
func TestUserRepository_FKRestrictsProviderDelete(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("integration test in short mode")
|
||||
}
|
||||
db := getTestDB(t).freshSchema(t)
|
||||
providerRepo := postgres.NewOIDCProviderRepository(db)
|
||||
userRepo := postgres.NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
p := newValidProvider("fkrest")
|
||||
if err := providerRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("Create provider: %v", err)
|
||||
}
|
||||
u := newValidUser("fkrest-user", p.ID)
|
||||
if err := userRepo.Create(ctx, u); err != nil {
|
||||
t.Fatalf("Create user: %v", err)
|
||||
}
|
||||
|
||||
if err := providerRepo.Delete(ctx, p.ID); !errors.Is(err, repository.ErrOIDCProviderInUse) {
|
||||
t.Errorf("Delete provider (with referencing user) err = %v; want ErrOIDCProviderInUse", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain"
|
||||
)
|
||||
|
||||
// Sentinel errors for the session repositories.
|
||||
var (
|
||||
// ErrSessionNotFound: Get returned no row. Phase 4 maps to 401
|
||||
// (the cookie either expired or was forged with a known-good key
|
||||
// id but stale session id).
|
||||
ErrSessionNotFound = errors.New("session: not found")
|
||||
|
||||
// ErrSessionRevoked: Get found a row but RevokedAt is set. Phase 4
|
||||
// maps to 401.
|
||||
ErrSessionRevoked = errors.New("session: revoked")
|
||||
|
||||
// ErrSessionExpired: Get found a row but the absolute expiry has
|
||||
// passed (Phase 4 also enforces idle expiry but that's a service-
|
||||
// level check against last_seen_at, not a repository sentinel).
|
||||
ErrSessionExpired = errors.New("session: expired")
|
||||
|
||||
// ErrSessionSigningKeyNotFound: GetActive returned no row. Phase 4
|
||||
// EnsureInitialSigningKey treats this as "boot-time provisioning
|
||||
// needed" and mints the first key.
|
||||
ErrSessionSigningKeyNotFound = errors.New("session: signing key not found")
|
||||
|
||||
// ErrSessionSigningKeyInUse: Delete (full purge, not Retire) failed
|
||||
// because at least one sessions row still references the key. Phase
|
||||
// 4's GarbageCollect waits for sessions to expire before purging.
|
||||
ErrSessionSigningKeyInUse = errors.New("session: signing key still referenced by active sessions")
|
||||
)
|
||||
|
||||
// SessionRepository wraps the sessions table. Two cookie shapes share
|
||||
// the rows: post-login sessions (1h-idle/8h-absolute) and pre-login
|
||||
// sessions (10-minute TTL, IsPreLogin=true; carry OIDC state + nonce
|
||||
// + PKCE verifier across the IdP redirect).
|
||||
type SessionRepository interface {
|
||||
// Create persists a session row. Caller MUST have called
|
||||
// s.Validate(). Returns ErrAuthDuplicateName-shape on the
|
||||
// extremely-unlikely id collision (the id is a 32-byte random;
|
||||
// callers SHOULD generate fresh ids on the second attempt).
|
||||
Create(ctx context.Context, s *sessiondomain.Session) error
|
||||
|
||||
// Get returns a session by id. ErrSessionNotFound on miss.
|
||||
// Returns the row even if revoked / expired so the service layer
|
||||
// can produce the right 401 reason code (revoked vs expired vs
|
||||
// not-found are all 401 to the wire but distinguishable in audit).
|
||||
Get(ctx context.Context, id string) (*sessiondomain.Session, error)
|
||||
|
||||
// ListByActor returns every active (non-revoked, non-expired,
|
||||
// non-pre-login) session for an actor. Used by the GUI's
|
||||
// /v1/auth/sessions surface so users can revoke their old laptops.
|
||||
ListByActor(ctx context.Context, actorID, actorType, tenantID string) ([]*sessiondomain.Session, error)
|
||||
|
||||
// UpdateLastSeen sets last_seen_at = NOW() for the named session.
|
||||
// Phase 4's middleware calls this on every request to keep the
|
||||
// idle-expiry sliding window fresh.
|
||||
UpdateLastSeen(ctx context.Context, id string) error
|
||||
|
||||
// Revoke sets revoked_at = NOW() for the named session. Subsequent
|
||||
// Get returns the row with RevokedAt set; Phase 4's Validate maps
|
||||
// to 401.
|
||||
Revoke(ctx context.Context, id string) error
|
||||
|
||||
// RevokeAllForActor sets revoked_at = NOW() on every active session
|
||||
// for an actor. Used on role change, fired-employee scenarios, and
|
||||
// the back-channel logout endpoint (Phase 5).
|
||||
RevokeAllForActor(ctx context.Context, actorID, actorType, tenantID string) error
|
||||
|
||||
// GarbageCollectExpired deletes sessions whose absolute expiry
|
||||
// has passed AND whose revoked_at is older than the configurable
|
||||
// retention window (default 24h). Pre-login rows older than the
|
||||
// 10-minute TTL are also deleted. Returns the number of rows
|
||||
// deleted.
|
||||
GarbageCollectExpired(ctx context.Context) (int, error)
|
||||
|
||||
// Delete unconditionally removes a session row. Used for the
|
||||
// admin-only "purge a specific session" surface (rarely needed;
|
||||
// Revoke is the normal path).
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// SessionSigningKeyRepository wraps the session_signing_keys table.
|
||||
// Phase 4's Service.RotateSigningKey + EnsureInitialSigningKey + the
|
||||
// scheduler-driven retention sweep consume this.
|
||||
type SessionSigningKeyRepository interface {
|
||||
// List returns every signing key in the tenant (including
|
||||
// retired). Order: created_at DESC.
|
||||
List(ctx context.Context, tenantID string) ([]*sessiondomain.SessionSigningKey, error)
|
||||
|
||||
// GetActive returns the most-recently-created non-retired key.
|
||||
// ErrSessionSigningKeyNotFound when no non-retired key exists
|
||||
// (Phase 4's EnsureInitialSigningKey treats this as "mint first
|
||||
// key").
|
||||
GetActive(ctx context.Context, tenantID string) (*sessiondomain.SessionSigningKey, error)
|
||||
|
||||
// Get returns one key by id (including retired keys; Phase 4's
|
||||
// Validate consults this for cookies signed under retired-but-
|
||||
// in-retention keys).
|
||||
Get(ctx context.Context, id string) (*sessiondomain.SessionSigningKey, error)
|
||||
|
||||
// Add persists a new signing key. Caller MUST have called
|
||||
// k.Validate() and encrypted the key_material via
|
||||
// internal/crypto/encryption.go. CreatedAt defaults to NOW() if
|
||||
// zero.
|
||||
Add(ctx context.Context, k *sessiondomain.SessionSigningKey) error
|
||||
|
||||
// Retire marks an active key as retired (sets retired_at = NOW()).
|
||||
// The key stays in the table for verification of cookies signed
|
||||
// under it; the scheduler's retention sweep purges it after the
|
||||
// configurable retention window (default 24h beyond retired_at).
|
||||
Retire(ctx context.Context, id string) error
|
||||
|
||||
// Delete unconditionally removes a signing key row. Returns
|
||||
// ErrSessionSigningKeyInUse if any sessions row still references
|
||||
// the key (FK ON DELETE RESTRICT). Phase 4's GarbageCollect calls
|
||||
// this only after RetentionWindow has passed AND no sessions
|
||||
// reference the key.
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
userdomain "github.com/certctl-io/certctl/internal/auth/user/domain"
|
||||
)
|
||||
|
||||
// Sentinel errors for the user repository.
|
||||
var (
|
||||
// ErrUserNotFound: Get / GetByOIDCSubject returned no row. Phase
|
||||
// 3's HandleCallback treats this as "first login for this person;
|
||||
// create the row".
|
||||
ErrUserNotFound = errors.New("user: not found")
|
||||
|
||||
// ErrUserDuplicateOIDCSubject: Create tripped the
|
||||
// (oidc_provider_id, oidc_subject) UNIQUE constraint. HTTP 409.
|
||||
ErrUserDuplicateOIDCSubject = errors.New("user: a user with this provider+subject already exists")
|
||||
)
|
||||
|
||||
// UserRepository wraps the users table. Phase 3's HandleCallback
|
||||
// uses GetByOIDCSubject + Create + Update on every login; the GUI's
|
||||
// admin user-list surface uses ListAll + Get.
|
||||
type UserRepository interface {
|
||||
// Get returns one user by id. ErrUserNotFound on miss.
|
||||
Get(ctx context.Context, id string) (*userdomain.User, error)
|
||||
|
||||
// GetByOIDCSubject is the Phase 3 hot-path lookup at login time.
|
||||
// Returns the existing row if present, ErrUserNotFound otherwise.
|
||||
GetByOIDCSubject(ctx context.Context, providerID, subject string) (*userdomain.User, error)
|
||||
|
||||
// Create persists a new user. Caller MUST have called u.Validate().
|
||||
// Returns ErrUserDuplicateOIDCSubject on UNIQUE constraint trip.
|
||||
Create(ctx context.Context, u *userdomain.User) error
|
||||
|
||||
// Update writes the mutable field set back to the row. Immutable
|
||||
// fields (id, tenant_id, oidc_subject, oidc_provider_id,
|
||||
// created_at) are preserved. updated_at is set to NOW() by the
|
||||
// implementation.
|
||||
Update(ctx context.Context, u *userdomain.User) error
|
||||
|
||||
// ListAll returns every user in the tenant. Order:
|
||||
// created_at ASC. Used by the GUI's admin surface.
|
||||
ListAll(ctx context.Context, tenantID string) ([]*userdomain.User, error)
|
||||
}
|
||||
Reference in New Issue
Block a user