sp3-solo-1770327084/services/chat-api/internal/api/handlers/ws_test.go
rdev-worker 82c41e819b
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
build: /implement-feature websocket-chat --requirements 'GET /ws upgrades to...
2026-02-05 21:58:16 +00:00

387 lines
9.7 KiB
Go

package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"git.threesix.ai/jordan/sp3-solo-1770327084/pkg/httpresponse"
"git.threesix.ai/jordan/sp3-solo-1770327084/pkg/logging"
"git.threesix.ai/jordan/sp3-solo-1770327084/pkg/realtime"
)
// newTestHub returns a realtime.LocalHub that is already running for the
// duration of a single test.
//
// The hub is built with a no-op logger and driven by hub.Run in a background
// goroutine. The context passed to Run is cancelled through t.Cleanup when the
// test finishes. NOTE(review): this presumably makes Run return; confirm
// in the realtime package.
func newTestHub(t *testing.T) *realtime.LocalHub {
	t.Helper()

	runCtx, stop := context.WithCancel(context.Background())
	t.Cleanup(stop)

	h := realtime.NewHub(logging.Nop())
	go h.Run(runCtx)
	return h
}
// newTestServer starts an httptest.Server that serves the realtime WebSocket
// handler for hub, mounted under /ws.
//
// The handler is configured with no Broadcaster (no Redis in unit tests) and
// with authentication disabled. The server also exposes GET /ws/stats, which
// returns wsHandler.GetStats() through httpresponse.OK.
//
// NOTE(review): a room literally named "stats" would presumably be shadowed
// by that route; confirm if room names are user-controlled.
//
// Callers are responsible for closing the returned server.
func newTestServer(t *testing.T, hub realtime.Hub) *httptest.Server {
	t.Helper()

	cfg := realtime.HandlerConfig{
		Broadcaster:  nil, // No Redis for unit tests
		AuthRequired: false,
	}
	wsHandler := realtime.NewHandler(hub, logging.Nop(), cfg)

	// Named req (not r) so it does not shadow the router variable.
	statsHandler := func(w http.ResponseWriter, req *http.Request) {
		httpresponse.OK(w, req, wsHandler.GetStats())
	}

	router := chi.NewRouter()
	router.Mount("/ws", wsHandler.Routes())
	router.Get("/ws/stats", statsHandler)
	return httptest.NewServer(router)
}
// wsURL builds a WebSocket URL for path on server by replacing the leading
// "http" of server.URL with "ws".
//
// Only the "http" prefix is touched, so an https:// server yields wss://.
// A URL without that prefix simply gets "ws" prepended.
func wsURL(server *httptest.Server, path string) string {
	hostPart := strings.TrimPrefix(server.URL, "http")
	return "ws" + hostPart + path
}
// TestWebSocket_Upgrade checks that dialing /ws completes the WebSocket
// handshake with 101 Switching Protocols and that the hub then registers
// exactly one connection.
func TestWebSocket_Upgrade(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect via WebSocket
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
	if err != nil {
		t.Fatalf("failed to connect: %v", err)
	}
	defer conn.Close()

	if resp.StatusCode != http.StatusSwitchingProtocols {
		t.Errorf("expected status 101, got %d", resp.StatusCode)
	}

	// Registration happens asynchronously in the hub goroutine. Poll with a
	// deadline instead of sleeping a fixed interval, which is flaky on a
	// loaded CI runner (especially under -race).
	deadline := time.Now().Add(time.Second)
	for hub.ConnectionCount() != 1 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.ConnectionCount(); got != 1 {
		t.Errorf("expected 1 connection, got %d", got)
	}
}
// TestWebSocket_RoomJoinViaURL checks that dialing /ws/{room} places the new
// client in that room.
func TestWebSocket_RoomJoinViaURL(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect to specific room via URL path
	conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/test-room"), nil)
	if err != nil {
		t.Fatalf("failed to connect: %v", err)
	}
	defer conn.Close()

	// Registration and room join happen asynchronously in the hub goroutine.
	// Poll with a deadline rather than relying on a fixed sleep.
	deadline := time.Now().Add(time.Second)
	for hub.RoomCount("test-room") != 1 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.RoomCount("test-room"); got != 1 {
		t.Errorf("expected 1 client in room, got %d", got)
	}
}
// TestWebSocket_RoomJoinViaQuery checks that dialing /ws?room=NAME places the
// new client in room NAME.
func TestWebSocket_RoomJoinViaQuery(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect with room query parameter
	conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws?room=query-room"), nil)
	if err != nil {
		t.Fatalf("failed to connect: %v", err)
	}
	defer conn.Close()

	// Registration and room join happen asynchronously in the hub goroutine.
	// Poll with a deadline rather than relying on a fixed sleep.
	deadline := time.Now().Add(time.Second)
	for hub.RoomCount("query-room") != 1 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.RoomCount("query-room"); got != 1 {
		t.Errorf("expected 1 client in query-room, got %d", got)
	}
}
// TestWebSocket_MessageBroadcast checks that a chat message sent by one
// member of a room is delivered to every member, including the sender.
func TestWebSocket_MessageBroadcast(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect two clients to the same room
	conn1, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/broadcast-room"), nil)
	if err != nil {
		t.Fatalf("client 1 failed to connect: %v", err)
	}
	defer conn1.Close()
	conn2, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/broadcast-room"), nil)
	if err != nil {
		t.Fatalf("client 2 failed to connect: %v", err)
	}
	defer conn2.Close()

	// Both joins are processed asynchronously by the hub. Wait (bounded) until
	// they are visible; sending earlier would make the test flaky.
	deadline := time.Now().Add(time.Second)
	for hub.RoomCount("broadcast-room") != 2 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.RoomCount("broadcast-room"); got != 2 {
		t.Fatalf("expected 2 clients in room, got %d", got)
	}

	// Send a message from client 1
	msg := realtime.Message{
		Type: realtime.MessageTypeChat,
		Room: "broadcast-room",
		Data: json.RawMessage(`{"text":"hello"}`),
	}
	if err := conn1.WriteJSON(msg); err != nil {
		t.Fatalf("failed to send message: %v", err)
	}

	// Both clients should receive the message. Each goroutine reads from its
	// own connection and only uses t.Errorf, which is safe off the test goroutine.
	var wg sync.WaitGroup
	wg.Add(2)
	readMessage := func(conn *websocket.Conn, name string) {
		defer wg.Done()
		if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
			t.Errorf("%s failed to set read deadline: %v", name, err)
			return
		}
		var received realtime.Message
		if err := conn.ReadJSON(&received); err != nil {
			t.Errorf("%s failed to read message: %v", name, err)
			return
		}
		if received.Type != realtime.MessageTypeChat {
			t.Errorf("%s: expected type 'chat', got %s", name, received.Type)
		}
	}
	go readMessage(conn1, "client1")
	go readMessage(conn2, "client2")
	wg.Wait()
}
// TestWebSocket_GlobalBroadcast checks that a chat message with no room, sent
// by one client connected to /ws, is delivered to every connected client.
func TestWebSocket_GlobalBroadcast(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect two clients without room (global)
	conn1, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
	if err != nil {
		t.Fatalf("client 1 failed to connect: %v", err)
	}
	defer conn1.Close()
	conn2, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
	if err != nil {
		t.Fatalf("client 2 failed to connect: %v", err)
	}
	defer conn2.Close()

	// Registration is asynchronous; poll with a deadline instead of a fixed sleep.
	deadline := time.Now().Add(time.Second)
	for hub.ConnectionCount() != 2 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.ConnectionCount(); got != 2 {
		t.Fatalf("expected 2 connections, got %d", got)
	}

	// Send a global message (no room)
	msg := realtime.Message{
		Type: realtime.MessageTypeChat,
		Data: json.RawMessage(`{"text":"global message"}`),
	}
	if err := conn1.WriteJSON(msg); err != nil {
		t.Fatalf("failed to send message: %v", err)
	}

	// Both clients should receive it. Previously any decodable frame passed;
	// now the frame must actually be the chat message that was sent.
	var wg sync.WaitGroup
	wg.Add(2)
	readMessage := func(conn *websocket.Conn, name string) {
		defer wg.Done()
		if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
			t.Errorf("%s failed to set read deadline: %v", name, err)
			return
		}
		var received realtime.Message
		if err := conn.ReadJSON(&received); err != nil {
			t.Errorf("%s failed to read message: %v", name, err)
			return
		}
		if received.Type != realtime.MessageTypeChat {
			t.Errorf("%s: expected type 'chat', got %s", name, received.Type)
		}
	}
	go readMessage(conn1, "client1")
	go readMessage(conn2, "client2")
	wg.Wait()
}
// TestWebSocket_Stats checks that GET /ws/stats responds 200 both before and
// after a client connects, and that data.total_connections reports the single
// open connection.
func TestWebSocket_Stats(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Check stats before any connections
	resp, err := http.Get(server.URL + "/ws/stats")
	if err != nil {
		t.Fatalf("failed to get stats: %v", err)
	}
	resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}

	// Connect a client
	conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
	if err != nil {
		t.Fatalf("failed to connect: %v", err)
	}
	defer conn.Close()

	// Wait (bounded) for the hub to register the client before sampling stats;
	// a fixed sleep races the hub goroutine.
	deadline := time.Now().Add(time.Second)
	for hub.ConnectionCount() != 1 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.ConnectionCount(); got != 1 {
		t.Fatalf("expected 1 connection, got %d", got)
	}

	// Check stats after connection
	resp, err = http.Get(server.URL + "/ws/stats")
	if err != nil {
		t.Fatalf("failed to get stats: %v", err)
	}
	defer resp.Body.Close()
	// Fail with the real cause rather than a confusing decode error on a non-200 body.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected status 200, got %d", resp.StatusCode)
	}

	var result map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("failed to decode stats: %v", err)
	}
	data, ok := result["data"].(map[string]any)
	if !ok {
		t.Fatal("expected 'data' field in response")
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	count, ok := data["total_connections"].(float64)
	if !ok {
		t.Fatal("expected 'total_connections' in data")
	}
	if int(count) != 1 {
		t.Errorf("expected 1 connection in stats, got %d", int(count))
	}
}
// TestWebSocket_Disconnect checks that closing the client side of a WebSocket
// connection causes the hub to unregister it.
func TestWebSocket_Disconnect(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect
	conn, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
	if err != nil {
		t.Fatalf("failed to connect: %v", err)
	}
	// Safety net: releases the client socket if a t.Fatalf below fires before
	// the explicit Close. The error from a second Close is intentionally ignored.
	defer conn.Close()

	// Registration is asynchronous; poll with a deadline instead of a fixed sleep.
	deadline := time.Now().Add(time.Second)
	for hub.ConnectionCount() != 1 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.ConnectionCount(); got != 1 {
		t.Fatalf("expected 1 connection, got %d", got)
	}

	// Close connection
	if err := conn.Close(); err != nil {
		t.Fatalf("failed to close connection: %v", err)
	}

	// Unregistration is also asynchronous; poll until the hub has caught up.
	deadline = time.Now().Add(time.Second)
	for hub.ConnectionCount() != 0 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if got := hub.ConnectionCount(); got != 0 {
		t.Errorf("expected 0 connections after disconnect, got %d", got)
	}
}
// TestWebSocket_PingPong sends an application-level ping message (a
// realtime.Message written as JSON, not a WebSocket control frame) and
// expects a pong message back on the same connection.
func TestWebSocket_PingPong(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	client, _, dialErr := websocket.DefaultDialer.Dial(wsURL(server, "/ws"), nil)
	if dialErr != nil {
		t.Fatalf("failed to connect: %v", dialErr)
	}
	defer client.Close()

	ping := realtime.Message{Type: realtime.MessageTypePing}
	if err := client.WriteJSON(ping); err != nil {
		t.Fatalf("failed to send ping: %v", err)
	}

	// Bound the wait so a missing reply fails the test instead of hanging it.
	_ = client.SetReadDeadline(time.Now().Add(time.Second))

	var reply realtime.Message
	if err := client.ReadJSON(&reply); err != nil {
		t.Fatalf("failed to read pong: %v", err)
	}
	if got := reply.Type; got != realtime.MessageTypePong {
		t.Errorf("expected pong message, got %s", got)
	}
}
// TestWebSocket_RoomIsolation checks that a message addressed to room-a is
// delivered to a member of room-a but not to a client that joined room-b.
func TestWebSocket_RoomIsolation(t *testing.T) {
	hub := newTestHub(t)
	server := newTestServer(t, hub)
	defer server.Close()

	// Connect client to room A
	connA, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/room-a"), nil)
	if err != nil {
		t.Fatalf("client A failed to connect: %v", err)
	}
	defer connA.Close()

	// Connect client to room B
	connB, _, err := websocket.DefaultDialer.Dial(wsURL(server, "/ws/room-b"), nil)
	if err != nil {
		t.Fatalf("client B failed to connect: %v", err)
	}
	defer connB.Close()

	// Both room joins must be visible before sending. Otherwise A may miss the
	// message (flaky failure), or B's "not received" check passes vacuously
	// because B has not joined anything yet.
	deadline := time.Now().Add(time.Second)
	for (hub.RoomCount("room-a") != 1 || hub.RoomCount("room-b") != 1) && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if a, b := hub.RoomCount("room-a"), hub.RoomCount("room-b"); a != 1 || b != 1 {
		t.Fatalf("expected 1 client in each room, got room-a=%d room-b=%d", a, b)
	}

	// Send message to room A
	msg := realtime.Message{
		Type: realtime.MessageTypeChat,
		Room: "room-a",
		Data: json.RawMessage(`{"text":"room A only"}`),
	}
	if err := connA.WriteJSON(msg); err != nil {
		t.Fatalf("failed to send message: %v", err)
	}

	// Client A should receive it
	if err := connA.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
		t.Fatalf("failed to set read deadline for client A: %v", err)
	}
	var received realtime.Message
	if err := connA.ReadJSON(&received); err != nil {
		t.Errorf("client A should receive message: %v", err)
	} else if received.Type != realtime.MessageTypeChat {
		t.Errorf("client A: expected type 'chat', got %s", received.Type)
	}

	// Client B should NOT receive it; a read timeout is the expected outcome.
	if err := connB.SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
		t.Fatalf("failed to set read deadline for client B: %v", err)
	}
	var leaked realtime.Message
	if err := connB.ReadJSON(&leaked); err == nil {
		t.Error("client B should NOT receive room-a message")
	}
}