persona-community-5/pkg/auth/middleware.go
jordan bd2f591b98
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
ci/woodpecker/manual/woodpecker Pipeline was successful
Initialize project from skeleton template
2026-02-24 07:39:46 +00:00

309 lines
9.1 KiB
Go

package auth
import (
"context"
"errors"
"net/http"
"strings"
"git.threesix.ai/jordan/persona-community-5/pkg/httperror"
"git.threesix.ai/jordan/persona-community-5/pkg/httpresponse"
)
// MiddlewareConfig configures the authentication middleware returned by Middleware.
type MiddlewareConfig struct {
	// Validator is the token/key validator to use. It is called for every
	// request that carries a token and whose path is not skipped, so it must
	// be non-nil whenever such requests can occur.
	Validator Validator
	// TokenExtractor extracts the token from the request (optional).
	// When nil, ExtractBearerToken is tried first and ExtractAPIKey is used
	// as a fallback if the bearer token is empty.
	TokenExtractor func(*http.Request) string
	// Optional makes authentication best-effort: a request with no token, or
	// with a token that fails validation, continues to the next handler
	// without a user in context (auth.GetUser() == nil) instead of receiving
	// a 401. Only requests presenting a valid token get a user attached.
	Optional bool
	// AllowExpired accepts expired tokens (still validates signature).
	// Use for token refresh endpoints where the caller presents an expired
	// access token to prove identity, and session validity is checked separately.
	// The relaxed re-validation is only attempted when Validator is a
	// *JWTValidator; with any other Validator an expired token is still rejected.
	AllowExpired bool
	// SkipPaths are request path prefixes (matched against r.URL.Path) that
	// skip authentication entirely: matching requests reach the next handler
	// with no token extraction, validation, or user in context.
	SkipPaths []string
}
// Middleware creates an authentication middleware.
//
// For every request the returned middleware:
//
//  1. Calls next directly, without authenticating, when r.URL.Path falls
//     under one of cfg.SkipPaths (see matchesSkipPath for the rules).
//  2. Extracts a credential with cfg.TokenExtractor or, when that is nil,
//     tries ExtractBearerToken and falls back to ExtractAPIKey.
//  3. With no credential, responds 401 "authentication required", or, when
//     cfg.Optional is set, calls next without a user in context.
//  4. Validates the credential with cfg.Validator. If cfg.AllowExpired is set,
//     the error matches ErrExpiredToken and cfg.Validator is a *JWTValidator,
//     the token is re-checked with ValidateAllowExpired.
//  5. On validation failure, responds 401 "invalid credentials", or, when
//     cfg.Optional is set, calls next anonymously.
//  6. On success, stores the user (SetUser) and the raw token (SetToken) in
//     the request context before calling next.
//
// Example:
//
//	r.Use(auth.Middleware(auth.MiddlewareConfig{
//		Validator: jwtValidator,
//	}))
//
//	// Or with optional auth (passes through if no token)
//	r.Use(auth.Middleware(auth.MiddlewareConfig{
//		Validator: jwtValidator,
//		Optional:  true,
//	}))
func Middleware(cfg MiddlewareConfig) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Check if path should be skipped. matchesSkipPath respects path
			// segment boundaries and ignores empty entries, so a skip entry
			// can never exempt unrelated routes by accident.
			for _, prefix := range cfg.SkipPaths {
				if matchesSkipPath(r.URL.Path, prefix) {
					next.ServeHTTP(w, r)
					return
				}
			}

			// Extract token
			var token string
			if cfg.TokenExtractor != nil {
				token = cfg.TokenExtractor(r)
			} else {
				// Try Bearer token first, then API key
				token = ExtractBearerToken(r)
				if token == "" {
					token = ExtractAPIKey(r)
				}
			}

			// No token provided
			if token == "" {
				if cfg.Optional {
					next.ServeHTTP(w, r)
					return
				}
				httpresponse.Unauthorized(w, r, "authentication required")
				return
			}

			// Validate token
			user, err := cfg.Validator.Validate(r.Context(), token)
			if err != nil {
				// If AllowExpired is set and the token is expired (but signature valid),
				// re-validate with relaxed expiry for refresh flows.
				if cfg.AllowExpired && errors.Is(err, ErrExpiredToken) {
					if jwtVal, ok := cfg.Validator.(*JWTValidator); ok {
						user, err = jwtVal.ValidateAllowExpired(r.Context(), token)
					}
				}
				if err != nil {
					if cfg.Optional {
						// Token invalid/expired but auth is optional — continue without user context.
						// The handler falls back to anonymous behavior via auth.GetUser() == nil.
						next.ServeHTTP(w, r)
						return
					}
					httpresponse.Unauthorized(w, r, "invalid credentials")
					return
				}
			}

			// Store user and token in context
			ctx := SetUser(r.Context(), user)
			ctx = SetToken(ctx, token)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// matchesSkipPath reports whether urlPath is covered by the SkipPaths entry prefix.
