106 lines
2.8 KiB
Go
106 lines
2.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.threesix.ai/jordan/slack-q-1770281596/pkg/httpcontext"
|
|
"git.threesix.ai/jordan/slack-q-1770281596/pkg/logging"
|
|
)
|
|
|
|
// responseWriter wraps http.ResponseWriter to capture the response status
// code and the number of body bytes written, for request logging.
//
// wroteHeader records whether a final status has been committed; once it is
// set, later status changes are neither recorded nor forwarded.
type responseWriter struct {
	http.ResponseWriter
	status       int
	wroteHeader  bool
	bytesWritten int
}

// Unwrap returns the wrapped http.ResponseWriter. http.ResponseController
// uses this to reach optional interfaces (flushing, hijacking, deadlines)
// that the wrapper would otherwise hide from handlers.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// Flush implements http.Flusher so streaming handlers that type-assert
// w.(http.Flusher) keep working behind this middleware. It is a no-op when
// the wrapped writer cannot flush.
func (rw *responseWriter) Flush() {
	f, ok := rw.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	// Flushing commits the headers with an implicit 200 OK, so record that
	// status now; a later WriteHeader would otherwise be logged incorrectly.
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	f.Flush()
}
|
|
|
|
func (rw *responseWriter) WriteHeader(code int) {
|
|
if rw.wroteHeader {
|
|
return
|
|
}
|
|
rw.status = code
|
|
rw.wroteHeader = true
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (rw *responseWriter) Write(b []byte) (int, error) {
|
|
if !rw.wroteHeader {
|
|
rw.WriteHeader(http.StatusOK)
|
|
}
|
|
n, err := rw.ResponseWriter.Write(b)
|
|
rw.bytesWritten += n
|
|
return n, err
|
|
}
|
|
|
|
// RequestLogger returns a middleware that logs HTTP requests using slog.
|
|
// It logs request completion with status code, duration, and bytes written.
|
|
// Log level is determined by response status (error for 5xx, warn for 4xx, info otherwise).
|
|
//
|
|
// IMPORTANT: This middleware expects the RequestID and Tracing middleware to have
|
|
// run first to set request_id and trace_id in context.
|
|
//
|
|
// Usage:
|
|
//
|
|
// r.Use(middleware.RequestID())
|
|
// r.Use(middleware.Tracing())
|
|
// r.Use(middleware.RequestLogger(logger))
|
|
// r.Use(middleware.Recoverer(logger))
|
|
func RequestLogger(logger *logging.Logger) func(next http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
// Wrap response writer to capture status code and bytes
|
|
wrapped := &responseWriter{
|
|
ResponseWriter: w,
|
|
status: http.StatusOK,
|
|
}
|
|
|
|
// Get request ID and trace ID from context (set by middleware)
|
|
requestID, _ := httpcontext.GetRequestID(r.Context())
|
|
traceID, _ := httpcontext.GetTraceID(r.Context())
|
|
|
|
// Create request-scoped logger
|
|
reqLogger := logger.With(
|
|
"request_id", requestID,
|
|
"trace_id", traceID,
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"remote_addr", r.RemoteAddr,
|
|
)
|
|
|
|
// Store logger in context for handlers to use
|
|
ctx := logging.NewContext(r.Context(), reqLogger)
|
|
|
|
// Log request start at debug level
|
|
reqLogger.Debug("request started",
|
|
"user_agent", r.UserAgent(),
|
|
)
|
|
|
|
// Call next handler with enriched context
|
|
next.ServeHTTP(wrapped, r.WithContext(ctx))
|
|
|
|
// Calculate duration
|
|
duration := time.Since(start)
|
|
|
|
// Determine log level based on status and log completion
|
|
attrs := []any{
|
|
"status", wrapped.status,
|
|
"duration_ms", duration.Milliseconds(),
|
|
"bytes", wrapped.bytesWritten,
|
|
}
|
|
|
|
switch {
|
|
case wrapped.status >= 500:
|
|
reqLogger.Error("request completed", attrs...)
|
|
case wrapped.status >= 400:
|
|
reqLogger.Warn("request completed", attrs...)
|
|
default:
|
|
reqLogger.Info("request completed", attrs...)
|
|
}
|
|
})
|
|
}
|
|
}
|