package postgres

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"strings"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/domain"
	"github.com/orchard9/rdev/internal/testutil"
)

// hashKey returns the hex-encoded SHA-256 digest of key. The tests use it to
// produce the hash passed to Create and looked up again with GetByHash.
func hashKey(key string) string {
	h := sha256.Sum256([]byte(key))
	return hex.EncodeToString(h[:])
}

// TestAPIKeyRepository_Create verifies that Create assigns an ID and that the
// stored fields round-trip through GetByHash, both for a fully populated key
// and for a key with only the required fields set.
func TestAPIKeyRepository_Create(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	t.Run("creates key with all fields", func(t *testing.T) {
		expires := time.Now().Add(24 * time.Hour)
		key := &domain.APIKey{
			Name:       "test-repo-create",
			KeyPrefix:  "abc12345",
			Scopes:     []domain.Scope{domain.ScopeProjectsRead, domain.ScopeKeysManage},
			ProjectIDs: []domain.ProjectID{"proj-a", "proj-b"},
			AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"},
			ExpiresAt:  &expires,
			CreatedBy:  "test-user",
		}
		keyHash := hashKey("test-key-123")

		if err := repo.Create(ctx, key, keyHash); err != nil {
			t.Fatalf("Create() error = %v", err)
		}
		if key.ID == "" {
			t.Error("ID should be set after create")
		}

		// Verify via GetByHash
		retrieved, err := repo.GetByHash(ctx, keyHash)
		if err != nil {
			t.Fatalf("GetByHash() error = %v", err)
		}
		if retrieved.ID != key.ID {
			t.Errorf("ID = %q, want %q", retrieved.ID, key.ID)
		}
		if retrieved.Name != "test-repo-create" {
			t.Errorf("Name = %q, want %q", retrieved.Name, "test-repo-create")
		}
		if len(retrieved.Scopes) != 2 {
			t.Errorf("Scopes length = %d, want 2", len(retrieved.Scopes))
		}
		if len(retrieved.ProjectIDs) != 2 {
			t.Errorf("ProjectIDs length = %d, want 2", len(retrieved.ProjectIDs))
		}
		if len(retrieved.AllowedIPs) != 2 {
			t.Errorf("AllowedIPs length = %d, want 2", len(retrieved.AllowedIPs))
		}
		if retrieved.ExpiresAt == nil {
			t.Error("ExpiresAt should be set for keys created with an expiration")
		}
	})

	t.Run("creates key with minimal fields", func(t *testing.T) {
		key := &domain.APIKey{
			Name:      "test-repo-minimal",
			KeyPrefix: "min12345",
			Scopes:    []domain.Scope{domain.ScopeProjectsRead},
			CreatedBy: "test",
		}
		keyHash := hashKey("minimal-key-456")

		if err := repo.Create(ctx, key, keyHash); err != nil {
			t.Fatalf("Create() error = %v", err)
		}

		// Check the error before touching retrieved so a lookup failure is
		// reported as a failure rather than a nil-pointer panic.
		retrieved, err := repo.GetByHash(ctx, keyHash)
		if err != nil {
			t.Fatalf("GetByHash() error = %v", err)
		}
		if retrieved.ExpiresAt != nil {
			t.Error("ExpiresAt should be nil for keys without expiration")
		}
		if len(retrieved.ProjectIDs) != 0 {
			t.Error("ProjectIDs should be empty")
		}
		if len(retrieved.AllowedIPs) != 0 {
			t.Error("AllowedIPs should be empty")
		}
	})
}

// TestAPIKeyRepository_GetByHash verifies lookup by key hash, including the
// domain.ErrKeyNotFound result for an unknown hash.
func TestAPIKeyRepository_GetByHash(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	// Create a test key
	keyHash := hashKey("get-by-hash-key")
	key := &domain.APIKey{
		Name:      "test-get-hash",
		KeyPrefix: "geth1234",
		Scopes:    []domain.Scope{domain.ScopeAdmin},
		CreatedBy: "test",
	}
	if err := repo.Create(ctx, key, keyHash); err != nil {
		t.Fatalf("Create() error = %v", err)
	}

	t.Run("finds existing key", func(t *testing.T) {
		retrieved, err := repo.GetByHash(ctx, keyHash)
		if err != nil {
			t.Fatalf("GetByHash() error = %v", err)
		}
		if retrieved.Name != "test-get-hash" {
			t.Errorf("Name = %q, want %q", retrieved.Name, "test-get-hash")
		}
	})

	t.Run("returns error for nonexistent hash", func(t *testing.T) {
		_, err := repo.GetByHash(ctx, hashKey("nonexistent"))
		if !errors.Is(err, domain.ErrKeyNotFound) {
			t.Errorf("GetByHash() error = %v, want %v", err, domain.ErrKeyNotFound)
		}
	})
}

