package realtime

import (
	"context"
	"encoding/json"
	"net/http"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"

	"git.threesix.ai/jordan/persona-community-2/pkg/logging"
)

const (
	// Time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// Time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second

	// Send pings to peer with this period. Must be less than pongWait.
	pingPeriod = (pongWait * 9) / 10

	// Maximum message size allowed from peer.
	maxMessageSize = 64 * 1024 // 64KB

	// Size of the send channel buffer.
	sendBufferSize = 256
)

// WSClient represents a WebSocket connection to the hub.
//
// Outbound messages are queued by Send on a buffered channel and written by a
// dedicated writer goroutine started in Run; inbound messages are read in the
// goroutine that calls Run.
type WSClient struct {
	id       string
	userID   string
	userName string
	hub      Hub
	conn     *websocket.Conn
	logger   *logging.Logger
	onMsg    func(*WSClient, *Message) // Optional message callback

	// mu guards closed and the closing of send. Send holds the read lock for
	// its non-blocking channel send and Close holds the write lock while
	// closing, so a Send that races with Close returns false instead of
	// panicking with "send on closed channel".
	mu     sync.RWMutex
	closed bool
	send   chan *Message
}

// Ensure WSClient implements Connection at compile time.
var _ Connection = (*WSClient)(nil)

// WSClientConfig configures a new WebSocket client.
type WSClientConfig struct {
	// UserID is the authenticated user ID (empty if anonymous).
	UserID string

	// UserName is the display name for the user (optional).
	UserName string

	// OnMessage is called for each incoming message.
	// If nil, messages are ignored (useful for broadcast-only connections).
	OnMessage func(*WSClient, *Message)
}

// NewWSClient creates a new WebSocket client from an upgraded connection.
// The client gets a freshly generated UUID as its connection ID. Call Run to
// start serving it.
func NewWSClient(hub Hub, conn *websocket.Conn, logger *logging.Logger, cfg WSClientConfig) *WSClient {
	return &WSClient{
		id:       uuid.New().String(),
		userID:   cfg.UserID,
		userName: cfg.UserName,
		hub:      hub,
		conn:     conn,
		send:     make(chan *Message, sendBufferSize),
		logger:   logger.WithComponent("ws-client"),
		onMsg:    cfg.OnMessage,
	}
}

// ID returns the unique connection identifier.
func (c *WSClient) ID() string { return c.id }

// UserID returns the authenticated user ID.
func (c *WSClient) UserID() string { return c.userID }

// UserName returns the display name for the user.
func (c *WSClient) UserName() string { return c.userName }

// Send queues msg for delivery to the peer without blocking.
//
// It returns false, discarding the message, when the client has already been
// closed or when its send buffer (sendBufferSize messages) is full. Send is
// safe to call concurrently with Close and with other Sends.
func (c *WSClient) Send(msg *Message) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.closed {
		// Sending on the closed channel would panic; report non-delivery.
		return false
	}

	select {
	case c.send <- msg:
		return true
	default:
		// Buffer full, message dropped
		c.logger.Warn("send buffer full, dropping message", "client_id", c.id)
		return false
	}
}

// Close stops accepting outbound messages and closes the send queue. The
// writer goroutine then writes whatever is still queued (best effort), sends a
// close frame and closes the underlying connection. Close is idempotent and
// safe to call from any goroutine.
func (c *WSClient) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return
	}
	c.closed = true
	close(c.send)
}

// Run registers the client with the hub, starts the writer goroutine and then
// reads from the connection until a read fails, the peer closes it, or ctx is
// done. It blocks until both the reader and the writer have stopped, and
// unregisters the client from the hub before returning.
func (c *WSClient) Run(ctx context.Context) {
	c.hub.Register(c)
	defer c.hub.Unregister(c)

	writerDone := make(chan struct{})
	go func() {
		defer close(writerDone)
		c.writePump(ctx)
	}()

	c.readPump(ctx)

	// readPump has already closed the send queue and the connection, so every
	// path in writePump returns promptly; waiting here guarantees the writer
	// goroutine does not outlive Run.
	<-writerDone
}

