258 lines
6.8 KiB
Go
258 lines
6.8 KiB
Go
package realtime
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/auth"
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/logging"
|
|
)
|
|
|
|
// HandlerConfig configures the WebSocket handler.
//
// All fields are optional; the zero value yields a handler that accepts
// unauthenticated connections from any origin and delivers messages only
// through the local hub.
type HandlerConfig struct {
	// Broadcaster handles cross-pod message distribution.
	// If nil, messages only broadcast to local pod (single-pod mode).
	// When set, both client messages and presence updates are published
	// through it; the local hub is used as a fallback only when Publish
	// returns an error.
	Broadcaster Broadcaster

	// OnConnect is called when a client connects (optional).
	// It runs after the client has joined its room (if any) and the
	// "online" presence message has been broadcast.
	OnConnect func(conn Connection)

	// OnDisconnect is called when a client disconnects (optional).
	// It runs after the client's run loop has returned and the "offline"
	// presence message has been broadcast.
	OnDisconnect func(conn Connection)

	// OnMessage is called for incoming messages (optional).
	// Return the message to broadcast it, or nil to suppress.
	// It receives the message after the server has filled in any missing
	// ID, Room (the connection's room) and Timestamp (UTC now).
	OnMessage func(conn Connection, msg *Message) *Message

	// AuthRequired requires authentication for WebSocket connections.
	// If true, unauthenticated connections are rejected with
	// 401 Unauthorized before the WebSocket upgrade is attempted.
	AuthRequired bool

	// JWTValidator validates JWT tokens from query parameters (optional).
	// If set, tokens passed via ?token= query param will be validated.
	// It is consulted only when the request context carries no auth user;
	// a token that fails validation is logged at debug level and ignored.
	JWTValidator auth.Validator

	// AllowedOrigins is a whitelist of allowed origins for WebSocket connections.
	// If empty, all origins are allowed (suitable for development only).
	// In production, set this to your frontend domain(s).
	// The list is passed as-is to UpgradeConnectionWithOrigins.
	AllowedOrigins []string
}
|
|
|
|
// Handler handles WebSocket connections.
//
// Construct it with NewHandler and mount the router returned by Routes.
type Handler struct {
	// hub is the local connection hub: rooms are joined through it, counts
	// are read from it, and local (non-broadcaster) delivery goes through it.
	hub Hub

	// broadcaster is copied from HandlerConfig.Broadcaster by NewHandler;
	// nil means messages are delivered through hub only.
	broadcaster Broadcaster

	// logger is the caller's logger tagged with the "ws-handler" component.
	logger *logging.Logger

	// config is the full configuration given to NewHandler; callbacks,
	// auth settings and allowed origins are read from it per request.
	config HandlerConfig
}
|
|
|
|
// NewHandler creates a new WebSocket handler.
|
|
func NewHandler(hub Hub, logger *logging.Logger, cfg HandlerConfig) *Handler {
|
|
return &Handler{
|
|
hub: hub,
|
|
broadcaster: cfg.Broadcaster,
|
|
logger: logger.WithComponent("ws-handler"),
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
// Routes returns the chi router for this handler.
|
|
// Mount at your desired path: r.Mount("/ws", handler.Routes())
|
|
func (h *Handler) Routes() http.Handler {
|
|
r := chi.NewRouter()
|
|
r.Get("/", h.HandleWebSocket)
|
|
r.Get("/{room}", h.HandleWebSocket)
|
|
return r
|
|
}
|
|
|
|
// HandleWebSocket upgrades HTTP to WebSocket and manages the connection lifecycle.
|
|
func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
// Extract room from URL (optional)
|
|
room := chi.URLParam(r, "room")
|
|
if room == "" {
|
|
room = r.URL.Query().Get("room")
|
|
}
|
|
|
|
// Extract user from auth context OR from query parameters
|
|
var userID string
|
|
var userName string
|
|
|
|
if user := auth.GetUser(r.Context()); user != nil {
|
|
userID = user.ID
|
|
userName = user.Email // Use email as display name from auth context
|
|
} else if token := r.URL.Query().Get("token"); token != "" && h.config.JWTValidator != nil {
|
|
// Validate JWT token from query parameter
|
|
user, err := h.config.JWTValidator.Validate(r.Context(), token)
|
|
if err != nil {
|
|
h.logger.Debug("token validation failed", "error", err)
|
|
} else {
|
|
userID = user.ID
|
|
userName = user.Email
|
|
}
|
|
}
|
|
|
|
// Fall back to query params for userId/userName (for testing or anonymous users)
|
|
if userID == "" {
|
|
userID = r.URL.Query().Get("userId")
|
|
}
|
|
if userName == "" {
|
|
userName = r.URL.Query().Get("userName")
|
|
}
|
|
|
|
// Check auth requirement
|
|
if h.config.AuthRequired && userID == "" {
|
|
http.Error(w, "authentication required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Upgrade connection with origin check
|
|
conn, err := UpgradeConnectionWithOrigins(w, r, h.config.AllowedOrigins)
|
|
if err != nil {
|
|
h.logger.Warn("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
|
|
// Create client
|
|
client := NewWSClient(h.hub, conn, h.logger, WSClientConfig{
|
|
UserID: userID,
|
|
UserName: userName,
|
|
OnMessage: h.makeMessageHandler(room),
|
|
})
|
|
|
|
h.logger.Info("websocket connection established",
|
|
"client_id", client.ID(),
|
|
"user_id", userID,
|
|
"user_name", userName,
|
|
"room", room,
|
|
)
|
|
|
|
// Join room if specified
|
|
if room != "" {
|
|
h.hub.JoinRoom(client, room)
|
|
}
|
|
|
|
// Broadcast presence (user joined)
|
|
if room != "" {
|
|
h.broadcastPresence(client, room, userID, userName, PresenceOnline)
|
|
}
|
|
|
|
// Notify connect callback
|
|
if h.config.OnConnect != nil {
|
|
h.config.OnConnect(client)
|
|
}
|
|
|
|
// Run connection (blocks until closed)
|
|
client.Run(r.Context())
|
|
|
|
// Broadcast presence (user left)
|
|
if room != "" {
|
|
h.broadcastPresence(client, room, userID, userName, PresenceOffline)
|
|
}
|
|
|
|
// Notify disconnect callback
|
|
if h.config.OnDisconnect != nil {
|
|
h.config.OnDisconnect(client)
|
|
}
|
|
|
|
h.logger.Info("websocket connection closed",
|
|
"client_id", client.ID(),
|
|
"user_id", userID,
|
|
)
|
|
}
|
|
|
|
// makeMessageHandler creates the message callback for a client.
|
|
func (h *Handler) makeMessageHandler(defaultRoom string) func(*WSClient, *Message) {
|
|
return func(client *WSClient, msg *Message) {
|
|
// Assign server-side fields for proper message tracking
|
|
if msg.ID == "" {
|
|
msg.ID = uuid.New().String()
|
|
}
|
|
if msg.Room == "" {
|
|
msg.Room = defaultRoom
|
|
}
|
|
if msg.Timestamp.IsZero() {
|
|
msg.Timestamp = time.Now().UTC()
|
|
}
|
|
|
|
// Call user callback for message transformation/filtering
|
|
if h.config.OnMessage != nil {
|
|
msg = h.config.OnMessage(client, msg)
|
|
if msg == nil {
|
|
return // Message suppressed
|
|
}
|
|
}
|
|
|
|
// Broadcast via Redis if available, otherwise local only
|
|
if h.broadcaster != nil {
|
|
if err := h.broadcaster.Publish(context.Background(), msg); err != nil {
|
|
h.logger.Warn("failed to publish to broadcaster",
|
|
"error", err,
|
|
"message_id", msg.ID,
|
|
)
|
|
// Fall back to local broadcast
|
|
h.hub.Broadcast(msg)
|
|
}
|
|
} else {
|
|
h.hub.Broadcast(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// broadcastPresence sends a presence message to the room.
|
|
func (h *Handler) broadcastPresence(client Connection, room, userID, userName, status string) {
|
|
presenceData, err := SystemMessage(MessageTypePresence, PresenceData{
|
|
UserID: userID,
|
|
UserName: userName,
|
|
Status: status,
|
|
})
|
|
if err != nil {
|
|
h.logger.Warn("failed to create presence message", "error", err)
|
|
return
|
|
}
|
|
|
|
presenceData.ID = uuid.New().String()
|
|
presenceData.Room = room
|
|
presenceData.From = client.ID()
|
|
presenceData.Timestamp = time.Now().UTC()
|
|
|
|
// Broadcast via Redis if available, otherwise local only
|
|
if h.broadcaster != nil {
|
|
if err := h.broadcaster.Publish(context.Background(), presenceData); err != nil {
|
|
h.logger.Warn("failed to publish presence to broadcaster",
|
|
"error", err,
|
|
"status", status,
|
|
)
|
|
// Fall back to local broadcast
|
|
h.hub.Broadcast(presenceData)
|
|
}
|
|
} else {
|
|
h.hub.Broadcast(presenceData)
|
|
}
|
|
}
|
|
|
|
// Stats holds connection statistics, as produced by Handler.GetStats.
type Stats struct {
	// TotalConnections is the hub's current connection count.
	TotalConnections int `json:"total_connections"`
	// RoomCounts maps each room name requested from GetStats to its
	// connection count. It is nil, and omitted from JSON, when no rooms were
	// requested.
	RoomCounts map[string]int `json:"room_counts,omitempty"`
}
|
|
|
|
// GetStats returns current connection statistics.
|
|
func (h *Handler) GetStats(rooms ...string) Stats {
|
|
stats := Stats{
|
|
TotalConnections: h.hub.ConnectionCount(),
|
|
}
|
|
|
|
if len(rooms) > 0 {
|
|
stats.RoomCounts = make(map[string]int)
|
|
for _, room := range rooms {
|
|
stats.RoomCounts[room] = h.hub.RoomCount(room)
|
|
}
|
|
}
|
|
|
|
return stats
|
|
}
|