// TestAPIKeyRepository_Get verifies lookup by key ID, including the
// domain.ErrKeyNotFound result for an unknown ID.
func TestAPIKeyRepository_Get(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	// Create a test key
	key := &domain.APIKey{
		Name:      "test-get-by-id",
		KeyPrefix: "getid123",
		Scopes:    []domain.Scope{domain.ScopeProjectsRead},
		CreatedBy: "test",
	}
	if err := repo.Create(ctx, key, hashKey("get-by-id-key")); err != nil {
		t.Fatalf("Create() error = %v", err)
	}

	t.Run("finds existing key", func(t *testing.T) {
		retrieved, err := repo.Get(ctx, key.ID)
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.Name != "test-get-by-id" {
			t.Errorf("Name = %q, want %q", retrieved.Name, "test-get-by-id")
		}
	})

	t.Run("returns error for nonexistent ID", func(t *testing.T) {
		_, err := repo.Get(ctx, "00000000-0000-0000-0000-000000000000")
		if !errors.Is(err, domain.ErrKeyNotFound) {
			t.Errorf("Get() error = %v, want %v", err, domain.ErrKeyNotFound)
		}
	})
}

// TestAPIKeyRepository_List creates three keys and checks that List returns
// all of them. Only keys with the "test-list-" name prefix are counted, since
// the table may contain other rows.
func TestAPIKeyRepository_List(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	const namePrefix = "test-list-"

	// Create test keys
	for _, suffix := range []string{"a", "b", "c"} {
		key := &domain.APIKey{
			Name:      namePrefix + suffix,
			KeyPrefix: "list1234",
			Scopes:    []domain.Scope{domain.ScopeProjectsRead},
			CreatedBy: "test",
		}
		if err := repo.Create(ctx, key, hashKey("list-key-"+suffix)); err != nil {
			t.Fatalf("Create(%q) error = %v", key.Name, err)
		}
	}

	keys, err := repo.List(ctx)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}

	// Count our test keys
	testKeyCount := 0
	for _, k := range keys {
		if strings.HasPrefix(k.Name, namePrefix) {
			testKeyCount++
		}
	}
	if testKeyCount != 3 {
		t.Errorf("List() returned %d test keys, want 3", testKeyCount)
	}
}

// TestAPIKeyRepository_Revoke verifies that Revoke sets RevokedAt and that it
// returns domain.ErrKeyNotFound both for an unknown ID and for a key that has
// already been revoked.
func TestAPIKeyRepository_Revoke(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	t.Run("revokes existing key", func(t *testing.T) {
		key := &domain.APIKey{
			Name:      "test-revoke",
			KeyPrefix: "rev12345",
			Scopes:    []domain.Scope{domain.ScopeProjectsRead},
			CreatedBy: "test",
		}
		if err := repo.Create(ctx, key, hashKey("revoke-key")); err != nil {
			t.Fatalf("Create() error = %v", err)
		}

		if err := repo.Revoke(ctx, key.ID); err != nil {
			t.Fatalf("Revoke() error = %v", err)
		}

		// Verify revoked
		retrieved, err := repo.Get(ctx, key.ID)
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.RevokedAt == nil {
			t.Error("RevokedAt should be set after revoke")
		}
	})

	t.Run("returns error for nonexistent key", func(t *testing.T) {
		err := repo.Revoke(ctx, "00000000-0000-0000-0000-000000000000")
		if !errors.Is(err, domain.ErrKeyNotFound) {
			t.Errorf("Revoke() error = %v, want %v", err, domain.ErrKeyNotFound)
		}
	})

	t.Run("returns error for already revoked key", func(t *testing.T) {
		key := &domain.APIKey{
			Name:      "test-revoke-twice",
			KeyPrefix: "rev21234",
			Scopes:    []domain.Scope{domain.ScopeProjectsRead},
			CreatedBy: "test",
		}
		if err := repo.Create(ctx, key, hashKey("revoke-twice-key")); err != nil {
			t.Fatalf("Create() error = %v", err)
		}

		// The first revoke must succeed, otherwise the second-revoke assertion
		// below would not be testing what it claims to.
		if err := repo.Revoke(ctx, key.ID); err != nil {
			t.Fatalf("First Revoke() error = %v", err)
		}

		// Second revoke should fail
		err := repo.Revoke(ctx, key.ID)
		if !errors.Is(err, domain.ErrKeyNotFound) {
			t.Errorf("Second Revoke() error = %v, want %v", err, domain.ErrKeyNotFound)
		}
	})
}

