package handlers import ( "context" "encoding/json" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "github.com/go-chi/chi/v5" "github.com/gorilla/websocket" "git.threesix.ai/jordan/sp3-solo-1770327084/pkg/httpresponse" "git.threesix.ai/jordan/sp3-solo-1770327084/pkg/logging" "git.threesix.ai/jordan/sp3-solo-1770327084/pkg/realtime" ) // testHub wraps LocalHub for testing. func newTestHub(t *testing.T) *realtime.LocalHub { t.Helper() hub := realtime.NewHub(logging.Nop()) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) go hub.Run(ctx) return hub } // testServer creates an HTTP test server with WebSocket handler. func newTestServer(t *testing.T, hub realtime.Hub) *httptest.Server { t.Helper() wsHandler := realtime.NewHandler(hub, logging.Nop(), realtime.HandlerConfig{ Broadcaster: nil, // No Redis for unit tests AuthRequired: false, }) r := chi.NewRouter() r.Mount("/ws", wsHandler.Routes()) r.Get("/ws/stats", func(w http.ResponseWriter, r *http.Request) { stats := wsHandler.GetStats() httpresponse.OK(w, r, stats) }) return httptest.NewServer(r) } // wsURL converts http:// URL to ws:// URL. func wsURL(server *httptest.Server, path string) string { return "ws" + strings.TrimPrefix(server.URL, "http") + path } func TestWebSocket_Upgrade(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect via WebSocket conn, resp, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil) if err != nil { t.Fatalf("failed to connect: %v", err) } defer conn.Close() if resp.StatusCode != http.StatusSwitchingProtocols { t.Errorf("expected status 101, got %d", resp.StatusCode) } // Give hub time to register time.Sleep(50 * time.Millisecond) if hub.ConnectionCount() != 1 { t.Errorf("expected 1 connection, got %d", hub.ConnectionCount()) } } func TestWebSocket_RoomJoinViaURL(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect to specific room via URL path conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/test-room"), nil) if err != nil { t.Fatalf("failed to connect: %v", err) } defer conn.Close() // Give hub time to register and join room time.Sleep(50 * time.Millisecond) if hub.RoomCount("test-room") != 1 { t.Errorf("expected 1 client in room, got %d", hub.RoomCount("test-room")) } } func TestWebSocket_RoomJoinViaQuery(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect with room query parameter conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws?room=query-room"), nil) if err != nil { t.Fatalf("failed to connect: %v", err) } defer conn.Close() // Give hub time to register and join room time.Sleep(50 * time.Millisecond) if hub.RoomCount("query-room") != 1 { t.Errorf("expected 1 client in query-room, got %d", hub.RoomCount("query-room")) } } func TestWebSocket_MessageBroadcast(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect two clients to the same room conn1, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/broadcast-room"), nil) if err != nil { t.Fatalf("client 1 failed to connect: %v", err) } defer conn1.Close() conn2, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/broadcast-room"), nil) if err != nil { t.Fatalf("client 2 failed to connect: %v", err) } defer conn2.Close() // Give hub time to register both clients time.Sleep(50 * time.Millisecond) if hub.RoomCount("broadcast-room") != 2 { t.Fatalf("expected 2 clients in room, got %d", hub.RoomCount("broadcast-room")) } // Send a message from client 1 msg := realtime.Message{ Type: realtime.MessageTypeChat, Room: "broadcast-room", Data: json.RawMessage(`{"text":"hello"}`), } if err := conn1.WriteJSON(msg); err != nil { t.Fatalf("failed to send message: %v", err) } // Both clients should receive the message var wg sync.WaitGroup wg.Add(2) readMessage := func(conn *websocket.Conn, name string) { defer wg.Done() conn.SetReadDeadline(time.Now().Add(time.Second)) var received realtime.Message if err := conn.ReadJSON(&received); err != nil { t.Errorf("%s failed to read message: %v", name, err) return } if received.Type != realtime.MessageTypeChat { t.Errorf("%s: expected type 'chat', got %s", name, received.Type) } } go readMessage(conn1, "client1") go readMessage(conn2, "client2") wg.Wait() } func TestWebSocket_GlobalBroadcast(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect two clients without room (global) conn1, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil) if err != nil { t.Fatalf("client 1 failed to connect: %v", err) } defer conn1.Close() conn2, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil) if err != nil { t.Fatalf("client 2 failed to connect: %v", err) } defer conn2.Close() // Give hub time to register time.Sleep(50 * time.Millisecond) if hub.ConnectionCount() != 2 { t.Fatalf("expected 2 connections, got %d", hub.ConnectionCount()) } // Send a global message (no room) msg := realtime.Message{ Type: realtime.MessageTypeChat, Data: json.RawMessage(`{"text":"global message"}`), } if err := conn1.WriteJSON(msg); err != nil { t.Fatalf("failed to send message: %v", err) } // Both clients should receive var wg sync.WaitGroup wg.Add(2) readMessage := func(conn *websocket.Conn, name string) { defer wg.Done() conn.SetReadDeadline(time.Now().Add(time.Second)) var received realtime.Message if err := conn.ReadJSON(&received); err != nil { t.Errorf("%s failed to read message: %v", name, err) return } } go readMessage(conn1, "client1") go readMessage(conn2, "client2") wg.Wait() } func TestWebSocket_Stats(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Check stats before any connections resp, err := http.Get(server.URL + "/ws/stats") if err != nil { t.Fatalf("failed to get stats: %v", err) } resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } // Connect a client conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil) if err != nil { t.Fatalf("failed to connect: %v", err) } defer conn.Close() // Give hub time to register time.Sleep(50 * time.Millisecond) // Check stats after connection resp, err = http.Get(server.URL + "/ws/stats") if err != nil { t.Fatalf("failed to get stats: %v", err) } defer resp.Body.Close() var result map[string]any if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { t.Fatalf("failed to decode stats: %v", err) } data, ok := result["data"].(map[string]any) if !ok { t.Fatal("expected 'data' field in response") } count, ok := data["total_connections"].(float64) if !ok { t.Fatal("expected 'total_connections' in data") } if int(count) != 1 { t.Errorf("expected 1 connection in stats, got %d", int(count)) } } func TestWebSocket_Disconnect(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil) if err != nil { t.Fatalf("failed to connect: %v", err) } // Give hub time to register time.Sleep(50 * time.Millisecond) if hub.ConnectionCount() != 1 { t.Fatalf("expected 1 connection, got %d", hub.ConnectionCount()) } // Close connection conn.Close() // Give hub time to unregister time.Sleep(100 * time.Millisecond) if hub.ConnectionCount() != 0 { t.Errorf("expected 0 connections after disconnect, got %d", hub.ConnectionCount()) } } func TestWebSocket_PingPong(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil) if err != nil { t.Fatalf("failed to connect: %v", err) } defer conn.Close() // Send ping message msg := realtime.Message{ Type: realtime.MessageTypePing, } if err := conn.WriteJSON(msg); err != nil { t.Fatalf("failed to send ping: %v", err) } // Should receive pong conn.SetReadDeadline(time.Now().Add(time.Second)) var pong realtime.Message if err := conn.ReadJSON(&pong); err != nil { t.Fatalf("failed to read pong: %v", err) } if pong.Type != realtime.MessageTypePong { t.Errorf("expected pong message, got %s", pong.Type) } } func TestWebSocket_RoomIsolation(t *testing.T) { hub := newTestHub(t) server := newTestServer(t, hub) defer server.Close() // Connect client to room A connA, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/room-a"), nil) if err != nil { t.Fatalf("client A failed to connect: %v", err) } defer connA.Close() // Connect client to room B connB, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/room-b"), nil) if err != nil { t.Fatalf("client B failed to connect: %v", err) } defer connB.Close() // Give hub time to register time.Sleep(50 * time.Millisecond) // Send message to room A msg := realtime.Message{ Type: realtime.MessageTypeChat, Room: "room-a", Data: json.RawMessage(`{"text":"room A only"}`), } if err := connA.WriteJSON(msg); err != nil { t.Fatalf("failed to send message: %v", err) } // Client A should receive it connA.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) var received realtime.Message if err := connA.ReadJSON(&received); err != nil { t.Errorf("client A should receive message: %v", err) } // Client B should NOT receive it (timeout expected) connB.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) if err := connB.ReadJSON(&received); err == nil { t.Error("client B should NOT receive room-a message") } }