package middleware import ( "net/http" "time" "github.com/jordan/composed4/pkg/httpcontext" "github.com/jordan/composed4/pkg/logging" ) // responseWriter wraps http.ResponseWriter to capture status code. type responseWriter struct { http.ResponseWriter status int wroteHeader bool bytesWritten int } func (rw *responseWriter) WriteHeader(code int) { if rw.wroteHeader { return } rw.status = code rw.wroteHeader = true rw.ResponseWriter.WriteHeader(code) } func (rw *responseWriter) Write(b []byte) (int, error) { if !rw.wroteHeader { rw.WriteHeader(http.StatusOK) } n, err := rw.ResponseWriter.Write(b) rw.bytesWritten += n return n, err } // RequestLogger returns a middleware that logs HTTP requests using slog. // It logs request completion with status code, duration, and bytes written. // Log level is determined by response status (error for 5xx, warn for 4xx, info otherwise). // // IMPORTANT: This middleware expects the RequestID and Tracing middleware to have // run first to set request_id and trace_id in context. // // Usage: // // r.Use(middleware.RequestID()) // r.Use(middleware.Tracing()) // r.Use(middleware.RequestLogger(logger)) // r.Use(middleware.Recoverer(logger)) func RequestLogger(logger *logging.Logger) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() // Wrap response writer to capture status code and bytes wrapped := &responseWriter{ ResponseWriter: w, status: http.StatusOK, } // Get request ID and trace ID from context (set by middleware) requestID, _ := httpcontext.GetRequestID(r.Context()) traceID, _ := httpcontext.GetTraceID(r.Context()) // Create request-scoped logger reqLogger := logger.With( "request_id", requestID, "trace_id", traceID, "method", r.Method, "path", r.URL.Path, "remote_addr", r.RemoteAddr, ) // Store logger in context for handlers to use ctx := logging.NewContext(r.Context(), reqLogger) // Log request start at debug level reqLogger.Debug("request started", "user_agent", r.UserAgent(), ) // Call next handler with enriched context next.ServeHTTP(wrapped, r.WithContext(ctx)) // Calculate duration duration := time.Since(start) // Determine log level based on status and log completion attrs := []any{ "status", wrapped.status, "duration_ms", duration.Milliseconds(), "bytes", wrapped.bytesWritten, } switch { case wrapped.status >= 500: reqLogger.Error("request completed", attrs...) case wrapped.status >= 400: reqLogger.Warn("request completed", attrs...) default: reqLogger.Info("request completed", attrs...) } }) } }