package service

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/domain"
)

// MockAPIKeyRepository implements port.APIKeyRepository for testing.
//
// It keeps keys in two in-memory maps (by ID and by hash) and exposes error
// injection fields so tests can force Create and UpdateLastUsed to fail.
// It performs no locking of its own.
type MockAPIKeyRepository struct {
	// keys indexes stored keys by their ID.
	keys map[domain.APIKeyID]*domain.APIKey

	// keysByHash indexes the same key pointers by the hash passed to Create.
	keysByHash map[string]*domain.APIKey

	// createErr, when non-nil, is returned by Create before anything is stored.
	createErr error

	// lastUsedCalls counts every UpdateLastUsed invocation, including failing ones.
	lastUsedCalls int

	// lastUsedErr, when non-nil, is returned by UpdateLastUsed.
	lastUsedErr error
}

// NewMockAPIKeyRepository returns an empty mock with both indexes initialized.
func NewMockAPIKeyRepository() *MockAPIKeyRepository {
	return &MockAPIKeyRepository{
		keys:       make(map[domain.APIKeyID]*domain.APIKey),
		keysByHash: make(map[string]*domain.APIKey),
	}
}

// Create assigns key an ID of the form "key-<Name>" and a CreatedAt timestamp,
// then stores the caller's pointer under both the ID and keyHash. If createErr
// is set it is returned and key is left untouched.
func (m *MockAPIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error {
	if m.createErr != nil {
		return m.createErr
	}
	key.ID = domain.APIKeyID("key-" + key.Name)
	key.CreatedAt = time.Now()
	m.keys[key.ID] = key
	m.keysByHash[keyHash] = key
	return nil
}

// GetByHash returns the key stored under keyHash, or domain.ErrKeyNotFound.
func (m *MockAPIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) {
	key, ok := m.keysByHash[keyHash]
	if !ok {
		return nil, domain.ErrKeyNotFound
	}
	return key, nil
}

// Get returns the key stored under id, or domain.ErrKeyNotFound.
func (m *MockAPIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) {
	key, ok := m.keys[id]
	if !ok {
		return nil, domain.ErrKeyNotFound
	}
	return key, nil
}

// List returns every stored key. The order follows map iteration and is
// therefore unspecified.
func (m *MockAPIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) {
	result := make([]*domain.APIKey, 0, len(m.keys))
	for _, k := range m.keys {
		result = append(result, k)
	}
	return result, nil
}

// Revoke sets RevokedAt on the key stored under id to the current time, or
// returns domain.ErrKeyNotFound when no such key exists.
func (m *MockAPIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error {
	key, ok := m.keys[id]
	if !ok {
		return domain.ErrKeyNotFound
	}
	now := time.Now()
	key.RevokedAt = &now
	return nil
}

// UpdateLastUsed sets LastUsedAt on the key stored under id to the current
// time. The call counter is incremented before any error is returned, so
// failed calls are counted too.
func (m *MockAPIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error {
	m.lastUsedCalls++
	if m.lastUsedErr != nil {
		return m.lastUsedErr
	}
	key, ok := m.keys[id]
	if !ok {
		return domain.ErrKeyNotFound
	}
	now := time.Now()
	key.LastUsedAt = &now
	return nil
}

// TestAPIKeyService_Create covers key creation with an expiration, without an
// expiration, and with project restrictions.
func TestAPIKeyService_Create(t *testing.T) {
	repo := NewMockAPIKeyRepository()
	svc := NewAPIKeyService(repo, "admin-secret")

	t.Run("creates key successfully", func(t *testing.T) {
		result, err := svc.Create(context.Background(), CreateKeyRequest{
			Name:      "test-key",
			Scopes:    []domain.Scope{domain.ScopeProjectsRead},
			ExpiresIn: 24 * time.Hour,
			CreatedBy: "test-user",
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
		if result.Key.Name != "test-key" {
			t.Errorf("Key.Name = %q, want %q", result.Key.Name, "test-key")
		}
		if result.Secret == "" {
			t.Error("Secret should not be empty")
		}
		if len(result.Key.KeyPrefix) != 8 {
			t.Errorf("KeyPrefix length = %d, want 8", len(result.Key.KeyPrefix))
		}
		if result.Key.ExpiresAt == nil {
			t.Error("ExpiresAt should be set")
		}
	})

	t.Run("creates key without expiration", func(t *testing.T) {
		result, err := svc.Create(context.Background(), CreateKeyRequest{
			Name:      "never-expires",
			Scopes:    []domain.Scope{domain.ScopeAdmin},
			CreatedBy: "admin",
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
		if result.Key.ExpiresAt != nil {
			t.Error("ExpiresAt should be nil for keys without expiration")
		}
	})

	t.Run("creates key with project restrictions", func(t *testing.T) {
		result, err := svc.Create(context.Background(), CreateKeyRequest{
			Name:       "restricted-key",
			Scopes:     []domain.Scope{domain.ScopeProjectsRead},
			ProjectIDs: []domain.ProjectID{"proj-a", "proj-b"},
			CreatedBy:  "test",
		})
		if err != nil {
			t.Fatalf("Create() error = %v", err)
		}
		if len(result.Key.ProjectIDs) != 2 {
			t.Errorf("ProjectIDs length = %d, want 2", len(result.Key.ProjectIDs))
		}
	})
}

// TestAPIKeyService_Get covers lookup of an existing key and of an unknown ID.
func TestAPIKeyService_Get(t *testing.T) {
	repo := NewMockAPIKeyRepository()
	svc := NewAPIKeyService(repo, "admin-secret")

	// Create a key first. Fail fast here: the subtests dereference
	// createResult, so a silent setup failure would surface as a nil panic.
	createResult, err := svc.Create(context.Background(), CreateKeyRequest{
		Name:      "get-test",
		Scopes:    []domain.Scope{domain.ScopeProjectsRead},
		CreatedBy: "test",
	})
	if err != nil {
		t.Fatalf("setup: Create() error = %v", err)
	}

	t.Run("gets existing key", func(t *testing.T) {
		key, err := svc.Get(context.Background(), createResult.Key.ID)
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if key.Name != "get-test" {
			t.Errorf("Name = %q, want %q", key.Name, "get-test")
		}
	})

	t.Run("returns error for nonexistent key", func(t *testing.T) {
		_, err := svc.Get(context.Background(), "nonexistent")
		// errors.Is keeps the assertion valid even if the error gets wrapped.
		if !errors.Is(err, domain.ErrKeyNotFound) {
			t.Errorf("Get() error = %v, want %v", err, domain.ErrKeyNotFound)
		}
	})
}

// TestAPIKeyService_List verifies that List returns every created key.
func TestAPIKeyService_List(t *testing.T) {
	repo := NewMockAPIKeyRepository()
	svc := NewAPIKeyService(repo, "admin-secret")

	// Create some keys with distinct names (list-key-a, -b, -c); the mock
	// derives the ID from the name, so distinct names avoid overwrites.
	for i := 0; i < 3; i++ {
		name := "list-key-" + string(rune('a'+i))
		_, err := svc.Create(context.Background(), CreateKeyRequest{
			Name:      name,
			Scopes:    []domain.Scope{domain.ScopeProjectsRead},
			CreatedBy: "test",
		})
		if err != nil {
			t.Fatalf("setup: Create(%q) error = %v", name, err)
		}
	}

	keys, err := svc.List(context.Background())
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if len(keys) != 3 {
		t.Errorf("List() returned %d keys, want 3", len(keys))
	}
}

// TestAPIKeyService_Revoke covers revoking an existing key and an unknown ID.
func TestAPIKeyService_Revoke(t *testing.T) {
	repo := NewMockAPIKeyRepository()
	svc := NewAPIKeyService(repo, "admin-secret")

	// Create a key
	createResult, err := svc.Create(context.Background(), CreateKeyRequest{
		Name:      "revoke-test",
		Scopes:    []domain.Scope{domain.ScopeProjectsRead},
		CreatedBy: "test",
	})
	if err != nil {
		t.Fatalf("setup: Create() error = %v", err)
	}

	t.Run("revokes existing key", func(t *testing.T) {
		if err := svc.Revoke(context.Background(), createResult.Key.ID); err != nil {
			t.Fatalf("Revoke() error = %v", err)
		}

		// Verify revoked
		key, err := svc.Get(context.Background(), createResult.Key.ID)
		if err != nil {
			t.Fatalf("Get() after Revoke() error = %v", err)
		}
		if key.RevokedAt == nil {
			t.Error("RevokedAt should be set after revoke")
		}
	})

	t.Run("returns error for nonexistent key", func(t *testing.T) {
		err := svc.Revoke(context.Background(), "nonexistent")
		if !errors.Is(err, domain.ErrKeyNotFound) {
			t.Errorf("Revoke() error = %v, want %v", err, domain.ErrKeyNotFound)
		}
	})
}

// TestAPIKeyService_UpdateLastUsed verifies that LastUsedAt is populated after
// the service records a use of the key.
func TestAPIKeyService_UpdateLastUsed(t *testing.T) {
	repo := NewMockAPIKeyRepository()
	svc := NewAPIKeyService(repo, "admin-secret")

	// Create a key
	createResult, err := svc.Create(context.Background(), CreateKeyRequest{
		Name:      "last-used-test",
		Scopes:    []domain.Scope{domain.ScopeProjectsRead},
		CreatedBy: "test",
	})
	if err != nil {
		t.Fatalf("setup: Create() error = %v", err)
	}

	if err := svc.UpdateLastUsed(context.Background(), createResult.Key.ID); err != nil {
		t.Fatalf("UpdateLastUsed() error = %v", err)
	}

	// Verify updated
	key, err := svc.Get(context.Background(), createResult.Key.ID)
	if err != nil {
		t.Fatalf("Get() after UpdateLastUsed() error = %v", err)
	}
	if key.LastUsedAt == nil {
		t.Error("LastUsedAt should be set after update")
	}
}

// TestAPIKeyService_ValidateAdminKey checks matching, non-matching, and empty
// candidates against a configured admin key.
func TestAPIKeyService_ValidateAdminKey(t *testing.T) {
	svc := NewAPIKeyService(nil, "super-secret-admin")

	tests := []struct {
		key  string
		want bool
	}{
		{"super-secret-admin", true},
		{"wrong-key", false},
		{"", false},
	}

	for _, tt := range tests {
		if got := svc.ValidateAdminKey(tt.key); got != tt.want {
			t.Errorf("ValidateAdminKey(%q) = %v, want %v", tt.key, got, tt.want)
		}
	}
}

// TestAPIKeyService_ValidateAdminKey_NoAdmin checks validation when the service
// was constructed with an empty admin key.
func TestAPIKeyService_ValidateAdminKey_NoAdmin(t *testing.T) {
	svc := NewAPIKeyService(nil, "")

	// When no admin key is set, validation should always fail
	if svc.ValidateAdminKey("anything") {
		t.Error("ValidateAdminKey should return false when no admin key is set")
	}
}

// TestAPIKeyService_AdminKey verifies that AdminKey returns the configured value.
func TestAPIKeyService_AdminKey(t *testing.T) {
	svc := NewAPIKeyService(nil, "my-admin-key")

	if got := svc.AdminKey(); got != "my-admin-key" {
		t.Errorf("AdminKey() = %q, want %q", got, "my-admin-key")
	}
}

// TestParseExpiration exercises the accepted expiration strings and two
// rejected inputs.
func TestParseExpiration(t *testing.T) {
	tests := []struct {
		input   string
		want    time.Duration
		wantErr bool
	}{
		{"", 0, false},
		{"never", 0, false},
		{"30d", 30 * 24 * time.Hour, false},
		{"60d", 60 * 24 * time.Hour, false},
		{"90d", 90 * 24 * time.Hour, false},
		{"1y", 365 * 24 * time.Hour, false},
		{"invalid", 0, true},
		{"10d", 0, true}, // Not a supported format
	}

	for _, tt := range tests {
		got, err := ParseExpiration(tt.input)
		if (err != nil) != tt.wantErr {
			t.Errorf("ParseExpiration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
			// The returned duration is meaningless once the error expectation failed.
			continue
		}
		if got != tt.want {
			t.Errorf("ParseExpiration(%q) = %v, want %v", tt.input, got, tt.want)
		}
	}
}

// TestHashKey checks that hashKey is deterministic, distinguishes different
// inputs, and produces a 64-character result.
func TestHashKey(t *testing.T) {
	// Same input should produce same hash
	hash1 := hashKey("test-key")
	hash2 := hashKey("test-key")
	if hash1 != hash2 {
		t.Error("Same input should produce same hash")
	}

	// Different input should produce different hash
	hash3 := hashKey("different-key")
	if hash1 == hash3 {
		t.Error("Different input should produce different hash")
	}

	// Hash should be 64 hex characters (SHA-256)
	if len(hash1) != 64 {
		t.Errorf("Hash length = %d, want 64", len(hash1))
	}
}