package api

import (
	"bytes"
	"encoding/json"
	"errors"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestWrap checks that Wrap turns a HandlerFunc's result into the JSON
// Response envelope. It covers four cases:
//   - a handler that writes its own success body and returns nil;
//   - a returned HTTPError (NotFound), which keeps its status and code;
//   - a plain error, which becomes 500 / INTERNAL_ERROR;
//   - a Validation error carrying details, which becomes 400 / VALIDATION_ERROR.
func TestWrap(t *testing.T) {
	tests := []struct {
		name         string
		handler      HandlerFunc
		wantStatus   int
		wantCode     string // expected resp.Error.Code; empty skips the code check
		wantHasError bool
		wantHasData  bool
	}{
		{
			name: "success response",
			handler: func(w http.ResponseWriter, r *http.Request) error {
				WriteSuccess(w, r, map[string]string{"message": "hello"})
				return nil
			},
			wantStatus:   http.StatusOK,
			wantHasData:  true,
			wantHasError: false,
		},
		{
			name: "HTTPError returned",
			handler: func(w http.ResponseWriter, r *http.Request) error {
				return NotFound("user not found")
			},
			wantStatus:   http.StatusNotFound,
			wantCode:     "NOT_FOUND",
			wantHasError: true,
		},
		{
			name: "generic error returned",
			handler: func(w http.ResponseWriter, r *http.Request) error {
				return errors.New("something went wrong")
			},
			wantStatus:   http.StatusInternalServerError,
			wantCode:     "INTERNAL_ERROR",
			wantHasError: true,
		},
		{
			name: "validation error with details",
			handler: func(w http.ResponseWriter, r *http.Request) error {
				return WithDetails(Validation("validation failed"), []ValidationDetail{
					{Field: "email", Message: "is required"},
				})
			},
			wantStatus:   http.StatusBadRequest,
			wantCode:     "VALIDATION_ERROR",
			wantHasError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			rec := httptest.NewRecorder()

			wrapped := Wrap(tt.handler)
			wrapped(rec, req)

			if rec.Code != tt.wantStatus {
				t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus)
			}

			var resp Response
			if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
				t.Fatalf("failed to decode response: %v", err)
			}

			if hasError := resp.Error != nil; hasError != tt.wantHasError {
				t.Errorf("hasError = %v, want %v", hasError, tt.wantHasError)
			}
			if hasData := resp.Data != nil; hasData != tt.wantHasData {
				t.Errorf("hasData = %v, want %v", hasData, tt.wantHasData)
			}

			// A missing resp.Error is already reported by the hasError check
			// above, so the code is only compared when an error is present.
			if tt.wantCode != "" && resp.Error != nil && resp.Error.Code != tt.wantCode {
				t.Errorf("error code = %q, want %q", resp.Error.Code, tt.wantCode)
			}
		})
	}
}

// TestWrapWithLogger checks that a plain (non-HTTP) error returned by the
// handler has three effects:
//   - the client receives a 500 whose message is the generic "internal error";
//   - the raw response body does not contain the original error text;
//   - the original error text is written to the supplied logger.
func TestWrapWithLogger(t *testing.T) {
	// Capture log output at the lowest level so the test can verify the
	// internal error is actually logged. With io.Discard, a missing log call
	// would go unnoticed.
	var logs bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&logs, &slog.HandlerOptions{Level: slog.LevelDebug}))

	const internalMsg = "database connection failed"

	// Handler that returns a generic error (should be logged).
	h := func(w http.ResponseWriter, r *http.Request) error {
		return errors.New(internalMsg)
	}

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rec := httptest.NewRecorder()

	wrapped := WrapWithLogger(h, logger)
	wrapped(rec, req)

	if rec.Code != http.StatusInternalServerError {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
	}

	// Verify the response doesn't leak internal error details anywhere in the
	// body, not just in the message field.
	body := rec.Body.Bytes()
	if bytes.Contains(body, []byte(internalMsg)) {
		t.Errorf("response body leaks internal error: %s", body)
	}

	var resp Response
	if err := json.NewDecoder(bytes.NewReader(body)).Decode(&resp); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	if resp.Error == nil {
		t.Fatal("expected error in response")
	}

	// Message should be generic, not the actual error.
	if resp.Error.Message != "internal error" {
		t.Errorf("error message = %q, want %q", resp.Error.Message, "internal error")
	}

	// The real cause must still reach the operator via the logger.
	if got := logs.String(); !strings.Contains(got, internalMsg) {
		t.Errorf("log output %q does not contain %q", got, internalMsg)
	}
}

// TestWrapMiddleware checks that WrapMiddleware adapts an error-returning
// middleware into standard http.Handler middleware. The test middleware
// returns Unauthorized when the Authorization header is missing. The test
// verifies two things:
//   - without a token, the request is rejected with 401 and never reaches
//     the next handler;
//   - with a token, the request passes through and the next handler's 200
//     response is returned.
func TestWrapMiddleware(t *testing.T) {
	authMiddleware := func(next http.Handler) func(w http.ResponseWriter, r *http.Request) error {
		return func(w http.ResponseWriter, r *http.Request) error {
			token := r.Header.Get("Authorization")
			if token == "" {
				return Unauthorized("missing authorization header")
			}
			next.ServeHTTP(w, r)
			return nil
		}
	}

	// reached records whether the inner handler ran. Subtests run
	// sequentially, so each one resets it before sending its request.
	var reached bool
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		WriteSuccess(w, r, map[string]string{"status": "ok"})
	})

	middleware := WrapMiddleware(authMiddleware)
	wrapped := middleware(handler)

	t.Run("unauthorized without token", func(t *testing.T) {
		reached = false
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		rec := httptest.NewRecorder()

		wrapped.ServeHTTP(rec, req)

		if rec.Code != http.StatusUnauthorized {
			t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
		}
		if reached {
			t.Error("next handler was called without an Authorization header")
		}
	})

	t.Run("success with token", func(t *testing.T) {
		reached = false
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		req.Header.Set("Authorization", "Bearer token123")
		rec := httptest.NewRecorder()

		wrapped.ServeHTTP(rec, req)

		if rec.Code != http.StatusOK {
			t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
		}
		if !reached {
			t.Error("next handler was not called despite a valid Authorization header")
		}
	})
}