package database

import (
	"context"
	"embed"
	"fmt"
	"io/fs"
	"log/slog"
	"path"
	"sort"
	"strings"
)

// MigrationsFS is the embedded filesystem containing migration files.
// Services should embed their migrations directory and pass it here:
//
//	//go:embed migrations/*.sql
//	var migrationsFS embed.FS
//
//	database.RunMigrations(ctx, pool, migrationsFS, "migrations")
type MigrationsFS = embed.FS

// RunMigrations applies every pending SQL migration found in dir, in order.
//
// Only non-directory entries ending in ".sql" directly inside dir are
// considered. They are applied in lexical filename order, so numeric
// prefixes must be zero-padded to a fixed width for the order to follow the
// numbering (otherwise "10_x.sql" sorts before "2_y.sql"):
//
//	migrations/001_create_users.sql
//	migrations/002_add_email_index.sql
//
// Applied filenames are recorded in the schema_migrations table, which is
// created if missing. Files already recorded there are skipped, so calling
// RunMigrations repeatedly only applies new files. The migration SQL itself
// is not required to be idempotent.
//
// Each migration's SQL and its bookkeeping row run in one transaction. On
// error that transaction is rolled back and the file is not recorded.
// Whether DDL is actually undone depends on the database. Processing stops
// at the first failing migration.
//
// dir is a slash-separated path inside migrations, as required by io/fs.
// An empty dir means the root of the filesystem.
//
// NOTE(review): there is no cross-process locking here; if several instances
// can start at once, serialize calls (e.g. with a database advisory lock).
func RunMigrations(ctx context.Context, pool *Pool, migrations MigrationsFS, dir string) error {
	if pool == nil || pool.DB == nil {
		return fmt.Errorf("database pool is required")
	}
	if dir == "" {
		// io/fs rejects the empty path; "." names the filesystem root.
		dir = "."
	}

	logger := slog.Default().With("component", "migrations")

	// Ensure migrations table exists
	if err := ensureMigrationsTable(ctx, pool); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	// Get applied migrations
	applied, err := getAppliedMigrations(ctx, pool)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	// Read migration files
	entries, err := fs.ReadDir(migrations, dir)
	if err != nil {
		return fmt.Errorf("failed to read migrations directory: %w", err)
	}

	// Collect .sql files and sort lexically (zero-padded prefixes give numeric order).
	files := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
			continue
		}
		files = append(files, entry.Name())
	}
	sort.Strings(files)

	// Run pending migrations
	for _, filename := range files {
		if applied[filename] {
			continue
		}

		logger.Info("running migration", "file", filename)

		// io/fs paths are always slash-separated; path.Join (not filepath.Join)
		// keeps this working on Windows, where filepath would insert '\'.
		content, err := fs.ReadFile(migrations, path.Join(dir, filename))
		if err != nil {
			return fmt.Errorf("failed to read migration %s: %w", filename, err)
		}

		if err := runMigration(ctx, pool, filename, string(content)); err != nil {
			return fmt.Errorf("failed to run migration %s: %w", filename, err)
		}

		logger.Info("migration complete", "file", filename)
	}

	return nil
}

// ensureMigrationsTable creates the schema_migrations tracking table if it
// does not already exist. The table stores each applied migration filename
// (primary key) and the time it was applied.
func ensureMigrationsTable(ctx context.Context, pool *Pool) error {
	_, err := pool.DB.ExecContext(ctx, `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			filename VARCHAR(255) PRIMARY KEY,
			applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
		)
	`)
	return err
}

// getAppliedMigrations returns the set of migration filenames already
// recorded in schema_migrations. Any scan or iteration error is returned.
func getAppliedMigrations(ctx context.Context, pool *Pool) (map[string]bool, error) {
	rows, err := pool.DB.QueryContext(ctx, `SELECT filename FROM schema_migrations`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	applied := make(map[string]bool)
	for rows.Next() {
		var filename string
		if err := rows.Scan(&filename); err != nil {
			return nil, err
		}
		applied[filename] = true
	}
	// rows.Err reports errors that ended iteration early.
	return applied, rows.Err()
}

// runMigration executes one migration's SQL and records its filename in
// schema_migrations, both inside a single transaction. If either step fails,
// the deferred Rollback discards the transaction and the error is returned.
//
// content is sent as one ExecContext call. NOTE(review): files with several
// statements rely on the driver accepting multi-statement execs without
// arguments; confirm for the driver in use.
func runMigration(ctx context.Context, pool *Pool, filename, content string) error {
	tx, err := pool.DB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// After a successful Commit this Rollback is a no-op; its error is
	// deliberately ignored (assumes pool.DB is a database/sql handle, where
	// it returns sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	// Execute migration SQL
	if _, err := tx.ExecContext(ctx, content); err != nil {
		return fmt.Errorf("failed to execute migration: %w", err)
	}

	// Record migration (PostgreSQL-style $1 placeholder).
	if _, err := tx.ExecContext(ctx, `
		INSERT INTO schema_migrations (filename) VALUES ($1)
	`, filename); err != nil {
		return fmt.Errorf("failed to record migration: %w", err)
	}

	return tx.Commit()
}

// MustRunMigrations is like RunMigrations but panics on error.
// Useful in main() for fail-fast initialization.
func MustRunMigrations(ctx context.Context, pool *Pool, migrations MigrationsFS, dir string) {
	if err := RunMigrations(ctx, pool, migrations, dir); err != nil {
		panic(fmt.Sprintf("failed to run migrations: %v", err))
	}
}