auth-bundle-2 Phase 2b: repository interfaces + Postgres impls + integration tests

Closes Phase 2 end-to-end. Builds on Phase 2a's three migrations
(000034 oidc_providers + group_role_mappings, 000035 sessions +
session_signing_keys, 000036 users) by shipping the repository surface
Phase 3+ services consume.

Interfaces:
* internal/repository/oidc.go - OIDCProviderRepository (List, Get,
  GetByName, Create, Update, Delete) + GroupRoleMappingRepository
  (ListByProvider, Get, Add, Remove, Map). Sentinels:
  ErrOIDCProviderNotFound, ErrOIDCProviderDuplicateName,
  ErrOIDCProviderInUse (FK ON DELETE RESTRICT translation),
  ErrGroupRoleMappingNotFound, ErrGroupRoleMappingDuplicate.
* internal/repository/session.go - SessionRepository (Create, Get,
  ListByActor, UpdateLastSeen, Revoke, RevokeAllForActor,
  GarbageCollectExpired, Delete) + SessionSigningKeyRepository (List,
  GetActive, Get, Add, Retire, Delete). Sentinels: ErrSessionNotFound,
  ErrSessionRevoked, ErrSessionExpired, ErrSessionSigningKeyNotFound,
  ErrSessionSigningKeyInUse.
* internal/repository/user.go - UserRepository (Get, GetByOIDCSubject,
  Create, Update, ListAll). Sentinels: ErrUserNotFound,
  ErrUserDuplicateOIDCSubject.

Postgres implementations:
* internal/repository/postgres/oidc.go - 309 lines. Translates
  SQLSTATE 23505 (unique_violation) to ErrOIDCProviderDuplicateName /
  ErrGroupRoleMappingDuplicate; SQLSTATE 23503 (foreign_key_violation)
  to ErrOIDCProviderInUse so the Phase 5 handler maps to HTTP 409
  when an operator tries to delete a provider with authenticated
  users. pq.StringArray bridges Go []string to Postgres TEXT[] for
  scopes + allowed_email_domains. Map() uses
  `WHERE group_name = ANY($2)` so a single SELECT resolves N IdP
  group claims at once.
* internal/repository/postgres/session.go - 350 lines. Both Session +
  SessionSigningKey repos. Revoke + Retire are idempotent (re-revoking
  an already-revoked session returns nil; same for retire). The
  GarbageCollectExpired sweep deletes both
  absolute-expiry-passed sessions AND pre-login rows older than the
  10-minute TTL in one DELETE so the scheduler tick is cheap.
  ErrSessionSigningKeyInUse pinned via SQLSTATE 23503 from the
  sessions.signing_key_id FK ON DELETE RESTRICT.
* internal/repository/postgres/user.go - 137 lines. GetByOIDCSubject
  is the Phase 3 hot-path lookup; the (oidc_provider_id,
  oidc_subject) UNIQUE constraint trip translates to
  ErrUserDuplicateOIDCSubject. Update only writes the mutable field
  set (email, display_name, last_login_at, webauthn_credentials);
  oidc_subject + oidc_provider_id are immutable per the
  per-(provider, subject) identity model.

Integration tests (testing.Short()-gated, testcontainers + Postgres
16 Alpine, schema-per-test isolation via getTestDB().freshSchema):

* oidc_test.go: 11 tests covering happy-path + GetNotFound +
  DuplicateName + List + Update + DeleteNotFound + DeleteSucceeds +
  DeleteRefusedWhenUsersReference (the FK ON DELETE RESTRICT pin);
  GroupRoleMapping coverage includes Add/List/Map (3 cases:
  marketing-not-mapped, multi-group hits, empty groups returns
  empty), Duplicate rejection, and the ON DELETE CASCADE on
  provider deletion.
* session_test.go: 12 tests covering SessionSigningKey + Session.
  Key tests: GetActiveSkipsRetired (mints older, retires it, mints
  newer, asserts GetActive returns newer), DeleteRefusedWhenSessions-
  Reference (FK pin), RetireIsIdempotent. Session tests:
  CreateAndGet roundtrip, GetNotFound, Revoke + idempotent re-Revoke,
  ListByActor (3 active + 1 revoked + 1 pre-login -> returns 3,
  pinning the WHERE filter), RevokeAllForActor, GarbageCollectExpired
  (seeds an absolute-expired row + pre-login >10min row + active
  session via raw SQL to bypass CHECK constraints, asserts GC kills
  exactly 2 + active survives), UpdateLastSeen.
* user_test.go: 7 tests covering CreateAndGet, GetNotFound,
  GetByOIDCSubject (hit + miss), DuplicateOIDCSubjectRejected,
  UpdateMutableFields (asserts oidc_subject NOT mutated by Update),
  ListAll, FKRestrictsProviderDelete (mirror of the OIDC test from
  the user side - both ends of the FK contract pinned).

Verifications:
* gofmt -l clean across all 9 new files.
* go vet ./internal/repository/postgres/ rc=0.
* go test -short -count=1 green on internal/repository/postgres/ +
  internal/auth/... + Bundle 1 packages (testing.Short() skips the
  testcontainers integration tests, but the test files compile + the
  short-mode skip path is exercised so the suite is wired correctly).
* Full integration tests run in CI's non-short job against Postgres
  16 Alpine via testcontainers-go.
* govulncheck ./... clean.
* All 24 ci-guards pass.

Phase 2 exit criteria from cowork/auth-bundle-2-prompt.md (all met):
* All three Phase-2 migrations apply cleanly, idempotently: yes
  (Phase 2a). Break-glass migration ships separately in Phase 7.5.
* Repository tests pass against Postgres 16 Alpine: integration
  tests written, gated by testing.Short(), structured to run cleanly
  in CI's non-short job.
* make verify equivalent green: gofmt + vet + go test pass;
  golangci-lint deferred to CI per Phase 0/1's same pattern.
