177 lines
4.4 KiB
Go
177 lines
4.4 KiB
Go
package realtime
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
|
|
"git.threesix.ai/jordan/slate-verify-1770510662/pkg/auth"
|
|
"git.threesix.ai/jordan/slate-verify-1770510662/pkg/logging"
|
|
)
|
|
|
|
// HandlerConfig configures the WebSocket handler.
|
|
type HandlerConfig struct {
|
|
// Broadcaster handles cross-pod message distribution.
|
|
// If nil, messages only broadcast to local pod (single-pod mode).
|
|
Broadcaster Broadcaster
|
|
|
|
// OnConnect is called when a client connects (optional).
|
|
OnConnect func(conn Connection)
|
|
|
|
// OnDisconnect is called when a client disconnects (optional).
|
|
OnDisconnect func(conn Connection)
|
|
|
|
// OnMessage is called for incoming messages (optional).
|
|
// Return the message to broadcast it, or nil to suppress.
|
|
OnMessage func(conn Connection, msg *Message) *Message
|
|
|
|
// AuthRequired requires authentication for WebSocket connections.
|
|
// If true, unauthenticated connections are rejected.
|
|
AuthRequired bool
|
|
}
|
|
|
|
// Handler handles WebSocket connections.
|
|
type Handler struct {
|
|
hub Hub
|
|
broadcaster Broadcaster
|
|
logger *logging.Logger
|
|
config HandlerConfig
|
|
}
|
|
|
|
// NewHandler creates a new WebSocket handler.
|
|
func NewHandler(hub Hub, logger *logging.Logger, cfg HandlerConfig) *Handler {
|
|
return &Handler{
|
|
hub: hub,
|
|
broadcaster: cfg.Broadcaster,
|
|
logger: logger.WithComponent("ws-handler"),
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
// Routes returns the chi router for this handler.
|
|
// Mount at your desired path: r.Mount("/ws", handler.Routes())
|
|
func (h *Handler) Routes() http.Handler {
|
|
r := chi.NewRouter()
|
|
r.Get("/", h.HandleWebSocket)
|
|
r.Get("/{room}", h.HandleWebSocket)
|
|
return r
|
|
}
|
|
|
|
// HandleWebSocket upgrades HTTP to WebSocket and manages the connection lifecycle.
|
|
func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
// Extract room from URL (optional)
|
|
room := chi.URLParam(r, "room")
|
|
if room == "" {
|
|
room = r.URL.Query().Get("room")
|
|
}
|
|
|
|
// Extract user from auth context
|
|
var userID string
|
|
if claims := auth.ClaimsFromContext(r.Context()); claims != nil {
|
|
userID = claims.Subject
|
|
}
|
|
|
|
// Check auth requirement
|
|
if h.config.AuthRequired && userID == "" {
|
|
http.Error(w, "authentication required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Upgrade connection
|
|
conn, err := UpgradeConnection(w, r)
|
|
if err != nil {
|
|
h.logger.Warn("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
|
|
// Create client
|
|
client := NewWSClient(h.hub, conn, h.logger, WSClientConfig{
|
|
UserID: userID,
|
|
OnMessage: h.makeMessageHandler(room),
|
|
})
|
|
|
|
h.logger.Info("websocket connection established",
|
|
"client_id", client.ID(),
|
|
"user_id", userID,
|
|
"room", room,
|
|
)
|
|
|
|
// Join room if specified
|
|
if room != "" {
|
|
h.hub.JoinRoom(client, room)
|
|
}
|
|
|
|
// Notify connect callback
|
|
if h.config.OnConnect != nil {
|
|
h.config.OnConnect(client)
|
|
}
|
|
|
|
// Run connection (blocks until closed)
|
|
client.Run(r.Context())
|
|
|
|
// Notify disconnect callback
|
|
if h.config.OnDisconnect != nil {
|
|
h.config.OnDisconnect(client)
|
|
}
|
|
|
|
h.logger.Info("websocket connection closed",
|
|
"client_id", client.ID(),
|
|
"user_id", userID,
|
|
)
|
|
}
|
|
|
|
// makeMessageHandler creates the message callback for a client.
|
|
func (h *Handler) makeMessageHandler(defaultRoom string) func(*WSClient, *Message) {
|
|
return func(client *WSClient, msg *Message) {
|
|
// Set room if not specified
|
|
if msg.Room == "" {
|
|
msg.Room = defaultRoom
|
|
}
|
|
|
|
// Call user callback for message transformation/filtering
|
|
if h.config.OnMessage != nil {
|
|
msg = h.config.OnMessage(client, msg)
|
|
if msg == nil {
|
|
return // Message suppressed
|
|
}
|
|
}
|
|
|
|
// Broadcast via Redis if available, otherwise local only
|
|
if h.broadcaster != nil {
|
|
if err := h.broadcaster.Publish(context.Background(), msg); err != nil {
|
|
h.logger.Warn("failed to publish to broadcaster",
|
|
"error", err,
|
|
"message_id", msg.ID,
|
|
)
|
|
// Fall back to local broadcast
|
|
h.hub.Broadcast(msg)
|
|
}
|
|
} else {
|
|
h.hub.Broadcast(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stats is a snapshot of connection counts as reported by the hub.
type Stats struct {
	// TotalConnections is the hub's overall connection count.
	TotalConnections int `json:"total_connections"`
	// RoomCounts maps each requested room to its connection count. It is nil
	// (and omitted from JSON) when no rooms were requested.
	RoomCounts map[string]int `json:"room_counts,omitempty"`
}
|
|
|
|
// GetStats returns current connection statistics.
|
|
func (h *Handler) GetStats(rooms ...string) Stats {
|
|
stats := Stats{
|
|
TotalConnections: h.hub.ConnectionCount(),
|
|
}
|
|
|
|
if len(rooms) > 0 {
|
|
stats.RoomCounts = make(map[string]int)
|
|
for _, room := range rooms {
|
|
stats.RoomCounts[room] = h.hub.RoomCount(room)
|
|
}
|
|
}
|
|
|
|
return stats
|
|
}
|