sp3-verify-1770325794/pkg/realtime/handler.go
jordan 286d313d81
All checks were successful
ci/woodpecker/manual/woodpecker Pipeline was successful
ci/woodpecker/push/woodpecker Pipeline was successful
Initialize project from skeleton template
2026-02-05 21:09:55 +00:00

177 lines
4.4 KiB
Go

package realtime
import (
	"context"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5"

	"git.threesix.ai/jordan/sp3-verify-1770325794/pkg/auth"
	"git.threesix.ai/jordan/sp3-verify-1770325794/pkg/logging"
)
// HandlerConfig configures the WebSocket handler. Every field is optional:
// the zero value gives a single-pod handler with no callbacks that accepts
// unauthenticated connections.
type HandlerConfig struct {
	// Broadcaster handles cross-pod message distribution.
	// If nil, messages only broadcast to local pod (single-pod mode).
	Broadcaster Broadcaster

	// OnConnect is called when a client connects (optional). It runs after
	// the client has joined its room (if one was requested) and before the
	// connection is run.
	OnConnect func(conn Connection)

	// OnDisconnect is called when a client disconnects (optional), once the
	// connection has finished running.
	OnDisconnect func(conn Connection)

	// OnMessage is called for incoming messages (optional). By the time it
	// runs, an empty msg.Room has already been defaulted to the room the
	// client connected with.
	// Return the message (possibly modified) to broadcast it, or nil to
	// suppress it.
	OnMessage func(conn Connection, msg *Message) *Message

	// AuthRequired requires authentication for WebSocket connections.
	// If true, requests whose auth claims are missing or carry an empty
	// Subject are rejected with 401 before the WebSocket upgrade.
	AuthRequired bool
}
// Handler handles WebSocket connections: it upgrades requests, registers
// each client with a Hub, and routes incoming messages either through the
// optional Broadcaster or directly to the Hub. Construct it with NewHandler.
type Handler struct {
	hub         Hub             // client registry used for rooms, counts and local broadcast
	broadcaster Broadcaster     // copied from config.Broadcaster; nil means local-only broadcast
	logger      *logging.Logger // scoped to the "ws-handler" component by NewHandler
	config      HandlerConfig   // callbacks and AuthRequired flag
}
// NewHandler returns a Handler that registers WebSocket clients on hub.
// The cross-pod broadcaster is taken from cfg.Broadcaster (nil selects
// single-pod mode), and logger is scoped to the "ws-handler" component.
func NewHandler(hub Hub, logger *logging.Logger, cfg HandlerConfig) *Handler {
	h := new(Handler)
	h.hub = hub
	h.broadcaster = cfg.Broadcaster
	h.logger = logger.WithComponent("ws-handler")
	h.config = cfg
	return h
}
// Routes builds a chi router that serves the WebSocket endpoint at the
// mount root and at "/{room}", where the path segment names the room to
// join. Mount it under a prefix, e.g. r.Mount("/ws", handler.Routes()).
func (h *Handler) Routes() http.Handler {
	router := chi.NewRouter()
	for _, pattern := range [...]string{"/", "/{room}"} {
		router.Get(pattern, h.HandleWebSocket)
	}
	return router
}
// HandleWebSocket upgrades HTTP to WebSocket and manages the connection lifecycle.
//
// The room comes from the "{room}" URL parameter, falling back to the
// "room" query parameter; an empty room means the client joins no room.
// The user ID is the Subject of any auth claims on the request context.
// When HandlerConfig.AuthRequired is set and no user ID is available, the
// request is rejected with 401 before the upgrade is attempted.
//
// The call blocks until the client's connection finishes running. Once
// OnConnect has returned, OnDisconnect and the "connection closed" log are
// deferred, so they still run if the connection exits by panicking.
func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	// Extract room from URL (optional).
	room := chi.URLParam(r, "room")
	if room == "" {
		room = r.URL.Query().Get("room")
	}

	// Extract user from auth context.
	var userID string
	if claims := auth.ClaimsFromContext(r.Context()); claims != nil {
		userID = claims.Subject
	}

	// Check auth requirement before upgrading, while a plain HTTP error
	// response can still be written.
	if h.config.AuthRequired && userID == "" {
		http.Error(w, "authentication required", http.StatusUnauthorized)
		return
	}

	conn, err := UpgradeConnection(w, r)
	if err != nil {
		// NOTE(review): presumably UpgradeConnection has already written an
		// HTTP error response on failure — confirm.
		h.logger.Warn("websocket upgrade failed", "error", err)
		return
	}

	client := NewWSClient(h.hub, conn, h.logger, WSClientConfig{
		UserID:    userID,
		OnMessage: h.makeMessageHandler(room),
	})

	h.logger.Info("websocket connection established",
		"client_id", client.ID(),
		"user_id", userID,
		"room", room,
	)

	if room != "" {
		h.hub.JoinRoom(client, room)
	}

	if h.config.OnConnect != nil {
		h.config.OnConnect(client)
	}

	// Registered only after OnConnect succeeded, so every completed
	// OnConnect is paired with exactly one OnDisconnect — even if Run
	// panics (e.g. from a user callback invoked on this goroutine), which
	// net/http would otherwise recover from without running the cleanup.
	defer func() {
		if h.config.OnDisconnect != nil {
			h.config.OnDisconnect(client)
		}
		h.logger.Info("websocket connection closed",
			"client_id", client.ID(),
			"user_id", userID,
		)
	}()

	// Run connection (blocks until closed).
	client.Run(r.Context())
}
// makeMessageHandler creates the message callback for a client.
func (h *Handler) makeMessageHandler(defaultRoom string) func(*WSClient, *Message) {
return func(client *WSClient, msg *Message) {
// Set room if not specified
if msg.Room == "" {
msg.Room = defaultRoom
}
// Call user callback for message transformation/filtering
if h.config.OnMessage != nil {
msg = h.config.OnMessage(client, msg)
if msg == nil {
return // Message suppressed
}
}
// Broadcast via Redis if available, otherwise local only
if h.broadcaster != nil {
if err := h.broadcaster.Publish(context.Background(), msg); err != nil {
h.logger.Warn("failed to publish to broadcaster",
"error", err,
"message_id", msg.ID,
)
// Fall back to local broadcast
h.hub.Broadcast(msg)
}
} else {
h.hub.Broadcast(msg)
}
}
}
// Stats is a snapshot of connection statistics, suitable for JSON encoding.
type Stats struct {
	// TotalConnections is the hub's connection count at snapshot time.
	TotalConnections int `json:"total_connections"`

	// RoomCounts maps each requested room name to its client count. It is
	// nil, and therefore omitted from JSON, when no rooms were requested.
	RoomCounts map[string]int `json:"room_counts,omitempty"`
}
// GetStats reports the hub's total connection count and, for every room
// name passed in, that room's client count. RoomCounts stays nil when no
// rooms are given, so it is omitted from JSON output; a room listed more
// than once appears a single time.
func (h *Handler) GetStats(rooms ...string) Stats {
	total := h.hub.ConnectionCount()

	var perRoom map[string]int
	if n := len(rooms); n != 0 {
		perRoom = make(map[string]int, n)
		for _, name := range rooms {
			perRoom[name] = h.hub.RoomCount(name)
		}
	}

	return Stats{TotalConnections: total, RoomCounts: perRoom}
}