package webhook

import (
	"context"
	"encoding/json"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/domain"
)

// discardLogger returns a logger that discards all output.
func discardLogger() *slog.Logger {
	return slog.New(slog.NewTextHandler(io.Discard, nil))
}

// pollUntil evaluates cond every few milliseconds until it returns true or
// timeout elapses, and reports whether cond was satisfied. Tests use it
// instead of fixed sleeps so asynchronous deliveries neither flake on slow
// machines nor make the suite wait longer than necessary on fast ones.
func pollUntil(timeout time.Duration, cond func() bool) bool {
	deadline := time.Now().Add(timeout)
	for {
		if cond() {
			return true
		}
		if time.Now().After(deadline) {
			return false
		}
		time.Sleep(5 * time.Millisecond)
	}
}

// mockWebhookRepo implements port.WebhookRepository for testing.
//
// webhooks is configured up front and only read afterwards. deliveries is
// appended by RecordDelivery, which may be called from dispatcher worker
// goroutines, so it is guarded by mu. When err is non-nil the repository
// methods return it as their error.
type mockWebhookRepo struct {
	webhooks   []*domain.Webhook
	mu         sync.RWMutex // guards deliveries
	deliveries []*domain.WebhookDelivery
	err        error
}

// Create is a no-op that returns the configured error.
func (m *mockWebhookRepo) Create(ctx context.Context, webhook *domain.Webhook) error {
	return m.err
}

// Update is a no-op that returns the configured error.
func (m *mockWebhookRepo) Update(ctx context.Context, webhook *domain.Webhook) error {
	return m.err
}

// Delete is a no-op that returns the configured error.
func (m *mockWebhookRepo) Delete(ctx context.Context, id domain.WebhookID) error {
	return m.err
}

// GetByID returns the configured webhook with the given ID, or
// domain.ErrWebhookNotFound when none matches.
func (m *mockWebhookRepo) GetByID(ctx context.Context, id domain.WebhookID) (*domain.Webhook, error) {
	if m.err != nil {
		return nil, m.err
	}
	for _, w := range m.webhooks {
		if w.ID == id {
			return w, nil
		}
	}
	return nil, domain.ErrWebhookNotFound
}

// ListByProject returns every configured webhook belonging to projectID.
func (m *mockWebhookRepo) ListByProject(ctx context.Context, projectID string) ([]*domain.Webhook, error) {
	if m.err != nil {
		return nil, m.err
	}
	var result []*domain.Webhook
	for _, w := range m.webhooks {
		if w.ProjectID == projectID {
			result = append(result, w)
		}
	}
	return result, nil
}

// ListEnabledByProjectAndEvent returns the enabled webhooks of projectID
// that subscribe to eventType.
func (m *mockWebhookRepo) ListEnabledByProjectAndEvent(ctx context.Context, projectID string, eventType domain.WebhookEventType) ([]*domain.Webhook, error) {
	if m.err != nil {
		return nil, m.err
	}
	var result []*domain.Webhook
	for _, w := range m.webhooks {
		if w.ProjectID == projectID && w.Enabled && slices.Contains(w.Events, eventType) {
			result = append(result, w)
		}
	}
	return result, nil
}

// RecordDelivery appends delivery to the recorded deliveries.
func (m *mockWebhookRepo) RecordDelivery(ctx context.Context, delivery *domain.WebhookDelivery) error {
	if m.err != nil {
		return m.err
	}
	m.mu.Lock()
	m.deliveries = append(m.deliveries, delivery)
	m.mu.Unlock()
	return nil
}

