persona-community-2/pkg/realtime/handler.go
jordan cb3d4d5786
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
ci/woodpecker/manual/woodpecker Pipeline was successful
Initialize project from skeleton template
2026-02-23 10:53:55 +00:00

258 lines
6.8 KiB
Go

package realtime
import (
"context"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"git.threesix.ai/jordan/persona-community-2/pkg/auth"
"git.threesix.ai/jordan/persona-community-2/pkg/logging"
)
// HandlerConfig configures the WebSocket handler.
//
// All fields are optional; the zero value gives a single-pod handler with no
// lifecycle hooks, no authentication requirement and no origin allow-list.
type HandlerConfig struct {
// Broadcaster handles cross-pod message distribution.
// If nil, messages only broadcast to local pod (single-pod mode).
// When set, a failed Publish falls back to a broadcast on the local hub.
Broadcaster Broadcaster
// OnConnect is called when a client connects (optional).
// It runs after the client has joined its room (if any) and the "online"
// presence message has been sent, and before the connection starts running.
OnConnect func(conn Connection)
// OnDisconnect is called when a client disconnects (optional).
// It runs after the connection has stopped and the "offline" presence
// message (if the client was in a room) has been sent.
OnDisconnect func(conn Connection)
// OnMessage is called for incoming messages (optional).
// Return the message to broadcast it, or nil to suppress.
// The server-side fields (ID, Room, Timestamp) are already filled in when
// it is called. Presence messages generated by the handler do not pass
// through this hook.
OnMessage func(conn Connection, msg *Message) *Message
// AuthRequired requires authentication for WebSocket connections.
// If true, requests for which no user identity could be resolved are
// rejected with HTTP 401 before the WebSocket upgrade is attempted.
AuthRequired bool
// JWTValidator validates JWT tokens from query parameters (optional).
// If set, tokens passed via ?token= query param will be validated.
// It is only consulted when no user is already present on the request
// context. A validation failure is logged at debug level and the request
// is treated as if no token had been supplied.
JWTValidator auth.Validator
// AllowedOrigins is a whitelist of allowed origins for WebSocket connections.
// It is passed to UpgradeConnectionWithOrigins. Per that helper's contract,
// an empty list allows all origins (suitable for development only).
// In production, set this to your frontend domain(s).
AllowedOrigins []string
}
// Handler handles WebSocket connections.
//
// Construct it with NewHandler and mount Routes() (or HandleWebSocket
// directly) on an HTTP router.
type Handler struct {
// hub tracks the connections served by this process and performs
// local broadcasts and room membership.
hub Hub
// broadcaster is copied from HandlerConfig.Broadcaster; nil means
// messages are broadcast on the local hub only.
broadcaster Broadcaster
// logger is tagged with the "ws-handler" component by NewHandler.
logger *logging.Logger
// config holds the callbacks, auth settings and origin allow-list.
config HandlerConfig
}
// NewHandler creates a new WebSocket handler backed by hub.
//
// The Broadcaster from cfg is lifted into its own field (nil keeps the
// handler in single-pod mode), the logger is tagged with the "ws-handler"
// component, and cfg is retained for the lifecycle callbacks, auth
// settings and origin allow-list.
func NewHandler(hub Hub, logger *logging.Logger, cfg HandlerConfig) *Handler {
	h := &Handler{config: cfg}
	h.hub = hub
	h.broadcaster = cfg.Broadcaster
	h.logger = logger.WithComponent("ws-handler")
	return h
}
// Routes returns a chi router that serves WebSocket upgrades on GET "/"
// (no room in the path; the room may be given via ?room=) and on
// GET "/{room}".
//
// Mount at your desired path: r.Mount("/ws", handler.Routes())
func (h *Handler) Routes() http.Handler {
	router := chi.NewRouter()
	for _, pattern := range []string{"/", "/{room}"} {
		router.Get(pattern, h.HandleWebSocket)
	}
	return router
}
// HandleWebSocket upgrades HTTP to WebSocket and manages the connection
// lifecycle.
//
// The room is taken from the {room} URL parameter, falling back to the
// ?room= query parameter. When a room is present, the client joins it and
// "online"/"offline" presence messages are broadcast around the
// connection's lifetime. When there is no room, neither happens.
//
// The caller's identity is resolved by resolveIdentity before the upgrade.
// If AuthRequired is set and no verified identity is available, the
// request is rejected with 401 Unauthorized and no upgrade is attempted.
//
// The call blocks in client.Run until the connection ends. It then sends
// the offline presence message and fires the OnDisconnect callback.
func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	room := chi.URLParam(r, "room")
	if room == "" {
		room = r.URL.Query().Get("room")
	}

	userID, userName, ok := h.resolveIdentity(r)
	if !ok {
		http.Error(w, "authentication required", http.StatusUnauthorized)
		return
	}

	// Upgrade the connection, enforcing the configured origin allow-list.
	conn, err := UpgradeConnectionWithOrigins(w, r, h.config.AllowedOrigins)
	if err != nil {
		h.logger.Warn("websocket upgrade failed", "error", err)
		return
	}

	client := NewWSClient(h.hub, conn, h.logger, WSClientConfig{
		UserID:    userID,
		UserName:  userName,
		OnMessage: h.makeMessageHandler(room),
	})

	h.logger.Info("websocket connection established",
		"client_id", client.ID(),
		"user_id", userID,
		"user_name", userName,
		"room", room,
	)

	// Join the room first, then announce the user as online.
	if room != "" {
		h.hub.JoinRoom(client, room)
		h.broadcastPresence(client, room, userID, userName, PresenceOnline)
	}

	if h.config.OnConnect != nil {
		h.config.OnConnect(client)
	}

	// Run blocks until the connection is closed.
	client.Run(r.Context())

	if room != "" {
		h.broadcastPresence(client, room, userID, userName, PresenceOffline)
	}

	if h.config.OnDisconnect != nil {
		h.config.OnDisconnect(client)
	}

	h.logger.Info("websocket connection closed",
		"client_id", client.ID(),
		"user_id", userID,
	)
}

