persona-community-5/pkg/realtime/sse.go
jordan bd2f591b98
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
ci/woodpecker/manual/woodpecker Pipeline was successful
Initialize project from skeleton template
2026-02-24 07:39:46 +00:00

248 lines
6.1 KiB
Go

package realtime
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"sync"
"time"
)
// SSEHub manages Server-Sent Events connections for user and channel subscriptions.
// Use this for one-way server-to-client events (generation progress, uploads, notifications).
// For bidirectional communication (chat), use WebSocket Hub instead.
//
// Construct it with NewSSEHub. The zero value has a nil connections map,
// so subscribing on it would panic when writing to that map.
type SSEHub struct {
	// connections maps channel names to the set of subscriber connections.
	// A channel's entry is deleted when its last subscriber unsubscribes.
	// Guarded by mu.
	connections map[string]map[*sseConn]struct{}
	// mu guards connections. Senders take only the read lock and release it
	// before writing to any connection.
	mu sync.RWMutex
	// logger receives diagnostics such as dropped events and marshal failures.
	logger *slog.Logger
}
// sseConn represents a single SSE connection.
//
// writer and flusher are used both from the goroutine serving the HTTP
// request (initial event, keepalives) and from whichever goroutine calls
// SSEHub.SendToChannel.
// NOTE(review): http.ResponseWriter is not documented as safe for concurrent
// use, so writes through writer/flusher must be serialized. Confirm that
// whatever is stored here does that.
type sseConn struct {
	writer  http.ResponseWriter
	flusher http.Flusher
	// done is closed when the request handler stops serving this connection.
	// Senders check it to skip connections that are already gone.
	done chan struct{}
}
// NewSSEHub returns an empty hub ready to accept subscriptions.
// If logger is nil, slog.Default() is used instead.
func NewSSEHub(logger *slog.Logger) *SSEHub {
	hub := &SSEHub{
		connections: make(map[string]map[*sseConn]struct{}),
		logger:      slog.Default(),
	}
	if logger != nil {
		hub.logger = logger
	}
	return hub
}
// SSEEvent represents an event to send via SSE.
//
// SendToChannel serializes it with encoding/json and sends it as the "data:"
// payload. Fields tagged omitempty are left out of the JSON when they hold
// their zero value. Field order here determines JSON key order.
type SSEEvent struct {
	// Type is always serialized (no omitempty).
	Type string `json:"type"`
	// Timestamp is filled with time.Now().UTC() by SendToChannel when it is zero.
	Timestamp time.Time `json:"timestamp"`
	JobID     string    `json:"jobId,omitempty"`
	// NOTE(review): with omitempty, a Progress of 0 is omitted from the JSON,
	// so clients cannot tell "0%" from "no progress field". Confirm this is intended.
	Progress int    `json:"progress,omitempty"`
	Message  string `json:"message,omitempty"`
	// Result must be JSON-serializable. If marshaling fails, SendToChannel
	// logs the error and drops the whole event.
	Result any    `json:"result,omitempty"`
	Error  string `json:"error,omitempty"`
}
// SendToChannel sends an event to all connections subscribed to a channel.
// Channel format: "user:<id>" for user-specific events, "channel:<id>" for shared events.
//
// If event.Timestamp is zero, it is set in place on the caller's event to the
// current UTC time. The event is marshaled once and written synchronously, on
// the calling goroutine, to each subscriber in turn. A slow client therefore
// delays delivery to the remaining subscribers.
//
// When the channel has no subscribers, the event is dropped. A warning lists
// every active channel and its subscriber count, to help spot mismatched names.
func (h *SSEHub) SendToChannel(channel string, event *SSEEvent) {
	if event.Timestamp.IsZero() {
		event.Timestamp = time.Now().UTC()
	}

	payload, err := json.Marshal(event)
	if err != nil {
		h.logger.Error("failed to marshal SSE event", "error", err)
		return
	}

	h.mu.RLock()
	subs := h.connections[channel]
	if len(subs) == 0 {
		summary := make([]string, 0, len(h.connections))
		for name, set := range h.connections {
			summary = append(summary, fmt.Sprintf("%s(%d)", name, len(set)))
		}
		h.mu.RUnlock()

		h.logger.Warn("SSE event dropped: no subscribers on channel",
			"target_channel", channel,
			"event_type", event.Type,
			"active_channels", summary,
		)
		return
	}

	// Snapshot the subscriber set so no lock is held while writing to clients.
	targets := make([]*sseConn, 0, len(subs))
	for c := range subs {
		targets = append(targets, c)
	}
	h.mu.RUnlock()

	for _, c := range targets {
		select {
		case <-c.done:
			// The handler for this connection has already exited; skip it.
		default:
			h.writeEvent(c, payload)
		}
	}
}
// SendToUser delivers event to every connection subscribed to the per-user
// channel "user:<userID>". It is shorthand for SendToChannel on that channel.
func (h *SSEHub) SendToUser(userID string, event *SSEEvent) {
	userChannel := "user:" + userID
	h.SendToChannel(userChannel, event)
}
// writeEvent frames data as one SSE message ("data: <payload>\n\n") on conn
// and flushes it to the client.
//
// Delivery is best-effort:
//   - a connection whose done channel is already closed is skipped;
//   - if the write fails, the flush is skipped and the error is dropped.
//
// data must not contain raw newlines, because a newline would end the SSE
// "data:" field early. Output of encoding/json.Marshal satisfies this.
func (h *SSEHub) writeEvent(conn *sseConn, data []byte) {
	alreadyClosed := false
	select {
	case <-conn.done:
		alreadyClosed = true
	default:
	}
	if alreadyClosed {
		return
	}

	if _, err := fmt.Fprintf(conn.writer, "data: %s\n\n", data); err == nil {
		conn.flusher.Flush()
	}
}
// subscribe registers conn as a subscriber of channel. It creates the
// channel's subscriber set on first use. Safe to call from any goroutine.
func (h *SSEHub) subscribe(channel string, conn *sseConn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	subs := h.connections[channel]
	if subs == nil {
		subs = make(map[*sseConn]struct{})
		h.connections[channel] = subs
	}
	subs[conn] = struct{}{}
}
// unsubscribe removes conn from channel's subscribers. Unknown channels and
// connections are ignored. A channel left with no subscribers is removed from
// the hub entirely.
func (h *SSEHub) unsubscribe(channel string, conn *sseConn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	subs, exists := h.connections[channel]
	if !exists {
		return
	}
	delete(subs, conn)
	if len(subs) == 0 {
		// Drop empty channels so the map does not keep stale keys.
		delete(h.connections, channel)
	}
}
// ChannelCount reports how many connections are currently subscribed to
// channel. It returns 0 for channels with no subscribers.
func (h *SSEHub) ChannelCount(channel string) int {
	h.mu.RLock()
	// len of a missing (nil) map entry is 0, so no presence check is needed.
	count := len(h.connections[channel])
	h.mu.RUnlock()
	return count
}
// SSEHandler handles HTTP requests for SSE event subscriptions.
// It implements http.Handler through its ServeHTTP method.
type SSEHandler struct {
	// hub receives a subscription for each incoming connection. ServeHTTP
	// calls methods on it, so it must be non-nil.
	hub *SSEHub
	// logger records connect and disconnect events.
	logger *slog.Logger
}
// NewSSEHandler returns an http.Handler-compatible SSE endpoint backed by hub.
// If logger is nil, slog.Default() is used instead. hub is stored as given.
func NewSSEHandler(hub *SSEHub, logger *slog.Logger) *SSEHandler {
	h := &SSEHandler{hub: hub, logger: logger}
	if h.logger == nil {
		h.logger = slog.Default()
	}
	return h
}
// ServeHTTP handles SSE subscription requests.
//
// Query params:
//   - channel: Channel to subscribe to (e.g., "user:123"). It must be
//     "user:<id>" or "channel:<id>" with a non-empty id.
//
// Example: GET /api/events?channel=user:123
//
// Flow:
//  1. Subscribe the connection to the hub.
//  2. Send an initial {"type":"connected"} event.
//  3. Write a keepalive comment every 30 seconds, until the client
//     disconnects or a keepalive write fails.
//
// The ResponseWriter is shared between this goroutine and any goroutine
// calling SSEHub.SendToChannel. All access therefore goes through a
// lockedSSEWriter, which does two things:
//   - it serializes writes;
//   - it is closed before this method returns, so no sender can touch w
//     afterwards (net/http forbids using a ResponseWriter once ServeHTTP has
//     returned).
//
// NOTE(review): no authorization is performed here. Any client can subscribe
// to any "user:<id>" channel. Confirm that middleware in front of this handler
// restricts subscriptions to the authenticated user.
func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	channel := r.URL.Query().Get("channel")
	if channel == "" {
		http.Error(w, "channel parameter required", http.StatusBadRequest)
		return
	}

	// Verify channel format: must be "user:<id>" or "channel:<id>".
	validUser := len(channel) > 5 && channel[:5] == "user:"
	validChannel := len(channel) > 8 && channel[:8] == "channel:"
	if !validUser && !validChannel {
		http.Error(w, "channel must be user:<id> or channel:<id>", http.StatusBadRequest)
		return
	}

	// Check for SSE support.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming not supported", http.StatusInternalServerError)
		return
	}

	// Set SSE headers before the connection becomes visible to senders.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no") // Disable nginx buffering

	lw := &lockedSSEWriter{w: w, flusher: flusher}
	conn := &sseConn{
		writer:  lw,
		flusher: lw,
		done:    make(chan struct{}),
	}

	h.hub.subscribe(channel, conn)
	defer h.hub.unsubscribe(channel, conn)
	// Registered after unsubscribe, so it runs first (defers are LIFO).
	// lw.close waits for any in-flight write, then rejects all later ones.
	// This also covers senders that copied conn out of the hub before the
	// unsubscribe. close(conn.done) lets those senders skip the connection
	// cheaply. This is the single place done is closed.
	defer func() {
		lw.close()
		close(conn.done)
	}()

	h.logger.Info("SSE client connected", "channel", channel)

	// Send initial connection event.
	h.hub.writeEvent(conn, []byte(`{"type":"connected"}`))

	// Keep connection alive until client disconnects.
	ctx := r.Context()
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			h.logger.Debug("SSE client disconnected", "channel", channel)
			return
		case <-ticker.C:
			// SSE comment line (leading ':'); EventSource clients ignore it.
			if _, err := lw.Write([]byte(": keepalive\n\n")); err != nil {
				h.logger.Debug("SSE keepalive failed", "channel", channel, "error", err)
				return
			}
			lw.Flush()
		}
	}
}