// readPump reads frames from the connection until a read fails or ctx is
// done. Each frame is decoded as a JSON Message; frames that fail to decode
// are logged and skipped. Incoming MessageTypePing messages are answered with
// a pong queued via Send; every other message has From set to this client's
// ID (and Timestamp filled in when missing) and is passed to the OnMessage
// callback, if one was configured.
//
// On return it closes the send queue and the connection, which stops
// writePump.
func (c *WSClient) readPump(ctx context.Context) {
	defer func() {
		c.Close()
		_ = c.conn.Close()
	}()

	c.conn.SetReadLimit(maxMessageSize)
	// Best effort: errors from the deadline setter are ignored.
	_ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		// Every pong from the peer extends the read deadline.
		_ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		// ReadMessage itself does not observe ctx; a cancellation that
		// arrives while it is blocked is delivered by writePump closing the
		// connection, which makes ReadMessage fail.
		select {
		case <-ctx.Done():
			return
		default:
		}

		_, data, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.logger.Debug("websocket read error", "client_id", c.id, "error", err)
			}
			return
		}

		// Parse message
		var msg Message
		if err := json.Unmarshal(data, &msg); err != nil {
			c.logger.Debug("invalid message format", "client_id", c.id, "error", err)
			continue
		}

		// Set server-side fields; From is always overwritten so a client
		// cannot impersonate another connection.
		msg.From = c.id
		if msg.Timestamp.IsZero() {
			msg.Timestamp = time.Now().UTC()
		}

		// Handle ping messages locally
		if msg.Type == MessageTypePing {
			pong := &Message{
				Type:      MessageTypePong,
				Timestamp: time.Now().UTC(),
			}
			c.Send(pong)
			continue
		}

		// Dispatch to callback if set
		if c.onMsg != nil {
			c.onMsg(c, &msg)
		}
	}
}

// writePump performs every write this file makes to the connection
// (gorilla/websocket permits only one concurrent writer). It writes queued
// messages as JSON and sends a WebSocket ping every pingPeriod; the peer's
// pongs extend the read deadline set in readPump. When ctx is done or the send
// queue has been closed and drained, it sends a close frame and returns. It
// closes the connection on return, which also unblocks readPump.
func (c *WSClient) writePump(ctx context.Context) {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		_ = c.conn.Close()
	}()

	for {
		select {
		case <-ctx.Done():
			// Send close message
			_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
			return

		case msg, ok := <-c.send:
			_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// Channel closed (Close was called) and fully drained.
				_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.conn.WriteJSON(msg); err != nil {
				c.logger.Debug("websocket write error", "client_id", c.id, "error", err)
				return
			}

		case <-ticker.C:
			_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// UpgradeConnection upgrades an HTTP connection to WebSocket without any
// origin restriction: every Origin is accepted.
//
// Deprecated: Use UpgradeConnectionWithOrigins for production use.
func UpgradeConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
	return UpgradeConnectionWithOrigins(w, r, nil)
}

// UpgradeConnectionWithOrigins upgrades an HTTP connection to WebSocket with origin checking.
// If allowedOrigins is empty, all origins are allowed (development mode).
// In production, pass a list of allowed origins (e.g., ["https://example.com"]).
// Requests that carry no Origin header are always accepted.
func UpgradeConnectionWithOrigins(w http.ResponseWriter, r *http.Request, allowedOrigins []string) (*websocket.Conn, error) {
	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     makeOriginChecker(allowedOrigins),
	}
	return upgrader.Upgrade(w, r, nil)
}

// makeOriginChecker creates an origin check function for the WebSocket upgrader.
// If allowedOrigins is empty, all origins are allowed (development mode).
func makeOriginChecker(allowedOrigins []string) func(r *http.Request) bool {
	// Development mode: with no allow-list configured, every origin passes.
	if len(allowedOrigins) == 0 {
		return func(*http.Request) bool { return true }
	}

	// Snapshot the allow-list into a set. Later edits to the caller's slice
	// therefore have no effect, and each check costs a single map lookup.
	permitted := make(map[string]struct{}, len(allowedOrigins))
	for _, o := range allowedOrigins {
		permitted[o] = struct{}{}
	}

	return func(r *http.Request) bool {
		origin := r.Header.Get("Origin")
		if origin == "" {
			// No Origin header (same-origin request or non-browser client).
			return true
		}
		// Exact, case-sensitive match against the configured origins.
		_, ok := permitted[origin]
		return ok
	}
}