248 lines
6.1 KiB
Go
248 lines
6.1 KiB
Go
package realtime
|
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"sync"
	"time"
)
|
|
|
|
// SSEHub manages Server-Sent Events connections for user and channel subscriptions.
|
|
// Use this for one-way server-to-client events (generation progress, uploads, notifications).
|
|
// For bidirectional communication (chat), use WebSocket Hub instead.
|
|
type SSEHub struct {
|
|
// connections maps channel names to subscriber connections
|
|
connections map[string]map[*sseConn]struct{}
|
|
mu sync.RWMutex
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// sseConn represents a single SSE connection.
|
|
type sseConn struct {
|
|
writer http.ResponseWriter
|
|
flusher http.Flusher
|
|
done chan struct{}
|
|
}
|
|
|
|
// NewSSEHub returns an empty hub ready to accept subscriptions.
// A nil logger falls back to slog.Default().
func NewSSEHub(logger *slog.Logger) *SSEHub {
	hub := &SSEHub{
		connections: make(map[string]map[*sseConn]struct{}),
		logger:      slog.Default(),
	}
	if logger != nil {
		hub.logger = logger
	}
	return hub
}
|
|
|
|
// SSEEvent represents an event to send via SSE.
|
|
type SSEEvent struct {
|
|
Type string `json:"type"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
JobID string `json:"jobId,omitempty"`
|
|
Progress int `json:"progress,omitempty"`
|
|
Message string `json:"message,omitempty"`
|
|
Result any `json:"result,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
|
|
// SendToChannel sends an event to all connections subscribed to a channel.
//
// Channel format: "user:<id>" for user-specific events, "channel:<id>" for
// shared events.
//
// A zero Timestamp is replaced with the current UTC time. That happens on a
// private copy of the event, so the caller's *SSEEvent is never modified and
// the same value may be passed to several concurrent SendToChannel calls
// without a data race. A nil event is logged and ignored. Writes run
// synchronously on the caller's goroutine, one connection after another.
func (h *SSEHub) SendToChannel(channel string, event *SSEEvent) {
	if event == nil {
		h.logger.Error("SSE event dropped: nil event", "target_channel", channel)
		return
	}

	ev := *event
	if ev.Timestamp.IsZero() {
		ev.Timestamp = time.Now().UTC()
	}

	data, err := json.Marshal(&ev)
	if err != nil {
		h.logger.Error("failed to marshal SSE event", "error", err, "event_type", ev.Type)
		return
	}

	h.mu.RLock()
	// A missing channel yields a nil map, whose len is 0.
	conns := h.connections[channel]
	if len(conns) == 0 {
		// Log active channels to help diagnose channel mismatches.
		channels := make([]string, 0, len(h.connections))
		for ch, cs := range h.connections {
			channels = append(channels, fmt.Sprintf("%s(%d)", ch, len(cs)))
		}
		h.mu.RUnlock()

		h.logger.Warn("SSE event dropped: no subscribers on channel",
			"target_channel", channel,
			"event_type", ev.Type,
			"active_channels", channels,
		)
		return
	}

	// Snapshot the subscriber set so the lock is not held during network writes.
	connList := make([]*sseConn, 0, len(conns))
	for conn := range conns {
		connList = append(connList, conn)
	}
	h.mu.RUnlock()

	// writeEvent itself skips connections whose done channel is closed.
	for _, conn := range connList {
		h.writeEvent(conn, data)
	}
}
|
|
|
|
// SendToUser sends an event to all connections for a specific user.
|
|
// Convenience wrapper for SendToChannel("user:<userID>", event).
|
|
func (h *SSEHub) SendToUser(userID string, event *SSEEvent) {
|
|
h.SendToChannel("user:"+userID, event)
|
|
}
|
|
|
|
// writeEvent frames data as one SSE "data:" record and flushes it to conn.
//
// Connections whose done channel is already closed are skipped. If the write
// fails, the flush is skipped and the error is dropped.
//
// NOTE(review): data must not contain raw newlines, or the record would be
// split. The callers visible here pass json.Marshal output or a fixed JSON
// literal, neither of which contains any.
func (h *SSEHub) writeEvent(conn *sseConn, data []byte) {
	select {
	case <-conn.done:
		return
	default:
	}

	if _, err := fmt.Fprintf(conn.writer, "data: %s\n\n", data); err == nil {
		conn.flusher.Flush()
	}
}
|
|
|
|
// subscribe adds a connection to a channel.
|
|
func (h *SSEHub) subscribe(channel string, conn *sseConn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if h.connections[channel] == nil {
|
|
h.connections[channel] = make(map[*sseConn]struct{})
|
|
}
|
|
h.connections[channel][conn] = struct{}{}
|
|
}
|
|
|
|
// unsubscribe removes a connection from a channel.
|
|
func (h *SSEHub) unsubscribe(channel string, conn *sseConn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if conns, ok := h.connections[channel]; ok {
|
|
delete(conns, conn)
|
|
if len(conns) == 0 {
|
|
delete(h.connections, channel)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ChannelCount returns the number of active connections for a channel.
|
|
func (h *SSEHub) ChannelCount(channel string) int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
if conns, ok := h.connections[channel]; ok {
|
|
return len(conns)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// SSEHandler handles HTTP requests for SSE event subscriptions.
|
|
type SSEHandler struct {
|
|
hub *SSEHub
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewSSEHandler returns a handler that serves subscriptions for hub.
// A nil logger falls back to slog.Default().
func NewSSEHandler(hub *SSEHub, logger *slog.Logger) *SSEHandler {
	handler := &SSEHandler{
		hub:    hub,
		logger: slog.Default(),
	}
	if logger != nil {
		handler.logger = logger
	}
	return handler
}
|
|
|
|
// errSSEWriterClosed is returned by lockedSSEWriter.Write once the owning
// ServeHTTP call has finished and the underlying ResponseWriter is off limits.
var errSSEWriterClosed = errors.New("sse: connection closed")

// lockedSSEWriter serializes every write and flush to one client's
// http.ResponseWriter.
//
// net/http forbids using a ResponseWriter concurrently or after ServeHTTP has
// returned. Hub broadcasts (SSEHub.SendToChannel -> writeEvent) run on the
// publisher's goroutine while ServeHTTP writes keepalives on its own, so all
// output goes through mu. ServeHTTP calls close() before returning, so late
// broadcasts become no-op errors instead of touching a dead writer.
type lockedSSEWriter struct {
	mu      sync.Mutex
	w       http.ResponseWriter
	flusher http.Flusher
	closed  bool // set by close; guarded by mu
}

var (
	_ http.ResponseWriter = (*lockedSSEWriter)(nil)
	_ http.Flusher        = (*lockedSSEWriter)(nil)
)

// Header returns the wrapped writer's header map. It is not locked because
// ServeHTTP sets headers on the raw writer before this wrapper exists, and
// nothing writes headers through the wrapper afterwards.
func (s *lockedSSEWriter) Header() http.Header { return s.w.Header() }

// Write forwards p to the client, or fails with errSSEWriterClosed after close.
func (s *lockedSSEWriter) Write(p []byte) (int, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return 0, errSSEWriterClosed
	}
	return s.w.Write(p)
}

// WriteHeader forwards the status code unless the writer is closed.
func (s *lockedSSEWriter) WriteHeader(code int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.closed {
		s.w.WriteHeader(code)
	}
}

// Flush pushes buffered output to the client unless the writer is closed.
func (s *lockedSSEWriter) Flush() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.closed {
		s.flusher.Flush()
	}
}

// close marks the writer unusable. It blocks until any in-flight Write or
// Flush has finished, so no write can reach the ResponseWriter afterwards.
func (s *lockedSSEWriter) close() {
	s.mu.Lock()
	s.closed = true
	s.mu.Unlock()
}

// ServeHTTP handles SSE subscription requests.
//
// Query params:
//   - channel: Channel to subscribe to (e.g., "user:123")
//
// Example: GET /api/events?channel=user:123
//
// The response stays open, with a keepalive comment every 30 seconds, until
// the request context is cancelled or a keepalive write fails.
//
// NOTE(review): only the channel *format* is checked here. Nothing in this
// handler verifies that the caller may read "user:<id>". Confirm that an auth
// middleware in front of it enforces this, otherwise any client can subscribe
// to another user's events.
func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	channel := r.URL.Query().Get("channel")
	if channel == "" {
		http.Error(w, "channel parameter required", http.StatusBadRequest)
		return
	}

	// Verify channel format: must be "user:<id>" or "channel:<id>" with a non-empty id.
	userID, isUser := strings.CutPrefix(channel, "user:")
	sharedID, isShared := strings.CutPrefix(channel, "channel:")
	if (!isUser || userID == "") && (!isShared || sharedID == "") {
		http.Error(w, "channel must be user:<id> or channel:<id>", http.StatusBadRequest)
		return
	}

	// Check for SSE support
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming not supported", http.StatusInternalServerError)
		return
	}

	// Set SSE headers before any concurrent writer can exist.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering

	// All output, from the hub and from the keepalive loop below, goes
	// through sw so writes are serialized and fenced on return.
	sw := &lockedSSEWriter{w: w, flusher: flusher}
	conn := &sseConn{
		writer:  sw,
		flusher: sw,
		done:    make(chan struct{}),
	}

	h.hub.subscribe(channel, conn)
	defer h.hub.unsubscribe(channel, conn)
	// Deferred after unsubscribe, so it runs first (LIFO). It tells hub
	// senders to skip this connection and guarantees no write lands on w
	// once ServeHTTP returns.
	defer func() {
		close(conn.done)
		sw.close()
	}()

	h.logger.Info("SSE client connected", "channel", channel)

	// Send initial connection event
	h.hub.writeEvent(conn, []byte(`{"type":"connected"}`))

	ctx := r.Context()
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			h.logger.Debug("SSE client disconnected", "channel", channel)
			return
		case <-ticker.C:
			// SSE comment line: ignored by clients, keeps proxies from timing out.
			if _, err := fmt.Fprint(sw, ": keepalive\n\n"); err != nil {
				h.logger.Debug("SSE keepalive failed", "channel", channel, "error", err)
				return
			}
			sw.Flush()
		}
	}
}
|
|
|
|
// Routes returns an http.Handler for the SSE endpoint.
|
|
// Mount at /api/events or similar.
|
|
func (h *SSEHandler) Routes() http.Handler {
|
|
return h
|
|
}
|