// errSSEWriterClosed is returned by lockedSSEWriter.Write after the handler
// that owns the connection has stopped serving it.
var errSSEWriterClosed = fmt.Errorf("sse: connection closed")

// lockedSSEWriter wraps the http.ResponseWriter of one SSE connection so the
// handler goroutine and hub senders can share it safely.
//
// Write, WriteHeader and Flush all run under mu. After close, writes fail
// with errSSEWriterClosed and WriteHeader/Flush become no-ops, so the wrapped
// writer is never touched once ServeHTTP has returned.
type lockedSSEWriter struct {
	mu      sync.Mutex
	w       http.ResponseWriter
	flusher http.Flusher
	closed  bool // guarded by mu
}

var (
	_ http.ResponseWriter = (*lockedSSEWriter)(nil)
	_ http.Flusher        = (*lockedSSEWriter)(nil)
)

// Header returns the wrapped writer's header map. It is not locked.
// ServeHTTP sets all headers before the connection is published to the hub.
func (lw *lockedSSEWriter) Header() http.Header { return lw.w.Header() }

// WriteHeader forwards to the wrapped writer unless the writer is closed.
func (lw *lockedSSEWriter) WriteHeader(statusCode int) {
	lw.mu.Lock()
	defer lw.mu.Unlock()
	if !lw.closed {
		lw.w.WriteHeader(statusCode)
	}
}

// Write forwards p to the wrapped writer. Once close has been called, it
// fails with errSSEWriterClosed instead.
//
// Each call is atomic with respect to other writers, so one SSE message
// passed in a single Write cannot be interleaved with another. fmt.Fprintf
// currently issues a single Write per call.
func (lw *lockedSSEWriter) Write(p []byte) (int, error) {
	lw.mu.Lock()
	defer lw.mu.Unlock()
	if lw.closed {
		return 0, errSSEWriterClosed
	}
	return lw.w.Write(p)
}

// Flush pushes buffered data to the client unless the writer is closed.
func (lw *lockedSSEWriter) Flush() {
	lw.mu.Lock()
	defer lw.mu.Unlock()
	if !lw.closed {
		lw.flusher.Flush()
	}
}

// close marks the writer closed. It blocks until any in-flight Write or Flush
// finishes. After it returns, the wrapped writer is never used again.
func (lw *lockedSSEWriter) close() {
	lw.mu.Lock()
	lw.closed = true
	lw.mu.Unlock()
}
// Routes exposes the SSE endpoint as an http.Handler, ready to be mounted at
// /api/events or a similar path. The returned handler is h itself.
func (h *SSEHandler) Routes() http.Handler {
	var handler http.Handler = h
	return handler
}