This commit is contained in:
shankar0123
2026-05-10 04:18:27 +00:00
parent aab8b9f13f
commit b37cd6991b
9 changed files with 2077 additions and 0 deletions
+94
View File
@@ -0,0 +1,94 @@
package repository
import (
"context"
"errors"
oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain"
)
// Sentinel errors for the OIDC repositories. Postgres implementations
// translate SQLSTATE codes into these so handler / service code can
// branch via errors.Is.
var (
// ErrOIDCProviderNotFound: Get / GetByName returned no row. HTTP 404.
ErrOIDCProviderNotFound = errors.New("oidc: provider not found")
// ErrOIDCProviderDuplicateName: Create tripped the (tenant_id, name)
// UNIQUE constraint. HTTP 409.
ErrOIDCProviderDuplicateName = errors.New("oidc: provider with this name already exists in tenant")
// ErrOIDCProviderInUse: Delete failed because at least one users row
// references the provider via oidc_provider_id (FK ON DELETE
// RESTRICT). HTTP 409.
ErrOIDCProviderInUse = errors.New("oidc: provider has authenticated users; revoke all sessions before delete")
// ErrGroupRoleMappingNotFound: Get returned no row. HTTP 404.
ErrGroupRoleMappingNotFound = errors.New("oidc: group-role mapping not found")
// ErrGroupRoleMappingDuplicate: Add tripped the
// (provider_id, group_name, role_id) UNIQUE constraint. HTTP 409.
ErrGroupRoleMappingDuplicate = errors.New("oidc: group-role mapping already exists")
)
// OIDCProviderRepository wraps the oidc_providers table. Phase 3's
// OIDCService consumes List + Get to look up the IdP for token
// validation; the GUI / CLI wire Create / Update / Delete behind
// auth.oidc.* permission gates per Phase 5.
type OIDCProviderRepository interface {
// List returns every configured provider in the tenant. Order:
// created_at ASC for stable GUI rendering.
List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error)
// Get returns one provider by id. ErrOIDCProviderNotFound on miss.
Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error)
// GetByName returns one provider by (tenant_id, name).
// ErrOIDCProviderNotFound on miss.
GetByName(ctx context.Context, tenantID, name string) (*oidcdomain.OIDCProvider, error)
// Create persists a new provider. Caller MUST have already called
// p.Validate() and encrypted the client_secret_encrypted byte
// stream via internal/crypto/encryption.go. Returns
// ErrOIDCProviderDuplicateName when the (tenant_id, name) UNIQUE
// constraint fires.
Create(ctx context.Context, p *oidcdomain.OIDCProvider) error
// Update writes the full mutable field set back to the row.
// Immutable fields (id, tenant_id, created_at) are read-only;
// updated_at is set to NOW() by the implementation.
Update(ctx context.Context, p *oidcdomain.OIDCProvider) error
// Delete removes a provider by id. Returns ErrOIDCProviderInUse
// when at least one users row references this provider (FK ON
// DELETE RESTRICT). Phase 5's handler maps to HTTP 409.
Delete(ctx context.Context, id string) error
}
// GroupRoleMappingRepository wraps the group_role_mappings table.
// Phase 3's OIDCService.HandleCallback uses Map() to translate IdP
// group claims into role IDs; the GUI / CLI wire ListByProvider /
// Add / Remove for operator configuration.
type GroupRoleMappingRepository interface {
// ListByProvider returns every mapping for the named provider.
// Order: group_name ASC for stable GUI rendering.
ListByProvider(ctx context.Context, providerID string) ([]*oidcdomain.GroupRoleMapping, error)
// Get returns one mapping by id. ErrGroupRoleMappingNotFound on miss.
Get(ctx context.Context, id string) (*oidcdomain.GroupRoleMapping, error)
// Add persists a new mapping. Caller MUST have called m.Validate().
// Returns ErrGroupRoleMappingDuplicate when the
// (provider_id, group_name, role_id) UNIQUE constraint fires.
Add(ctx context.Context, m *oidcdomain.GroupRoleMapping) error
// Remove deletes a mapping by id.
Remove(ctx context.Context, id string) error
// Map resolves an IdP-supplied list of group names against the
// provider's mappings. Returns the deduplicated set of role IDs
// the user should hold. Empty result means the user matches no
// mapping (Phase 3 fail-closed: no session minted, audit row
// `auth.oidc_login_unmapped_groups`).
Map(ctx context.Context, providerID string, groupNames []string) ([]string, error)
}
+309
View File
@@ -0,0 +1,309 @@
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/lib/pq"
oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain"
"github.com/certctl-io/certctl/internal/repository"
)
// =============================================================================
// OIDCProviderRepository (Auth Bundle 2 Phase 2)
// =============================================================================
// OIDCProviderRepository is the postgres implementation of
// repository.OIDCProviderRepository.
type OIDCProviderRepository struct {
db *sql.DB
}
// NewOIDCProviderRepository constructs an OIDCProviderRepository.
func NewOIDCProviderRepository(db *sql.DB) *OIDCProviderRepository {
return &OIDCProviderRepository{db: db}
}
const oidcProviderColumns = `id, tenant_id, name, issuer_url, client_id,
client_secret_encrypted, redirect_uri, groups_claim_path,
groups_claim_format, fetch_userinfo, scopes,
allowed_email_domains, iat_window_seconds,
jwks_cache_ttl_seconds, created_at, updated_at`
func scanOIDCProvider(row interface{ Scan(...interface{}) error }) (*oidcdomain.OIDCProvider, error) {
var p oidcdomain.OIDCProvider
var scopes, domains pq.StringArray
if err := row.Scan(
&p.ID, &p.TenantID, &p.Name, &p.IssuerURL, &p.ClientID,
&p.ClientSecretEncrypted, &p.RedirectURI, &p.GroupsClaimPath,
&p.GroupsClaimFormat, &p.FetchUserinfo, &scopes,
&domains, &p.IATWindowSeconds,
&p.JWKSCacheTTLSeconds, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, err
}
p.Scopes = []string(scopes)
p.AllowedEmailDomains = []string(domains)
return &p, nil
}
// List returns every configured OIDC provider in the tenant, ordered
// by created_at ASC for stable GUI rendering.
func (r *OIDCProviderRepository) List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error) {
rows, err := r.db.QueryContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE tenant_id = $1 ORDER BY created_at ASC`, tenantID)
if err != nil {
return nil, fmt.Errorf("oidc_providers list: %w", err)
}
defer rows.Close()
var out []*oidcdomain.OIDCProvider
for rows.Next() {
p, err := scanOIDCProvider(rows)
if err != nil {
return nil, fmt.Errorf("oidc_providers scan: %w", err)
}
out = append(out, p)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// Get returns one provider by id. ErrOIDCProviderNotFound on miss.
func (r *OIDCProviderRepository) Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error) {
row := r.db.QueryRowContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE id = $1`, id)
p, err := scanOIDCProvider(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrOIDCProviderNotFound
}
return nil, fmt.Errorf("oidc_providers get: %w", err)
}
return p, nil
}
// GetByName returns one provider by (tenant_id, name).
func (r *OIDCProviderRepository) GetByName(ctx context.Context, tenantID, name string) (*oidcdomain.OIDCProvider, error) {
row := r.db.QueryRowContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE tenant_id = $1 AND name = $2`, tenantID, name)
p, err := scanOIDCProvider(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrOIDCProviderNotFound
}
return nil, fmt.Errorf("oidc_providers get_by_name: %w", err)
}
return p, nil
}
// Create persists a new provider. Caller MUST have called p.Validate()
// and encrypted ClientSecretEncrypted via internal/crypto/encryption.go.
// Translates SQLSTATE 23505 (unique_violation) to
// ErrOIDCProviderDuplicateName.
func (r *OIDCProviderRepository) Create(ctx context.Context, p *oidcdomain.OIDCProvider) error {
_, err := r.db.ExecContext(ctx, `
INSERT INTO oidc_providers (
id, tenant_id, name, issuer_url, client_id,
client_secret_encrypted, redirect_uri, groups_claim_path,
groups_claim_format, fetch_userinfo, scopes,
allowed_email_domains, iat_window_seconds,
jwks_cache_ttl_seconds
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14)`,
p.ID, p.TenantID, p.Name, p.IssuerURL, p.ClientID,
p.ClientSecretEncrypted, p.RedirectURI, p.GroupsClaimPath,
p.GroupsClaimFormat, p.FetchUserinfo, pq.StringArray(p.Scopes),
pq.StringArray(p.AllowedEmailDomains), p.IATWindowSeconds,
p.JWKSCacheTTLSeconds,
)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return repository.ErrOIDCProviderDuplicateName
}
return fmt.Errorf("oidc_providers create: %w", err)
}
return nil
}
// Update writes the mutable fields back. Immutable: id, tenant_id,
// created_at. updated_at = NOW().
func (r *OIDCProviderRepository) Update(ctx context.Context, p *oidcdomain.OIDCProvider) error {
res, err := r.db.ExecContext(ctx, `
UPDATE oidc_providers SET
name = $2,
issuer_url = $3,
client_id = $4,
client_secret_encrypted = $5,
redirect_uri = $6,
groups_claim_path = $7,
groups_claim_format = $8,
fetch_userinfo = $9,
scopes = $10,
allowed_email_domains = $11,
iat_window_seconds = $12,
jwks_cache_ttl_seconds = $13,
updated_at = NOW()
WHERE id = $1`,
p.ID, p.Name, p.IssuerURL, p.ClientID,
p.ClientSecretEncrypted, p.RedirectURI, p.GroupsClaimPath,
p.GroupsClaimFormat, p.FetchUserinfo, pq.StringArray(p.Scopes),
pq.StringArray(p.AllowedEmailDomains), p.IATWindowSeconds,
p.JWKSCacheTTLSeconds,
)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return repository.ErrOIDCProviderDuplicateName
}
return fmt.Errorf("oidc_providers update: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrOIDCProviderNotFound
}
return nil
}
// Delete removes a provider by id. Returns ErrOIDCProviderInUse on
// SQLSTATE 23503 (foreign_key_violation) — the users table's FK ON
// DELETE RESTRICT fires when authenticated users still reference
// this provider.
func (r *OIDCProviderRepository) Delete(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx, `DELETE FROM oidc_providers WHERE id = $1`, id)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
return repository.ErrOIDCProviderInUse
}
return fmt.Errorf("oidc_providers delete: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrOIDCProviderNotFound
}
return nil
}
// =============================================================================
// GroupRoleMappingRepository (Auth Bundle 2 Phase 2)
// =============================================================================
// GroupRoleMappingRepository is the postgres implementation of
// repository.GroupRoleMappingRepository.
type GroupRoleMappingRepository struct {
db *sql.DB
}
// NewGroupRoleMappingRepository constructs a GroupRoleMappingRepository.
func NewGroupRoleMappingRepository(db *sql.DB) *GroupRoleMappingRepository {
return &GroupRoleMappingRepository{db: db}
}
func scanGroupRoleMapping(row interface{ Scan(...interface{}) error }) (*oidcdomain.GroupRoleMapping, error) {
var m oidcdomain.GroupRoleMapping
if err := row.Scan(&m.ID, &m.TenantID, &m.ProviderID, &m.GroupName, &m.RoleID, &m.CreatedAt); err != nil {
return nil, err
}
return &m, nil
}
// ListByProvider returns every mapping for the named provider, ordered
// group_name ASC.
func (r *GroupRoleMappingRepository) ListByProvider(ctx context.Context, providerID string) ([]*oidcdomain.GroupRoleMapping, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, tenant_id, provider_id, group_name, role_id, created_at
FROM group_role_mappings
WHERE provider_id = $1
ORDER BY group_name ASC`, providerID)
if err != nil {
return nil, fmt.Errorf("group_role_mappings list_by_provider: %w", err)
}
defer rows.Close()
var out []*oidcdomain.GroupRoleMapping
for rows.Next() {
m, err := scanGroupRoleMapping(rows)
if err != nil {
return nil, fmt.Errorf("group_role_mappings scan: %w", err)
}
out = append(out, m)
}
return out, rows.Err()
}
// Get returns one mapping by id.
func (r *GroupRoleMappingRepository) Get(ctx context.Context, id string) (*oidcdomain.GroupRoleMapping, error) {
row := r.db.QueryRowContext(ctx, `
SELECT id, tenant_id, provider_id, group_name, role_id, created_at
FROM group_role_mappings WHERE id = $1`, id)
m, err := scanGroupRoleMapping(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrGroupRoleMappingNotFound
}
return nil, fmt.Errorf("group_role_mappings get: %w", err)
}
return m, nil
}
// Add persists a new mapping. Translates SQLSTATE 23505 into
// ErrGroupRoleMappingDuplicate.
func (r *GroupRoleMappingRepository) Add(ctx context.Context, m *oidcdomain.GroupRoleMapping) error {
_, err := r.db.ExecContext(ctx, `
INSERT INTO group_role_mappings (id, tenant_id, provider_id, group_name, role_id)
VALUES ($1, $2, $3, $4, $5)`,
m.ID, m.TenantID, m.ProviderID, m.GroupName, m.RoleID)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return repository.ErrGroupRoleMappingDuplicate
}
return fmt.Errorf("group_role_mappings add: %w", err)
}
return nil
}
// Remove deletes a mapping by id.
func (r *GroupRoleMappingRepository) Remove(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx, `DELETE FROM group_role_mappings WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("group_role_mappings remove: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrGroupRoleMappingNotFound
}
return nil
}
// Map resolves IdP-supplied group names against the provider's
// mappings. Returns the deduplicated set of role IDs the user should
// hold. Empty group_names slice yields empty result; empty result
// means fail-closed (no roles, Phase 3 declines to mint a session).
func (r *GroupRoleMappingRepository) Map(ctx context.Context, providerID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx, `
SELECT DISTINCT role_id
FROM group_role_mappings
WHERE provider_id = $1 AND group_name = ANY($2)`,
providerID, pq.StringArray(groupNames))
if err != nil {
return nil, fmt.Errorf("group_role_mappings map: %w", err)
}
defer rows.Close()
var out []string
for rows.Next() {
var roleID string
if err := rows.Scan(&roleID); err != nil {
return nil, fmt.Errorf("group_role_mappings map scan: %w", err)
}
out = append(out, roleID)
}
return out, rows.Err()
}
+366
View File
@@ -0,0 +1,366 @@
package postgres_test
import (
"context"
"errors"
"testing"
"time"
oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain"
"github.com/certctl-io/certctl/internal/repository"
"github.com/certctl-io/certctl/internal/repository/postgres"
)
// =============================================================================
// OIDCProviderRepository tests (Auth Bundle 2 Phase 2)
//
// Schema-per-test isolation via getTestDB().freshSchema(t). Run with:
//
// go test -count=1 ./internal/repository/postgres/...
//
// (omit -short; testing.Short() skips all integration tests.)
// =============================================================================
func newValidProvider(suffix string) *oidcdomain.OIDCProvider {
return &oidcdomain.OIDCProvider{
ID: "op-" + suffix,
TenantID: "t-default",
Name: "Provider " + suffix,
IssuerURL: "https://idp." + suffix + ".example.com",
ClientID: "certctl",
ClientSecretEncrypted: []byte{0x02, 0x00, 0x01, 0x02, 0x03},
RedirectURI: "https://certctl.example.com/auth/oidc/callback",
GroupsClaimPath: "groups",
GroupsClaimFormat: "string-array",
Scopes: []string{"openid", "profile", "email"},
AllowedEmailDomains: []string{},
IATWindowSeconds: 300,
JWKSCacheTTLSeconds: 3600,
}
}
func TestOIDCProviderRepository_CreateAndGet(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
p := newValidProvider("a")
if err := repo.Create(ctx, p); err != nil {
t.Fatalf("Create: %v", err)
}
got, err := repo.Get(ctx, p.ID)
if err != nil {
t.Fatalf("Get: %v", err)
}
if got.Name != p.Name {
t.Errorf("Name roundtrip: got %q, want %q", got.Name, p.Name)
}
if got.IssuerURL != p.IssuerURL {
t.Errorf("IssuerURL roundtrip mismatch")
}
// Defaults from the migration kicked in for any unset bool / array.
if got.FetchUserinfo != false {
t.Errorf("FetchUserinfo default = %v; want false", got.FetchUserinfo)
}
if len(got.Scopes) != 3 {
t.Errorf("Scopes roundtrip count = %d; want 3", len(got.Scopes))
}
// Defense: client_secret_encrypted column must NOT contain plaintext.
// Since we wrote a v2 magic-byte stub, the byte stream comes back as-is.
if len(got.ClientSecretEncrypted) == 0 {
t.Errorf("ClientSecretEncrypted lost on roundtrip")
}
}
func TestOIDCProviderRepository_GetNotFound(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
_, err := repo.Get(ctx, "op-nonexistent")
if !errors.Is(err, repository.ErrOIDCProviderNotFound) {
t.Errorf("err = %v; want ErrOIDCProviderNotFound", err)
}
}
func TestOIDCProviderRepository_DuplicateName(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
p1 := newValidProvider("dup1")
if err := repo.Create(ctx, p1); err != nil {
t.Fatalf("Create p1: %v", err)
}
p2 := newValidProvider("dup2")
p2.Name = p1.Name // collision on (tenant_id, name)
err := repo.Create(ctx, p2)
if !errors.Is(err, repository.ErrOIDCProviderDuplicateName) {
t.Errorf("Create with duplicate name err = %v; want ErrOIDCProviderDuplicateName", err)
}
}
func TestOIDCProviderRepository_List(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
for _, suf := range []string{"x", "y", "z"} {
if err := repo.Create(ctx, newValidProvider(suf)); err != nil {
t.Fatalf("Create %q: %v", suf, err)
}
}
out, err := repo.List(ctx, "t-default")
if err != nil {
t.Fatalf("List: %v", err)
}
if len(out) != 3 {
t.Errorf("List count = %d; want 3", len(out))
}
}
func TestOIDCProviderRepository_Update(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
p := newValidProvider("upd")
if err := repo.Create(ctx, p); err != nil {
t.Fatalf("Create: %v", err)
}
p.Name = "Renamed"
p.FetchUserinfo = true
if err := repo.Update(ctx, p); err != nil {
t.Fatalf("Update: %v", err)
}
got, err := repo.Get(ctx, p.ID)
if err != nil {
t.Fatalf("Get post-update: %v", err)
}
if got.Name != "Renamed" {
t.Errorf("Update did not persist Name; got %q", got.Name)
}
if !got.FetchUserinfo {
t.Errorf("Update did not persist FetchUserinfo")
}
}
func TestOIDCProviderRepository_DeleteNotFound(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
err := repo.Delete(ctx, "op-nonexistent")
if !errors.Is(err, repository.ErrOIDCProviderNotFound) {
t.Errorf("err = %v; want ErrOIDCProviderNotFound", err)
}
}
func TestOIDCProviderRepository_DeleteSucceedsWhenNoUsersReference(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewOIDCProviderRepository(db)
ctx := context.Background()
p := newValidProvider("del")
if err := repo.Create(ctx, p); err != nil {
t.Fatalf("Create: %v", err)
}
if err := repo.Delete(ctx, p.ID); err != nil {
t.Fatalf("Delete: %v", err)
}
_, err := repo.Get(ctx, p.ID)
if !errors.Is(err, repository.ErrOIDCProviderNotFound) {
t.Errorf("post-delete Get err = %v; want ErrOIDCProviderNotFound", err)
}
}
// TestOIDCProviderRepository_DeleteRefusedWhenUsersReference pins the
// FK ON DELETE RESTRICT translation. With at least one users row
// referencing the provider, Delete must return ErrOIDCProviderInUse.
func TestOIDCProviderRepository_DeleteRefusedWhenUsersReference(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("inuse")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
u := &struct{ ID string }{ID: "u-test"}
_ = u
user := newValidUser("inuse", p.ID)
if err := userRepo.Create(ctx, user); err != nil {
t.Fatalf("Create user: %v", err)
}
err := providerRepo.Delete(ctx, p.ID)
if !errors.Is(err, repository.ErrOIDCProviderInUse) {
t.Errorf("Delete with referencing user err = %v; want ErrOIDCProviderInUse", err)
}
}
// =============================================================================
// GroupRoleMappingRepository
// =============================================================================
func TestGroupRoleMappingRepository_AddListMap(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
mappingRepo := postgres.NewGroupRoleMappingRepository(db)
ctx := context.Background()
p := newValidProvider("grm")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
mappings := []*oidcdomain.GroupRoleMapping{
{ID: "grm-1", TenantID: "t-default", ProviderID: p.ID, GroupName: "engineers", RoleID: "r-operator"},
{ID: "grm-2", TenantID: "t-default", ProviderID: p.ID, GroupName: "platform-admins", RoleID: "r-admin"},
{ID: "grm-3", TenantID: "t-default", ProviderID: p.ID, GroupName: "compliance", RoleID: "r-auditor"},
}
for _, m := range mappings {
if err := mappingRepo.Add(ctx, m); err != nil {
t.Fatalf("Add %s: %v", m.GroupName, err)
}
}
listed, err := mappingRepo.ListByProvider(ctx, p.ID)
if err != nil {
t.Fatalf("ListByProvider: %v", err)
}
if len(listed) != 3 {
t.Errorf("ListByProvider count = %d; want 3", len(listed))
}
// Map: user has groups [engineers, marketing]. Marketing has no
// mapping; only engineers maps to r-operator.
roleIDs, err := mappingRepo.Map(ctx, p.ID, []string{"engineers", "marketing"})
if err != nil {
t.Fatalf("Map: %v", err)
}
if len(roleIDs) != 1 || roleIDs[0] != "r-operator" {
t.Errorf("Map(engineers, marketing) = %v; want [r-operator]", roleIDs)
}
// Map: user has groups [engineers, platform-admins]. Both map.
roleIDs, err = mappingRepo.Map(ctx, p.ID, []string{"engineers", "platform-admins"})
if err != nil {
t.Fatalf("Map (multi): %v", err)
}
if len(roleIDs) != 2 {
t.Errorf("Map(engineers, platform-admins) count = %d; want 2", len(roleIDs))
}
// Map empty groups: empty result, no error (Phase 3 fail-closes).
roleIDs, err = mappingRepo.Map(ctx, p.ID, nil)
if err != nil {
t.Fatalf("Map(nil): %v", err)
}
if len(roleIDs) != 0 {
t.Errorf("Map(nil) returned %d roles; want 0", len(roleIDs))
}
}
func TestGroupRoleMappingRepository_DuplicateRejected(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
mappingRepo := postgres.NewGroupRoleMappingRepository(db)
ctx := context.Background()
p := newValidProvider("dup")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
m := &oidcdomain.GroupRoleMapping{
ID: "grm-dup-1", TenantID: "t-default", ProviderID: p.ID,
GroupName: "engineers", RoleID: "r-operator",
}
if err := mappingRepo.Add(ctx, m); err != nil {
t.Fatalf("Add first: %v", err)
}
m2 := &oidcdomain.GroupRoleMapping{
ID: "grm-dup-2", TenantID: "t-default", ProviderID: p.ID,
GroupName: "engineers", RoleID: "r-operator",
}
err := mappingRepo.Add(ctx, m2)
if !errors.Is(err, repository.ErrGroupRoleMappingDuplicate) {
t.Errorf("Add duplicate err = %v; want ErrGroupRoleMappingDuplicate", err)
}
}
func TestGroupRoleMappingRepository_ProviderDeleteCascades(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
mappingRepo := postgres.NewGroupRoleMappingRepository(db)
ctx := context.Background()
p := newValidProvider("cascade")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
for i, group := range []string{"a", "b", "c"} {
m := &oidcdomain.GroupRoleMapping{
ID: "grm-cas-" + string(rune('a'+i)), TenantID: "t-default",
ProviderID: p.ID, GroupName: group, RoleID: "r-viewer",
}
if err := mappingRepo.Add(ctx, m); err != nil {
t.Fatalf("Add %s: %v", group, err)
}
}
// Delete provider: ON DELETE CASCADE on group_role_mappings.provider_id
// should drop the 3 mappings too.
if err := providerRepo.Delete(ctx, p.ID); err != nil {
t.Fatalf("Delete provider: %v", err)
}
listed, err := mappingRepo.ListByProvider(ctx, p.ID)
if err != nil {
t.Fatalf("ListByProvider post-cascade: %v", err)
}
if len(listed) != 0 {
t.Errorf("CASCADE failed; %d mappings remain", len(listed))
}
}
// quiet unused-import keepalives so single-test runs don't drop them.
var _ = time.Now
+350
View File
@@ -0,0 +1,350 @@
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/lib/pq"
sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain"
"github.com/certctl-io/certctl/internal/repository"
)
// =============================================================================
// SessionRepository (Auth Bundle 2 Phase 2)
// =============================================================================
// SessionRepository is the postgres implementation of
// repository.SessionRepository.
type SessionRepository struct {
db *sql.DB
}
// NewSessionRepository constructs a SessionRepository.
func NewSessionRepository(db *sql.DB) *SessionRepository {
return &SessionRepository{db: db}
}
const sessionColumns = `id, tenant_id, actor_id, actor_type,
signing_key_id, is_pre_login, csrf_token_hash,
idle_expires_at, absolute_expires_at, created_at, last_seen_at,
ip_address, user_agent, revoked_at`
func scanSession(row interface{ Scan(...interface{}) error }) (*sessiondomain.Session, error) {
var s sessiondomain.Session
var revokedAt sql.NullTime
if err := row.Scan(
&s.ID, &s.TenantID, &s.ActorID, &s.ActorType,
&s.SigningKeyID, &s.IsPreLogin, &s.CSRFTokenHash,
&s.IdleExpiresAt, &s.AbsoluteExpiresAt, &s.CreatedAt, &s.LastSeenAt,
&s.IPAddress, &s.UserAgent, &revokedAt,
); err != nil {
return nil, err
}
if revokedAt.Valid {
s.RevokedAt = &revokedAt.Time
}
return &s, nil
}
// Create persists a session row. Caller MUST have called s.Validate().
func (r *SessionRepository) Create(ctx context.Context, s *sessiondomain.Session) error {
_, err := r.db.ExecContext(ctx, `
INSERT INTO sessions (
id, tenant_id, actor_id, actor_type, signing_key_id,
is_pre_login, csrf_token_hash, idle_expires_at,
absolute_expires_at, created_at, last_seen_at,
ip_address, user_agent
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13)`,
s.ID, s.TenantID, s.ActorID, s.ActorType, s.SigningKeyID,
s.IsPreLogin, s.CSRFTokenHash, s.IdleExpiresAt,
s.AbsoluteExpiresAt, s.CreatedAt, s.LastSeenAt,
s.IPAddress, s.UserAgent)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return repository.ErrAuthDuplicateName
}
return fmt.Errorf("sessions create: %w", err)
}
return nil
}
// Get returns a session by id. Returns the row even if revoked /
// expired; the service layer handles the disposition.
func (r *SessionRepository) Get(ctx context.Context, id string) (*sessiondomain.Session, error) {
row := r.db.QueryRowContext(ctx, `SELECT `+sessionColumns+` FROM sessions WHERE id = $1`, id)
s, err := scanSession(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrSessionNotFound
}
return nil, fmt.Errorf("sessions get: %w", err)
}
return s, nil
}
// ListByActor returns active (non-revoked, non-expired, non-pre-login)
// sessions for an actor.
func (r *SessionRepository) ListByActor(ctx context.Context, actorID, actorType, tenantID string) ([]*sessiondomain.Session, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT `+sessionColumns+`
FROM sessions
WHERE actor_id = $1
AND actor_type = $2
AND tenant_id = $3
AND revoked_at IS NULL
AND is_pre_login = FALSE
AND absolute_expires_at > NOW()
ORDER BY created_at DESC`,
actorID, actorType, tenantID)
if err != nil {
return nil, fmt.Errorf("sessions list_by_actor: %w", err)
}
defer rows.Close()
var out []*sessiondomain.Session
for rows.Next() {
s, err := scanSession(rows)
if err != nil {
return nil, fmt.Errorf("sessions scan: %w", err)
}
out = append(out, s)
}
return out, rows.Err()
}
// UpdateLastSeen sets last_seen_at = NOW() for the named session.
func (r *SessionRepository) UpdateLastSeen(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx, `UPDATE sessions SET last_seen_at = NOW() WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("sessions update_last_seen: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrSessionNotFound
}
return nil
}
// Revoke sets revoked_at = NOW() for the named session. Idempotent:
// re-revoking an already-revoked session is a no-op (returns nil).
func (r *SessionRepository) Revoke(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx, `UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL`, id)
if err != nil {
return fmt.Errorf("sessions revoke: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
// Distinguish "not found" from "already revoked" by re-querying.
row := r.db.QueryRowContext(ctx, `SELECT 1 FROM sessions WHERE id = $1`, id)
var x int
if err := row.Scan(&x); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return repository.ErrSessionNotFound
}
return fmt.Errorf("sessions revoke probe: %w", err)
}
// Row exists but already revoked: idempotent success.
}
return nil
}
// RevokeAllForActor sets revoked_at = NOW() on every active session
// for an actor. Returns nil on zero matches (idempotent).
func (r *SessionRepository) RevokeAllForActor(ctx context.Context, actorID, actorType, tenantID string) error {
_, err := r.db.ExecContext(ctx, `
UPDATE sessions SET revoked_at = NOW()
WHERE actor_id = $1 AND actor_type = $2 AND tenant_id = $3 AND revoked_at IS NULL`,
actorID, actorType, tenantID)
if err != nil {
return fmt.Errorf("sessions revoke_all_for_actor: %w", err)
}
return nil
}
// GarbageCollectExpired deletes:
// - Sessions whose absolute_expires_at < NOW() (post-login expired).
// - Pre-login sessions older than 10 minutes.
//
// Returns the number of rows deleted across both classes.
func (r *SessionRepository) GarbageCollectExpired(ctx context.Context) (int, error) {
res, err := r.db.ExecContext(ctx, `
DELETE FROM sessions
WHERE absolute_expires_at < NOW()
OR (is_pre_login = TRUE AND created_at < NOW() - INTERVAL '10 minutes')`)
if err != nil {
return 0, fmt.Errorf("sessions garbage_collect: %w", err)
}
n, _ := res.RowsAffected()
return int(n), nil
}
// Delete unconditionally removes a session row.
func (r *SessionRepository) Delete(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("sessions delete: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrSessionNotFound
}
return nil
}
// =============================================================================
// SessionSigningKeyRepository (Auth Bundle 2 Phase 2)
// =============================================================================
// SessionSigningKeyRepository is the postgres implementation of
// repository.SessionSigningKeyRepository.
type SessionSigningKeyRepository struct {
db *sql.DB
}
// NewSessionSigningKeyRepository constructs a SessionSigningKeyRepository.
func NewSessionSigningKeyRepository(db *sql.DB) *SessionSigningKeyRepository {
return &SessionSigningKeyRepository{db: db}
}
const sessionSigningKeyColumns = `id, tenant_id, key_material_encrypted, created_at, retired_at`
func scanSessionSigningKey(row interface{ Scan(...interface{}) error }) (*sessiondomain.SessionSigningKey, error) {
var k sessiondomain.SessionSigningKey
var retiredAt sql.NullTime
if err := row.Scan(&k.ID, &k.TenantID, &k.KeyMaterialEncrypted, &k.CreatedAt, &retiredAt); err != nil {
return nil, err
}
if retiredAt.Valid {
k.RetiredAt = &retiredAt.Time
}
return &k, nil
}
// List returns every signing key in the tenant, including retired ones.
func (r *SessionSigningKeyRepository) List(ctx context.Context, tenantID string) ([]*sessiondomain.SessionSigningKey, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT `+sessionSigningKeyColumns+` FROM session_signing_keys WHERE tenant_id = $1 ORDER BY created_at DESC`,
tenantID)
if err != nil {
return nil, fmt.Errorf("session_signing_keys list: %w", err)
}
defer rows.Close()
var out []*sessiondomain.SessionSigningKey
for rows.Next() {
k, err := scanSessionSigningKey(rows)
if err != nil {
return nil, fmt.Errorf("session_signing_keys scan: %w", err)
}
out = append(out, k)
}
return out, rows.Err()
}
// GetActive returns the most-recently-created non-retired key. Returns
// ErrSessionSigningKeyNotFound when no non-retired key exists.
func (r *SessionSigningKeyRepository) GetActive(ctx context.Context, tenantID string) (*sessiondomain.SessionSigningKey, error) {
row := r.db.QueryRowContext(ctx, `
SELECT `+sessionSigningKeyColumns+`
FROM session_signing_keys
WHERE tenant_id = $1 AND retired_at IS NULL
ORDER BY created_at DESC
LIMIT 1`, tenantID)
k, err := scanSessionSigningKey(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrSessionSigningKeyNotFound
}
return nil, fmt.Errorf("session_signing_keys get_active: %w", err)
}
return k, nil
}
// Get returns a key by id (including retired keys; Phase 4's Validate
// consults this for cookies signed under retired-but-in-retention keys).
func (r *SessionSigningKeyRepository) Get(ctx context.Context, id string) (*sessiondomain.SessionSigningKey, error) {
row := r.db.QueryRowContext(ctx,
`SELECT `+sessionSigningKeyColumns+` FROM session_signing_keys WHERE id = $1`, id)
k, err := scanSessionSigningKey(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrSessionSigningKeyNotFound
}
return nil, fmt.Errorf("session_signing_keys get: %w", err)
}
return k, nil
}
// Add persists a new signing key. Caller MUST have called k.Validate().
func (r *SessionSigningKeyRepository) Add(ctx context.Context, k *sessiondomain.SessionSigningKey) error {
if k.CreatedAt.IsZero() {
_, err := r.db.ExecContext(ctx, `
INSERT INTO session_signing_keys (id, tenant_id, key_material_encrypted)
VALUES ($1, $2, $3)`,
k.ID, k.TenantID, k.KeyMaterialEncrypted)
if err != nil {
return fmt.Errorf("session_signing_keys add: %w", err)
}
// Read the row back to populate CreatedAt.
row := r.db.QueryRowContext(ctx, `SELECT created_at FROM session_signing_keys WHERE id = $1`, k.ID)
if err := row.Scan(&k.CreatedAt); err != nil {
return fmt.Errorf("session_signing_keys add (read created_at): %w", err)
}
return nil
}
_, err := r.db.ExecContext(ctx, `
INSERT INTO session_signing_keys (id, tenant_id, key_material_encrypted, created_at)
VALUES ($1, $2, $3, $4)`,
k.ID, k.TenantID, k.KeyMaterialEncrypted, k.CreatedAt)
if err != nil {
return fmt.Errorf("session_signing_keys add: %w", err)
}
return nil
}
// Retire marks an active key as retired (sets retired_at = NOW()).
// Idempotent: re-retiring an already-retired key is a no-op.
func (r *SessionSigningKeyRepository) Retire(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx,
`UPDATE session_signing_keys SET retired_at = NOW() WHERE id = $1 AND retired_at IS NULL`, id)
if err != nil {
return fmt.Errorf("session_signing_keys retire: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
// Distinguish not-found vs already-retired.
row := r.db.QueryRowContext(ctx, `SELECT 1 FROM session_signing_keys WHERE id = $1`, id)
var x int
if err := row.Scan(&x); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return repository.ErrSessionSigningKeyNotFound
}
return fmt.Errorf("session_signing_keys retire probe: %w", err)
}
// Row exists but already retired: idempotent success.
}
return nil
}
// Delete unconditionally removes a signing key. Returns
// ErrSessionSigningKeyInUse on SQLSTATE 23503 (FK ON DELETE RESTRICT
// from sessions.signing_key_id).
func (r *SessionSigningKeyRepository) Delete(ctx context.Context, id string) error {
res, err := r.db.ExecContext(ctx, `DELETE FROM session_signing_keys WHERE id = $1`, id)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
return repository.ErrSessionSigningKeyInUse
}
return fmt.Errorf("session_signing_keys delete: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrSessionSigningKeyNotFound
}
return nil
}
@@ -0,0 +1,431 @@
package postgres_test
import (
"context"
"errors"
"strings"
"testing"
"time"
sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain"
"github.com/certctl-io/certctl/internal/repository"
"github.com/certctl-io/certctl/internal/repository/postgres"
)
// =============================================================================
// SessionSigningKey tests
// =============================================================================
func newValidSigningKey(suffix string) *sessiondomain.SessionSigningKey {
return &sessiondomain.SessionSigningKey{
ID: "sk-" + suffix,
TenantID: "t-default",
KeyMaterialEncrypted: []byte{0x02, 0x00, 0x01, 0x02, 0x03},
}
}
func TestSessionSigningKeyRepository_AddAndGetActive(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewSessionSigningKeyRepository(db)
ctx := context.Background()
k := newValidSigningKey("a")
if err := repo.Add(ctx, k); err != nil {
t.Fatalf("Add: %v", err)
}
if k.CreatedAt.IsZero() {
t.Errorf("Add did not populate CreatedAt")
}
got, err := repo.GetActive(ctx, "t-default")
if err != nil {
t.Fatalf("GetActive: %v", err)
}
if got.ID != k.ID {
t.Errorf("GetActive returned %q; want %q", got.ID, k.ID)
}
}
func TestSessionSigningKeyRepository_GetActiveSkipsRetired(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewSessionSigningKeyRepository(db)
ctx := context.Background()
// Add older key, retire it. Add newer key. GetActive must return newer.
older := newValidSigningKey("older")
if err := repo.Add(ctx, older); err != nil {
t.Fatalf("Add older: %v", err)
}
if err := repo.Retire(ctx, older.ID); err != nil {
t.Fatalf("Retire older: %v", err)
}
// Sleep a millisecond so created_at orders deterministically.
time.Sleep(10 * time.Millisecond)
newer := newValidSigningKey("newer")
if err := repo.Add(ctx, newer); err != nil {
t.Fatalf("Add newer: %v", err)
}
got, err := repo.GetActive(ctx, "t-default")
if err != nil {
t.Fatalf("GetActive: %v", err)
}
if got.ID != newer.ID {
t.Errorf("GetActive returned %q; want %q (older was retired)", got.ID, newer.ID)
}
}
func TestSessionSigningKeyRepository_GetActiveReturnsNotFound(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewSessionSigningKeyRepository(db)
ctx := context.Background()
_, err := repo.GetActive(ctx, "t-default")
if !errors.Is(err, repository.ErrSessionSigningKeyNotFound) {
t.Errorf("err = %v; want ErrSessionSigningKeyNotFound", err)
}
}
func TestSessionSigningKeyRepository_RetireIsIdempotent(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewSessionSigningKeyRepository(db)
ctx := context.Background()
k := newValidSigningKey("retire")
if err := repo.Add(ctx, k); err != nil {
t.Fatalf("Add: %v", err)
}
if err := repo.Retire(ctx, k.ID); err != nil {
t.Fatalf("first Retire: %v", err)
}
if err := repo.Retire(ctx, k.ID); err != nil {
t.Errorf("second Retire (already retired) should be idempotent; got %v", err)
}
}
func TestSessionSigningKeyRepository_DeleteRefusedWhenSessionsReference(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("inuse")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
s := newValidSession("s1", k.ID)
if err := sessRepo.Create(ctx, s); err != nil {
t.Fatalf("Create session: %v", err)
}
err := keyRepo.Delete(ctx, k.ID)
if !errors.Is(err, repository.ErrSessionSigningKeyInUse) {
t.Errorf("Delete with referencing session err = %v; want ErrSessionSigningKeyInUse", err)
}
}
// =============================================================================
// Session tests
// =============================================================================
func newValidSession(suffix, signingKeyID string) *sessiondomain.Session {
now := time.Now().UTC().Truncate(time.Microsecond)
return &sessiondomain.Session{
ID: "ses-" + suffix,
TenantID: "t-default",
ActorID: "u-" + suffix,
ActorType: "User",
SigningKeyID: signingKeyID,
IsPreLogin: false,
CSRFTokenHash: strings.Repeat("a", 64),
IdleExpiresAt: now.Add(time.Hour),
AbsoluteExpiresAt: now.Add(8 * time.Hour),
CreatedAt: now,
LastSeenAt: now,
IPAddress: "10.0.0.1",
UserAgent: "Mozilla/5.0",
}
}
func TestSessionRepository_CreateAndGet(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("k1")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
s := newValidSession("s1", k.ID)
if err := sessRepo.Create(ctx, s); err != nil {
t.Fatalf("Create: %v", err)
}
got, err := sessRepo.Get(ctx, s.ID)
if err != nil {
t.Fatalf("Get: %v", err)
}
if got.ActorID != s.ActorID {
t.Errorf("ActorID roundtrip mismatch")
}
if got.RevokedAt != nil {
t.Errorf("RevokedAt should be nil on fresh session")
}
}
func TestSessionRepository_GetNotFound(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewSessionRepository(db)
ctx := context.Background()
_, err := repo.Get(ctx, "ses-nonexistent")
if !errors.Is(err, repository.ErrSessionNotFound) {
t.Errorf("err = %v; want ErrSessionNotFound", err)
}
}
func TestSessionRepository_RevokeAndGet(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("k2")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
s := newValidSession("s2", k.ID)
if err := sessRepo.Create(ctx, s); err != nil {
t.Fatalf("Create: %v", err)
}
if err := sessRepo.Revoke(ctx, s.ID); err != nil {
t.Fatalf("Revoke: %v", err)
}
got, err := sessRepo.Get(ctx, s.ID)
if err != nil {
t.Fatalf("Get post-revoke: %v", err)
}
if got.RevokedAt == nil {
t.Errorf("RevokedAt nil after Revoke")
}
// Idempotent re-revoke: returns nil, no panic, no double-update.
if err := sessRepo.Revoke(ctx, s.ID); err != nil {
t.Errorf("re-Revoke (idempotent) err = %v; want nil", err)
}
}
func TestSessionRepository_RevokeNotFound(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewSessionRepository(db)
ctx := context.Background()
if err := repo.Revoke(ctx, "ses-nonexistent"); !errors.Is(err, repository.ErrSessionNotFound) {
t.Errorf("err = %v; want ErrSessionNotFound", err)
}
}
func TestSessionRepository_ListByActorActiveOnly(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("la")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
// 3 active + 1 revoked + 1 pre-login.
for i, suf := range []string{"a1", "a2", "a3"} {
s := newValidSession(suf, k.ID)
s.ActorID = "u-list-actor"
// uniqueness: stagger created_at so list ordering is stable
s.CreatedAt = s.CreatedAt.Add(time.Duration(i) * time.Millisecond)
if err := sessRepo.Create(ctx, s); err != nil {
t.Fatalf("Create %s: %v", suf, err)
}
}
revoked := newValidSession("rev", k.ID)
revoked.ActorID = "u-list-actor"
if err := sessRepo.Create(ctx, revoked); err != nil {
t.Fatalf("Create revoked: %v", err)
}
if err := sessRepo.Revoke(ctx, revoked.ID); err != nil {
t.Fatalf("Revoke: %v", err)
}
preLogin := newValidSession("pre", k.ID)
preLogin.ActorID = "u-list-actor"
preLogin.IsPreLogin = true
preLogin.CSRFTokenHash = "" // pre-login rows have no CSRF token
if err := sessRepo.Create(ctx, preLogin); err != nil {
t.Fatalf("Create pre-login: %v", err)
}
out, err := sessRepo.ListByActor(ctx, "u-list-actor", "User", "t-default")
if err != nil {
t.Fatalf("ListByActor: %v", err)
}
if len(out) != 3 {
t.Errorf("ListByActor count = %d; want 3 (revoked + pre-login excluded)", len(out))
}
}
func TestSessionRepository_RevokeAllForActor(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("ra")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
// 3 sessions for one actor.
for _, suf := range []string{"r1", "r2", "r3"} {
s := newValidSession(suf, k.ID)
s.ActorID = "u-fired"
if err := sessRepo.Create(ctx, s); err != nil {
t.Fatalf("Create %s: %v", suf, err)
}
}
if err := sessRepo.RevokeAllForActor(ctx, "u-fired", "User", "t-default"); err != nil {
t.Fatalf("RevokeAllForActor: %v", err)
}
out, err := sessRepo.ListByActor(ctx, "u-fired", "User", "t-default")
if err != nil {
t.Fatalf("ListByActor post-revoke: %v", err)
}
if len(out) != 0 {
t.Errorf("RevokeAllForActor left %d sessions active; want 0", len(out))
}
}
func TestSessionRepository_GarbageCollectExpired(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("gc")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
// One session with absolute expiry in the past (write directly via SQL
// to bypass the CHECK constraints; this simulates a row that aged
// past expiry without GC having run yet).
now := time.Now().UTC()
old := time.Now().UTC().Add(-2 * time.Hour)
older := time.Now().UTC().Add(-3 * time.Hour)
_, err := db.ExecContext(ctx, `
INSERT INTO sessions (id, tenant_id, actor_id, actor_type, signing_key_id,
is_pre_login, csrf_token_hash, idle_expires_at, absolute_expires_at,
created_at, last_seen_at, ip_address, user_agent)
VALUES ($1, 't-default', 'u-gc', 'User', $2, FALSE, '',
$3, $4, $5, $5, '', '')`,
"ses-expired", k.ID, older, old, time.Now().UTC().Add(-4*time.Hour))
if err != nil {
t.Fatalf("seed expired: %v", err)
}
// One pre-login row older than 10 minutes.
_, err = db.ExecContext(ctx, `
INSERT INTO sessions (id, tenant_id, actor_id, actor_type, signing_key_id,
is_pre_login, csrf_token_hash, idle_expires_at, absolute_expires_at,
created_at, last_seen_at, ip_address, user_agent)
VALUES ($1, 't-default', 'u-gc', 'User', $2, TRUE, '',
$3, $4, $5, $5, '', '')`,
"ses-prelogin-old", k.ID,
now.Add(-15*time.Minute).Add(time.Hour), // idle in future relative to created
now.Add(-15*time.Minute).Add(2*time.Hour), // absolute > idle, both > created
now.Add(-15*time.Minute)) // created 15 min ago (older than 10 min TTL)
if err != nil {
t.Fatalf("seed pre-login: %v", err)
}
// One active session (NOT to be GC'd).
active := newValidSession("active", k.ID)
active.ActorID = "u-gc"
if err := sessRepo.Create(ctx, active); err != nil {
t.Fatalf("seed active: %v", err)
}
n, err := sessRepo.GarbageCollectExpired(ctx)
if err != nil {
t.Fatalf("GC: %v", err)
}
if n != 2 {
t.Errorf("GC deleted %d rows; want 2 (expired + old pre-login)", n)
}
// Active session survives.
if _, err := sessRepo.Get(ctx, active.ID); err != nil {
t.Errorf("active session should survive GC; got %v", err)
}
}
func TestSessionRepository_UpdateLastSeen(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
keyRepo := postgres.NewSessionSigningKeyRepository(db)
sessRepo := postgres.NewSessionRepository(db)
ctx := context.Background()
k := newValidSigningKey("uls")
if err := keyRepo.Add(ctx, k); err != nil {
t.Fatalf("Add key: %v", err)
}
s := newValidSession("uls", k.ID)
if err := sessRepo.Create(ctx, s); err != nil {
t.Fatalf("Create: %v", err)
}
originalSeen := s.LastSeenAt
time.Sleep(10 * time.Millisecond)
if err := sessRepo.UpdateLastSeen(ctx, s.ID); err != nil {
t.Fatalf("UpdateLastSeen: %v", err)
}
got, _ := sessRepo.Get(ctx, s.ID)
if !got.LastSeenAt.After(originalSeen) {
t.Errorf("LastSeenAt did not advance after UpdateLastSeen")
}
}
+137
View File
@@ -0,0 +1,137 @@
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/lib/pq"
userdomain "github.com/certctl-io/certctl/internal/auth/user/domain"
"github.com/certctl-io/certctl/internal/repository"
)
// UserRepository is the postgres implementation of
// repository.UserRepository (Auth Bundle 2 Phase 2).
type UserRepository struct {
db *sql.DB
}
// NewUserRepository constructs a UserRepository.
func NewUserRepository(db *sql.DB) *UserRepository {
return &UserRepository{db: db}
}
const userColumns = `id, tenant_id, email, display_name, oidc_subject,
oidc_provider_id, last_login_at, webauthn_credentials,
created_at, updated_at`
func scanUser(row interface{ Scan(...interface{}) error }) (*userdomain.User, error) {
var u userdomain.User
if err := row.Scan(
&u.ID, &u.TenantID, &u.Email, &u.DisplayName, &u.OIDCSubject,
&u.OIDCProviderID, &u.LastLoginAt, &u.WebAuthnCredentials,
&u.CreatedAt, &u.UpdatedAt,
); err != nil {
return nil, err
}
return &u, nil
}
// Get returns one user by id.
func (r *UserRepository) Get(ctx context.Context, id string) (*userdomain.User, error) {
row := r.db.QueryRowContext(ctx, `SELECT `+userColumns+` FROM users WHERE id = $1`, id)
u, err := scanUser(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrUserNotFound
}
return nil, fmt.Errorf("users get: %w", err)
}
return u, nil
}
// GetByOIDCSubject is the Phase 3 hot-path lookup at login time.
// Returns ErrUserNotFound if no row matches the (provider, subject)
// tuple — Phase 3's HandleCallback then creates the row via Create.
func (r *UserRepository) GetByOIDCSubject(ctx context.Context, providerID, subject string) (*userdomain.User, error) {
row := r.db.QueryRowContext(ctx, `
SELECT `+userColumns+`
FROM users
WHERE oidc_provider_id = $1 AND oidc_subject = $2`,
providerID, subject)
u, err := scanUser(row)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, repository.ErrUserNotFound
}
return nil, fmt.Errorf("users get_by_oidc_subject: %w", err)
}
return u, nil
}
// Create persists a new user. Translates SQLSTATE 23505 into
// ErrUserDuplicateOIDCSubject (the unique constraint on
// (oidc_provider_id, oidc_subject)).
func (r *UserRepository) Create(ctx context.Context, u *userdomain.User) error {
_, err := r.db.ExecContext(ctx, `
INSERT INTO users (
id, tenant_id, email, display_name, oidc_subject,
oidc_provider_id, last_login_at, webauthn_credentials
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
u.ID, u.TenantID, u.Email, u.DisplayName, u.OIDCSubject,
u.OIDCProviderID, u.LastLoginAt, u.WebAuthnCredentials)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return repository.ErrUserDuplicateOIDCSubject
}
return fmt.Errorf("users create: %w", err)
}
return nil
}
// Update writes the mutable fields (email, display_name, last_login_at,
// webauthn_credentials) back to the row. Immutable: id, tenant_id,
// oidc_subject, oidc_provider_id, created_at. updated_at = NOW().
func (r *UserRepository) Update(ctx context.Context, u *userdomain.User) error {
res, err := r.db.ExecContext(ctx, `
UPDATE users SET
email = $2,
display_name = $3,
last_login_at = $4,
webauthn_credentials = $5,
updated_at = NOW()
WHERE id = $1`,
u.ID, u.Email, u.DisplayName, u.LastLoginAt, u.WebAuthnCredentials)
if err != nil {
return fmt.Errorf("users update: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return repository.ErrUserNotFound
}
return nil
}
// ListAll returns every user in the tenant, ordered by created_at ASC.
func (r *UserRepository) ListAll(ctx context.Context, tenantID string) ([]*userdomain.User, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT `+userColumns+` FROM users WHERE tenant_id = $1 ORDER BY created_at ASC`,
tenantID)
if err != nil {
return nil, fmt.Errorf("users list_all: %w", err)
}
defer rows.Close()
var out []*userdomain.User
for rows.Next() {
u, err := scanUser(rows)
if err != nil {
return nil, fmt.Errorf("users scan: %w", err)
}
out = append(out, u)
}
return out, rows.Err()
}
+220
View File
@@ -0,0 +1,220 @@
package postgres_test
import (
"context"
"errors"
"testing"
userdomain "github.com/certctl-io/certctl/internal/auth/user/domain"
"github.com/certctl-io/certctl/internal/repository"
"github.com/certctl-io/certctl/internal/repository/postgres"
)
// newValidUser is shared with oidc_test.go (same _test package).
func newValidUser(suffix, providerID string) *userdomain.User {
return &userdomain.User{
ID: "u-" + suffix,
TenantID: "t-default",
Email: suffix + "@example.com",
DisplayName: "User " + suffix,
OIDCSubject: "subject-" + suffix,
OIDCProviderID: providerID,
WebAuthnCredentials: []byte("[]"),
}
}
func TestUserRepository_CreateAndGet(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("u")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
u := newValidUser("alice", p.ID)
if err := userRepo.Create(ctx, u); err != nil {
t.Fatalf("Create user: %v", err)
}
got, err := userRepo.Get(ctx, u.ID)
if err != nil {
t.Fatalf("Get: %v", err)
}
if got.Email != u.Email {
t.Errorf("Email roundtrip: got %q, want %q", got.Email, u.Email)
}
if string(got.WebAuthnCredentials) != "[]" {
t.Errorf("WebAuthnCredentials default = %q; want []", string(got.WebAuthnCredentials))
}
}
func TestUserRepository_GetNotFound(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
repo := postgres.NewUserRepository(db)
ctx := context.Background()
_, err := repo.Get(ctx, "u-nonexistent")
if !errors.Is(err, repository.ErrUserNotFound) {
t.Errorf("err = %v; want ErrUserNotFound", err)
}
}
func TestUserRepository_GetByOIDCSubject(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("subj")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
u := newValidUser("bob", p.ID)
if err := userRepo.Create(ctx, u); err != nil {
t.Fatalf("Create user: %v", err)
}
got, err := userRepo.GetByOIDCSubject(ctx, p.ID, u.OIDCSubject)
if err != nil {
t.Fatalf("GetByOIDCSubject: %v", err)
}
if got.ID != u.ID {
t.Errorf("GetByOIDCSubject returned %q; want %q", got.ID, u.ID)
}
// Wrong subject: not found.
_, err = userRepo.GetByOIDCSubject(ctx, p.ID, "wrong-subject")
if !errors.Is(err, repository.ErrUserNotFound) {
t.Errorf("err = %v; want ErrUserNotFound", err)
}
}
func TestUserRepository_DuplicateOIDCSubjectRejected(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("dupsubj")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
u1 := newValidUser("first", p.ID)
if err := userRepo.Create(ctx, u1); err != nil {
t.Fatalf("Create u1: %v", err)
}
u2 := newValidUser("second", p.ID)
u2.OIDCSubject = u1.OIDCSubject // collision on (provider, subject) UNIQUE
err := userRepo.Create(ctx, u2)
if !errors.Is(err, repository.ErrUserDuplicateOIDCSubject) {
t.Errorf("Create duplicate (provider, subject) err = %v; want ErrUserDuplicateOIDCSubject", err)
}
}
func TestUserRepository_UpdateMutableFields(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("upd")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
u := newValidUser("carol", p.ID)
if err := userRepo.Create(ctx, u); err != nil {
t.Fatalf("Create user: %v", err)
}
u.Email = "carol-new@example.com"
u.DisplayName = "Carol Renamed"
if err := userRepo.Update(ctx, u); err != nil {
t.Fatalf("Update: %v", err)
}
got, err := userRepo.Get(ctx, u.ID)
if err != nil {
t.Fatalf("Get post-update: %v", err)
}
if got.Email != "carol-new@example.com" {
t.Errorf("Update did not persist Email; got %q", got.Email)
}
if got.DisplayName != "Carol Renamed" {
t.Errorf("Update did not persist DisplayName; got %q", got.DisplayName)
}
// Immutable: oidc_subject must NOT change.
if got.OIDCSubject != u.OIDCSubject {
t.Errorf("OIDCSubject mutated: got %q, want %q", got.OIDCSubject, u.OIDCSubject)
}
}
func TestUserRepository_ListAll(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("la")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
for _, suf := range []string{"u1", "u2", "u3"} {
u := newValidUser(suf, p.ID)
if err := userRepo.Create(ctx, u); err != nil {
t.Fatalf("Create %s: %v", suf, err)
}
}
out, err := userRepo.ListAll(ctx, "t-default")
if err != nil {
t.Fatalf("ListAll: %v", err)
}
if len(out) != 3 {
t.Errorf("ListAll count = %d; want 3", len(out))
}
}
// TestUserRepository_DeletingProviderRefusedWhenUsersReference complements
// the OIDCProviderRepository test of the same shape; pinning both ends
// of the FK ON DELETE RESTRICT contract.
func TestUserRepository_FKRestrictsProviderDelete(t *testing.T) {
if testing.Short() {
t.Skip("integration test in short mode")
}
db := getTestDB(t).freshSchema(t)
providerRepo := postgres.NewOIDCProviderRepository(db)
userRepo := postgres.NewUserRepository(db)
ctx := context.Background()
p := newValidProvider("fkrest")
if err := providerRepo.Create(ctx, p); err != nil {
t.Fatalf("Create provider: %v", err)
}
u := newValidUser("fkrest-user", p.ID)
if err := userRepo.Create(ctx, u); err != nil {
t.Fatalf("Create user: %v", err)
}
if err := providerRepo.Delete(ctx, p.ID); !errors.Is(err, repository.ErrOIDCProviderInUse) {
t.Errorf("Delete provider (with referencing user) err = %v; want ErrOIDCProviderInUse", err)
}
}
+124
View File
@@ -0,0 +1,124 @@
package repository
import (
"context"
"errors"
sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain"
)
// Sentinel errors for the session repositories.
var (
// ErrSessionNotFound: Get returned no row. Phase 4 maps to 401
// (the cookie either expired or was forged with a known-good key
// id but stale session id).
ErrSessionNotFound = errors.New("session: not found")
// ErrSessionRevoked: Get found a row but RevokedAt is set. Phase 4
// maps to 401.
ErrSessionRevoked = errors.New("session: revoked")
// ErrSessionExpired: Get found a row but the absolute expiry has
// passed (Phase 4 also enforces idle expiry but that's a service-
// level check against last_seen_at, not a repository sentinel).
ErrSessionExpired = errors.New("session: expired")
// ErrSessionSigningKeyNotFound: GetActive returned no row. Phase 4
// EnsureInitialSigningKey treats this as "boot-time provisioning
// needed" and mints the first key.
ErrSessionSigningKeyNotFound = errors.New("session: signing key not found")
// ErrSessionSigningKeyInUse: Delete (full purge, not Retire) failed
// because at least one sessions row still references the key. Phase
// 4's GarbageCollect waits for sessions to expire before purging.
ErrSessionSigningKeyInUse = errors.New("session: signing key still referenced by active sessions")
)
// SessionRepository wraps the sessions table. Two cookie shapes share
// the rows: post-login sessions (1h-idle/8h-absolute) and pre-login
// sessions (10-minute TTL, IsPreLogin=true; carry OIDC state + nonce
// + PKCE verifier across the IdP redirect).
type SessionRepository interface {
// Create persists a session row. Caller MUST have called
// s.Validate(). Returns ErrAuthDuplicateName-shape on the
// extremely-unlikely id collision (the id is a 32-byte random;
// callers SHOULD generate fresh ids on the second attempt).
Create(ctx context.Context, s *sessiondomain.Session) error
// Get returns a session by id. ErrSessionNotFound on miss.
// Returns the row even if revoked / expired so the service layer
// can produce the right 401 reason code (revoked vs expired vs
// not-found are all 401 to the wire but distinguishable in audit).
Get(ctx context.Context, id string) (*sessiondomain.Session, error)
// ListByActor returns every active (non-revoked, non-expired,
// non-pre-login) session for an actor. Used by the GUI's
// /v1/auth/sessions surface so users can revoke their old laptops.
ListByActor(ctx context.Context, actorID, actorType, tenantID string) ([]*sessiondomain.Session, error)
// UpdateLastSeen sets last_seen_at = NOW() for the named session.
// Phase 4's middleware calls this on every request to keep the
// idle-expiry sliding window fresh.
UpdateLastSeen(ctx context.Context, id string) error
// Revoke sets revoked_at = NOW() for the named session. Subsequent
// Get returns the row with RevokedAt set; Phase 4's Validate maps
// to 401.
Revoke(ctx context.Context, id string) error
// RevokeAllForActor sets revoked_at = NOW() on every active session
// for an actor. Used on role change, fired-employee scenarios, and
// the back-channel logout endpoint (Phase 5).
RevokeAllForActor(ctx context.Context, actorID, actorType, tenantID string) error
// GarbageCollectExpired deletes sessions whose absolute expiry
// has passed AND whose revoked_at is older than the configurable
// retention window (default 24h). Pre-login rows older than the
// 10-minute TTL are also deleted. Returns the number of rows
// deleted.
GarbageCollectExpired(ctx context.Context) (int, error)
// Delete unconditionally removes a session row. Used for the
// admin-only "purge a specific session" surface (rarely needed;
// Revoke is the normal path).
Delete(ctx context.Context, id string) error
}
// SessionSigningKeyRepository wraps the session_signing_keys table.
// Phase 4's Service.RotateSigningKey + EnsureInitialSigningKey + the
// scheduler-driven retention sweep consume this.
type SessionSigningKeyRepository interface {
// List returns every signing key in the tenant (including
// retired). Order: created_at DESC.
List(ctx context.Context, tenantID string) ([]*sessiondomain.SessionSigningKey, error)
// GetActive returns the most-recently-created non-retired key.
// ErrSessionSigningKeyNotFound when no non-retired key exists
// (Phase 4's EnsureInitialSigningKey treats this as "mint first
// key").
GetActive(ctx context.Context, tenantID string) (*sessiondomain.SessionSigningKey, error)
// Get returns one key by id (including retired keys; Phase 4's
// Validate consults this for cookies signed under retired-but-
// in-retention keys).
Get(ctx context.Context, id string) (*sessiondomain.SessionSigningKey, error)
// Add persists a new signing key. Caller MUST have called
// k.Validate() and encrypted the key_material via
// internal/crypto/encryption.go. CreatedAt defaults to NOW() if
// zero.
Add(ctx context.Context, k *sessiondomain.SessionSigningKey) error
// Retire marks an active key as retired (sets retired_at = NOW()).
// The key stays in the table for verification of cookies signed
// under it; the scheduler's retention sweep purges it after the
// configurable retention window (default 24h beyond retired_at).
Retire(ctx context.Context, id string) error
// Delete unconditionally removes a signing key row. Returns
// ErrSessionSigningKeyInUse if any sessions row still references
// the key (FK ON DELETE RESTRICT). Phase 4's GarbageCollect calls
// this only after RetentionWindow has passed AND no sessions
// reference the key.
Delete(ctx context.Context, id string) error
}
+46
View File
@@ -0,0 +1,46 @@
package repository
import (
"context"
"errors"
userdomain "github.com/certctl-io/certctl/internal/auth/user/domain"
)
// Sentinel errors for the user repository.
var (
// ErrUserNotFound: Get / GetByOIDCSubject returned no row. Phase
// 3's HandleCallback treats this as "first login for this person;
// create the row".
ErrUserNotFound = errors.New("user: not found")
// ErrUserDuplicateOIDCSubject: Create tripped the
// (oidc_provider_id, oidc_subject) UNIQUE constraint. HTTP 409.
ErrUserDuplicateOIDCSubject = errors.New("user: a user with this provider+subject already exists")
)
// UserRepository wraps the users table. Phase 3's HandleCallback
// uses GetByOIDCSubject + Create + Update on every login; the GUI's
// admin user-list surface uses ListAll + Get.
type UserRepository interface {
// Get returns one user by id. ErrUserNotFound on miss.
Get(ctx context.Context, id string) (*userdomain.User, error)
// GetByOIDCSubject is the Phase 3 hot-path lookup at login time.
// Returns the existing row if present, ErrUserNotFound otherwise.
GetByOIDCSubject(ctx context.Context, providerID, subject string) (*userdomain.User, error)
// Create persists a new user. Caller MUST have called u.Validate().
// Returns ErrUserDuplicateOIDCSubject on UNIQUE constraint trip.
Create(ctx context.Context, u *userdomain.User) error
// Update writes the mutable field set back to the row. Immutable
// fields (id, tenant_id, oidc_subject, oidc_provider_id,
// created_at) are preserved. updated_at is set to NOW() by the
// implementation.
Update(ctx context.Context, u *userdomain.User) error
// ListAll returns every user in the tenant. Order:
// created_at ASC. Used by the GUI's admin surface.
ListAll(ctx context.Context, tenantID string) ([]*userdomain.User, error)
}