package realtime import ( "context" "net/http" "github.com/go-chi/chi/v5" "git.threesix.ai/jordan/slack5-1770574304/pkg/auth" "git.threesix.ai/jordan/slack5-1770574304/pkg/logging" ) // HandlerConfig configures the WebSocket handler. type HandlerConfig struct { // Broadcaster handles cross-pod message distribution. // If nil, messages only broadcast to local pod (single-pod mode). Broadcaster Broadcaster // OnConnect is called when a client connects (optional). OnConnect func(conn Connection) // OnDisconnect is called when a client disconnects (optional). OnDisconnect func(conn Connection) // OnMessage is called for incoming messages (optional). // Return the message to broadcast it, or nil to suppress. OnMessage func(conn Connection, msg *Message) *Message // AuthRequired requires authentication for WebSocket connections. // If true, unauthenticated connections are rejected. AuthRequired bool } // Handler handles WebSocket connections. type Handler struct { hub Hub broadcaster Broadcaster logger *logging.Logger config HandlerConfig } // NewHandler creates a new WebSocket handler. func NewHandler(hub Hub, logger *logging.Logger, cfg HandlerConfig) *Handler { return &Handler{ hub: hub, broadcaster: cfg.Broadcaster, logger: logger.WithComponent("ws-handler"), config: cfg, } } // Routes returns the chi router for this handler. // Mount at your desired path: r.Mount("/ws", handler.Routes()) func (h *Handler) Routes() http.Handler { r := chi.NewRouter() r.Get("/", h.HandleWebSocket) r.Get("/{room}", h.HandleWebSocket) return r } // HandleWebSocket upgrades HTTP to WebSocket and manages the connection lifecycle. func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { // Extract room from URL (optional) room := chi.URLParam(r, "room") if room == "" { room = r.URL.Query().Get("room") } // Extract user from auth context var userID string if claims := auth.ClaimsFromContext(r.Context()); claims != nil { userID = claims.Subject } // Check auth requirement if h.config.AuthRequired && userID == "" { http.Error(w, "authentication required", http.StatusUnauthorized) return } // Upgrade connection conn, err := UpgradeConnection(w, r) if err != nil { h.logger.Warn("websocket upgrade failed", "error", err) return } // Create client client := NewWSClient(h.hub, conn, h.logger, WSClientConfig{ UserID: userID, OnMessage: h.makeMessageHandler(room), }) h.logger.Info("websocket connection established", "client_id", client.ID(), "user_id", userID, "room", room, ) // Join room if specified if room != "" { h.hub.JoinRoom(client, room) } // Notify connect callback if h.config.OnConnect != nil { h.config.OnConnect(client) } // Run connection (blocks until closed) client.Run(r.Context()) // Notify disconnect callback if h.config.OnDisconnect != nil { h.config.OnDisconnect(client) } h.logger.Info("websocket connection closed", "client_id", client.ID(), "user_id", userID, ) } // makeMessageHandler creates the message callback for a client. func (h *Handler) makeMessageHandler(defaultRoom string) func(*WSClient, *Message) { return func(client *WSClient, msg *Message) { // Set room if not specified if msg.Room == "" { msg.Room = defaultRoom } // Call user callback for message transformation/filtering if h.config.OnMessage != nil { msg = h.config.OnMessage(client, msg) if msg == nil { return // Message suppressed } } // Broadcast via Redis if available, otherwise local only if h.broadcaster != nil { if err := h.broadcaster.Publish(context.Background(), msg); err != nil { h.logger.Warn("failed to publish to broadcaster", "error", err, "message_id", msg.ID, ) // Fall back to local broadcast h.hub.Broadcast(msg) } } else { h.hub.Broadcast(msg) } } } // Stats returns connection statistics. type Stats struct { TotalConnections int `json:"total_connections"` RoomCounts map[string]int `json:"room_counts,omitempty"` } // GetStats returns current connection statistics. func (h *Handler) GetStats(rooms ...string) Stats { stats := Stats{ TotalConnections: h.hub.ConnectionCount(), } if len(rooms) > 0 { stats.RoomCounts = make(map[string]int) for _, room := range rooms { stats.RoomCounts[room] = h.hub.RoomCount(room) } } return stats }