package handlers import ( "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/gorilla/websocket" "git.threesix.ai/jordan/sp3-verify-1770325830/pkg/logging" "git.threesix.ai/jordan/sp3-verify-1770325830/pkg/realtime" ) // testSetup creates a hub and handler for WebSocket tests. func testSetup(t *testing.T) (*realtime.LocalHub, *realtime.Handler, context.CancelFunc) { t.Helper() ctx, cancel := context.WithCancel(context.Background()) logger := logging.Nop() hub := realtime.NewHub(logger) go hub.Run(ctx) handler := realtime.NewHandler(hub, logger, realtime.HandlerConfig{ Broadcaster: nil, // No Redis for tests AuthRequired: false, OnMessage: func(conn realtime.Connection, msg *realtime.Message) *realtime.Message { if msg.Type == "" { msg.Type = realtime.MessageTypeChat } return msg }, }) return hub, handler, cancel } // dial connects to a WebSocket test server. func dial(t *testing.T, server *httptest.Server, path string) *websocket.Conn { t.Helper() url := "ws" + strings.TrimPrefix(server.URL, "http") + path conn, _, err := websocket.DefaultDialer.Dial(url, nil) if err != nil { t.Fatalf("failed to dial websocket: %v", err) } return conn } func TestWebSocket_Connection(t *testing.T) { hub, handler, cancel := testSetup(t) defer cancel() server := httptest.NewServer(handler.Routes()) defer server.Close() // Connect conn := dial(t, server, "/") defer conn.Close() // Wait for registration time.Sleep(50 * time.Millisecond) // Verify connection count if count := hub.ConnectionCount(); count != 1 { t.Errorf("expected 1 connection, got %d", count) } // Close and verify cleanup conn.Close() time.Sleep(50 * time.Millisecond) if count := hub.ConnectionCount(); count != 0 { t.Errorf("expected 0 connections after close, got %d", count) } } func TestWebSocket_RoomConnection(t *testing.T) { hub, handler, cancel := testSetup(t) defer cancel() server := httptest.NewServer(handler.Routes()) defer server.Close() // Connect to room conn := dial(t, server, "/test-room") defer conn.Close() // Wait for registration time.Sleep(50 * time.Millisecond) // Verify room count if count := hub.RoomCount("test-room"); count != 1 { t.Errorf("expected 1 connection in room, got %d", count) } } func TestWebSocket_MessageBroadcast(t *testing.T) { _, handler, cancel := testSetup(t) defer cancel() server := httptest.NewServer(handler.Routes()) defer server.Close() // Connect two clients conn1 := dial(t, server, "/") defer conn1.Close() conn2 := dial(t, server, "/") defer conn2.Close() // Wait for registration time.Sleep(50 * time.Millisecond) // Send message from conn1 msg := realtime.Message{ Type: realtime.MessageTypeChat, Data: json.RawMessage(`{"content":"Hello"}`), } if err := conn1.WriteJSON(msg); err != nil { t.Fatalf("failed to send message: %v", err) } // Read from conn2 conn2.SetReadDeadline(time.Now().Add(2 * time.Second)) var received realtime.Message if err := conn2.ReadJSON(&received); err != nil { t.Fatalf("failed to receive message: %v", err) } if received.Type != realtime.MessageTypeChat { t.Errorf("expected type %s, got %s", realtime.MessageTypeChat, received.Type) } } func TestWebSocket_RoomIsolation(t *testing.T) { _, handler, cancel := testSetup(t) defer cancel() server := httptest.NewServer(handler.Routes()) defer server.Close() // Connect to room1 conn1 := dial(t, server, "/room1") defer conn1.Close() // Connect to room2 conn2 := dial(t, server, "/room2") defer conn2.Close() // Wait for registration time.Sleep(50 * time.Millisecond) // Send message to room1 msg := realtime.Message{ Type: realtime.MessageTypeChat, Room: "room1", Data: json.RawMessage(`{"content":"Room1 Only"}`), } if err := conn1.WriteJSON(msg); err != nil { t.Fatalf("failed to send message: %v", err) } // Set short timeout on conn2 - it should NOT receive the message conn2.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) var received realtime.Message err := conn2.ReadJSON(&received) // We expect a timeout error since room2 should not receive room1 messages if err == nil { t.Errorf("room2 client received message meant for room1: %+v", received) } } func TestWebSocket_PingPong(t *testing.T) { _, handler, cancel := testSetup(t) defer cancel() server := httptest.NewServer(handler.Routes()) defer server.Close() conn := dial(t, server, "/") defer conn.Close() // Wait for registration time.Sleep(50 * time.Millisecond) // Send ping msg := realtime.Message{ Type: realtime.MessageTypePing, } if err := conn.WriteJSON(msg); err != nil { t.Fatalf("failed to send ping: %v", err) } // Expect pong conn.SetReadDeadline(time.Now().Add(2 * time.Second)) var received realtime.Message if err := conn.ReadJSON(&received); err != nil { t.Fatalf("failed to receive pong: %v", err) } if received.Type != realtime.MessageTypePong { t.Errorf("expected type %s, got %s", realtime.MessageTypePong, received.Type) } } func TestWebSocket_AuthRequired(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() logger := logging.Nop() hub := realtime.NewHub(logger) go hub.Run(ctx) handler := realtime.NewHandler(hub, logger, realtime.HandlerConfig{ AuthRequired: true, // Require auth }) server := httptest.NewServer(handler.Routes()) defer server.Close() // Try to connect without auth url := "ws" + strings.TrimPrefix(server.URL, "http") + "/" _, resp, err := websocket.DefaultDialer.Dial(url, nil) // Should fail with 401 if err == nil { t.Error("expected connection to fail without auth") } if resp != nil && resp.StatusCode != http.StatusUnauthorized { t.Errorf("expected status 401, got %d", resp.StatusCode) } } func TestWebSocket_MultipleClients(t *testing.T) { hub, handler, cancel := testSetup(t) defer cancel() server := httptest.NewServer(handler.Routes()) defer server.Close() // Connect multiple clients const numClients = 5 conns := make([]*websocket.Conn, numClients) for i := 0; i < numClients; i++ { conns[i] = dial(t, server, "/") defer conns[i].Close() } // Wait for registration time.Sleep(100 * time.Millisecond) if count := hub.ConnectionCount(); count != numClients { t.Errorf("expected %d connections, got %d", numClients, count) } // Close half the clients for i := 0; i < numClients/2; i++ { conns[i].Close() } time.Sleep(100 * time.Millisecond) expected := numClients - numClients/2 if count := hub.ConnectionCount(); count != expected { t.Errorf("expected %d connections after partial close, got %d", expected, count) } }