diff --git a/internal/repository/oidc.go b/internal/repository/oidc.go new file mode 100644 index 0000000..6f2d2d8 --- /dev/null +++ b/internal/repository/oidc.go @@ -0,0 +1,94 @@ +package repository + +import ( + "context" + "errors" + + oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain" +) + +// Sentinel errors for the OIDC repositories. Postgres implementations +// translate SQLSTATE codes into these so handler / service code can +// branch via errors.Is. +var ( + // ErrOIDCProviderNotFound: Get / GetByName returned no row. HTTP 404. + ErrOIDCProviderNotFound = errors.New("oidc: provider not found") + + // ErrOIDCProviderDuplicateName: Create tripped the (tenant_id, name) + // UNIQUE constraint. HTTP 409. + ErrOIDCProviderDuplicateName = errors.New("oidc: provider with this name already exists in tenant") + + // ErrOIDCProviderInUse: Delete failed because at least one users row + // references the provider via oidc_provider_id (FK ON DELETE + // RESTRICT). HTTP 409. + ErrOIDCProviderInUse = errors.New("oidc: provider has authenticated users; revoke all sessions before delete") + + // ErrGroupRoleMappingNotFound: Get returned no row. HTTP 404. + ErrGroupRoleMappingNotFound = errors.New("oidc: group-role mapping not found") + + // ErrGroupRoleMappingDuplicate: Add tripped the + // (provider_id, group_name, role_id) UNIQUE constraint. HTTP 409. + ErrGroupRoleMappingDuplicate = errors.New("oidc: group-role mapping already exists") +) + +// OIDCProviderRepository wraps the oidc_providers table. Phase 3's +// OIDCService consumes List + Get to look up the IdP for token +// validation; the GUI / CLI wire Create / Update / Delete behind +// auth.oidc.* permission gates per Phase 5. +type OIDCProviderRepository interface { + // List returns every configured provider in the tenant. Order: + // created_at ASC for stable GUI rendering. + List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error) + + // Get returns one provider by id. ErrOIDCProviderNotFound on miss. + Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error) + + // GetByName returns one provider by (tenant_id, name). + // ErrOIDCProviderNotFound on miss. + GetByName(ctx context.Context, tenantID, name string) (*oidcdomain.OIDCProvider, error) + + // Create persists a new provider. Caller MUST have already called + // p.Validate() and encrypted the client_secret_encrypted byte + // stream via internal/crypto/encryption.go. Returns + // ErrOIDCProviderDuplicateName when the (tenant_id, name) UNIQUE + // constraint fires. + Create(ctx context.Context, p *oidcdomain.OIDCProvider) error + + // Update writes the full mutable field set back to the row. + // Immutable fields (id, tenant_id, created_at) are read-only; + // updated_at is set to NOW() by the implementation. + Update(ctx context.Context, p *oidcdomain.OIDCProvider) error + + // Delete removes a provider by id. Returns ErrOIDCProviderInUse + // when at least one users row references this provider (FK ON + // DELETE RESTRICT). Phase 5's handler maps to HTTP 409. + Delete(ctx context.Context, id string) error +} + +// GroupRoleMappingRepository wraps the group_role_mappings table. +// Phase 3's OIDCService.HandleCallback uses Map() to translate IdP +// group claims into role IDs; the GUI / CLI wire ListByProvider / +// Add / Remove for operator configuration. +type GroupRoleMappingRepository interface { + // ListByProvider returns every mapping for the named provider. + // Order: group_name ASC for stable GUI rendering. + ListByProvider(ctx context.Context, providerID string) ([]*oidcdomain.GroupRoleMapping, error) + + // Get returns one mapping by id. ErrGroupRoleMappingNotFound on miss. + Get(ctx context.Context, id string) (*oidcdomain.GroupRoleMapping, error) + + // Add persists a new mapping. Caller MUST have called m.Validate(). + // Returns ErrGroupRoleMappingDuplicate when the + // (provider_id, group_name, role_id) UNIQUE constraint fires. + Add(ctx context.Context, m *oidcdomain.GroupRoleMapping) error + + // Remove deletes a mapping by id. + Remove(ctx context.Context, id string) error + + // Map resolves an IdP-supplied list of group names against the + // provider's mappings. Returns the deduplicated set of role IDs + // the user should hold. Empty result means the user matches no + // mapping (Phase 3 fail-closed: no session minted, audit row + // `auth.oidc_login_unmapped_groups`). + Map(ctx context.Context, providerID string, groupNames []string) ([]string, error) +} diff --git a/internal/repository/postgres/oidc.go b/internal/repository/postgres/oidc.go new file mode 100644 index 0000000..9cf08a4 --- /dev/null +++ b/internal/repository/postgres/oidc.go @@ -0,0 +1,309 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lib/pq" + + oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain" + "github.com/certctl-io/certctl/internal/repository" +) + +// ============================================================================= +// OIDCProviderRepository (Auth Bundle 2 Phase 2) +// ============================================================================= + +// OIDCProviderRepository is the postgres implementation of +// repository.OIDCProviderRepository. +type OIDCProviderRepository struct { + db *sql.DB +} + +// NewOIDCProviderRepository constructs an OIDCProviderRepository. +func NewOIDCProviderRepository(db *sql.DB) *OIDCProviderRepository { + return &OIDCProviderRepository{db: db} +} + +const oidcProviderColumns = `id, tenant_id, name, issuer_url, client_id, + client_secret_encrypted, redirect_uri, groups_claim_path, + groups_claim_format, fetch_userinfo, scopes, + allowed_email_domains, iat_window_seconds, + jwks_cache_ttl_seconds, created_at, updated_at` + +func scanOIDCProvider(row interface{ Scan(...interface{}) error }) (*oidcdomain.OIDCProvider, error) { + var p oidcdomain.OIDCProvider + var scopes, domains pq.StringArray + if err := row.Scan( + &p.ID, &p.TenantID, &p.Name, &p.IssuerURL, &p.ClientID, + &p.ClientSecretEncrypted, &p.RedirectURI, &p.GroupsClaimPath, + &p.GroupsClaimFormat, &p.FetchUserinfo, &scopes, + &domains, &p.IATWindowSeconds, + &p.JWKSCacheTTLSeconds, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, err + } + p.Scopes = []string(scopes) + p.AllowedEmailDomains = []string(domains) + return &p, nil +} + +// List returns every configured OIDC provider in the tenant, ordered +// by created_at ASC for stable GUI rendering. +func (r *OIDCProviderRepository) List(ctx context.Context, tenantID string) ([]*oidcdomain.OIDCProvider, error) { + rows, err := r.db.QueryContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE tenant_id = $1 ORDER BY created_at ASC`, tenantID) + if err != nil { + return nil, fmt.Errorf("oidc_providers list: %w", err) + } + defer rows.Close() + + var out []*oidcdomain.OIDCProvider + for rows.Next() { + p, err := scanOIDCProvider(rows) + if err != nil { + return nil, fmt.Errorf("oidc_providers scan: %w", err) + } + out = append(out, p) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +// Get returns one provider by id. ErrOIDCProviderNotFound on miss. +func (r *OIDCProviderRepository) Get(ctx context.Context, id string) (*oidcdomain.OIDCProvider, error) { + row := r.db.QueryRowContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE id = $1`, id) + p, err := scanOIDCProvider(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrOIDCProviderNotFound + } + return nil, fmt.Errorf("oidc_providers get: %w", err) + } + return p, nil +} + +// GetByName returns one provider by (tenant_id, name). +func (r *OIDCProviderRepository) GetByName(ctx context.Context, tenantID, name string) (*oidcdomain.OIDCProvider, error) { + row := r.db.QueryRowContext(ctx, `SELECT `+oidcProviderColumns+` FROM oidc_providers WHERE tenant_id = $1 AND name = $2`, tenantID, name) + p, err := scanOIDCProvider(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrOIDCProviderNotFound + } + return nil, fmt.Errorf("oidc_providers get_by_name: %w", err) + } + return p, nil +} + +// Create persists a new provider. Caller MUST have called p.Validate() +// and encrypted ClientSecretEncrypted via internal/crypto/encryption.go. +// Translates SQLSTATE 23505 (unique_violation) to +// ErrOIDCProviderDuplicateName. +func (r *OIDCProviderRepository) Create(ctx context.Context, p *oidcdomain.OIDCProvider) error { + _, err := r.db.ExecContext(ctx, ` + INSERT INTO oidc_providers ( + id, tenant_id, name, issuer_url, client_id, + client_secret_encrypted, redirect_uri, groups_claim_path, + groups_claim_format, fetch_userinfo, scopes, + allowed_email_domains, iat_window_seconds, + jwks_cache_ttl_seconds + ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14)`, + p.ID, p.TenantID, p.Name, p.IssuerURL, p.ClientID, + p.ClientSecretEncrypted, p.RedirectURI, p.GroupsClaimPath, + p.GroupsClaimFormat, p.FetchUserinfo, pq.StringArray(p.Scopes), + pq.StringArray(p.AllowedEmailDomains), p.IATWindowSeconds, + p.JWKSCacheTTLSeconds, + ) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return repository.ErrOIDCProviderDuplicateName + } + return fmt.Errorf("oidc_providers create: %w", err) + } + return nil +} + +// Update writes the mutable fields back. Immutable: id, tenant_id, +// created_at. updated_at = NOW(). +func (r *OIDCProviderRepository) Update(ctx context.Context, p *oidcdomain.OIDCProvider) error { + res, err := r.db.ExecContext(ctx, ` + UPDATE oidc_providers SET + name = $2, + issuer_url = $3, + client_id = $4, + client_secret_encrypted = $5, + redirect_uri = $6, + groups_claim_path = $7, + groups_claim_format = $8, + fetch_userinfo = $9, + scopes = $10, + allowed_email_domains = $11, + iat_window_seconds = $12, + jwks_cache_ttl_seconds = $13, + updated_at = NOW() + WHERE id = $1`, + p.ID, p.Name, p.IssuerURL, p.ClientID, + p.ClientSecretEncrypted, p.RedirectURI, p.GroupsClaimPath, + p.GroupsClaimFormat, p.FetchUserinfo, pq.StringArray(p.Scopes), + pq.StringArray(p.AllowedEmailDomains), p.IATWindowSeconds, + p.JWKSCacheTTLSeconds, + ) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return repository.ErrOIDCProviderDuplicateName + } + return fmt.Errorf("oidc_providers update: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrOIDCProviderNotFound + } + return nil +} + +// Delete removes a provider by id. Returns ErrOIDCProviderInUse on +// SQLSTATE 23503 (foreign_key_violation) — the users table's FK ON +// DELETE RESTRICT fires when authenticated users still reference +// this provider. +func (r *OIDCProviderRepository) Delete(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, `DELETE FROM oidc_providers WHERE id = $1`, id) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23503" { + return repository.ErrOIDCProviderInUse + } + return fmt.Errorf("oidc_providers delete: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrOIDCProviderNotFound + } + return nil +} + +// ============================================================================= +// GroupRoleMappingRepository (Auth Bundle 2 Phase 2) +// ============================================================================= + +// GroupRoleMappingRepository is the postgres implementation of +// repository.GroupRoleMappingRepository. +type GroupRoleMappingRepository struct { + db *sql.DB +} + +// NewGroupRoleMappingRepository constructs a GroupRoleMappingRepository. +func NewGroupRoleMappingRepository(db *sql.DB) *GroupRoleMappingRepository { + return &GroupRoleMappingRepository{db: db} +} + +func scanGroupRoleMapping(row interface{ Scan(...interface{}) error }) (*oidcdomain.GroupRoleMapping, error) { + var m oidcdomain.GroupRoleMapping + if err := row.Scan(&m.ID, &m.TenantID, &m.ProviderID, &m.GroupName, &m.RoleID, &m.CreatedAt); err != nil { + return nil, err + } + return &m, nil +} + +// ListByProvider returns every mapping for the named provider, ordered +// group_name ASC. +func (r *GroupRoleMappingRepository) ListByProvider(ctx context.Context, providerID string) ([]*oidcdomain.GroupRoleMapping, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, tenant_id, provider_id, group_name, role_id, created_at + FROM group_role_mappings + WHERE provider_id = $1 + ORDER BY group_name ASC`, providerID) + if err != nil { + return nil, fmt.Errorf("group_role_mappings list_by_provider: %w", err) + } + defer rows.Close() + + var out []*oidcdomain.GroupRoleMapping + for rows.Next() { + m, err := scanGroupRoleMapping(rows) + if err != nil { + return nil, fmt.Errorf("group_role_mappings scan: %w", err) + } + out = append(out, m) + } + return out, rows.Err() +} + +// Get returns one mapping by id. +func (r *GroupRoleMappingRepository) Get(ctx context.Context, id string) (*oidcdomain.GroupRoleMapping, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT id, tenant_id, provider_id, group_name, role_id, created_at + FROM group_role_mappings WHERE id = $1`, id) + m, err := scanGroupRoleMapping(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrGroupRoleMappingNotFound + } + return nil, fmt.Errorf("group_role_mappings get: %w", err) + } + return m, nil +} + +// Add persists a new mapping. Translates SQLSTATE 23505 into +// ErrGroupRoleMappingDuplicate. +func (r *GroupRoleMappingRepository) Add(ctx context.Context, m *oidcdomain.GroupRoleMapping) error { + _, err := r.db.ExecContext(ctx, ` + INSERT INTO group_role_mappings (id, tenant_id, provider_id, group_name, role_id) + VALUES ($1, $2, $3, $4, $5)`, + m.ID, m.TenantID, m.ProviderID, m.GroupName, m.RoleID) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return repository.ErrGroupRoleMappingDuplicate + } + return fmt.Errorf("group_role_mappings add: %w", err) + } + return nil +} + +// Remove deletes a mapping by id. +func (r *GroupRoleMappingRepository) Remove(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, `DELETE FROM group_role_mappings WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("group_role_mappings remove: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrGroupRoleMappingNotFound + } + return nil +} + +// Map resolves IdP-supplied group names against the provider's +// mappings. Returns the deduplicated set of role IDs the user should +// hold. Empty group_names slice yields empty result; empty result +// means fail-closed (no roles, Phase 3 declines to mint a session). +func (r *GroupRoleMappingRepository) Map(ctx context.Context, providerID string, groupNames []string) ([]string, error) { + if len(groupNames) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, ` + SELECT DISTINCT role_id + FROM group_role_mappings + WHERE provider_id = $1 AND group_name = ANY($2)`, + providerID, pq.StringArray(groupNames)) + if err != nil { + return nil, fmt.Errorf("group_role_mappings map: %w", err) + } + defer rows.Close() + + var out []string + for rows.Next() { + var roleID string + if err := rows.Scan(&roleID); err != nil { + return nil, fmt.Errorf("group_role_mappings map scan: %w", err) + } + out = append(out, roleID) + } + return out, rows.Err() +} diff --git a/internal/repository/postgres/oidc_test.go b/internal/repository/postgres/oidc_test.go new file mode 100644 index 0000000..427d176 --- /dev/null +++ b/internal/repository/postgres/oidc_test.go @@ -0,0 +1,366 @@ +package postgres_test + +import ( + "context" + "errors" + "testing" + "time" + + oidcdomain "github.com/certctl-io/certctl/internal/auth/oidc/domain" + "github.com/certctl-io/certctl/internal/repository" + "github.com/certctl-io/certctl/internal/repository/postgres" +) + +// ============================================================================= +// OIDCProviderRepository tests (Auth Bundle 2 Phase 2) +// +// Schema-per-test isolation via getTestDB().freshSchema(t). Run with: +// +// go test -count=1 ./internal/repository/postgres/... +// +// (omit -short; testing.Short() skips all integration tests.) +// ============================================================================= + +func newValidProvider(suffix string) *oidcdomain.OIDCProvider { + return &oidcdomain.OIDCProvider{ + ID: "op-" + suffix, + TenantID: "t-default", + Name: "Provider " + suffix, + IssuerURL: "https://idp." + suffix + ".example.com", + ClientID: "certctl", + ClientSecretEncrypted: []byte{0x02, 0x00, 0x01, 0x02, 0x03}, + RedirectURI: "https://certctl.example.com/auth/oidc/callback", + GroupsClaimPath: "groups", + GroupsClaimFormat: "string-array", + Scopes: []string{"openid", "profile", "email"}, + AllowedEmailDomains: []string{}, + IATWindowSeconds: 300, + JWKSCacheTTLSeconds: 3600, + } +} + +func TestOIDCProviderRepository_CreateAndGet(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + p := newValidProvider("a") + if err := repo.Create(ctx, p); err != nil { + t.Fatalf("Create: %v", err) + } + + got, err := repo.Get(ctx, p.ID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Name != p.Name { + t.Errorf("Name roundtrip: got %q, want %q", got.Name, p.Name) + } + if got.IssuerURL != p.IssuerURL { + t.Errorf("IssuerURL roundtrip mismatch") + } + // Defaults from the migration kicked in for any unset bool / array. + if got.FetchUserinfo != false { + t.Errorf("FetchUserinfo default = %v; want false", got.FetchUserinfo) + } + if len(got.Scopes) != 3 { + t.Errorf("Scopes roundtrip count = %d; want 3", len(got.Scopes)) + } + // Defense: client_secret_encrypted column must NOT contain plaintext. + // Since we wrote a v2 magic-byte stub, the byte stream comes back as-is. + if len(got.ClientSecretEncrypted) == 0 { + t.Errorf("ClientSecretEncrypted lost on roundtrip") + } +} + +func TestOIDCProviderRepository_GetNotFound(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + _, err := repo.Get(ctx, "op-nonexistent") + if !errors.Is(err, repository.ErrOIDCProviderNotFound) { + t.Errorf("err = %v; want ErrOIDCProviderNotFound", err) + } +} + +func TestOIDCProviderRepository_DuplicateName(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + p1 := newValidProvider("dup1") + if err := repo.Create(ctx, p1); err != nil { + t.Fatalf("Create p1: %v", err) + } + + p2 := newValidProvider("dup2") + p2.Name = p1.Name // collision on (tenant_id, name) + err := repo.Create(ctx, p2) + if !errors.Is(err, repository.ErrOIDCProviderDuplicateName) { + t.Errorf("Create with duplicate name err = %v; want ErrOIDCProviderDuplicateName", err) + } +} + +func TestOIDCProviderRepository_List(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + for _, suf := range []string{"x", "y", "z"} { + if err := repo.Create(ctx, newValidProvider(suf)); err != nil { + t.Fatalf("Create %q: %v", suf, err) + } + } + + out, err := repo.List(ctx, "t-default") + if err != nil { + t.Fatalf("List: %v", err) + } + if len(out) != 3 { + t.Errorf("List count = %d; want 3", len(out)) + } +} + +func TestOIDCProviderRepository_Update(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + p := newValidProvider("upd") + if err := repo.Create(ctx, p); err != nil { + t.Fatalf("Create: %v", err) + } + + p.Name = "Renamed" + p.FetchUserinfo = true + if err := repo.Update(ctx, p); err != nil { + t.Fatalf("Update: %v", err) + } + + got, err := repo.Get(ctx, p.ID) + if err != nil { + t.Fatalf("Get post-update: %v", err) + } + if got.Name != "Renamed" { + t.Errorf("Update did not persist Name; got %q", got.Name) + } + if !got.FetchUserinfo { + t.Errorf("Update did not persist FetchUserinfo") + } +} + +func TestOIDCProviderRepository_DeleteNotFound(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + err := repo.Delete(ctx, "op-nonexistent") + if !errors.Is(err, repository.ErrOIDCProviderNotFound) { + t.Errorf("err = %v; want ErrOIDCProviderNotFound", err) + } +} + +func TestOIDCProviderRepository_DeleteSucceedsWhenNoUsersReference(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewOIDCProviderRepository(db) + ctx := context.Background() + + p := newValidProvider("del") + if err := repo.Create(ctx, p); err != nil { + t.Fatalf("Create: %v", err) + } + if err := repo.Delete(ctx, p.ID); err != nil { + t.Fatalf("Delete: %v", err) + } + _, err := repo.Get(ctx, p.ID) + if !errors.Is(err, repository.ErrOIDCProviderNotFound) { + t.Errorf("post-delete Get err = %v; want ErrOIDCProviderNotFound", err) + } +} + +// TestOIDCProviderRepository_DeleteRefusedWhenUsersReference pins the +// FK ON DELETE RESTRICT translation. With at least one users row +// referencing the provider, Delete must return ErrOIDCProviderInUse. +func TestOIDCProviderRepository_DeleteRefusedWhenUsersReference(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("inuse") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + u := &struct{ ID string }{ID: "u-test"} + _ = u + user := newValidUser("inuse", p.ID) + if err := userRepo.Create(ctx, user); err != nil { + t.Fatalf("Create user: %v", err) + } + + err := providerRepo.Delete(ctx, p.ID) + if !errors.Is(err, repository.ErrOIDCProviderInUse) { + t.Errorf("Delete with referencing user err = %v; want ErrOIDCProviderInUse", err) + } +} + +// ============================================================================= +// GroupRoleMappingRepository +// ============================================================================= + +func TestGroupRoleMappingRepository_AddListMap(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + mappingRepo := postgres.NewGroupRoleMappingRepository(db) + ctx := context.Background() + + p := newValidProvider("grm") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + + mappings := []*oidcdomain.GroupRoleMapping{ + {ID: "grm-1", TenantID: "t-default", ProviderID: p.ID, GroupName: "engineers", RoleID: "r-operator"}, + {ID: "grm-2", TenantID: "t-default", ProviderID: p.ID, GroupName: "platform-admins", RoleID: "r-admin"}, + {ID: "grm-3", TenantID: "t-default", ProviderID: p.ID, GroupName: "compliance", RoleID: "r-auditor"}, + } + for _, m := range mappings { + if err := mappingRepo.Add(ctx, m); err != nil { + t.Fatalf("Add %s: %v", m.GroupName, err) + } + } + + listed, err := mappingRepo.ListByProvider(ctx, p.ID) + if err != nil { + t.Fatalf("ListByProvider: %v", err) + } + if len(listed) != 3 { + t.Errorf("ListByProvider count = %d; want 3", len(listed)) + } + + // Map: user has groups [engineers, marketing]. Marketing has no + // mapping; only engineers maps to r-operator. + roleIDs, err := mappingRepo.Map(ctx, p.ID, []string{"engineers", "marketing"}) + if err != nil { + t.Fatalf("Map: %v", err) + } + if len(roleIDs) != 1 || roleIDs[0] != "r-operator" { + t.Errorf("Map(engineers, marketing) = %v; want [r-operator]", roleIDs) + } + + // Map: user has groups [engineers, platform-admins]. Both map. + roleIDs, err = mappingRepo.Map(ctx, p.ID, []string{"engineers", "platform-admins"}) + if err != nil { + t.Fatalf("Map (multi): %v", err) + } + if len(roleIDs) != 2 { + t.Errorf("Map(engineers, platform-admins) count = %d; want 2", len(roleIDs)) + } + + // Map empty groups: empty result, no error (Phase 3 fail-closes). + roleIDs, err = mappingRepo.Map(ctx, p.ID, nil) + if err != nil { + t.Fatalf("Map(nil): %v", err) + } + if len(roleIDs) != 0 { + t.Errorf("Map(nil) returned %d roles; want 0", len(roleIDs)) + } +} + +func TestGroupRoleMappingRepository_DuplicateRejected(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + mappingRepo := postgres.NewGroupRoleMappingRepository(db) + ctx := context.Background() + + p := newValidProvider("dup") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + m := &oidcdomain.GroupRoleMapping{ + ID: "grm-dup-1", TenantID: "t-default", ProviderID: p.ID, + GroupName: "engineers", RoleID: "r-operator", + } + if err := mappingRepo.Add(ctx, m); err != nil { + t.Fatalf("Add first: %v", err) + } + m2 := &oidcdomain.GroupRoleMapping{ + ID: "grm-dup-2", TenantID: "t-default", ProviderID: p.ID, + GroupName: "engineers", RoleID: "r-operator", + } + err := mappingRepo.Add(ctx, m2) + if !errors.Is(err, repository.ErrGroupRoleMappingDuplicate) { + t.Errorf("Add duplicate err = %v; want ErrGroupRoleMappingDuplicate", err) + } +} + +func TestGroupRoleMappingRepository_ProviderDeleteCascades(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + mappingRepo := postgres.NewGroupRoleMappingRepository(db) + ctx := context.Background() + + p := newValidProvider("cascade") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + for i, group := range []string{"a", "b", "c"} { + m := &oidcdomain.GroupRoleMapping{ + ID: "grm-cas-" + string(rune('a'+i)), TenantID: "t-default", + ProviderID: p.ID, GroupName: group, RoleID: "r-viewer", + } + if err := mappingRepo.Add(ctx, m); err != nil { + t.Fatalf("Add %s: %v", group, err) + } + } + + // Delete provider: ON DELETE CASCADE on group_role_mappings.provider_id + // should drop the 3 mappings too. + if err := providerRepo.Delete(ctx, p.ID); err != nil { + t.Fatalf("Delete provider: %v", err) + } + listed, err := mappingRepo.ListByProvider(ctx, p.ID) + if err != nil { + t.Fatalf("ListByProvider post-cascade: %v", err) + } + if len(listed) != 0 { + t.Errorf("CASCADE failed; %d mappings remain", len(listed)) + } +} + +// quiet unused-import keepalives so single-test runs don't drop them. +var _ = time.Now diff --git a/internal/repository/postgres/session.go b/internal/repository/postgres/session.go new file mode 100644 index 0000000..c6dd503 --- /dev/null +++ b/internal/repository/postgres/session.go @@ -0,0 +1,350 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lib/pq" + + sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain" + "github.com/certctl-io/certctl/internal/repository" +) + +// ============================================================================= +// SessionRepository (Auth Bundle 2 Phase 2) +// ============================================================================= + +// SessionRepository is the postgres implementation of +// repository.SessionRepository. +type SessionRepository struct { + db *sql.DB +} + +// NewSessionRepository constructs a SessionRepository. +func NewSessionRepository(db *sql.DB) *SessionRepository { + return &SessionRepository{db: db} +} + +const sessionColumns = `id, tenant_id, actor_id, actor_type, + signing_key_id, is_pre_login, csrf_token_hash, + idle_expires_at, absolute_expires_at, created_at, last_seen_at, + ip_address, user_agent, revoked_at` + +func scanSession(row interface{ Scan(...interface{}) error }) (*sessiondomain.Session, error) { + var s sessiondomain.Session + var revokedAt sql.NullTime + if err := row.Scan( + &s.ID, &s.TenantID, &s.ActorID, &s.ActorType, + &s.SigningKeyID, &s.IsPreLogin, &s.CSRFTokenHash, + &s.IdleExpiresAt, &s.AbsoluteExpiresAt, &s.CreatedAt, &s.LastSeenAt, + &s.IPAddress, &s.UserAgent, &revokedAt, + ); err != nil { + return nil, err + } + if revokedAt.Valid { + s.RevokedAt = &revokedAt.Time + } + return &s, nil +} + +// Create persists a session row. Caller MUST have called s.Validate(). +func (r *SessionRepository) Create(ctx context.Context, s *sessiondomain.Session) error { + _, err := r.db.ExecContext(ctx, ` + INSERT INTO sessions ( + id, tenant_id, actor_id, actor_type, signing_key_id, + is_pre_login, csrf_token_hash, idle_expires_at, + absolute_expires_at, created_at, last_seen_at, + ip_address, user_agent + ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13)`, + s.ID, s.TenantID, s.ActorID, s.ActorType, s.SigningKeyID, + s.IsPreLogin, s.CSRFTokenHash, s.IdleExpiresAt, + s.AbsoluteExpiresAt, s.CreatedAt, s.LastSeenAt, + s.IPAddress, s.UserAgent) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return repository.ErrAuthDuplicateName + } + return fmt.Errorf("sessions create: %w", err) + } + return nil +} + +// Get returns a session by id. Returns the row even if revoked / +// expired; the service layer handles the disposition. +func (r *SessionRepository) Get(ctx context.Context, id string) (*sessiondomain.Session, error) { + row := r.db.QueryRowContext(ctx, `SELECT `+sessionColumns+` FROM sessions WHERE id = $1`, id) + s, err := scanSession(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrSessionNotFound + } + return nil, fmt.Errorf("sessions get: %w", err) + } + return s, nil +} + +// ListByActor returns active (non-revoked, non-expired, non-pre-login) +// sessions for an actor. +func (r *SessionRepository) ListByActor(ctx context.Context, actorID, actorType, tenantID string) ([]*sessiondomain.Session, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT `+sessionColumns+` + FROM sessions + WHERE actor_id = $1 + AND actor_type = $2 + AND tenant_id = $3 + AND revoked_at IS NULL + AND is_pre_login = FALSE + AND absolute_expires_at > NOW() + ORDER BY created_at DESC`, + actorID, actorType, tenantID) + if err != nil { + return nil, fmt.Errorf("sessions list_by_actor: %w", err) + } + defer rows.Close() + + var out []*sessiondomain.Session + for rows.Next() { + s, err := scanSession(rows) + if err != nil { + return nil, fmt.Errorf("sessions scan: %w", err) + } + out = append(out, s) + } + return out, rows.Err() +} + +// UpdateLastSeen sets last_seen_at = NOW() for the named session. +func (r *SessionRepository) UpdateLastSeen(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, `UPDATE sessions SET last_seen_at = NOW() WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("sessions update_last_seen: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrSessionNotFound + } + return nil +} + +// Revoke sets revoked_at = NOW() for the named session. Idempotent: +// re-revoking an already-revoked session is a no-op (returns nil). +func (r *SessionRepository) Revoke(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, `UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL`, id) + if err != nil { + return fmt.Errorf("sessions revoke: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + // Distinguish "not found" from "already revoked" by re-querying. + row := r.db.QueryRowContext(ctx, `SELECT 1 FROM sessions WHERE id = $1`, id) + var x int + if err := row.Scan(&x); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return repository.ErrSessionNotFound + } + return fmt.Errorf("sessions revoke probe: %w", err) + } + // Row exists but already revoked: idempotent success. + } + return nil +} + +// RevokeAllForActor sets revoked_at = NOW() on every active session +// for an actor. Returns nil on zero matches (idempotent). +func (r *SessionRepository) RevokeAllForActor(ctx context.Context, actorID, actorType, tenantID string) error { + _, err := r.db.ExecContext(ctx, ` + UPDATE sessions SET revoked_at = NOW() + WHERE actor_id = $1 AND actor_type = $2 AND tenant_id = $3 AND revoked_at IS NULL`, + actorID, actorType, tenantID) + if err != nil { + return fmt.Errorf("sessions revoke_all_for_actor: %w", err) + } + return nil +} + +// GarbageCollectExpired deletes: +// - Sessions whose absolute_expires_at < NOW() (post-login expired). +// - Pre-login sessions older than 10 minutes. +// +// Returns the number of rows deleted across both classes. +func (r *SessionRepository) GarbageCollectExpired(ctx context.Context) (int, error) { + res, err := r.db.ExecContext(ctx, ` + DELETE FROM sessions + WHERE absolute_expires_at < NOW() + OR (is_pre_login = TRUE AND created_at < NOW() - INTERVAL '10 minutes')`) + if err != nil { + return 0, fmt.Errorf("sessions garbage_collect: %w", err) + } + n, _ := res.RowsAffected() + return int(n), nil +} + +// Delete unconditionally removes a session row. +func (r *SessionRepository) Delete(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("sessions delete: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrSessionNotFound + } + return nil +} + +// ============================================================================= +// SessionSigningKeyRepository (Auth Bundle 2 Phase 2) +// ============================================================================= + +// SessionSigningKeyRepository is the postgres implementation of +// repository.SessionSigningKeyRepository. +type SessionSigningKeyRepository struct { + db *sql.DB +} + +// NewSessionSigningKeyRepository constructs a SessionSigningKeyRepository. +func NewSessionSigningKeyRepository(db *sql.DB) *SessionSigningKeyRepository { + return &SessionSigningKeyRepository{db: db} +} + +const sessionSigningKeyColumns = `id, tenant_id, key_material_encrypted, created_at, retired_at` + +func scanSessionSigningKey(row interface{ Scan(...interface{}) error }) (*sessiondomain.SessionSigningKey, error) { + var k sessiondomain.SessionSigningKey + var retiredAt sql.NullTime + if err := row.Scan(&k.ID, &k.TenantID, &k.KeyMaterialEncrypted, &k.CreatedAt, &retiredAt); err != nil { + return nil, err + } + if retiredAt.Valid { + k.RetiredAt = &retiredAt.Time + } + return &k, nil +} + +// List returns every signing key in the tenant, including retired ones. +func (r *SessionSigningKeyRepository) List(ctx context.Context, tenantID string) ([]*sessiondomain.SessionSigningKey, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT `+sessionSigningKeyColumns+` FROM session_signing_keys WHERE tenant_id = $1 ORDER BY created_at DESC`, + tenantID) + if err != nil { + return nil, fmt.Errorf("session_signing_keys list: %w", err) + } + defer rows.Close() + + var out []*sessiondomain.SessionSigningKey + for rows.Next() { + k, err := scanSessionSigningKey(rows) + if err != nil { + return nil, fmt.Errorf("session_signing_keys scan: %w", err) + } + out = append(out, k) + } + return out, rows.Err() +} + +// GetActive returns the most-recently-created non-retired key. Returns +// ErrSessionSigningKeyNotFound when no non-retired key exists. +func (r *SessionSigningKeyRepository) GetActive(ctx context.Context, tenantID string) (*sessiondomain.SessionSigningKey, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT `+sessionSigningKeyColumns+` + FROM session_signing_keys + WHERE tenant_id = $1 AND retired_at IS NULL + ORDER BY created_at DESC + LIMIT 1`, tenantID) + k, err := scanSessionSigningKey(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrSessionSigningKeyNotFound + } + return nil, fmt.Errorf("session_signing_keys get_active: %w", err) + } + return k, nil +} + +// Get returns a key by id (including retired keys; Phase 4's Validate +// consults this for cookies signed under retired-but-in-retention keys). +func (r *SessionSigningKeyRepository) Get(ctx context.Context, id string) (*sessiondomain.SessionSigningKey, error) { + row := r.db.QueryRowContext(ctx, + `SELECT `+sessionSigningKeyColumns+` FROM session_signing_keys WHERE id = $1`, id) + k, err := scanSessionSigningKey(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrSessionSigningKeyNotFound + } + return nil, fmt.Errorf("session_signing_keys get: %w", err) + } + return k, nil +} + +// Add persists a new signing key. Caller MUST have called k.Validate(). +func (r *SessionSigningKeyRepository) Add(ctx context.Context, k *sessiondomain.SessionSigningKey) error { + if k.CreatedAt.IsZero() { + _, err := r.db.ExecContext(ctx, ` + INSERT INTO session_signing_keys (id, tenant_id, key_material_encrypted) + VALUES ($1, $2, $3)`, + k.ID, k.TenantID, k.KeyMaterialEncrypted) + if err != nil { + return fmt.Errorf("session_signing_keys add: %w", err) + } + // Read the row back to populate CreatedAt. + row := r.db.QueryRowContext(ctx, `SELECT created_at FROM session_signing_keys WHERE id = $1`, k.ID) + if err := row.Scan(&k.CreatedAt); err != nil { + return fmt.Errorf("session_signing_keys add (read created_at): %w", err) + } + return nil + } + _, err := r.db.ExecContext(ctx, ` + INSERT INTO session_signing_keys (id, tenant_id, key_material_encrypted, created_at) + VALUES ($1, $2, $3, $4)`, + k.ID, k.TenantID, k.KeyMaterialEncrypted, k.CreatedAt) + if err != nil { + return fmt.Errorf("session_signing_keys add: %w", err) + } + return nil +} + +// Retire marks an active key as retired (sets retired_at = NOW()). +// Idempotent: re-retiring an already-retired key is a no-op. +func (r *SessionSigningKeyRepository) Retire(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, + `UPDATE session_signing_keys SET retired_at = NOW() WHERE id = $1 AND retired_at IS NULL`, id) + if err != nil { + return fmt.Errorf("session_signing_keys retire: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + // Distinguish not-found vs already-retired. + row := r.db.QueryRowContext(ctx, `SELECT 1 FROM session_signing_keys WHERE id = $1`, id) + var x int + if err := row.Scan(&x); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return repository.ErrSessionSigningKeyNotFound + } + return fmt.Errorf("session_signing_keys retire probe: %w", err) + } + // Row exists but already retired: idempotent success. + } + return nil +} + +// Delete unconditionally removes a signing key. Returns +// ErrSessionSigningKeyInUse on SQLSTATE 23503 (FK ON DELETE RESTRICT +// from sessions.signing_key_id). +func (r *SessionSigningKeyRepository) Delete(ctx context.Context, id string) error { + res, err := r.db.ExecContext(ctx, `DELETE FROM session_signing_keys WHERE id = $1`, id) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23503" { + return repository.ErrSessionSigningKeyInUse + } + return fmt.Errorf("session_signing_keys delete: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrSessionSigningKeyNotFound + } + return nil +} diff --git a/internal/repository/postgres/session_test.go b/internal/repository/postgres/session_test.go new file mode 100644 index 0000000..57625c7 --- /dev/null +++ b/internal/repository/postgres/session_test.go @@ -0,0 +1,431 @@ +package postgres_test + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain" + "github.com/certctl-io/certctl/internal/repository" + "github.com/certctl-io/certctl/internal/repository/postgres" +) + +// ============================================================================= +// SessionSigningKey tests +// ============================================================================= + +func newValidSigningKey(suffix string) *sessiondomain.SessionSigningKey { + return &sessiondomain.SessionSigningKey{ + ID: "sk-" + suffix, + TenantID: "t-default", + KeyMaterialEncrypted: []byte{0x02, 0x00, 0x01, 0x02, 0x03}, + } +} + +func TestSessionSigningKeyRepository_AddAndGetActive(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewSessionSigningKeyRepository(db) + ctx := context.Background() + + k := newValidSigningKey("a") + if err := repo.Add(ctx, k); err != nil { + t.Fatalf("Add: %v", err) + } + if k.CreatedAt.IsZero() { + t.Errorf("Add did not populate CreatedAt") + } + + got, err := repo.GetActive(ctx, "t-default") + if err != nil { + t.Fatalf("GetActive: %v", err) + } + if got.ID != k.ID { + t.Errorf("GetActive returned %q; want %q", got.ID, k.ID) + } +} + +func TestSessionSigningKeyRepository_GetActiveSkipsRetired(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewSessionSigningKeyRepository(db) + ctx := context.Background() + + // Add older key, retire it. Add newer key. GetActive must return newer. + older := newValidSigningKey("older") + if err := repo.Add(ctx, older); err != nil { + t.Fatalf("Add older: %v", err) + } + if err := repo.Retire(ctx, older.ID); err != nil { + t.Fatalf("Retire older: %v", err) + } + // Sleep a millisecond so created_at orders deterministically. + time.Sleep(10 * time.Millisecond) + newer := newValidSigningKey("newer") + if err := repo.Add(ctx, newer); err != nil { + t.Fatalf("Add newer: %v", err) + } + + got, err := repo.GetActive(ctx, "t-default") + if err != nil { + t.Fatalf("GetActive: %v", err) + } + if got.ID != newer.ID { + t.Errorf("GetActive returned %q; want %q (older was retired)", got.ID, newer.ID) + } +} + +func TestSessionSigningKeyRepository_GetActiveReturnsNotFound(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewSessionSigningKeyRepository(db) + ctx := context.Background() + + _, err := repo.GetActive(ctx, "t-default") + if !errors.Is(err, repository.ErrSessionSigningKeyNotFound) { + t.Errorf("err = %v; want ErrSessionSigningKeyNotFound", err) + } +} + +func TestSessionSigningKeyRepository_RetireIsIdempotent(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewSessionSigningKeyRepository(db) + ctx := context.Background() + + k := newValidSigningKey("retire") + if err := repo.Add(ctx, k); err != nil { + t.Fatalf("Add: %v", err) + } + if err := repo.Retire(ctx, k.ID); err != nil { + t.Fatalf("first Retire: %v", err) + } + if err := repo.Retire(ctx, k.ID); err != nil { + t.Errorf("second Retire (already retired) should be idempotent; got %v", err) + } +} + +func TestSessionSigningKeyRepository_DeleteRefusedWhenSessionsReference(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("inuse") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + s := newValidSession("s1", k.ID) + if err := sessRepo.Create(ctx, s); err != nil { + t.Fatalf("Create session: %v", err) + } + + err := keyRepo.Delete(ctx, k.ID) + if !errors.Is(err, repository.ErrSessionSigningKeyInUse) { + t.Errorf("Delete with referencing session err = %v; want ErrSessionSigningKeyInUse", err) + } +} + +// ============================================================================= +// Session tests +// ============================================================================= + +func newValidSession(suffix, signingKeyID string) *sessiondomain.Session { + now := time.Now().UTC().Truncate(time.Microsecond) + return &sessiondomain.Session{ + ID: "ses-" + suffix, + TenantID: "t-default", + ActorID: "u-" + suffix, + ActorType: "User", + SigningKeyID: signingKeyID, + IsPreLogin: false, + CSRFTokenHash: strings.Repeat("a", 64), + IdleExpiresAt: now.Add(time.Hour), + AbsoluteExpiresAt: now.Add(8 * time.Hour), + CreatedAt: now, + LastSeenAt: now, + IPAddress: "10.0.0.1", + UserAgent: "Mozilla/5.0", + } +} + +func TestSessionRepository_CreateAndGet(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("k1") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + s := newValidSession("s1", k.ID) + if err := sessRepo.Create(ctx, s); err != nil { + t.Fatalf("Create: %v", err) + } + + got, err := sessRepo.Get(ctx, s.ID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.ActorID != s.ActorID { + t.Errorf("ActorID roundtrip mismatch") + } + if got.RevokedAt != nil { + t.Errorf("RevokedAt should be nil on fresh session") + } +} + +func TestSessionRepository_GetNotFound(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewSessionRepository(db) + ctx := context.Background() + + _, err := repo.Get(ctx, "ses-nonexistent") + if !errors.Is(err, repository.ErrSessionNotFound) { + t.Errorf("err = %v; want ErrSessionNotFound", err) + } +} + +func TestSessionRepository_RevokeAndGet(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("k2") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + s := newValidSession("s2", k.ID) + if err := sessRepo.Create(ctx, s); err != nil { + t.Fatalf("Create: %v", err) + } + + if err := sessRepo.Revoke(ctx, s.ID); err != nil { + t.Fatalf("Revoke: %v", err) + } + got, err := sessRepo.Get(ctx, s.ID) + if err != nil { + t.Fatalf("Get post-revoke: %v", err) + } + if got.RevokedAt == nil { + t.Errorf("RevokedAt nil after Revoke") + } + + // Idempotent re-revoke: returns nil, no panic, no double-update. + if err := sessRepo.Revoke(ctx, s.ID); err != nil { + t.Errorf("re-Revoke (idempotent) err = %v; want nil", err) + } +} + +func TestSessionRepository_RevokeNotFound(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewSessionRepository(db) + ctx := context.Background() + + if err := repo.Revoke(ctx, "ses-nonexistent"); !errors.Is(err, repository.ErrSessionNotFound) { + t.Errorf("err = %v; want ErrSessionNotFound", err) + } +} + +func TestSessionRepository_ListByActorActiveOnly(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("la") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + // 3 active + 1 revoked + 1 pre-login. + for i, suf := range []string{"a1", "a2", "a3"} { + s := newValidSession(suf, k.ID) + s.ActorID = "u-list-actor" + // uniqueness: stagger created_at so list ordering is stable + s.CreatedAt = s.CreatedAt.Add(time.Duration(i) * time.Millisecond) + if err := sessRepo.Create(ctx, s); err != nil { + t.Fatalf("Create %s: %v", suf, err) + } + } + revoked := newValidSession("rev", k.ID) + revoked.ActorID = "u-list-actor" + if err := sessRepo.Create(ctx, revoked); err != nil { + t.Fatalf("Create revoked: %v", err) + } + if err := sessRepo.Revoke(ctx, revoked.ID); err != nil { + t.Fatalf("Revoke: %v", err) + } + preLogin := newValidSession("pre", k.ID) + preLogin.ActorID = "u-list-actor" + preLogin.IsPreLogin = true + preLogin.CSRFTokenHash = "" // pre-login rows have no CSRF token + if err := sessRepo.Create(ctx, preLogin); err != nil { + t.Fatalf("Create pre-login: %v", err) + } + + out, err := sessRepo.ListByActor(ctx, "u-list-actor", "User", "t-default") + if err != nil { + t.Fatalf("ListByActor: %v", err) + } + if len(out) != 3 { + t.Errorf("ListByActor count = %d; want 3 (revoked + pre-login excluded)", len(out)) + } +} + +func TestSessionRepository_RevokeAllForActor(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("ra") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + // 3 sessions for one actor. + for _, suf := range []string{"r1", "r2", "r3"} { + s := newValidSession(suf, k.ID) + s.ActorID = "u-fired" + if err := sessRepo.Create(ctx, s); err != nil { + t.Fatalf("Create %s: %v", suf, err) + } + } + if err := sessRepo.RevokeAllForActor(ctx, "u-fired", "User", "t-default"); err != nil { + t.Fatalf("RevokeAllForActor: %v", err) + } + out, err := sessRepo.ListByActor(ctx, "u-fired", "User", "t-default") + if err != nil { + t.Fatalf("ListByActor post-revoke: %v", err) + } + if len(out) != 0 { + t.Errorf("RevokeAllForActor left %d sessions active; want 0", len(out)) + } +} + +func TestSessionRepository_GarbageCollectExpired(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("gc") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + + // One session with absolute expiry in the past (write directly via SQL + // to bypass the CHECK constraints; this simulates a row that aged + // past expiry without GC having run yet). + now := time.Now().UTC() + old := time.Now().UTC().Add(-2 * time.Hour) + older := time.Now().UTC().Add(-3 * time.Hour) + _, err := db.ExecContext(ctx, ` + INSERT INTO sessions (id, tenant_id, actor_id, actor_type, signing_key_id, + is_pre_login, csrf_token_hash, idle_expires_at, absolute_expires_at, + created_at, last_seen_at, ip_address, user_agent) + VALUES ($1, 't-default', 'u-gc', 'User', $2, FALSE, '', + $3, $4, $5, $5, '', '')`, + "ses-expired", k.ID, older, old, time.Now().UTC().Add(-4*time.Hour)) + if err != nil { + t.Fatalf("seed expired: %v", err) + } + + // One pre-login row older than 10 minutes. + _, err = db.ExecContext(ctx, ` + INSERT INTO sessions (id, tenant_id, actor_id, actor_type, signing_key_id, + is_pre_login, csrf_token_hash, idle_expires_at, absolute_expires_at, + created_at, last_seen_at, ip_address, user_agent) + VALUES ($1, 't-default', 'u-gc', 'User', $2, TRUE, '', + $3, $4, $5, $5, '', '')`, + "ses-prelogin-old", k.ID, + now.Add(-15*time.Minute).Add(time.Hour), // idle in future relative to created + now.Add(-15*time.Minute).Add(2*time.Hour), // absolute > idle, both > created + now.Add(-15*time.Minute)) // created 15 min ago (older than 10 min TTL) + if err != nil { + t.Fatalf("seed pre-login: %v", err) + } + + // One active session (NOT to be GC'd). + active := newValidSession("active", k.ID) + active.ActorID = "u-gc" + if err := sessRepo.Create(ctx, active); err != nil { + t.Fatalf("seed active: %v", err) + } + + n, err := sessRepo.GarbageCollectExpired(ctx) + if err != nil { + t.Fatalf("GC: %v", err) + } + if n != 2 { + t.Errorf("GC deleted %d rows; want 2 (expired + old pre-login)", n) + } + + // Active session survives. + if _, err := sessRepo.Get(ctx, active.ID); err != nil { + t.Errorf("active session should survive GC; got %v", err) + } +} + +func TestSessionRepository_UpdateLastSeen(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + keyRepo := postgres.NewSessionSigningKeyRepository(db) + sessRepo := postgres.NewSessionRepository(db) + ctx := context.Background() + + k := newValidSigningKey("uls") + if err := keyRepo.Add(ctx, k); err != nil { + t.Fatalf("Add key: %v", err) + } + s := newValidSession("uls", k.ID) + if err := sessRepo.Create(ctx, s); err != nil { + t.Fatalf("Create: %v", err) + } + originalSeen := s.LastSeenAt + time.Sleep(10 * time.Millisecond) + if err := sessRepo.UpdateLastSeen(ctx, s.ID); err != nil { + t.Fatalf("UpdateLastSeen: %v", err) + } + got, _ := sessRepo.Get(ctx, s.ID) + if !got.LastSeenAt.After(originalSeen) { + t.Errorf("LastSeenAt did not advance after UpdateLastSeen") + } +} diff --git a/internal/repository/postgres/user.go b/internal/repository/postgres/user.go new file mode 100644 index 0000000..95a9ad2 --- /dev/null +++ b/internal/repository/postgres/user.go @@ -0,0 +1,137 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lib/pq" + + userdomain "github.com/certctl-io/certctl/internal/auth/user/domain" + "github.com/certctl-io/certctl/internal/repository" +) + +// UserRepository is the postgres implementation of +// repository.UserRepository (Auth Bundle 2 Phase 2). +type UserRepository struct { + db *sql.DB +} + +// NewUserRepository constructs a UserRepository. +func NewUserRepository(db *sql.DB) *UserRepository { + return &UserRepository{db: db} +} + +const userColumns = `id, tenant_id, email, display_name, oidc_subject, + oidc_provider_id, last_login_at, webauthn_credentials, + created_at, updated_at` + +func scanUser(row interface{ Scan(...interface{}) error }) (*userdomain.User, error) { + var u userdomain.User + if err := row.Scan( + &u.ID, &u.TenantID, &u.Email, &u.DisplayName, &u.OIDCSubject, + &u.OIDCProviderID, &u.LastLoginAt, &u.WebAuthnCredentials, + &u.CreatedAt, &u.UpdatedAt, + ); err != nil { + return nil, err + } + return &u, nil +} + +// Get returns one user by id. +func (r *UserRepository) Get(ctx context.Context, id string) (*userdomain.User, error) { + row := r.db.QueryRowContext(ctx, `SELECT `+userColumns+` FROM users WHERE id = $1`, id) + u, err := scanUser(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrUserNotFound + } + return nil, fmt.Errorf("users get: %w", err) + } + return u, nil +} + +// GetByOIDCSubject is the Phase 3 hot-path lookup at login time. +// Returns ErrUserNotFound if no row matches the (provider, subject) +// tuple — Phase 3's HandleCallback then creates the row via Create. +func (r *UserRepository) GetByOIDCSubject(ctx context.Context, providerID, subject string) (*userdomain.User, error) { + row := r.db.QueryRowContext(ctx, ` + SELECT `+userColumns+` + FROM users + WHERE oidc_provider_id = $1 AND oidc_subject = $2`, + providerID, subject) + u, err := scanUser(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, repository.ErrUserNotFound + } + return nil, fmt.Errorf("users get_by_oidc_subject: %w", err) + } + return u, nil +} + +// Create persists a new user. Translates SQLSTATE 23505 into +// ErrUserDuplicateOIDCSubject (the unique constraint on +// (oidc_provider_id, oidc_subject)). +func (r *UserRepository) Create(ctx context.Context, u *userdomain.User) error { + _, err := r.db.ExecContext(ctx, ` + INSERT INTO users ( + id, tenant_id, email, display_name, oidc_subject, + oidc_provider_id, last_login_at, webauthn_credentials + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + u.ID, u.TenantID, u.Email, u.DisplayName, u.OIDCSubject, + u.OIDCProviderID, u.LastLoginAt, u.WebAuthnCredentials) + if err != nil { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr.Code == "23505" { + return repository.ErrUserDuplicateOIDCSubject + } + return fmt.Errorf("users create: %w", err) + } + return nil +} + +// Update writes the mutable fields (email, display_name, last_login_at, +// webauthn_credentials) back to the row. Immutable: id, tenant_id, +// oidc_subject, oidc_provider_id, created_at. updated_at = NOW(). +func (r *UserRepository) Update(ctx context.Context, u *userdomain.User) error { + res, err := r.db.ExecContext(ctx, ` + UPDATE users SET + email = $2, + display_name = $3, + last_login_at = $4, + webauthn_credentials = $5, + updated_at = NOW() + WHERE id = $1`, + u.ID, u.Email, u.DisplayName, u.LastLoginAt, u.WebAuthnCredentials) + if err != nil { + return fmt.Errorf("users update: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return repository.ErrUserNotFound + } + return nil +} + +// ListAll returns every user in the tenant, ordered by created_at ASC. +func (r *UserRepository) ListAll(ctx context.Context, tenantID string) ([]*userdomain.User, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT `+userColumns+` FROM users WHERE tenant_id = $1 ORDER BY created_at ASC`, + tenantID) + if err != nil { + return nil, fmt.Errorf("users list_all: %w", err) + } + defer rows.Close() + + var out []*userdomain.User + for rows.Next() { + u, err := scanUser(rows) + if err != nil { + return nil, fmt.Errorf("users scan: %w", err) + } + out = append(out, u) + } + return out, rows.Err() +} diff --git a/internal/repository/postgres/user_test.go b/internal/repository/postgres/user_test.go new file mode 100644 index 0000000..3c6dcc0 --- /dev/null +++ b/internal/repository/postgres/user_test.go @@ -0,0 +1,220 @@ +package postgres_test + +import ( + "context" + "errors" + "testing" + + userdomain "github.com/certctl-io/certctl/internal/auth/user/domain" + "github.com/certctl-io/certctl/internal/repository" + "github.com/certctl-io/certctl/internal/repository/postgres" +) + +// newValidUser is shared with oidc_test.go (same _test package). +func newValidUser(suffix, providerID string) *userdomain.User { + return &userdomain.User{ + ID: "u-" + suffix, + TenantID: "t-default", + Email: suffix + "@example.com", + DisplayName: "User " + suffix, + OIDCSubject: "subject-" + suffix, + OIDCProviderID: providerID, + WebAuthnCredentials: []byte("[]"), + } +} + +func TestUserRepository_CreateAndGet(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("u") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + u := newValidUser("alice", p.ID) + if err := userRepo.Create(ctx, u); err != nil { + t.Fatalf("Create user: %v", err) + } + + got, err := userRepo.Get(ctx, u.ID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Email != u.Email { + t.Errorf("Email roundtrip: got %q, want %q", got.Email, u.Email) + } + if string(got.WebAuthnCredentials) != "[]" { + t.Errorf("WebAuthnCredentials default = %q; want []", string(got.WebAuthnCredentials)) + } +} + +func TestUserRepository_GetNotFound(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + repo := postgres.NewUserRepository(db) + ctx := context.Background() + + _, err := repo.Get(ctx, "u-nonexistent") + if !errors.Is(err, repository.ErrUserNotFound) { + t.Errorf("err = %v; want ErrUserNotFound", err) + } +} + +func TestUserRepository_GetByOIDCSubject(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("subj") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + u := newValidUser("bob", p.ID) + if err := userRepo.Create(ctx, u); err != nil { + t.Fatalf("Create user: %v", err) + } + + got, err := userRepo.GetByOIDCSubject(ctx, p.ID, u.OIDCSubject) + if err != nil { + t.Fatalf("GetByOIDCSubject: %v", err) + } + if got.ID != u.ID { + t.Errorf("GetByOIDCSubject returned %q; want %q", got.ID, u.ID) + } + + // Wrong subject: not found. + _, err = userRepo.GetByOIDCSubject(ctx, p.ID, "wrong-subject") + if !errors.Is(err, repository.ErrUserNotFound) { + t.Errorf("err = %v; want ErrUserNotFound", err) + } +} + +func TestUserRepository_DuplicateOIDCSubjectRejected(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("dupsubj") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + u1 := newValidUser("first", p.ID) + if err := userRepo.Create(ctx, u1); err != nil { + t.Fatalf("Create u1: %v", err) + } + u2 := newValidUser("second", p.ID) + u2.OIDCSubject = u1.OIDCSubject // collision on (provider, subject) UNIQUE + err := userRepo.Create(ctx, u2) + if !errors.Is(err, repository.ErrUserDuplicateOIDCSubject) { + t.Errorf("Create duplicate (provider, subject) err = %v; want ErrUserDuplicateOIDCSubject", err) + } +} + +func TestUserRepository_UpdateMutableFields(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("upd") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + u := newValidUser("carol", p.ID) + if err := userRepo.Create(ctx, u); err != nil { + t.Fatalf("Create user: %v", err) + } + + u.Email = "carol-new@example.com" + u.DisplayName = "Carol Renamed" + if err := userRepo.Update(ctx, u); err != nil { + t.Fatalf("Update: %v", err) + } + got, err := userRepo.Get(ctx, u.ID) + if err != nil { + t.Fatalf("Get post-update: %v", err) + } + if got.Email != "carol-new@example.com" { + t.Errorf("Update did not persist Email; got %q", got.Email) + } + if got.DisplayName != "Carol Renamed" { + t.Errorf("Update did not persist DisplayName; got %q", got.DisplayName) + } + // Immutable: oidc_subject must NOT change. + if got.OIDCSubject != u.OIDCSubject { + t.Errorf("OIDCSubject mutated: got %q, want %q", got.OIDCSubject, u.OIDCSubject) + } +} + +func TestUserRepository_ListAll(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("la") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + for _, suf := range []string{"u1", "u2", "u3"} { + u := newValidUser(suf, p.ID) + if err := userRepo.Create(ctx, u); err != nil { + t.Fatalf("Create %s: %v", suf, err) + } + } + + out, err := userRepo.ListAll(ctx, "t-default") + if err != nil { + t.Fatalf("ListAll: %v", err) + } + if len(out) != 3 { + t.Errorf("ListAll count = %d; want 3", len(out)) + } +} + +// TestUserRepository_DeletingProviderRefusedWhenUsersReference complements +// the OIDCProviderRepository test of the same shape; pinning both ends +// of the FK ON DELETE RESTRICT contract. +func TestUserRepository_FKRestrictsProviderDelete(t *testing.T) { + if testing.Short() { + t.Skip("integration test in short mode") + } + db := getTestDB(t).freshSchema(t) + providerRepo := postgres.NewOIDCProviderRepository(db) + userRepo := postgres.NewUserRepository(db) + ctx := context.Background() + + p := newValidProvider("fkrest") + if err := providerRepo.Create(ctx, p); err != nil { + t.Fatalf("Create provider: %v", err) + } + u := newValidUser("fkrest-user", p.ID) + if err := userRepo.Create(ctx, u); err != nil { + t.Fatalf("Create user: %v", err) + } + + if err := providerRepo.Delete(ctx, p.ID); !errors.Is(err, repository.ErrOIDCProviderInUse) { + t.Errorf("Delete provider (with referencing user) err = %v; want ErrOIDCProviderInUse", err) + } +} diff --git a/internal/repository/session.go b/internal/repository/session.go new file mode 100644 index 0000000..c15533c --- /dev/null +++ b/internal/repository/session.go @@ -0,0 +1,124 @@ +package repository + +import ( + "context" + "errors" + + sessiondomain "github.com/certctl-io/certctl/internal/auth/session/domain" +) + +// Sentinel errors for the session repositories. +var ( + // ErrSessionNotFound: Get returned no row. Phase 4 maps to 401 + // (the cookie either expired or was forged with a known-good key + // id but stale session id). + ErrSessionNotFound = errors.New("session: not found") + + // ErrSessionRevoked: Get found a row but RevokedAt is set. Phase 4 + // maps to 401. + ErrSessionRevoked = errors.New("session: revoked") + + // ErrSessionExpired: Get found a row but the absolute expiry has + // passed (Phase 4 also enforces idle expiry but that's a service- + // level check against last_seen_at, not a repository sentinel). + ErrSessionExpired = errors.New("session: expired") + + // ErrSessionSigningKeyNotFound: GetActive returned no row. Phase 4 + // EnsureInitialSigningKey treats this as "boot-time provisioning + // needed" and mints the first key. + ErrSessionSigningKeyNotFound = errors.New("session: signing key not found") + + // ErrSessionSigningKeyInUse: Delete (full purge, not Retire) failed + // because at least one sessions row still references the key. Phase + // 4's GarbageCollect waits for sessions to expire before purging. + ErrSessionSigningKeyInUse = errors.New("session: signing key still referenced by active sessions") +) + +// SessionRepository wraps the sessions table. Two cookie shapes share +// the rows: post-login sessions (1h-idle/8h-absolute) and pre-login +// sessions (10-minute TTL, IsPreLogin=true; carry OIDC state + nonce +// + PKCE verifier across the IdP redirect). +type SessionRepository interface { + // Create persists a session row. Caller MUST have called + // s.Validate(). Returns ErrAuthDuplicateName-shape on the + // extremely-unlikely id collision (the id is a 32-byte random; + // callers SHOULD generate fresh ids on the second attempt). + Create(ctx context.Context, s *sessiondomain.Session) error + + // Get returns a session by id. ErrSessionNotFound on miss. + // Returns the row even if revoked / expired so the service layer + // can produce the right 401 reason code (revoked vs expired vs + // not-found are all 401 to the wire but distinguishable in audit). + Get(ctx context.Context, id string) (*sessiondomain.Session, error) + + // ListByActor returns every active (non-revoked, non-expired, + // non-pre-login) session for an actor. Used by the GUI's + // /v1/auth/sessions surface so users can revoke their old laptops. + ListByActor(ctx context.Context, actorID, actorType, tenantID string) ([]*sessiondomain.Session, error) + + // UpdateLastSeen sets last_seen_at = NOW() for the named session. + // Phase 4's middleware calls this on every request to keep the + // idle-expiry sliding window fresh. + UpdateLastSeen(ctx context.Context, id string) error + + // Revoke sets revoked_at = NOW() for the named session. Subsequent + // Get returns the row with RevokedAt set; Phase 4's Validate maps + // to 401. + Revoke(ctx context.Context, id string) error + + // RevokeAllForActor sets revoked_at = NOW() on every active session + // for an actor. Used on role change, fired-employee scenarios, and + // the back-channel logout endpoint (Phase 5). + RevokeAllForActor(ctx context.Context, actorID, actorType, tenantID string) error + + // GarbageCollectExpired deletes sessions whose absolute expiry + // has passed AND whose revoked_at is older than the configurable + // retention window (default 24h). Pre-login rows older than the + // 10-minute TTL are also deleted. Returns the number of rows + // deleted. + GarbageCollectExpired(ctx context.Context) (int, error) + + // Delete unconditionally removes a session row. Used for the + // admin-only "purge a specific session" surface (rarely needed; + // Revoke is the normal path). + Delete(ctx context.Context, id string) error +} + +// SessionSigningKeyRepository wraps the session_signing_keys table. +// Phase 4's Service.RotateSigningKey + EnsureInitialSigningKey + the +// scheduler-driven retention sweep consume this. +type SessionSigningKeyRepository interface { + // List returns every signing key in the tenant (including + // retired). Order: created_at DESC. + List(ctx context.Context, tenantID string) ([]*sessiondomain.SessionSigningKey, error) + + // GetActive returns the most-recently-created non-retired key. + // ErrSessionSigningKeyNotFound when no non-retired key exists + // (Phase 4's EnsureInitialSigningKey treats this as "mint first + // key"). + GetActive(ctx context.Context, tenantID string) (*sessiondomain.SessionSigningKey, error) + + // Get returns one key by id (including retired keys; Phase 4's + // Validate consults this for cookies signed under retired-but- + // in-retention keys). + Get(ctx context.Context, id string) (*sessiondomain.SessionSigningKey, error) + + // Add persists a new signing key. Caller MUST have called + // k.Validate() and encrypted the key_material via + // internal/crypto/encryption.go. CreatedAt defaults to NOW() if + // zero. + Add(ctx context.Context, k *sessiondomain.SessionSigningKey) error + + // Retire marks an active key as retired (sets retired_at = NOW()). + // The key stays in the table for verification of cookies signed + // under it; the scheduler's retention sweep purges it after the + // configurable retention window (default 24h beyond retired_at). + Retire(ctx context.Context, id string) error + + // Delete unconditionally removes a signing key row. Returns + // ErrSessionSigningKeyInUse if any sessions row still references + // the key (FK ON DELETE RESTRICT). Phase 4's GarbageCollect calls + // this only after RetentionWindow has passed AND no sessions + // reference the key. + Delete(ctx context.Context, id string) error +} diff --git a/internal/repository/user.go b/internal/repository/user.go new file mode 100644 index 0000000..f40b923 --- /dev/null +++ b/internal/repository/user.go @@ -0,0 +1,46 @@ +package repository + +import ( + "context" + "errors" + + userdomain "github.com/certctl-io/certctl/internal/auth/user/domain" +) + +// Sentinel errors for the user repository. +var ( + // ErrUserNotFound: Get / GetByOIDCSubject returned no row. Phase + // 3's HandleCallback treats this as "first login for this person; + // create the row". + ErrUserNotFound = errors.New("user: not found") + + // ErrUserDuplicateOIDCSubject: Create tripped the + // (oidc_provider_id, oidc_subject) UNIQUE constraint. HTTP 409. + ErrUserDuplicateOIDCSubject = errors.New("user: a user with this provider+subject already exists") +) + +// UserRepository wraps the users table. Phase 3's HandleCallback +// uses GetByOIDCSubject + Create + Update on every login; the GUI's +// admin user-list surface uses ListAll + Get. +type UserRepository interface { + // Get returns one user by id. ErrUserNotFound on miss. + Get(ctx context.Context, id string) (*userdomain.User, error) + + // GetByOIDCSubject is the Phase 3 hot-path lookup at login time. + // Returns the existing row if present, ErrUserNotFound otherwise. + GetByOIDCSubject(ctx context.Context, providerID, subject string) (*userdomain.User, error) + + // Create persists a new user. Caller MUST have called u.Validate(). + // Returns ErrUserDuplicateOIDCSubject on UNIQUE constraint trip. + Create(ctx context.Context, u *userdomain.User) error + + // Update writes the mutable field set back to the row. Immutable + // fields (id, tenant_id, oidc_subject, oidc_provider_id, + // created_at) are preserved. updated_at is set to NOW() by the + // implementation. + Update(ctx context.Context, u *userdomain.User) error + + // ListAll returns every user in the tenant. Order: + // created_at ASC. Used by the GUI's admin surface. + ListAll(ctx context.Context, tenantID string) ([]*userdomain.User, error) +}