package circuitbreaker

import (
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// errTest is the failure returned by callbacks that simulate a failing call.
var errTest = errors.New("test error")

// TestCircuitBreaker_Closed verifies that a new breaker starts Closed and
// runs the supplied function.
func TestCircuitBreaker_Closed(t *testing.T) {
	cb := New(DefaultConfig())

	// Should be closed initially.
	if cb.State() != Closed {
		t.Errorf("initial state = %v, want Closed", cb.State())
	}

	// Successful calls should work.
	called := false
	err := cb.Execute(func() error {
		called = true
		return nil
	})
	if err != nil {
		t.Errorf("Execute() error = %v", err)
	}
	if !called {
		t.Error("function was not called")
	}
}

// TestCircuitBreaker_OpensAfterFailures verifies that FailureThreshold
// consecutive failures open the circuit and that an open circuit rejects
// calls with ErrCircuitOpen without invoking the function.
func TestCircuitBreaker_OpensAfterFailures(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 3,
		ResetTimeout:     1 * time.Second,
	})

	// Fail 3 times; the individual results are irrelevant here.
	for i := 0; i < 3; i++ {
		_ = cb.Execute(func() error { return errTest })
	}

	// Should be open now.
	if cb.State() != Open {
		t.Errorf("state after 3 failures = %v, want Open", cb.State())
	}

	// Next call should fail immediately.
	called := false
	err := cb.Execute(func() error {
		called = true
		return nil
	})
	// errors.Is keeps the check valid if the implementation wraps the sentinel.
	if !errors.Is(err, ErrCircuitOpen) {
		t.Errorf("Execute() error = %v, want ErrCircuitOpen", err)
	}
	if called {
		t.Error("function should not be called when circuit is open")
	}
}

// TestCircuitBreaker_HalfOpenAfterTimeout verifies that after ResetTimeout a
// probe call is allowed through and a successful probe closes the circuit.
func TestCircuitBreaker_HalfOpenAfterTimeout(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 2,
		ResetTimeout:     50 * time.Millisecond,
	})

	// Trip the circuit.
	_ = cb.Execute(func() error { return errTest })
	_ = cb.Execute(func() error { return errTest })
	if cb.State() != Open {
		t.Fatalf("expected Open state, got %v", cb.State())
	}

	// Wait for reset timeout.
	time.Sleep(60 * time.Millisecond)

	// Next request should be allowed (half-open).
	called := false
	err := cb.Execute(func() error {
		called = true
		return nil
	})
	if err != nil {
		t.Errorf("Execute() in half-open = %v", err)
	}
	if !called {
		t.Error("function should be called in half-open state")
	}

	// After success, circuit should be closed.
	if cb.State() != Closed {
		t.Errorf("state after successful probe = %v, want Closed", cb.State())
	}
}

// TestCircuitBreaker_HalfOpenRetripsOnFailure verifies that a failing probe in
// the half-open state reopens the circuit.
func TestCircuitBreaker_HalfOpenRetripsOnFailure(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 2,
		ResetTimeout:     50 * time.Millisecond,
	})

	// Trip the circuit.
	_ = cb.Execute(func() error { return errTest })
	_ = cb.Execute(func() error { return errTest })
	if cb.State() != Open {
		t.Fatalf("expected Open state, got %v", cb.State())
	}

	// Wait for reset timeout.
	time.Sleep(60 * time.Millisecond)

	// Fail in half-open state. Recording the invocation matters: without it
	// this test would pass even if the breaker never left Open, because a
	// rejected call also leaves the state Open.
	called := false
	_ = cb.Execute(func() error {
		called = true
		return errTest
	})
	if !called {
		t.Error("probe function should be called in half-open state")
	}

	// Should be open again.
	if cb.State() != Open {
		t.Errorf("state after half-open failure = %v, want Open", cb.State())
	}
}

// TestCircuitBreaker_SuccessResetsFailures verifies that a success clears the
// consecutive-failure count.
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 3,
		ResetTimeout:     1 * time.Second,
	})

	// 2 failures.
	_ = cb.Execute(func() error { return errTest })
	_ = cb.Execute(func() error { return errTest })

	// 1 success should reset the count.
	_ = cb.Execute(func() error { return nil })

	// 2 more failures - should not open (only 2 consecutive).
	_ = cb.Execute(func() error { return errTest })
	_ = cb.Execute(func() error { return errTest })

	if cb.State() != Closed {
		t.Errorf("state = %v, want Closed (success reset counter)", cb.State())
	}
}

