package realtime

import (
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"sync"
	"time"
)

// SSEHub manages Server-Sent Events connections for user and channel subscriptions.
// Use this for one-way server-to-client events (generation progress, uploads, notifications).
// For bidirectional communication (chat), use WebSocket Hub instead.
//
// Connections are grouped by channel name. Each SendToChannel call JSON-encodes
// one event and writes it to every connection currently subscribed to that channel.
type SSEHub struct {
	// connections maps channel names to the set of subscriber connections.
	// Every access goes through mu; a channel's entry is deleted once its
	// set becomes empty (see unsubscribe).
	connections map[string]map[*sseConn]struct{}
	mu          sync.RWMutex
	logger      *slog.Logger
}

// sseConn represents a single SSE connection.
type sseConn struct {
	// writer/flusher are where event frames are written and then flushed.
	// They may be used concurrently by the handler goroutine (keepalives) and
	// by goroutines calling SendToChannel, so the value stored here must
	// serialize its own writes.
	// NOTE(review): confirm the handler stores such a serializing wrapper.
	writer  http.ResponseWriter
	flusher http.Flusher
	// done is closed by the HTTP handler when the client goes away, so that
	// senders skip this connection instead of writing to it.
	done chan struct{}
}

// NewSSEHub creates a new SSE hub for event distribution.
// A nil logger falls back to slog.Default().
func NewSSEHub(logger *slog.Logger) *SSEHub {
	if logger == nil {
		logger = slog.Default()
	}
	return &SSEHub{
		connections: make(map[string]map[*sseConn]struct{}),
		logger:      logger,
	}
}

// SSEEvent represents an event to send via SSE.
//
// It is encoded with encoding/json and sent as a single "data:" frame.
// Fields tagged omitempty are left out of the payload when they hold their
// zero value. This includes Progress == 0, so clients should treat a missing
// "progress" key as 0.
type SSEEvent struct {
	Type string `json:"type"`
	// Timestamp is filled in by SendToChannel with time.Now().UTC() when zero.
	Timestamp time.Time `json:"timestamp"`
	JobID     string    `json:"jobId,omitempty"`
	Progress  int       `json:"progress,omitempty"`
	Message   string    `json:"message,omitempty"`
	// Result must be JSON-encodable; otherwise SendToChannel logs an error
	// and drops the whole event.
	Result any    `json:"result,omitempty"`
	Error  string `json:"error,omitempty"`
}

// SendToChannel sends an event to all connections subscribed to a channel.
// Channel format: "user:<id>" for user-specific events, "channel:<name>" for shared events.
//
// If event.Timestamp is zero, it is set on the caller's value to the current
// UTC time before encoding. When the channel has no subscribers, the event is
// dropped and a warning listing every active channel is logged. Frames are
// written synchronously, one connection after another, on the caller's goroutine.
func (h *SSEHub) SendToChannel(channel string, event *SSEEvent) { if event.Timestamp.IsZero() { event.Timestamp = time.Now().UTC() } data, err := json.Marshal(event) if err != nil { h.logger.Error("failed to marshal SSE event", "error", err) return } h.mu.RLock() conns, ok := h.connections[channel] if !ok || len(conns) == 0 { // Log active channels to help diagnose channel mismatches channels := make([]string, 0, len(h.connections)) for ch, cs := range h.connections { channels = append(channels, fmt.Sprintf("%s(%d)", ch, len(cs))) } h.mu.RUnlock() h.logger.Warn("SSE event dropped: no subscribers on channel", "target_channel", channel, "event_type", event.Type, "active_channels", channels, ) return } // Copy connections to avoid holding lock during send connList := make([]*sseConn, 0, len(conns)) for conn := range conns { connList = append(connList, conn) } h.mu.RUnlock() for _, conn := range connList { select { case <-conn.done: continue default: h.writeEvent(conn, data) } } } // SendToUser sends an event to all connections for a specific user. // Convenience wrapper for SendToChannel("user:", event). func (h *SSEHub) SendToUser(userID string, event *SSEEvent) { h.SendToChannel("user:"+userID, event) } // writeEvent writes a single SSE event to a connection. func (h *SSEHub) writeEvent(conn *sseConn, data []byte) { select { case <-conn.done: return default: } _, err := fmt.Fprintf(conn.writer, "data: %s\n\n", data) if err != nil { return } conn.flusher.Flush() } // subscribe adds a connection to a channel. func (h *SSEHub) subscribe(channel string, conn *sseConn) { h.mu.Lock() defer h.mu.Unlock() if h.connections[channel] == nil { h.connections[channel] = make(map[*sseConn]struct{}) } h.connections[channel][conn] = struct{}{} } // unsubscribe removes a connection from a channel. 
func (h *SSEHub) unsubscribe(channel string, conn *sseConn) { h.mu.Lock() defer h.mu.Unlock() if conns, ok := h.connections[channel]; ok { delete(conns, conn) if len(conns) == 0 { delete(h.connections, channel) } } } // ChannelCount returns the number of active connections for a channel. func (h *SSEHub) ChannelCount(channel string) int { h.mu.RLock() defer h.mu.RUnlock() if conns, ok := h.connections[channel]; ok { return len(conns) } return 0 } // SSEHandler handles HTTP requests for SSE event subscriptions. type SSEHandler struct { hub *SSEHub logger *slog.Logger } // NewSSEHandler creates a new SSE HTTP handler. func NewSSEHandler(hub *SSEHub, logger *slog.Logger) *SSEHandler { if logger == nil { logger = slog.Default() } return &SSEHandler{ hub: hub, logger: logger, } } // ServeHTTP handles SSE subscription requests. // Query params: // - channel: Channel to subscribe to (e.g., "user:123") // // Example: GET /api/events?channel=user:123 func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { channel := r.URL.Query().Get("channel") if channel == "" { http.Error(w, "channel parameter required", http.StatusBadRequest) return } // Verify channel format: must be "user:" or "channel:" validUser := len(channel) > 5 && channel[:5] == "user:" validChannel := len(channel) > 8 && channel[:8] == "channel:" if !validUser && !validChannel { http.Error(w, "channel must be user: or channel:", http.StatusBadRequest) return } // Check for SSE support flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "streaming not supported", http.StatusInternalServerError) return } // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering // Create connection conn := &sseConn{ writer: w, flusher: flusher, done: make(chan struct{}), } // Subscribe to channel h.hub.subscribe(channel, conn) defer 
h.hub.unsubscribe(channel, conn) h.logger.Info("SSE client connected", "channel", channel) // Send initial connection event h.hub.writeEvent(conn, []byte(`{"type":"connected"}`)) // Keep connection alive until client disconnects ctx := r.Context() ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-ctx.Done(): close(conn.done) h.logger.Debug("SSE client disconnected", "channel", channel) return case <-ticker.C: // Send keepalive comment _, err := fmt.Fprintf(w, ": keepalive\n\n") if err != nil { close(conn.done) return } flusher.Flush() } } } // Routes returns an http.Handler for the SSE endpoint. // Mount at /api/events or similar. func (h *SSEHandler) Routes() http.Handler { return h }