// GetDeliveries returns a snapshot of every recorded delivery. The mock
// ignores webhookID and filters.
func (m *mockWebhookRepo) GetDeliveries(ctx context.Context, webhookID domain.WebhookID, filters *domain.WebhookDeliveryFilters) ([]*domain.WebhookDelivery, error) {
	if m.err != nil {
		return nil, m.err
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	// Clone so the caller's slice is isolated from later RecordDelivery calls.
	return slices.Clone(m.deliveries), nil
}

// DeliveryCount reports how many deliveries have been recorded so far.
func (m *mockWebhookRepo) DeliveryCount() int {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return len(m.deliveries)
}

// CleanupOldDeliveries is a no-op that reports zero removed rows and the
// configured error.
func (m *mockWebhookRepo) CleanupOldDeliveries(ctx context.Context, olderThanDays int) (int64, error) {
	return 0, m.err
}

func TestDispatcher_NewDispatcher(t *testing.T) {
	repo := &mockWebhookRepo{}

	// With nil config, should use defaults
	d := NewDispatcher(repo, nil)
	if d == nil {
		t.Fatal("NewDispatcher returned nil")
	}
	if d.config.WorkerCount != 10 {
		t.Errorf("expected default WorkerCount of 10, got %d", d.config.WorkerCount)
	}

	// With custom config
	cfg := &DispatcherConfig{
		WorkerCount: 5,
		MaxRetries:  5,
		Timeout:     10 * time.Second,
	}
	d = NewDispatcher(repo, cfg)
	if d.config.WorkerCount != 5 {
		t.Errorf("expected WorkerCount of 5, got %d", d.config.WorkerCount)
	}
}

func TestDispatcher_StartStop(t *testing.T) {
	repo := &mockWebhookRepo{}
	d := NewDispatcher(repo, &DispatcherConfig{
		WorkerCount: 2,
		Logger:      discardLogger(),
	})

	if err := d.Start(); err != nil {
		t.Fatalf("Start() error = %v", err)
	}

	// Verify health
	if !d.Health() {
		t.Error("expected dispatcher to be healthy after start")
	}

	// Stop should complete without deadlock
	done := make(chan struct{})
	go func() {
		d.Stop()
		close(done)
	}()

	select {
	case <-done:
		// OK
	case <-time.After(5 * time.Second):
		t.Fatal("Stop() timed out")
	}

	// After stop, should not be healthy
	if d.Health() {
		t.Error("expected dispatcher to be unhealthy after stop")
	}
}

func TestDispatcher_Dispatch(t *testing.T) {
	// Test server that captures the body of each webhook request.
	var (
		receivedCount   atomic.Int32
		payloadMu       sync.Mutex
		receivedPayload []byte
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Read the whole body: a single Read into a fixed-size buffer may
		// return only part of the payload and truncates larger ones.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("reading webhook body: %v", err)
		}
		payloadMu.Lock()
		receivedPayload = body
		payloadMu.Unlock()
		// Count only after storing the payload so anyone who observes the
		// count also observes the payload.
		receivedCount.Add(1)
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	repo := &mockWebhookRepo{
		webhooks: []*domain.Webhook{
			{
				ID:        "wh-1",
				ProjectID: "proj-1",
				URL:       server.URL,
				Secret:    "test-secret",
				Events:    []domain.WebhookEventType{domain.WebhookEventCommandStarted},
				Enabled:   true,
			},
		},
	}

	d := NewDispatcher(repo, &DispatcherConfig{
		WorkerCount:  2,
		MaxRetries:   0,
		Timeout:      5 * time.Second,
		RetryBackoff: time.Millisecond,
		Logger:       discardLogger(),
	})

	if err := d.Start(); err != nil {
		t.Fatalf("Start() error = %v", err)
	}
	defer d.Stop()

	// Dispatch an event
	event := &domain.WebhookEvent{
		Type:      domain.WebhookEventCommandStarted,
		ProjectID: "proj-1",
		Timestamp: time.Now(),
		Data: map[string]any{
			"command_id": "cmd-123",
		},
	}

	if err := d.Dispatch(context.Background(), "proj-1", event); err != nil {
		t.Fatalf("Dispatch() error = %v", err)
	}

	// Wait for the asynchronous delivery instead of sleeping a fixed time.
	if !pollUntil(5*time.Second, func() bool { return receivedCount.Load() >= 1 }) {
		t.Fatal("timed out waiting for webhook delivery")
	}
	if got := receivedCount.Load(); got != 1 {
		t.Errorf("expected 1 webhook delivery, got %d", got)
	}

	// Verify payload
	payloadMu.Lock()
	payloadCopy := receivedPayload
	payloadMu.Unlock()

	if len(payloadCopy) == 0 {
		t.Fatal("received an empty webhook payload")
	}
	var payload domain.WebhookPayload
	if err := json.Unmarshal(payloadCopy, &payload); err != nil {
		t.Fatalf("failed to unmarshal payload: %v", err)
	}
	if payload.Event != domain.WebhookEventCommandStarted {
		t.Errorf("expected event type %s, got %s", domain.WebhookEventCommandStarted, payload.Event)
	}
}

func TestDispatcher_DispatchNoWebhooks(t *testing.T) {
	repo := &mockWebhookRepo{
		webhooks: nil, // No webhooks configured
	}

	d := NewDispatcher(repo, &DispatcherConfig{
		WorkerCount: 1,
		Logger:      discardLogger(),
	})

	if err := d.Start(); err != nil {
		t.Fatalf("Start() error = %v", err)
	}
	defer d.Stop()

	event := &domain.WebhookEvent{
		Type:      domain.WebhookEventCommandStarted,
		ProjectID: "proj-1",
		Timestamp: time.Now(),
	}

	// Should not error when there are no webhooks
	if err := d.Dispatch(context.Background(), "proj-1", event); err != nil {
		t.Errorf("Dispatch() error = %v, want nil", err)
	}
}

func TestDispatcher_DeliveryFailure(t *testing.T) {
	// Create a test server that always fails
	var requestCount atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestCount.Add(1)
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer server.Close()

	repo := &mockWebhookRepo{
		webhooks: []*domain.Webhook{
			{
				ID:        "wh-1",
				ProjectID: "proj-1",
				URL:       server.URL,
				Events:    []domain.WebhookEventType{domain.WebhookEventCommandStarted},
				Enabled:   true,
			},
		},
	}

	d := NewDispatcher(repo, &DispatcherConfig{
		WorkerCount:  1,
		MaxRetries:   2, // 2 retries = 3 total attempts
		Timeout:      5 * time.Second,
		RetryBackoff: 10 * time.Millisecond, // Fast retries for testing
		Logger:       discardLogger(),
	})

	if err := d.Start(); err != nil {
		t.Fatalf("Start() error = %v", err)
	}
	defer d.Stop()

	event := &domain.WebhookEvent{
		Type:      domain.WebhookEventCommandStarted,
		ProjectID: "proj-1",
		Timestamp: time.Now(),
	}

	if err := d.Dispatch(context.Background(), "proj-1", event); err != nil {
		t.Fatalf("Dispatch() error = %v", err)
	}

	// Wait until all attempts (initial + 2 retries) have hit the server and
	// the outcome has been recorded. The bool result is not checked here;
	// the assertions below report what actually happened on timeout.
	pollUntil(5*time.Second, func() bool {
		return requestCount.Load() >= 3 && repo.DeliveryCount() > 0
	})

	// Should have attempted delivery 3 times (initial + 2 retries)
	if count := requestCount.Load(); count != 3 {
		t.Errorf("expected 3 delivery attempts, got %d", count)
	}

	// Verify delivery was recorded
	if repo.DeliveryCount() == 0 {
		t.Error("expected delivery to be recorded")
	}
}

func TestDispatcher_QueueSize(t *testing.T) {
	repo := &mockWebhookRepo{}
	d := NewDispatcher(repo, &DispatcherConfig{
		WorkerCount: 1,
	})

	// Before start, queue should be empty
	if d.QueueSize() != 0 {
		t.Errorf("expected queue size 0, got %d", d.QueueSize())
	}
}

func TestDispatcher_SignPayload(t *testing.T) {
	d := &Dispatcher{}

	payload := []byte(`{"test": true}`)
	secret := "my-secret"

	signature := d.signPayload(payload, secret)

	// Signature should be "sha256=" followed by a non-trivial digest.
	if len(signature) < 10 || !strings.HasPrefix(signature, "sha256=") {
		t.Errorf("invalid signature format: %s", signature)
	}

	// Same payload and secret should produce same signature
	signature2 := d.signPayload(payload, secret)
	if signature != signature2 {
		t.Error("signatures should be deterministic")
	}

	// Different secret should produce different signature
	signature3 := d.signPayload(payload, "different-secret")
	if signature == signature3 {
		t.Error("different secrets should produce different signatures")
	}
}

func TestDefaultDispatcherConfig(t *testing.T) {
	cfg := DefaultDispatcherConfig()

	if cfg.WorkerCount != 10 {
		t.Errorf("expected WorkerCount 10, got %d", cfg.WorkerCount)
	}
	if cfg.MaxRetries != 3 {
		t.Errorf("expected MaxRetries 3, got %d", cfg.MaxRetries)
	}
	if cfg.Timeout != 30*time.Second {
		t.Errorf("expected Timeout 30s, got %v", cfg.Timeout)
	}
	if cfg.RetryBackoff != 5*time.Second {
		t.Errorf("expected RetryBackoff 5s, got %v", cfg.RetryBackoff)
	}
	if cfg.MaxResponseBodySize != 1024 {
		t.Errorf("expected MaxResponseBodySize 1024, got %d", cfg.MaxResponseBodySize)
	}
}