// TestCircuitBreaker_Stats verifies the snapshot returned by Stats after a mix
// of successes and failures.
func TestCircuitBreaker_Stats(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 5,
		ResetTimeout:     1 * time.Second,
	})

	// Some operations.
	_ = cb.Execute(func() error { return nil })
	_ = cb.Execute(func() error { return errTest })
	_ = cb.Execute(func() error { return errTest })

	stats := cb.Stats()
	if stats.State != Closed {
		t.Errorf("Stats.State = %v, want Closed", stats.State)
	}
	if stats.Failures != 2 {
		t.Errorf("Stats.Failures = %d, want 2", stats.Failures)
	}
	if stats.LastFailure.IsZero() {
		t.Error("Stats.LastFailure should not be zero")
	}
}

// TestCircuitBreaker_Reset verifies that Reset closes an open circuit and that
// calls succeed afterwards.
func TestCircuitBreaker_Reset(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 2,
		ResetTimeout:     1 * time.Hour,
	})

	// Trip the circuit.
	_ = cb.Execute(func() error { return errTest })
	_ = cb.Execute(func() error { return errTest })
	if cb.State() != Open {
		t.Fatalf("expected Open state, got %v", cb.State())
	}

	// Manual reset.
	cb.Reset()
	if cb.State() != Closed {
		t.Errorf("state after Reset() = %v, want Closed", cb.State())
	}

	// Should work again.
	called := false
	err := cb.Execute(func() error {
		called = true
		return nil
	})
	if err != nil {
		t.Errorf("Execute() after Reset() error = %v", err)
	}
	if !called {
		t.Error("function should be called after Reset()")
	}
}

// TestCircuitBreaker_Concurrent drives the breaker from many goroutines (run
// with -race) and checks the Execute result contract for every call.
func TestCircuitBreaker_Concurrent(t *testing.T) {
	cb := New(Config{
		FailureThreshold: 10,
		ResetTimeout:     100 * time.Millisecond,
	})

	var wg sync.WaitGroup
	var successCount, failCount atomic.Int64

	// Concurrent executions; every third call fails.
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			var fnErr error
			if id%3 == 0 {
				fnErr = errTest
			}
			invoked := false
			result := cb.Execute(func() error {
				invoked = true
				return fnErr
			})

			// Whether a call is admitted depends on scheduling, so check the
			// result against what actually happened to this call instead.
			// t.Errorf is safe to call from multiple goroutines.
			switch {
			case !invoked && result == nil:
				t.Errorf("goroutine %d: rejected call returned nil error", id)
			case invoked && fnErr == nil && result != nil:
				t.Errorf("goroutine %d: successful call returned %v", id, result)
			case invoked && fnErr != nil && result == nil:
				t.Errorf("goroutine %d: failing call returned nil error", id)
			}

			if result == nil {
				successCount.Add(1)
			} else {
				failCount.Add(1)
			}
		}(i)
	}
	wg.Wait()

	if total := successCount.Load() + failCount.Load(); total != 100 {
		t.Errorf("total executions = %d, want 100", total)
	}
}

// TestState_String verifies the textual form of each State, including an
// out-of-range value.
func TestState_String(t *testing.T) {
	tests := []struct {
		state State
		want  string
	}{
		{Closed, "closed"},
		{Open, "open"},
		{HalfOpen, "half-open"},
		{State(99), "unknown"},
	}

	for _, tt := range tests {
		if got := tt.state.String(); got != tt.want {
			t.Errorf("State(%d).String() = %q, want %q", tt.state, got, tt.want)
		}
	}
}

// TestDefaultConfig pins the values returned by DefaultConfig.
func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()

	if cfg.FailureThreshold != 5 {
		t.Errorf("FailureThreshold = %d, want 5", cfg.FailureThreshold)
	}
	if cfg.ResetTimeout != 30*time.Second {
		t.Errorf("ResetTimeout = %v, want 30s", cfg.ResetTimeout)
	}
	if cfg.HalfOpenRequests != 1 {
		t.Errorf("HalfOpenRequests = %d, want 1", cfg.HalfOpenRequests)
	}
}

// TestNew_DefaultsInvalidValues verifies that New accepts negative config
// values and returns a usable, Closed breaker.
func TestNew_DefaultsInvalidValues(t *testing.T) {
	cb := New(Config{
		FailureThreshold: -1,
		ResetTimeout:     -1,
		HalfOpenRequests: -1,
	})

	stats := cb.Stats()
	if stats.State != Closed {
		t.Error("new circuit breaker should be Closed")
	}

	// The breaker must still accept calls despite the invalid config.
	if err := cb.Execute(func() error { return nil }); err != nil {
		t.Errorf("Execute() error = %v, want nil", err)
	}
}