387 lines
9.7 KiB
Go
387 lines
9.7 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/gorilla/websocket"
|
|
|
|
"git.threesix.ai/jordan/sp3-solo-1770327084/pkg/httpresponse"
|
|
"git.threesix.ai/jordan/sp3-solo-1770327084/pkg/logging"
|
|
"git.threesix.ai/jordan/sp3-solo-1770327084/pkg/realtime"
|
|
)
|
|
|
|
// testHub wraps LocalHub for testing.
|
|
func newTestHub(t *testing.T) *realtime.LocalHub {
|
|
t.Helper()
|
|
hub := realtime.NewHub(logging.Nop())
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
t.Cleanup(cancel)
|
|
go hub.Run(ctx)
|
|
return hub
|
|
}
|
|
|
|
// testServer creates an HTTP test server with WebSocket handler.
|
|
func newTestServer(t *testing.T, hub realtime.Hub) *httptest.Server {
|
|
t.Helper()
|
|
|
|
wsHandler := realtime.NewHandler(hub, logging.Nop(), realtime.HandlerConfig{
|
|
Broadcaster: nil, // No Redis for unit tests
|
|
AuthRequired: false,
|
|
})
|
|
|
|
r := chi.NewRouter()
|
|
r.Mount("/ws", wsHandler.Routes())
|
|
r.Get("/ws/stats", func(w http.ResponseWriter, r *http.Request) {
|
|
stats := wsHandler.GetStats()
|
|
httpresponse.OK(w, r, stats)
|
|
})
|
|
|
|
return httptest.NewServer(r)
|
|
}
|
|
|
|
// wsURL builds a WebSocket URL for path on the given test server by swapping
// the leading "http" of server.URL for "ws" (so http:// becomes ws:// and
// https:// becomes wss://) and appending path unchanged.
func wsURL(server *httptest.Server, path string) string {
	remainder := strings.TrimPrefix(server.URL, "http")
	return "ws" + remainder + path
}
|
|
|
|
func TestWebSocket_Upgrade(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect via WebSocket
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
t.Errorf("expected status 101, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Give hub time to register
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if hub.ConnectionCount() != 1 {
|
|
t.Errorf("expected 1 connection, got %d", hub.ConnectionCount())
|
|
}
|
|
}
|
|
|
|
func TestWebSocket_RoomJoinViaURL(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect to specific room via URL path
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/test-room"), nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Give hub time to register and join room
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if hub.RoomCount("test-room") != 1 {
|
|
t.Errorf("expected 1 client in room, got %d", hub.RoomCount("test-room"))
|
|
}
|
|
}
|
|
|
|
func TestWebSocket_RoomJoinViaQuery(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect with room query parameter
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws?room=query-room"), nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Give hub time to register and join room
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if hub.RoomCount("query-room") != 1 {
|
|
t.Errorf("expected 1 client in query-room, got %d", hub.RoomCount("query-room"))
|
|
}
|
|
}
|
|
|
|
func TestWebSocket_MessageBroadcast(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect two clients to the same room
|
|
conn1, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/broadcast-room"), nil)
|
|
if err != nil {
|
|
t.Fatalf("client 1 failed to connect: %v", err)
|
|
}
|
|
defer conn1.Close()
|
|
|
|
conn2, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/broadcast-room"), nil)
|
|
if err != nil {
|
|
t.Fatalf("client 2 failed to connect: %v", err)
|
|
}
|
|
defer conn2.Close()
|
|
|
|
// Give hub time to register both clients
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if hub.RoomCount("broadcast-room") != 2 {
|
|
t.Fatalf("expected 2 clients in room, got %d", hub.RoomCount("broadcast-room"))
|
|
}
|
|
|
|
// Send a message from client 1
|
|
msg := realtime.Message{
|
|
Type: realtime.MessageTypeChat,
|
|
Room: "broadcast-room",
|
|
Data: json.RawMessage(`{"text":"hello"}`),
|
|
}
|
|
if err := conn1.WriteJSON(msg); err != nil {
|
|
t.Fatalf("failed to send message: %v", err)
|
|
}
|
|
|
|
// Both clients should receive the message
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
readMessage := func(conn *websocket.Conn, name string) {
|
|
defer wg.Done()
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var received realtime.Message
|
|
if err := conn.ReadJSON(&received); err != nil {
|
|
t.Errorf("%s failed to read message: %v", name, err)
|
|
return
|
|
}
|
|
if received.Type != realtime.MessageTypeChat {
|
|
t.Errorf("%s: expected type 'chat', got %s", name, received.Type)
|
|
}
|
|
}
|
|
|
|
go readMessage(conn1, "client1")
|
|
go readMessage(conn2, "client2")
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestWebSocket_GlobalBroadcast(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect two clients without room (global)
|
|
conn1, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
|
|
if err != nil {
|
|
t.Fatalf("client 1 failed to connect: %v", err)
|
|
}
|
|
defer conn1.Close()
|
|
|
|
conn2, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
|
|
if err != nil {
|
|
t.Fatalf("client 2 failed to connect: %v", err)
|
|
}
|
|
defer conn2.Close()
|
|
|
|
// Give hub time to register
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if hub.ConnectionCount() != 2 {
|
|
t.Fatalf("expected 2 connections, got %d", hub.ConnectionCount())
|
|
}
|
|
|
|
// Send a global message (no room)
|
|
msg := realtime.Message{
|
|
Type: realtime.MessageTypeChat,
|
|
Data: json.RawMessage(`{"text":"global message"}`),
|
|
}
|
|
if err := conn1.WriteJSON(msg); err != nil {
|
|
t.Fatalf("failed to send message: %v", err)
|
|
}
|
|
|
|
// Both clients should receive
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
readMessage := func(conn *websocket.Conn, name string) {
|
|
defer wg.Done()
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var received realtime.Message
|
|
if err := conn.ReadJSON(&received); err != nil {
|
|
t.Errorf("%s failed to read message: %v", name, err)
|
|
return
|
|
}
|
|
}
|
|
|
|
go readMessage(conn1, "client1")
|
|
go readMessage(conn2, "client2")
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestWebSocket_Stats(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Check stats before any connections
|
|
resp, err := http.Get(server.URL + "/ws/stats")
|
|
if err != nil {
|
|
t.Fatalf("failed to get stats: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Connect a client
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Give hub time to register
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Check stats after connection
|
|
resp, err = http.Get(server.URL + "/ws/stats")
|
|
if err != nil {
|
|
t.Fatalf("failed to get stats: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var result map[string]any
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
t.Fatalf("failed to decode stats: %v", err)
|
|
}
|
|
|
|
data, ok := result["data"].(map[string]any)
|
|
if !ok {
|
|
t.Fatal("expected 'data' field in response")
|
|
}
|
|
|
|
count, ok := data["total_connections"].(float64)
|
|
if !ok {
|
|
t.Fatal("expected 'total_connections' in data")
|
|
}
|
|
|
|
if int(count) != 1 {
|
|
t.Errorf("expected 1 connection in stats, got %d", int(count))
|
|
}
|
|
}
|
|
|
|
func TestWebSocket_Disconnect(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect: %v", err)
|
|
}
|
|
|
|
// Give hub time to register
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if hub.ConnectionCount() != 1 {
|
|
t.Fatalf("expected 1 connection, got %d", hub.ConnectionCount())
|
|
}
|
|
|
|
// Close connection
|
|
conn.Close()
|
|
|
|
// Give hub time to unregister
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
if hub.ConnectionCount() != 0 {
|
|
t.Errorf("expected 0 connections after disconnect, got %d", hub.ConnectionCount())
|
|
}
|
|
}
|
|
|
|
func TestWebSocket_PingPong(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Send ping message
|
|
msg := realtime.Message{
|
|
Type: realtime.MessageTypePing,
|
|
}
|
|
if err := conn.WriteJSON(msg); err != nil {
|
|
t.Fatalf("failed to send ping: %v", err)
|
|
}
|
|
|
|
// Should receive pong
|
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
|
var pong realtime.Message
|
|
if err := conn.ReadJSON(&pong); err != nil {
|
|
t.Fatalf("failed to read pong: %v", err)
|
|
}
|
|
|
|
if pong.Type != realtime.MessageTypePong {
|
|
t.Errorf("expected pong message, got %s", pong.Type)
|
|
}
|
|
}
|
|
|
|
func TestWebSocket_RoomIsolation(t *testing.T) {
|
|
hub := newTestHub(t)
|
|
server := newTestServer(t, hub)
|
|
defer server.Close()
|
|
|
|
// Connect client to room A
|
|
connA, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/room-a"), nil)
|
|
if err != nil {
|
|
t.Fatalf("client A failed to connect: %v", err)
|
|
}
|
|
defer connA.Close()
|
|
|
|
// Connect client to room B
|
|
connB, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/room-b"), nil)
|
|
if err != nil {
|
|
t.Fatalf("client B failed to connect: %v", err)
|
|
}
|
|
defer connB.Close()
|
|
|
|
// Give hub time to register
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Send message to room A
|
|
msg := realtime.Message{
|
|
Type: realtime.MessageTypeChat,
|
|
Room: "room-a",
|
|
Data: json.RawMessage(`{"text":"room A only"}`),
|
|
}
|
|
if err := connA.WriteJSON(msg); err != nil {
|
|
t.Fatalf("failed to send message: %v", err)
|
|
}
|
|
|
|
// Client A should receive it
|
|
connA.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
|
var received realtime.Message
|
|
if err := connA.ReadJSON(&received); err != nil {
|
|
t.Errorf("client A should receive message: %v", err)
|
|
}
|
|
|
|
// Client B should NOT receive it (timeout expected)
|
|
connB.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
|
if err := connB.ReadJSON(&received); err == nil {
|
|
t.Error("client B should NOT receive room-a message")
|
|
}
|
|
}
|