package logging

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestMiddleware checks that a successful request passes through the
// middleware unchanged and produces a "request completed" log entry that
// carries the HTTP method, path, status and duration fields.
func TestMiddleware(t *testing.T) {
	var buf bytes.Buffer
	cfg := DefaultConfig()
	cfg.RedactEnabled = false
	logger := NewWithWriter(cfg, &buf)

	middlewareCfg := DefaultMiddlewareConfig()
	middlewareCfg.Logger = logger

	handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok")) // the recorder's write cannot meaningfully fail here
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	output := buf.String()

	// Should log request completion.
	if !strings.Contains(output, "request completed") {
		t.Errorf("expected 'request completed' in log, got: %s", output)
	}

	// Should include the HTTP fields.
	for _, field := range []string{FieldHTTPMethod, FieldHTTPPath, FieldHTTPStatus, FieldDuration} {
		if !strings.Contains(output, field) {
			t.Errorf("expected %s field, got: %s", field, output)
		}
	}
}

// TestMiddlewareGeneratesRequestID checks that, with GenerateRequestID
// enabled and no incoming X-Request-ID, the middleware sets the response
// header and logs that same ID.
func TestMiddlewareGeneratesRequestID(t *testing.T) {
	var buf bytes.Buffer
	cfg := DefaultConfig()
	cfg.RedactEnabled = false
	logger := NewWithWriter(cfg, &buf)

	middlewareCfg := DefaultMiddlewareConfig()
	middlewareCfg.Logger = logger
	middlewareCfg.GenerateRequestID = true

	handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// Should set the request ID header in the response.
	requestID := rec.Header().Get(RequestIDHeader)
	if requestID == "" {
		t.Error("expected X-Request-ID header to be set")
	}

	// Should include the request ID in the log, and it should be the same
	// value that was echoed to the client.
	output := buf.String()
	if !strings.Contains(output, FieldRequestID) {
		t.Errorf("expected request_id field, got: %s", output)
	}
	if requestID != "" && !strings.Contains(output, requestID) {
		t.Errorf("expected generated request_id %q in log, got: %s", requestID, output)
	}
}

// TestMiddlewarePropagatesRequestID checks that an incoming X-Request-ID is
// echoed back on the response and used in the log output.
func TestMiddlewarePropagatesRequestID(t *testing.T) {
	var buf bytes.Buffer
	cfg := DefaultConfig()
	cfg.RedactEnabled = false
	logger := NewWithWriter(cfg, &buf)

	middlewareCfg := DefaultMiddlewareConfig()
	middlewareCfg.Logger = logger

	handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	req.Header.Set(RequestIDHeader, "test-request-id-123")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// Should echo back the request ID.
	requestID := rec.Header().Get(RequestIDHeader)
	if requestID != "test-request-id-123" {
		t.Errorf("expected X-Request-ID to be echoed, got: %s", requestID)
	}

	// Should use the provided request ID in the log.
	output := buf.String()
	if !strings.Contains(output, "test-request-id-123") {
		t.Errorf("expected provided request_id in log, got: %s", output)
	}
}

// TestMiddlewareSkipPaths checks that requests to a path listed in SkipPaths
// are still served by the wrapped handler but produce no log output.
func TestMiddlewareSkipPaths(t *testing.T) {
	var buf bytes.Buffer
	cfg := DefaultConfig()
	cfg.RedactEnabled = false
	logger := NewWithWriter(cfg, &buf)

	middlewareCfg := DefaultMiddlewareConfig()
	middlewareCfg.Logger = logger
	middlewareCfg.SkipPaths = map[string]bool{"/health": true}

	// A ResponseRecorder reports 200 even if nothing is written, so track
	// invocation explicitly to prove the request was not swallowed.
	called := false
	handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if !called {
		t.Error("expected wrapped handler to run for skipped path")
	}

	// Should not log anything for skipped paths.
	output := buf.String()
	if output != "" {
		t.Errorf("expected no log output for skipped path, got: %s", output)
	}
}