// TestAPIKeyRepository_UpdateLastUsed verifies that LastUsedAt starts out nil
// and is set to a recent timestamp by UpdateLastUsed.
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	key := &domain.APIKey{
		Name:      "test-last-used",
		KeyPrefix: "lu123456",
		Scopes:    []domain.Scope{domain.ScopeProjectsRead},
		CreatedBy: "test",
	}
	if err := repo.Create(ctx, key, hashKey("last-used-key")); err != nil {
		t.Fatalf("Create() error = %v", err)
	}

	// Initial state - no last_used_at
	retrieved, err := repo.Get(ctx, key.ID)
	if err != nil {
		t.Fatalf("Get() error = %v", err)
	}
	if retrieved.LastUsedAt != nil {
		t.Error("LastUsedAt should be nil initially")
	}

	// Update last used
	if err := repo.UpdateLastUsed(ctx, key.ID); err != nil {
		t.Fatalf("UpdateLastUsed() error = %v", err)
	}

	// Verify updated
	retrieved, err = repo.Get(ctx, key.ID)
	if err != nil {
		t.Fatalf("Get() error = %v", err)
	}
	if retrieved.LastUsedAt == nil {
		// Fatal, not Error: the recency check below dereferences the pointer.
		t.Fatal("LastUsedAt should be set after update")
	}
	if time.Since(*retrieved.LastUsedAt) > time.Minute {
		t.Error("LastUsedAt should be recent")
	}
}

// TestAPIKeyRepository_ScopeArrayHandling round-trips several scope sets
// through Create/Get and checks that every scope is preserved.
func TestAPIKeyRepository_ScopeArrayHandling(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	tests := []struct {
		name   string
		scopes []domain.Scope
	}{
		{"single scope", []domain.Scope{domain.ScopeProjectsRead}},
		{"multiple scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysManage}},
		{"admin scope", []domain.Scope{domain.ScopeAdmin}},
		{"all scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysManage, domain.ScopeAdmin}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			key := &domain.APIKey{
				Name:      "test-scopes-" + tt.name,
				KeyPrefix: "sc123456",
				Scopes:    tt.scopes,
				CreatedBy: "test",
			}
			if err := repo.Create(ctx, key, hashKey("scopes-"+tt.name)); err != nil {
				t.Fatalf("Create() error = %v", err)
			}

			retrieved, err := repo.Get(ctx, key.ID)
			if err != nil {
				t.Fatalf("Get() error = %v", err)
			}
			if len(retrieved.Scopes) != len(tt.scopes) {
				t.Errorf("Scopes length = %d, want %d", len(retrieved.Scopes), len(tt.scopes))
			}

			// Verify each scope
			for _, expected := range tt.scopes {
				found := false
				for _, actual := range retrieved.Scopes {
					if actual == expected {
						found = true
						break
					}
				}
				if !found {
					t.Errorf("Missing scope: %q", expected)
				}
			}
		})
	}
}

// TestAPIKeyRepository_ProjectIDArrayHandling round-trips nil, empty and
// populated project ID lists through Create/Get. Nil and empty inputs are
// both expected to read back with length 0.
func TestAPIKeyRepository_ProjectIDArrayHandling(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	tests := []struct {
		name       string
		projectIDs []domain.ProjectID
	}{
		{"nil projects", nil},
		{"empty projects", []domain.ProjectID{}},
		{"single project", []domain.ProjectID{"proj-a"}},
		{"multiple projects", []domain.ProjectID{"proj-a", "proj-b", "proj-c"}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			key := &domain.APIKey{
				Name:       "test-projects-" + tt.name,
				KeyPrefix:  "pr123456",
				Scopes:     []domain.Scope{domain.ScopeProjectsRead},
				ProjectIDs: tt.projectIDs,
				CreatedBy:  "test",
			}
			if err := repo.Create(ctx, key, hashKey("projects-"+tt.name)); err != nil {
				t.Fatalf("Create() error = %v", err)
			}

			retrieved, err := repo.Get(ctx, key.ID)
			if err != nil {
				t.Fatalf("Get() error = %v", err)
			}

			// len(nil) == 0, so nil and empty inputs share the same expectation.
			expectedLen := len(tt.projectIDs)
			if len(retrieved.ProjectIDs) != expectedLen {
				t.Errorf("ProjectIDs length = %d, want %d", len(retrieved.ProjectIDs), expectedLen)
			}
		})
	}
}

