Implement M7: auth middleware, rate limiting, CORS, and GUI login flow

Add SHA-256 API key authentication with constant-time comparison, configurable
token bucket rate limiter, CORS origin allowlist middleware, and React auth
context with login page. Auth info endpoint bootstraps GUI without credentials.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shankar0123
2026-03-15 11:58:13 -04:00
parent 2ba8245159
commit 28205e1131
12 changed files with 590 additions and 71 deletions
+23 -3
View File
@@ -5,11 +5,13 @@ import (
)
// HealthHandler handles health and readiness check endpoints.
type HealthHandler struct{}
type HealthHandler struct {
AuthType string // "api-key", "jwt", "none"
}
// NewHealthHandler creates a new HealthHandler.
func NewHealthHandler() HealthHandler {
return HealthHandler{}
func NewHealthHandler(authType string) HealthHandler {
return HealthHandler{AuthType: authType}
}
// Health responds with a simple health check indicating the service is alive.
@@ -41,3 +43,21 @@ func (h HealthHandler) Ready(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, response)
}
// AuthInfo responds with the server's authentication configuration.
// This lets the GUI know whether to show a login screen.
// GET /api/v1/auth/info (served without auth middleware)
func (h HealthHandler) AuthInfo(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"auth_type": h.AuthType,
"required": h.AuthType != "none",
}
JSON(w, http.StatusOK, response)
}
// AuthCheck returns 200 if the request has valid auth credentials.
// The auth middleware runs before this handler, so reaching here means auth passed.
// GET /api/v1/auth/check
func (h HealthHandler) AuthCheck(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, map[string]string{"status": "authenticated"})
}
+168 -21
View File
@@ -2,8 +2,12 @@ package middleware
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"log"
"net/http"
"sync"
"time"
"github.com/google/uuid"
@@ -48,36 +52,178 @@ func Recovery(next http.Handler) http.Handler {
if err := recover(); err != nil {
requestID := getRequestID(r.Context())
log.Printf("[%s] PANIC: %v", requestID, err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
// Auth middleware is a placeholder that checks the Authorization header and extracts user information.
// In production, this would validate tokens, verify signatures, etc.
func Auth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
// For now, allow requests without auth (placeholder)
// In production, enforce auth on protected routes
next.ServeHTTP(w, r)
return
}
// HashAPIKey computes the SHA-256 hash of an API key for secure storage.
// We use SHA-256 rather than bcrypt because API keys are high-entropy
// random strings (not user-chosen passwords), so rainbow tables and
// brute-force attacks are not a practical concern.
func HashAPIKey(key string) string {
h := sha256.Sum256([]byte(key))
return hex.EncodeToString(h[:])
}
// Simple stub: just extract user ID from Bearer token (format: "Bearer <user_id>")
// This is NOT secure and for development only
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
userID := authHeader[7:]
ctx := context.WithValue(r.Context(), UserKey{}, userID)
// AuthConfig holds configuration for the Auth middleware.
type AuthConfig struct {
Type string // "api-key", "jwt", "none"
Secret string // The raw API key (server compares against this)
}
// NewAuth creates an authentication middleware based on config.
// When Type is "none", all requests pass through (demo/development mode).
// When Type is "api-key", requests must include a valid Bearer token.
func NewAuth(cfg AuthConfig) func(http.Handler) http.Handler {
if cfg.Type == "none" {
return func(next http.Handler) http.Handler {
return next
}
}
// Pre-compute hash of the expected key for constant-time comparison
expectedHash := HashAPIKey(cfg.Secret)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("WWW-Authenticate", `Bearer realm="certctl"`)
http.Error(w, `{"error":"Authorization header required"}`, http.StatusUnauthorized)
return
}
// Extract Bearer token
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
http.Error(w, `{"error":"Invalid Authorization header format, expected: Bearer <token>"}`, http.StatusUnauthorized)
return
}
token := authHeader[7:]
tokenHash := HashAPIKey(token)
// Constant-time comparison to prevent timing attacks
if subtle.ConstantTimeCompare([]byte(tokenHash), []byte(expectedHash)) != 1 {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
http.Error(w, `{"error":"Invalid API key"}`, http.StatusUnauthorized)
return
}
// Store the authenticated identity in context
ctx := context.WithValue(r.Context(), UserKey{}, "api-key-user")
next.ServeHTTP(w, r.WithContext(ctx))
return
}
})
}
}
http.Error(w, "Invalid Authorization header", http.StatusUnauthorized)
})
// RateLimitConfig holds configuration for the rate limiter.
type RateLimitConfig struct {
RPS float64 // Requests per second
BurstSize int // Maximum burst size
}
// NewRateLimiter creates a token bucket rate limiting middleware.
// Uses a simple token bucket: tokens refill at RPS rate, burst allows short spikes.
func NewRateLimiter(cfg RateLimitConfig) func(http.Handler) http.Handler {
limiter := &tokenBucket{
rate: cfg.RPS,
burstSize: float64(cfg.BurstSize),
tokens: float64(cfg.BurstSize),
lastRefill: time.Now(),
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !limiter.allow() {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Retry-After", "1")
http.Error(w, `{"error":"Rate limit exceeded"}`, http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// tokenBucket implements a simple thread-safe token bucket rate limiter.
// This avoids importing golang.org/x/time/rate to keep dependencies minimal.
type tokenBucket struct {
mu sync.Mutex
rate float64 // tokens per second
burstSize float64 // max tokens
tokens float64 // current tokens
lastRefill time.Time // last refill time
}
func (tb *tokenBucket) allow() bool {
tb.mu.Lock()
defer tb.mu.Unlock()
now := time.Now()
elapsed := now.Sub(tb.lastRefill).Seconds()
tb.tokens += elapsed * tb.rate
if tb.tokens > tb.burstSize {
tb.tokens = tb.burstSize
}
tb.lastRefill = now
if tb.tokens < 1 {
return false
}
tb.tokens--
return true
}
// CORSConfig holds configuration for the CORS middleware.
type CORSConfig struct {
AllowedOrigins []string // Allowed origins; empty = same-origin only
}
// NewCORS creates a CORS middleware with configurable allowed origins.
// If no origins are configured, same-origin requests are allowed by default.
// If ["*"] is configured, all origins are allowed (development/demo mode).
func NewCORS(cfg CORSConfig) func(http.Handler) http.Handler {
allowAll := false
originSet := make(map[string]bool)
for _, o := range cfg.AllowedOrigins {
if o == "*" {
allowAll = true
}
originSet[o] = true
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if allowAll {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" && originSet[origin] {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
} else if len(cfg.AllowedOrigins) == 0 && origin != "" {
// No config = permissive same-origin default for single-host deployments
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
w.Header().Set("Access-Control-Max-Age", "86400")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
}
// ContentType middleware sets the Content-Type header to application/json.
@@ -89,6 +235,7 @@ func ContentType(next http.Handler) http.Handler {
}
// CORS middleware adds CORS headers to allow cross-origin requests.
// Deprecated: Use NewCORS for configurable origins. Kept for health endpoints.
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
+9 -1
View File
@@ -57,7 +57,7 @@ func (r *Router) RegisterHandlers(
notifications handler.NotificationHandler,
health handler.HealthHandler,
) {
// Health endpoints (no middleware)
// Health endpoints (no auth middleware — must always be accessible)
r.mux.Handle("GET /health", middleware.Chain(
http.HandlerFunc(health.Health),
middleware.CORS,
@@ -68,6 +68,14 @@ func (r *Router) RegisterHandlers(
middleware.CORS,
middleware.ContentType,
))
// Auth info endpoint (no auth middleware — GUI needs this before login)
r.mux.Handle("GET /api/v1/auth/info", middleware.Chain(
http.HandlerFunc(health.AuthInfo),
middleware.CORS,
middleware.ContentType,
))
// Auth check endpoint (uses full middleware chain via r.Register)
r.Register("GET /api/v1/auth/check", http.HandlerFunc(health.AuthCheck))
// Certificates routes: /api/v1/certificates
r.Register("GET /api/v1/certificates", http.HandlerFunc(certificates.ListCertificates))
+83
View File
@@ -16,6 +16,8 @@ type Config struct {
Scheduler SchedulerConfig
Log LogConfig
Auth AuthConfig
RateLimit RateLimitConfig
CORS CORSConfig
}
// ServerConfig contains HTTP server configuration.
@@ -51,6 +53,18 @@ type AuthConfig struct {
Secret string // Secret key for signing (if applicable)
}
// RateLimitConfig contains rate limiting configuration.
type RateLimitConfig struct {
Enabled bool
RPS float64 // Requests per second
BurstSize int // Maximum burst size
}
// CORSConfig contains CORS configuration.
type CORSConfig struct {
AllowedOrigins []string // Allowed origins; empty = same-origin only; ["*"] = all
}
// Load reads configuration from environment variables and returns a Config.
// Environment variables must have the CERTCTL_ prefix.
// Example: CERTCTL_SERVER_HOST, CERTCTL_DATABASE_URL, etc.
@@ -79,6 +93,14 @@ func Load() (*Config, error) {
Type: getEnv("CERTCTL_AUTH_TYPE", "api-key"),
Secret: getEnv("CERTCTL_AUTH_SECRET", ""),
},
RateLimit: RateLimitConfig{
Enabled: getEnvBool("CERTCTL_RATE_LIMIT_ENABLED", true),
RPS: getEnvFloat("CERTCTL_RATE_LIMIT_RPS", 50),
BurstSize: getEnvInt("CERTCTL_RATE_LIMIT_BURST", 100),
},
CORS: CORSConfig{
AllowedOrigins: getEnvList("CERTCTL_CORS_ORIGINS", nil),
},
}
if err := cfg.Validate(); err != nil {
@@ -192,6 +214,67 @@ func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
return defaultValue
}
// getEnvBool reads a boolean environment variable.
func getEnvBool(key string, defaultValue bool) bool {
if value := os.Getenv(key); value != "" {
return value == "true" || value == "1" || value == "yes"
}
return defaultValue
}
// getEnvFloat reads a float64 environment variable.
func getEnvFloat(key string, defaultValue float64) float64 {
if value := os.Getenv(key); value != "" {
f, err := strconv.ParseFloat(value, 64)
if err != nil {
return defaultValue
}
return f
}
return defaultValue
}
// getEnvList reads a comma-separated list environment variable.
func getEnvList(key string, defaultValue []string) []string {
if value := os.Getenv(key); value != "" {
var result []string
for _, s := range splitComma(value) {
s = trimSpace(s)
if s != "" {
result = append(result, s)
}
}
return result
}
return defaultValue
}
// splitComma splits a string by commas (no strings import needed).
func splitComma(s string) []string {
var parts []string
start := 0
for i := 0; i < len(s); i++ {
if s[i] == ',' {
parts = append(parts, s[start:i])
start = i + 1
}
}
parts = append(parts, s[start:])
return parts
}
// trimSpace trims leading/trailing whitespace.
func trimSpace(s string) string {
start, end := 0, len(s)
for start < end && (s[start] == ' ' || s[start] == '\t') {
start++
}
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
end--
}
return s[start:end]
}
// GetLogLevel returns the appropriate slog.Level from the configured log level.
func (c *Config) GetLogLevel() slog.Level {
switch c.Log.Level {