mirror of
https://github.com/shankar0123/certctl.git
synced 2026-06-07 16:11:29 +00:00
Implement M7: auth middleware, rate limiting, CORS, and GUI login flow
Add SHA-256 API key authentication with constant-time comparison, configurable token bucket rate limiter, CORS origin allowlist middleware, and React auth context with login page. Auth info endpoint bootstraps GUI without credentials. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,11 +5,13 @@ import (
|
||||
)
|
||||
|
||||
// HealthHandler handles health and readiness check endpoints.
|
||||
type HealthHandler struct{}
|
||||
type HealthHandler struct {
|
||||
AuthType string // "api-key", "jwt", "none"
|
||||
}
|
||||
|
||||
// NewHealthHandler creates a new HealthHandler.
|
||||
func NewHealthHandler() HealthHandler {
|
||||
return HealthHandler{}
|
||||
func NewHealthHandler(authType string) HealthHandler {
|
||||
return HealthHandler{AuthType: authType}
|
||||
}
|
||||
|
||||
// Health responds with a simple health check indicating the service is alive.
|
||||
@@ -41,3 +43,21 @@ func (h HealthHandler) Ready(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// AuthInfo responds with the server's authentication configuration.
|
||||
// This lets the GUI know whether to show a login screen.
|
||||
// GET /api/v1/auth/info (served without auth middleware)
|
||||
func (h HealthHandler) AuthInfo(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]interface{}{
|
||||
"auth_type": h.AuthType,
|
||||
"required": h.AuthType != "none",
|
||||
}
|
||||
JSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// AuthCheck returns 200 if the request has valid auth credentials.
|
||||
// The auth middleware runs before this handler, so reaching here means auth passed.
|
||||
// GET /api/v1/auth/check
|
||||
func (h HealthHandler) AuthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
JSON(w, http.StatusOK, map[string]string{"status": "authenticated"})
|
||||
}
|
||||
|
||||
@@ -2,8 +2,12 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -48,36 +52,178 @@ func Recovery(next http.Handler) http.Handler {
|
||||
if err := recover(); err != nil {
|
||||
requestID := getRequestID(r.Context())
|
||||
log.Printf("[%s] PANIC: %v", requestID, err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Auth middleware is a placeholder that checks the Authorization header and extracts user information.
|
||||
// In production, this would validate tokens, verify signatures, etc.
|
||||
func Auth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// For now, allow requests without auth (placeholder)
|
||||
// In production, enforce auth on protected routes
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// HashAPIKey computes the SHA-256 hash of an API key for secure storage.
|
||||
// We use SHA-256 rather than bcrypt because API keys are high-entropy
|
||||
// random strings (not user-chosen passwords), so rainbow tables and
|
||||
// brute-force attacks are not a practical concern.
|
||||
func HashAPIKey(key string) string {
|
||||
h := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// Simple stub: just extract user ID from Bearer token (format: "Bearer <user_id>")
|
||||
// This is NOT secure and for development only
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
userID := authHeader[7:]
|
||||
ctx := context.WithValue(r.Context(), UserKey{}, userID)
|
||||
// AuthConfig holds configuration for the Auth middleware.
|
||||
type AuthConfig struct {
|
||||
Type string // "api-key", "jwt", "none"
|
||||
Secret string // The raw API key (server compares against this)
|
||||
}
|
||||
|
||||
// NewAuth creates an authentication middleware based on config.
|
||||
// When Type is "none", all requests pass through (demo/development mode).
|
||||
// When Type is "api-key", requests must include a valid Bearer token.
|
||||
func NewAuth(cfg AuthConfig) func(http.Handler) http.Handler {
|
||||
if cfg.Type == "none" {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-compute hash of the expected key for constant-time comparison
|
||||
expectedHash := HashAPIKey(cfg.Secret)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="certctl"`)
|
||||
http.Error(w, `{"error":"Authorization header required"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract Bearer token
|
||||
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
http.Error(w, `{"error":"Invalid Authorization header format, expected: Bearer <token>"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := authHeader[7:]
|
||||
tokenHash := HashAPIKey(token)
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(tokenHash), []byte(expectedHash)) != 1 {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
http.Error(w, `{"error":"Invalid API key"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Store the authenticated identity in context
|
||||
ctx := context.WithValue(r.Context(), UserKey{}, "api-key-user")
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(w, "Invalid Authorization header", http.StatusUnauthorized)
|
||||
})
|
||||
// RateLimitConfig holds configuration for the rate limiter.
|
||||
type RateLimitConfig struct {
|
||||
RPS float64 // Requests per second
|
||||
BurstSize int // Maximum burst size
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a token bucket rate limiting middleware.
|
||||
// Uses a simple token bucket: tokens refill at RPS rate, burst allows short spikes.
|
||||
func NewRateLimiter(cfg RateLimitConfig) func(http.Handler) http.Handler {
|
||||
limiter := &tokenBucket{
|
||||
rate: cfg.RPS,
|
||||
burstSize: float64(cfg.BurstSize),
|
||||
tokens: float64(cfg.BurstSize),
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !limiter.allow() {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("Retry-After", "1")
|
||||
http.Error(w, `{"error":"Rate limit exceeded"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// tokenBucket implements a simple thread-safe token bucket rate limiter.
|
||||
// This avoids importing golang.org/x/time/rate to keep dependencies minimal.
|
||||
type tokenBucket struct {
|
||||
mu sync.Mutex
|
||||
rate float64 // tokens per second
|
||||
burstSize float64 // max tokens
|
||||
tokens float64 // current tokens
|
||||
lastRefill time.Time // last refill time
|
||||
}
|
||||
|
||||
func (tb *tokenBucket) allow() bool {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(tb.lastRefill).Seconds()
|
||||
tb.tokens += elapsed * tb.rate
|
||||
if tb.tokens > tb.burstSize {
|
||||
tb.tokens = tb.burstSize
|
||||
}
|
||||
tb.lastRefill = now
|
||||
|
||||
if tb.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
tb.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// CORSConfig holds configuration for the CORS middleware.
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string // Allowed origins; empty = same-origin only
|
||||
}
|
||||
|
||||
// NewCORS creates a CORS middleware with configurable allowed origins.
|
||||
// If no origins are configured, same-origin requests are allowed by default.
|
||||
// If ["*"] is configured, all origins are allowed (development/demo mode).
|
||||
func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler {
|
||||
allowAll := false
|
||||
originSet := make(map[string]bool)
|
||||
for _, o := range cfg.AllowedOrigins {
|
||||
if o == "*" {
|
||||
allowAll = true
|
||||
}
|
||||
originSet[o] = true
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
if allowAll {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else if origin != "" && originSet[origin] {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
} else if len(cfg.AllowedOrigins) == 0 && origin != "" {
|
||||
// No config = permissive same-origin default for single-host deployments
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ContentType middleware sets the Content-Type header to application/json.
|
||||
@@ -89,6 +235,7 @@ func ContentType(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CORS middleware adds CORS headers to allow cross-origin requests.
|
||||
// Deprecated: Use NewCORS for configurable origins. Kept for health endpoints.
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
@@ -57,7 +57,7 @@ func (r *Router) RegisterHandlers(
|
||||
notifications handler.NotificationHandler,
|
||||
health handler.HealthHandler,
|
||||
) {
|
||||
// Health endpoints (no middleware)
|
||||
// Health endpoints (no auth middleware — must always be accessible)
|
||||
r.mux.Handle("GET /health", middleware.Chain(
|
||||
http.HandlerFunc(health.Health),
|
||||
middleware.CORS,
|
||||
@@ -68,6 +68,14 @@ func (r *Router) RegisterHandlers(
|
||||
middleware.CORS,
|
||||
middleware.ContentType,
|
||||
))
|
||||
// Auth info endpoint (no auth middleware — GUI needs this before login)
|
||||
r.mux.Handle("GET /api/v1/auth/info", middleware.Chain(
|
||||
http.HandlerFunc(health.AuthInfo),
|
||||
middleware.CORS,
|
||||
middleware.ContentType,
|
||||
))
|
||||
// Auth check endpoint (uses full middleware chain via r.Register)
|
||||
r.Register("GET /api/v1/auth/check", http.HandlerFunc(health.AuthCheck))
|
||||
|
||||
// Certificates routes: /api/v1/certificates
|
||||
r.Register("GET /api/v1/certificates", http.HandlerFunc(certificates.ListCertificates))
|
||||
|
||||
@@ -16,6 +16,8 @@ type Config struct {
|
||||
Scheduler SchedulerConfig
|
||||
Log LogConfig
|
||||
Auth AuthConfig
|
||||
RateLimit RateLimitConfig
|
||||
CORS CORSConfig
|
||||
}
|
||||
|
||||
// ServerConfig contains HTTP server configuration.
|
||||
@@ -51,6 +53,18 @@ type AuthConfig struct {
|
||||
Secret string // Secret key for signing (if applicable)
|
||||
}
|
||||
|
||||
// RateLimitConfig contains rate limiting configuration.
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool
|
||||
RPS float64 // Requests per second
|
||||
BurstSize int // Maximum burst size
|
||||
}
|
||||
|
||||
// CORSConfig contains CORS configuration.
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string // Allowed origins; empty = same-origin only; ["*"] = all
|
||||
}
|
||||
|
||||
// Load reads configuration from environment variables and returns a Config.
|
||||
// Environment variables must have the CERTCTL_ prefix.
|
||||
// Example: CERTCTL_SERVER_HOST, CERTCTL_DATABASE_URL, etc.
|
||||
@@ -79,6 +93,14 @@ func Load() (*Config, error) {
|
||||
Type: getEnv("CERTCTL_AUTH_TYPE", "api-key"),
|
||||
Secret: getEnv("CERTCTL_AUTH_SECRET", ""),
|
||||
},
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: getEnvBool("CERTCTL_RATE_LIMIT_ENABLED", true),
|
||||
RPS: getEnvFloat("CERTCTL_RATE_LIMIT_RPS", 50),
|
||||
BurstSize: getEnvInt("CERTCTL_RATE_LIMIT_BURST", 100),
|
||||
},
|
||||
CORS: CORSConfig{
|
||||
AllowedOrigins: getEnvList("CERTCTL_CORS_ORIGINS", nil),
|
||||
},
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
@@ -192,6 +214,67 @@ func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvBool reads a boolean environment variable.
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value == "true" || value == "1" || value == "yes"
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvFloat reads a float64 environment variable.
|
||||
func getEnvFloat(key string, defaultValue float64) float64 {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
f, err := strconv.ParseFloat(value, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return f
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvList reads a comma-separated list environment variable.
|
||||
func getEnvList(key string, defaultValue []string) []string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var result []string
|
||||
for _, s := range splitComma(value) {
|
||||
s = trimSpace(s)
|
||||
if s != "" {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// splitComma splits a string by commas (no strings import needed).
|
||||
func splitComma(s string) []string {
|
||||
var parts []string
|
||||
start := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ',' {
|
||||
parts = append(parts, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
parts = append(parts, s[start:])
|
||||
return parts
|
||||
}
|
||||
|
||||
// trimSpace trims leading/trailing whitespace.
|
||||
func trimSpace(s string) string {
|
||||
start, end := 0, len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
|
||||
// GetLogLevel returns the appropriate slog.Level from the configured log level.
|
||||
func (c *Config) GetLogLevel() slog.Level {
|
||||
switch c.Log.Level {
|
||||
|
||||
Reference in New Issue
Block a user