// TestAPIKeyRepository_AllowedIPsArrayHandling round-trips nil, empty, single
// IP, CIDR and mixed IPv4/IPv6 allow-lists through Create/Get, checking both
// the length and the per-position values.
func TestAPIKeyRepository_AllowedIPsArrayHandling(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
	repo := NewAPIKeyRepository(db)
	ctx := context.Background()

	tests := []struct {
		name       string
		allowedIPs []string
	}{
		{"nil IPs", nil},
		{"empty IPs", []string{}},
		{"single IP", []string{"192.168.1.100"}},
		{"CIDR", []string{"10.0.0.0/8"}},
		{"mixed IPs and CIDRs", []string{"192.168.1.0/24", "10.0.0.1", "2001:db8::/32"}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			key := &domain.APIKey{
				Name:       "test-ips-" + tt.name,
				KeyPrefix:  "ip123456",
				Scopes:     []domain.Scope{domain.ScopeProjectsRead},
				AllowedIPs: tt.allowedIPs,
				CreatedBy:  "test",
			}
			if err := repo.Create(ctx, key, hashKey("ips-"+tt.name)); err != nil {
				t.Fatalf("Create() error = %v", err)
			}

			retrieved, err := repo.Get(ctx, key.ID)
			if err != nil {
				t.Fatalf("Get() error = %v", err)
			}

			// len(nil) == 0, so nil and empty inputs share the same expectation.
			expectedLen := len(tt.allowedIPs)
			if len(retrieved.AllowedIPs) != expectedLen {
				t.Errorf("AllowedIPs length = %d, want %d", len(retrieved.AllowedIPs), expectedLen)
			}

			// Verify content preserved
			for i, expected := range tt.allowedIPs {
				if i < len(retrieved.AllowedIPs) && retrieved.AllowedIPs[i] != expected {
					t.Errorf("AllowedIPs[%d] = %q, want %q", i, retrieved.AllowedIPs[i], expected)
				}
			}
		})
	}
}

// Helper function conversion tests

// TestScopesToStrings checks that scopesToStrings maps each scope to its
// string form, preserving order.
func TestScopesToStrings(t *testing.T) {
	scopes := []domain.Scope{domain.ScopeProjectsRead, domain.ScopeAdmin}
	got := scopesToStrings(scopes)

	if len(got) != 2 {
		t.Fatalf("Length = %d, want 2", len(got))
	}
	if got[0] != "projects:read" {
		t.Errorf("strings[0] = %q, want %q", got[0], "projects:read")
	}
	if got[1] != "admin" {
		t.Errorf("strings[1] = %q, want %q", got[1], "admin")
	}
}

// TestScopesFromStrings checks that scopesFromStrings parses string forms back
// into the matching domain.Scope values, preserving order.
func TestScopesFromStrings(t *testing.T) {
	raw := []string{"projects:read", "keys:manage"}
	scopes := scopesFromStrings(raw)

	if len(scopes) != 2 {
		t.Fatalf("Length = %d, want 2", len(scopes))
	}
	if scopes[0] != domain.ScopeProjectsRead {
		t.Errorf("scopes[0] = %q, want %q", scopes[0], domain.ScopeProjectsRead)
	}
	if scopes[1] != domain.ScopeKeysManage {
		t.Errorf("scopes[1] = %q, want %q", scopes[1], domain.ScopeKeysManage)
	}
}

// TestProjectIDsToStrings checks that projectIDsToStrings keeps nil as nil and
// converts non-nil input element by element.
func TestProjectIDsToStrings(t *testing.T) {
	t.Run("nil input", func(t *testing.T) {
		result := projectIDsToStrings(nil)
		if result != nil {
			t.Errorf("Expected nil, got %v", result)
		}
	})

	t.Run("non-nil input", func(t *testing.T) {
		ids := []domain.ProjectID{"proj-a", "proj-b"}
		result := projectIDsToStrings(ids)
		if len(result) != 2 {
			t.Fatalf("Length = %d, want 2", len(result))
		}
		if result[0] != "proj-a" || result[1] != "proj-b" {
			t.Errorf("Unexpected result: %v", result)
		}
	})
}

// TestProjectIDsFromStrings checks that projectIDsFromStrings keeps nil as nil
// and converts non-nil input element by element.
func TestProjectIDsFromStrings(t *testing.T) {
	t.Run("nil input", func(t *testing.T) {
		result := projectIDsFromStrings(nil)
		if result != nil {
			t.Errorf("Expected nil, got %v", result)
		}
	})

	t.Run("non-nil input", func(t *testing.T) {
		raw := []string{"proj-x", "proj-y"}
		result := projectIDsFromStrings(raw)
		if len(result) != 2 {
			t.Fatalf("Length = %d, want 2", len(result))
		}
		if result[0] != "proj-x" || result[1] != "proj-y" {
			t.Errorf("Unexpected result: %v", result)
		}
	})
}