// Package testutil provides testing utilities for rdev-api. package testutil import ( "context" "database/sql" "log/slog" "os" "testing" "time" "github.com/orchard9/rdev/internal/db" ) // TestDB returns a database connection for testing. // Uses TEST_DATABASE_URL or falls back to the standard local dev connection. // Automatically runs migrations to ensure schema is up to date. func TestDB(t *testing.T) *sql.DB { t.Helper() // Use db.New() to get a connection with migrations applied cfg := db.Config{ Host: "localhost", Port: 5433, User: "appuser", Password: "localdev", Database: "rdev", SSLMode: "disable", } // Check for override if dsn := os.Getenv("TEST_DATABASE_URL"); dsn != "" { // Parse DSN - for simplicity, just use it directly with sql.Open // This path is for CI/CD environments rawDB, err := sql.Open("postgres", dsn) if err != nil { t.Fatalf("open database: %v", err) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := rawDB.PingContext(ctx); err != nil { t.Skipf("database not available: %v", err) } t.Cleanup(func() { _ = rawDB.Close() }) return rawDB } // Use the db package which handles migrations logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn})) database, err := db.New(cfg, logger) if err != nil { // Check if it's a connection error vs migration error if ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second); true { defer cancel() rawDB, openErr := sql.Open("postgres", cfg.DSN()) if openErr == nil { if pingErr := rawDB.PingContext(ctx); pingErr != nil { t.Skipf("database not available: %v", pingErr) } _ = rawDB.Close() } } t.Fatalf("open database with migrations: %v", err) } t.Cleanup(func() { _ = database.Close() }) return database.DB } // CleanupTestKeys removes all test keys from the database. 
func CleanupTestKeys(t *testing.T, conn *sql.DB) {
	t.Helper()
	// The parameter is named conn rather than db so it does not shadow the
	// imported internal/db package used elsewhere in this file.
	// LIKE 'test-%' matches only rows whose name begins with the literal "test-".
	if _, err := conn.Exec("DELETE FROM api_keys WHERE name LIKE 'test-%'"); err != nil {
		t.Fatalf("cleanup test keys: %v", err)
	}
}

// TimePtr returns a pointer to a copy of t, which is convenient for filling
// optional *time.Time fields in test fixtures. Modifying the pointee does not
// affect the caller's value.
func TimePtr(t time.Time) *time.Time {
	return &t
}

// MustParseTime parses value using layout (see time.Parse) and panics with the
// parse error if it fails. It is intended for constant inputs in tests.
func MustParseTime(layout, value string) time.Time {
	parsed, err := time.Parse(layout, value)
	if err != nil {
		panic(err)
	}
	return parsed
}