// TestMiddlewareLogLevelByStatus checks the level of the completion entry:
// INFO for 2xx/3xx, WARN for 4xx and ERROR for 5xx responses.
func TestMiddlewareLogLevelByStatus(t *testing.T) {
	tests := []struct {
		status   int
		logLevel string
	}{
		{200, "INFO"},
		{201, "INFO"},
		{204, "INFO"},
		{301, "INFO"},
		{400, "WARN"},
		{401, "WARN"},
		{404, "WARN"},
		{500, "ERROR"},
		{502, "ERROR"},
		{503, "ERROR"},
	}

	for _, tt := range tests {
		// Subtests run synchronously (no t.Parallel), so capturing tt in the
		// handler closure is safe on every Go version.
		t.Run(http.StatusText(tt.status), func(t *testing.T) {
			var buf bytes.Buffer
			cfg := DefaultConfig()
			cfg.RedactEnabled = false
			logger := NewWithWriter(cfg, &buf)

			middlewareCfg := DefaultMiddlewareConfig()
			middlewareCfg.Logger = logger

			handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tt.status)
			}))

			req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			output := buf.String()
			if !strings.Contains(output, tt.logLevel) {
				t.Errorf("expected %s level for status %d, got: %s", tt.logLevel, tt.status, output)
			}
		})
	}
}

// TestMiddlewareContextLogger checks that the middleware makes a logger
// available through the request context and that entries written with it
// reach the middleware logger's output.
func TestMiddlewareContextLogger(t *testing.T) {
	var buf bytes.Buffer
	cfg := DefaultConfig()
	cfg.RedactEnabled = false
	logger := NewWithWriter(cfg, &buf)

	middlewareCfg := DefaultMiddlewareConfig()
	middlewareCfg.Logger = logger

	var ctxLogger *Logger
	handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Handler can get the logger from the request context.
		ctxLogger = FromContext(r.Context())
		ctxLogger.Info("handler logging")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if ctxLogger == nil {
		t.Fatal("expected logger in context")
	}

	// The handler's entry should be written to the same buffer the
	// middleware logger was configured with.
	output := buf.String()
	if !strings.Contains(output, "handler logging") {
		t.Errorf("expected handler log in output, got: %s", output)
	}
}

// TestMiddlewareJSONOutput checks that, with FormatJSON, the completion entry
// is valid JSON and carries the method, path, status, duration and request ID.
func TestMiddlewareJSONOutput(t *testing.T) {
	var buf bytes.Buffer
	cfg := DefaultConfig()
	cfg.Format = FormatJSON
	cfg.RedactEnabled = false
	logger := NewWithWriter(cfg, &buf)

	middlewareCfg := DefaultMiddlewareConfig()
	middlewareCfg.Logger = logger

	handler := Middleware(middlewareCfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("test")) // the recorder's write cannot meaningfully fail here
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// strings.Split never returns an empty slice, so emptiness must be
	// checked on the raw output rather than on len(lines).
	raw := strings.TrimSpace(buf.String())
	if raw == "" {
		t.Fatal("expected at least 1 log line, got none")
	}

	// The middleware may emit a start entry before the completion entry
	// (depending on the configured level), so only the last line is parsed.
	lines := strings.Split(raw, "\n")
	last := lines[len(lines)-1]

	var entry map[string]any
	if err := json.Unmarshal([]byte(last), &entry); err != nil {
		t.Fatalf("expected valid JSON, got error: %v, line: %s", err, last)
	}

	// Verify expected fields. encoding/json decodes numbers as float64.
	if entry[FieldHTTPMethod] != "GET" {
		t.Errorf("expected http_method=GET, got: %v", entry[FieldHTTPMethod])
	}
	if entry[FieldHTTPPath] != "/api/test" {
		t.Errorf("expected http_path=/api/test, got: %v", entry[FieldHTTPPath])
	}
	if entry[FieldHTTPStatus] != float64(200) {
		t.Errorf("expected http_status=200, got: %v", entry[FieldHTTPStatus])
	}
	if _, ok := entry[FieldDuration]; !ok {
		t.Error("expected duration_ms field")
	}
	if _, ok := entry[FieldRequestID]; !ok {
		t.Error("expected request_id field")
	}
}