// resolveIdentity determines the user ID and display name for a WebSocket
// request.
//
// Verified sources are tried first. The first is a user already placed on
// the request context (auth.GetUser). The second is a ?token= query
// parameter checked with JWTValidator, when one is configured; a failed
// validation is logged at debug level and otherwise ignored.
//
// The unverified ?userId= query parameter is honoured only when
// AuthRequired is false. Accepting it under AuthRequired would let any
// client bypass authentication simply by naming a user. ok is false when
// AuthRequired is set and no verified user ID was found.
//
// ?userName= fills in the display name whenever no other source supplied
// one.
func (h *Handler) resolveIdentity(r *http.Request) (userID, userName string, ok bool) {
	query := r.URL.Query()

	if user := auth.GetUser(r.Context()); user != nil {
		// Use email as display name from auth context.
		userID, userName = user.ID, user.Email
	} else if token := query.Get("token"); token != "" && h.config.JWTValidator != nil {
		validated, err := h.config.JWTValidator.Validate(r.Context(), token)
		if err != nil {
			h.logger.Debug("token validation failed", "error", err)
		} else {
			userID, userName = validated.ID, validated.Email
		}
	}

	if userID == "" {
		if h.config.AuthRequired {
			return "", "", false
		}
		// Unauthenticated mode: accept a self-declared ID (testing/anonymous users).
		userID = query.Get("userId")
	}
	if userName == "" {
		userName = query.Get("userName")
	}
	return userID, userName, true
}
// makeMessageHandler creates the message callback for a client.
//
// Before the message reaches OnMessage or is broadcast, the server fills in
// its own fields:
//   - ID: a fresh UUID, if the client left it empty.
//   - Room: defaultRoom (the room from the connection URL), if empty.
//   - Timestamp: the current UTC time, if zero.
//   - From: always overwritten with the sending client's connection ID, so
//     one client cannot impersonate another. This matches the sender
//     attribution broadcastPresence uses.
//
// If OnMessage is configured, it may rewrite the message or return nil to
// drop it. The result is published through the Broadcaster when one is
// configured, falling back to a local hub broadcast if publishing fails.
// Without a Broadcaster, it is broadcast on the local hub only.
func (h *Handler) makeMessageHandler(defaultRoom string) func(*WSClient, *Message) {
	return func(client *WSClient, msg *Message) {
		if msg == nil {
			return
		}

		// Assign server-side fields for proper message tracking.
		if msg.ID == "" {
			msg.ID = uuid.New().String()
		}
		if msg.Room == "" {
			msg.Room = defaultRoom
		}
		if msg.Timestamp.IsZero() {
			msg.Timestamp = time.Now().UTC()
		}
		// Never trust a client-supplied sender.
		msg.From = client.ID()

		// Call user callback for message transformation/filtering.
		if h.config.OnMessage != nil {
			msg = h.config.OnMessage(client, msg)
			if msg == nil {
				return // Message suppressed
			}
		}

		// Broadcast via the cross-pod broadcaster if available, otherwise local only.
		if h.broadcaster != nil {
			if err := h.broadcaster.Publish(context.Background(), msg); err != nil {
				h.logger.Warn("failed to publish to broadcaster",
					"error", err,
					"message_id", msg.ID,
				)
				// Fall back to local broadcast.
				h.hub.Broadcast(msg)
			}
		} else {
			h.hub.Broadcast(msg)
		}
	}
}
// broadcastPresence announces a user's presence status in room.
//
// It builds a MessageTypePresence system message carrying userID, userName
// and status. The message gets a fresh UUID, the room, the sending
// connection's ID as From, and the current UTC time. If the message cannot
// be built, a warning is logged and nothing is sent.
//
// Delivery mirrors regular messages: publish through the broadcaster when
// one is configured, falling back to the local hub if publishing fails. Use
// the local hub only when there is no broadcaster.
func (h *Handler) broadcastPresence(client Connection, room, userID, userName, status string) {
	msg, err := SystemMessage(MessageTypePresence, PresenceData{
		UserID:   userID,
		UserName: userName,
		Status:   status,
	})
	if err != nil {
		h.logger.Warn("failed to create presence message", "error", err)
		return
	}

	msg.ID = uuid.New().String()
	msg.Room = room
	msg.From = client.ID()
	msg.Timestamp = time.Now().UTC()

	if h.broadcaster == nil {
		h.hub.Broadcast(msg)
		return
	}

	publishErr := h.broadcaster.Publish(context.Background(), msg)
	if publishErr == nil {
		return
	}
	h.logger.Warn("failed to publish presence to broadcaster",
		"error", publishErr,
		"status", status,
	)
	// Keep local clients informed even when cross-pod delivery failed.
	h.hub.Broadcast(msg)
}
// Stats is a snapshot of connection statistics, as produced by GetStats.
type Stats struct {
// TotalConnections is the hub's current connection count.
TotalConnections int `json:"total_connections"`
// RoomCounts maps each room requested from GetStats to its connection
// count. It is nil (and omitted from JSON) when no rooms were requested.
RoomCounts map[string]int `json:"room_counts,omitempty"`
}
// GetStats returns current connection statistics.
//
// TotalConnections is always populated. Per-room counts are gathered only
// for the room names passed in. With no rooms, RoomCounts stays nil, so it
// is omitted from the JSON encoding. Repeated room names map to a single
// entry.
func (h *Handler) GetStats(rooms ...string) Stats {
	total := h.hub.ConnectionCount()
	if len(rooms) == 0 {
		return Stats{TotalConnections: total}
	}

	counts := make(map[string]int, len(rooms))
	for _, name := range rooms {
		counts[name] = h.hub.RoomCount(name)
	}
	return Stats{
		TotalConnections: total,
		RoomCounts:       counts,
	}
}