sp3-verify-1770325830/services/chat-api/internal/api/handlers/websocket_test.go
rdev-worker 42c1444274
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
build: /implement-feature websocket-chat --requirements 'GET /ws upgrades to...
2026-02-05 21:50:17 +00:00

272 lines
6.6 KiB
Go

package handlers
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"git.threesix.ai/jordan/sp3-verify-1770325830/pkg/logging"
"git.threesix.ai/jordan/sp3-verify-1770325830/pkg/realtime"
)
// testSetup builds the fixtures shared by the WebSocket tests: a LocalHub
// whose Run loop is started in a background goroutine, and a realtime.Handler
// wrapping it with no Redis broadcaster, authentication disabled, and an
// OnMessage hook that fills in MessageTypeChat when an incoming message has an
// empty Type (the message is otherwise returned as-is).
//
// The returned CancelFunc cancels the context passed to hub.Run; callers are
// expected to defer it.
func testSetup(t *testing.T) (*realtime.LocalHub, *realtime.Handler, context.CancelFunc) {
	t.Helper()

	runCtx, stop := context.WithCancel(context.Background())
	nopLogger := logging.Nop()

	hub := realtime.NewHub(nopLogger)
	go hub.Run(runCtx)

	// Untyped messages are treated as chat messages.
	defaultToChat := func(_ realtime.Connection, m *realtime.Message) *realtime.Message {
		if m.Type == "" {
			m.Type = realtime.MessageTypeChat
		}
		return m
	}

	cfg := realtime.HandlerConfig{
		Broadcaster:  nil, // No Redis for tests
		AuthRequired: false,
		OnMessage:    defaultToChat,
	}
	return hub, realtime.NewHandler(hub, nopLogger, cfg), stop
}
// dial opens a client WebSocket connection to path on the given httptest
// server. The server's "http" URL scheme prefix is replaced with "ws". If the
// handshake fails the calling test is aborted via t.Fatalf.
func dial(t *testing.T, server *httptest.Server, path string) *websocket.Conn {
	t.Helper()

	hostPart := strings.TrimPrefix(server.URL, "http")
	target := "ws" + hostPart + path

	c, _, dialErr := websocket.DefaultDialer.Dial(target, nil)
	if dialErr != nil {
		t.Fatalf("failed to dial websocket: %v", dialErr)
	}
	return c
}
// TestWebSocket_Connection verifies that dialing the root path registers one
// connection with the hub and that closing the client unregisters it.
func TestWebSocket_Connection(t *testing.T) {
	hub, handler, cancel := testSetup(t)
	defer cancel()
	server := httptest.NewServer(handler.Routes())
	defer server.Close()

	// waitForCount polls until the hub reports want connections or the
	// deadline passes, returning the last observed count. Registration and
	// cleanup happen asynchronously in the hub goroutine, so polling is used
	// instead of a fixed sleep, which is flaky on slow/-race runs and wasteful
	// on fast ones.
	waitForCount := func(want int) int {
		deadline := time.Now().Add(2 * time.Second)
		for {
			got := hub.ConnectionCount()
			if got == want || time.Now().After(deadline) {
				return got
			}
			time.Sleep(5 * time.Millisecond)
		}
	}

	// Connect.
	conn := dial(t, server, "/")
	defer conn.Close() // safety net if the test stops before the explicit Close

	if count := waitForCount(1); count != 1 {
		t.Errorf("expected 1 connection, got %d", count)
	}

	// Close and verify cleanup.
	conn.Close()
	if count := waitForCount(0); count != 0 {
		t.Errorf("expected 0 connections after close, got %d", count)
	}
}
// TestWebSocket_RoomConnection verifies that dialing "/test-room" places the
// client in the hub room named "test-room".
func TestWebSocket_RoomConnection(t *testing.T) {
	hub, handler, cancel := testSetup(t)
	defer cancel()
	server := httptest.NewServer(handler.Routes())
	defer server.Close()

	// Connect to room.
	conn := dial(t, server, "/test-room")
	defer conn.Close()

	// Room membership is recorded asynchronously by the hub goroutine; poll
	// with a deadline rather than sleeping a fixed (flaky) interval.
	deadline := time.Now().Add(2 * time.Second)
	for hub.RoomCount("test-room") != 1 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}

	// Verify room count.
	if count := hub.RoomCount("test-room"); count != 1 {
		t.Errorf("expected 1 connection in room, got %d", count)
	}
}
// TestWebSocket_MessageBroadcast verifies that a chat message written by one
// client on the root path is delivered to a second client on the same path.
func TestWebSocket_MessageBroadcast(t *testing.T) {
	hub, handler, cancel := testSetup(t)
	defer cancel()
	server := httptest.NewServer(handler.Routes())
	defer server.Close()

	// Connect two clients.
	conn1 := dial(t, server, "/")
	defer conn1.Close()
	conn2 := dial(t, server, "/")
	defer conn2.Close()

	// Both clients must be registered before sending, or conn2 can miss the
	// broadcast. Registration is asynchronous, so poll with a deadline instead
	// of sleeping a fixed interval.
	deadline := time.Now().Add(2 * time.Second)
	for hub.ConnectionCount() != 2 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if count := hub.ConnectionCount(); count != 2 {
		t.Fatalf("expected 2 connections before broadcasting, got %d", count)
	}

	// Send message from conn1.
	msg := realtime.Message{
		Type: realtime.MessageTypeChat,
		Data: json.RawMessage(`{"content":"Hello"}`),
	}
	if err := conn1.WriteJSON(msg); err != nil {
		t.Fatalf("failed to send message: %v", err)
	}

	// Read from conn2, bounded so a missing broadcast fails instead of hanging.
	if err := conn2.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("failed to set read deadline: %v", err)
	}
	var received realtime.Message
	if err := conn2.ReadJSON(&received); err != nil {
		t.Fatalf("failed to receive message: %v", err)
	}
	if received.Type != realtime.MessageTypeChat {
		t.Errorf("expected type %s, got %s", realtime.MessageTypeChat, received.Type)
	}
}
// TestWebSocket_RoomIsolation verifies that a message sent by a client in
// "room1" is not delivered to a client in "room2".
func TestWebSocket_RoomIsolation(t *testing.T) {
	hub, handler, cancel := testSetup(t)
	defer cancel()
	server := httptest.NewServer(handler.Routes())
	defer server.Close()

	// Connect to room1.
	conn1 := dial(t, server, "/room1")
	defer conn1.Close()
	// Connect to room2.
	conn2 := dial(t, server, "/room2")
	defer conn2.Close()

	// Both clients must have joined their rooms before sending; otherwise
	// room2 "passes" merely because it had not joined yet. Membership is
	// recorded asynchronously, so poll with a deadline.
	deadline := time.Now().Add(2 * time.Second)
	for (hub.RoomCount("room1") != 1 || hub.RoomCount("room2") != 1) && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if n1, n2 := hub.RoomCount("room1"), hub.RoomCount("room2"); n1 != 1 || n2 != 1 {
		t.Fatalf("expected 1 connection in each room, got room1=%d room2=%d", n1, n2)
	}

	// Send message to room1.
	msg := realtime.Message{
		Type: realtime.MessageTypeChat,
		Room: "room1",
		Data: json.RawMessage(`{"content":"Room1 Only"}`),
	}
	if err := conn1.WriteJSON(msg); err != nil {
		t.Fatalf("failed to send message: %v", err)
	}

	// Set short timeout on conn2 - it should NOT receive the message.
	if err := conn2.SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
		t.Fatalf("failed to set read deadline: %v", err)
	}
	var received realtime.Message
	err := conn2.ReadJSON(&received)

	// Only a read timeout shows room2 stayed silent. Any other error, such as
	// the server dropping the connection, would otherwise let this test pass
	// vacuously. The deadline error from the underlying connection exposes
	// Timeout() (net.Error), which is checked structurally here.
	if err == nil {
		t.Errorf("room2 client received message meant for room1: %+v", received)
	} else if te, ok := err.(interface{ Timeout() bool }); !ok || !te.Timeout() {
		t.Errorf("expected read timeout on room2 client, got: %v", err)
	}
}
// TestWebSocket_PingPong checks that an application-level ping message sent
// by a client is answered with a pong message on the same connection.
func TestWebSocket_PingPong(t *testing.T) {
	_, h, stop := testSetup(t)
	defer stop()

	srv := httptest.NewServer(h.Routes())
	defer srv.Close()

	client := dial(t, srv, "/")
	defer client.Close()

	// Give the hub a moment to register the client before talking to it.
	time.Sleep(50 * time.Millisecond)

	ping := realtime.Message{Type: realtime.MessageTypePing}
	if err := client.WriteJSON(ping); err != nil {
		t.Fatalf("failed to send ping: %v", err)
	}

	// Bound the wait so a missing pong fails the test instead of hanging it.
	_ = client.SetReadDeadline(time.Now().Add(2 * time.Second))
	var reply realtime.Message
	if err := client.ReadJSON(&reply); err != nil {
		t.Fatalf("failed to receive pong: %v", err)
	}
	if got, want := reply.Type, realtime.MessageTypePong; got != want {
		t.Errorf("expected type %s, got %s", want, got)
	}
}
// TestWebSocket_AuthRequired verifies that, with AuthRequired enabled, an
// unauthenticated upgrade request is rejected with HTTP 401.
func TestWebSocket_AuthRequired(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	logger := logging.Nop()
	hub := realtime.NewHub(logger)
	go hub.Run(ctx)
	handler := realtime.NewHandler(hub, logger, realtime.HandlerConfig{
		AuthRequired: true, // Require auth
	})
	server := httptest.NewServer(handler.Routes())
	defer server.Close()

	// Try to connect without auth.
	url := "ws" + strings.TrimPrefix(server.URL, "http") + "/"
	conn, resp, err := websocket.DefaultDialer.Dial(url, nil)
	// Always release the handshake response body, whether or not the dial
	// succeeded.
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err == nil {
		// Unexpected success: close the connection rather than leak it.
		conn.Close()
		t.Fatal("expected connection to fail without auth")
	}
	// With no HTTP response, the dial failed before the server could reject
	// it (e.g. a network error), so the 401 path was never exercised.
	if resp == nil {
		t.Fatalf("expected HTTP 401 response, got dial error without response: %v", err)
	}
	// Should fail with 401.
	if resp.StatusCode != http.StatusUnauthorized {
		t.Errorf("expected status 401, got %d", resp.StatusCode)
	}
}
// TestWebSocket_MultipleClients verifies that the hub tracks several
// concurrent root-path clients and drops only the ones that disconnect.
func TestWebSocket_MultipleClients(t *testing.T) {
	hub, handler, cancel := testSetup(t)
	defer cancel()
	server := httptest.NewServer(handler.Routes())
	defer server.Close()

	// waitForCount polls until the hub reports want connections or the
	// deadline passes, returning the last observed count. The hub updates its
	// bookkeeping asynchronously, so polling replaces fixed sleeps that were
	// flaky under load.
	waitForCount := func(want int) int {
		deadline := time.Now().Add(2 * time.Second)
		for {
			got := hub.ConnectionCount()
			if got == want || time.Now().After(deadline) {
				return got
			}
			time.Sleep(5 * time.Millisecond)
		}
	}

	// Connect multiple clients.
	const numClients = 5
	conns := make([]*websocket.Conn, numClients)
	for i := 0; i < numClients; i++ {
		conns[i] = dial(t, server, "/")
		// The receiver is evaluated now, so each defer closes its own conn.
		// A second Close on the conns closed below is harmless.
		defer conns[i].Close()
	}

	if count := waitForCount(numClients); count != numClients {
		t.Errorf("expected %d connections, got %d", numClients, count)
	}

	// Close half the clients.
	for i := 0; i < numClients/2; i++ {
		conns[i].Close()
	}

	expected := numClients - numClients/2
	if count := waitForCount(expected); count != expected {
		t.Errorf("expected %d connections after partial close, got %d", expected, count)
	}
}