//
// Matching happens on path-segment boundaries. "/health" matches "/health"
// and "/health/live" but not "/healthz" or "/health-admin", so a skip entry
// cannot silently exempt a different route that merely shares its leading
// characters. An entry ending in "/" (e.g. "/static/") matches everything
// beneath it. Empty entries match nothing, so a stray "" (for example from
// splitting an empty configuration string) cannot disable authentication
// for every route.
func matchesSkipPath(urlPath, prefix string) bool {
	if prefix == "" || !strings.HasPrefix(urlPath, prefix) {
		return false
	}
	if len(urlPath) == len(prefix) || strings.HasSuffix(prefix, "/") {
		return true
	}
	// urlPath is strictly longer than prefix here; require the next byte to
	// start a new segment.
	return urlPath[len(prefix)] == '/'
}
// RequireAuth guards a handler so that only authenticated requests reach it.
//
// Requests for which IsAuthenticated reports false receive a 401
// "authentication required" response and never reach next. Mount it after
// Middleware, which is what places the user in the request context.
func RequireAuth(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		if authenticated := IsAuthenticated(r.Context()); authenticated {
			next.ServeHTTP(w, r)
			return
		}
		httpresponse.Unauthorized(w, r, "authentication required")
	}
	return http.HandlerFunc(guard)
}
// RequireRole returns middleware that admits only users holding role.
//
// A request with no user in context gets 401 "authentication required".
// A user for whom HasRole(role) is false gets 403 "insufficient permissions".
// Mount it after Middleware so the user has been populated.
func RequireRole(role string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch user := GetUser(r.Context()); {
			case user == nil:
				httpresponse.Unauthorized(w, r, "authentication required")
			case !user.HasRole(role):
				httpresponse.Forbidden(w, r, "insufficient permissions")
			default:
				next.ServeHTTP(w, r)
			}
		})
	}
}
// RequireAnyRole returns middleware that admits users holding at least one
// of roles, as decided by the user's HasAnyRole method.
//
// A request with no user in context gets 401 "authentication required".
// Users failing the role check get 403 "insufficient permissions".
// NOTE(review): behavior for an empty roles list depends on HasAnyRole —
// confirm before relying on it.
func RequireAnyRole(roles ...string) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			current := GetUser(r.Context())
			if current == nil {
				httpresponse.Unauthorized(w, r, "authentication required")
				return
			}
			if allowed := current.HasAnyRole(roles...); !allowed {
				httpresponse.Forbidden(w, r, "insufficient permissions")
				return
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// RequireScope returns middleware that admits only users granted scope.
//
// A request with no user in context gets 401 "authentication required".
// A user for whom HasScope(scope) is false gets 403 "insufficient scope".
func RequireScope(scope string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch caller := GetUser(r.Context()); {
			case caller == nil:
				httpresponse.Unauthorized(w, r, "authentication required")
			case !caller.HasScope(scope):
				httpresponse.Forbidden(w, r, "insufficient scope")
			default:
				next.ServeHTTP(w, r)
			}
		})
	}
}
// -----------------------------------------------------------------------------
// Token Extractors
// -----------------------------------------------------------------------------
// ExtractBearerToken extracts a Bearer token from the Authorization header.
//
// It returns "" when the header is absent, uses a scheme other than
// "Bearer" (compared case-insensitively), or carries no credentials after
// the scheme. Surrounding whitespace is trimmed from the token: the HTTP
// credentials grammar allows one or more spaces after the scheme, so
// "Bearer  abc" and "Bearer abc " both yield "abc" instead of passing stray
// whitespace to the validator.
func ExtractBearerToken(r *http.Request) string {
	auth := r.Header.Get("Authorization")
	if auth == "" {
		return ""
	}
	parts := strings.SplitN(auth, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
		return ""
	}
	return strings.TrimSpace(parts[1])
}
// ExtractAPIKey returns the value of the request's X-API-Key header, or ""
// when it is absent. Lookup goes through http.Header.Get, so the header name
// is matched case-insensitively, and only the first value is returned if the
// header is repeated.
func ExtractAPIKey(r *http.Request) string {
	const headerName = "X-API-Key"
	key := r.Header.Get(headerName)
	return key
}
// ExtractFromQuery builds a token extractor that reads the query parameter
// paramName from the request URL. The extractor returns the parameter's
// first value, or "" when it is missing.
func ExtractFromQuery(paramName string) func(*http.Request) string {
	extract := func(r *http.Request) string {
		values := r.URL.Query()
		return values.Get(paramName)
	}
	return extract
}
// ExtractFromCookie builds a token extractor that reads the cookie named
// cookieName. The extractor returns the cookie's value, or "" whenever
// r.Cookie reports an error (e.g. the cookie is not present).
func ExtractFromCookie(cookieName string) func(*http.Request) string {
	return func(r *http.Request) string {
		if c, err := r.Cookie(cookieName); err == nil {
			return c.Value
		}
		return ""
	}
}
// -----------------------------------------------------------------------------
// Error-returning middleware helpers (for use with app.Wrap)
// -----------------------------------------------------------------------------
// RequireAuthErr is the error-returning counterpart of RequireAuth for use
// with the app.Wrap pattern. It returns nil when IsAuthenticated(ctx) reports
// true, and an httperror.Unauthorized("authentication required") otherwise.
func RequireAuthErr(ctx context.Context) error {
	if IsAuthenticated(ctx) {
		return nil
	}
	return httperror.Unauthorized("authentication required")
}
// RequireRoleErr is the error-returning counterpart of RequireRole for use
// with the app.Wrap pattern.
//
// It returns httperror.Unauthorized("authentication required") when ctx has
// no user, httperror.Forbidden("insufficient permissions") when the user
// lacks role, and nil otherwise.
func RequireRoleErr(ctx context.Context, role string) error {
	switch user := GetUser(ctx); {
	case user == nil:
		return httperror.Unauthorized("authentication required")
	case !user.HasRole(role):
		return httperror.Forbidden("insufficient permissions")
	default:
		return nil
	}
}
// RequireScopeErr is the error-returning counterpart of RequireScope for use
// with the app.Wrap pattern.
//
// It returns httperror.Unauthorized("authentication required") when ctx has
// no user, httperror.Forbidden("insufficient scope") when the user lacks
// scope, and nil otherwise.
func RequireScopeErr(ctx context.Context, scope string) error {
	caller := GetUser(ctx)
	if caller == nil {
		return httperror.Unauthorized("authentication required")
	}
	if granted := caller.HasScope(scope); !granted {
		return httperror.Forbidden("insufficient scope")
	}
	return nil
}
// SessionChecker is a function that checks whether a session is still active.
// Returns true if the session is active, false if revoked/expired.
// Implementations should query the session store.
//
// When used with SessionCheck, the two negative outcomes are distinguished.
// A non-nil error rejects the request with 401 "session validation failed"
// (fail closed). A (false, nil) result rejects it with 401 "session has been
// revoked".
type SessionChecker func(ctx context.Context, sessionID string) (bool, error)
// SessionCheck returns middleware that rejects requests whose token refers to
// a session the checker no longer considers active.
//
// The session ID is read from the string "sid" entry of the authenticated
// user's Metadata. The request passes through untouched when either:
//
//   - there is no user in context (enforcement is left to later middleware), or
//   - the token carries no "sid" (e.g. tokens issued before sessions existed).
//
// Otherwise checker decides the outcome:
//
//   - checker returns an error: 401 "session validation failed" (fails closed)
//   - session inactive: 401 "session has been revoked"
//   - session active: the next handler runs
//
// Apply it AFTER Middleware, which is what places the user in context. It is
// opt-in; services that do not need session revocation can omit it.
//
// Example:
//
//	checker := func(ctx context.Context, sid string) (bool, error) {
//		session, err := sessionRepo.Get(ctx, domain.SessionID(sid))
//		if err != nil {
//			return false, nil
//		}
//		return session.IsActive(), nil
//	}
//	r.Use(auth.SessionCheck(checker))
//
// The example reports every lookup error as "inactive" (revoked). Returning
// err instead would surface store outages as "session validation failed".
func SessionCheck(checker SessionChecker) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// A missing user and a missing/non-string "sid" both leave
			// sessionID empty; either way the request is waved through.
			var sessionID string
			if user := GetUser(ctx); user != nil {
				sessionID, _ = user.Metadata["sid"].(string)
			}
			if sessionID == "" {
				next.ServeHTTP(w, r)
				return
			}

			active, err := checker(ctx, sessionID)
			switch {
			case err != nil:
				// Session state unknown: fail closed rather than trusting the token.
				httpresponse.Unauthorized(w, r, "session validation failed")
			case !active:
				httpresponse.Unauthorized(w, r, "session has been revoked")
			default:
				next.ServeHTTP(w, r)
			}
		})
	}
}