108 lines
2.9 KiB
Go
108 lines
2.9 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
)
|
|
|
|
// TxFn is the unit of work run by WithTx, WithTxOptions, ReadOnlyTx and
// SerializableTx inside a single transaction.
//
// Returning a non-nil error makes the helper roll the transaction back and
// return that error; returning nil makes the helper commit. If the function
// panics, the helper rolls the transaction back before the panic continues.
//
// The function should not call Commit or Rollback on the *sql.Tx itself: the
// helpers finish the transaction after it returns.
type TxFn func(*sql.Tx) error
|
|
|
|
// WithTx executes a function within a database transaction.
|
|
// The transaction is automatically committed on success or rolled back on error.
|
|
//
|
|
// Usage:
|
|
//
|
|
// err := database.WithTx(ctx, pool, func(tx *sql.Tx) error {
|
|
// _, err := tx.ExecContext(ctx, "INSERT INTO users (name) VALUES ($1)", name)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
// _, err = tx.ExecContext(ctx, "INSERT INTO audit_log (action) VALUES ($1)", "user_created")
|
|
// return err
|
|
// })
|
|
func WithTx(ctx context.Context, pool *Pool, fn TxFn) error {
|
|
tx, err := pool.DB.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
_ = tx.Rollback()
|
|
panic(p) // re-throw panic after rollback
|
|
}
|
|
}()
|
|
|
|
if err := fn(tx); err != nil {
|
|
if rbErr := tx.Rollback(); rbErr != nil {
|
|
return fmt.Errorf("tx failed: %w, rollback failed: %v", err, rbErr)
|
|
}
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// TxOptions configures transaction behavior.
|
|
type TxOptions struct {
|
|
// Isolation sets the transaction isolation level.
|
|
// Default: sql.LevelDefault (database default)
|
|
Isolation sql.IsolationLevel
|
|
|
|
// ReadOnly marks the transaction as read-only.
|
|
// Useful for reporting queries that should never modify data.
|
|
ReadOnly bool
|
|
}
|
|
|
|
// WithTxOptions is like WithTx but with configurable transaction options.
|
|
func WithTxOptions(ctx context.Context, pool *Pool, opts TxOptions, fn TxFn) error {
|
|
tx, err := pool.DB.BeginTx(ctx, &sql.TxOptions{
|
|
Isolation: opts.Isolation,
|
|
ReadOnly: opts.ReadOnly,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
_ = tx.Rollback()
|
|
panic(p)
|
|
}
|
|
}()
|
|
|
|
if err := fn(tx); err != nil {
|
|
if rbErr := tx.Rollback(); rbErr != nil {
|
|
return fmt.Errorf("tx failed: %w, rollback failed: %v", err, rbErr)
|
|
}
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReadOnlyTx executes a function in a read-only transaction.
|
|
// This is useful for queries that should never accidentally modify data.
|
|
func ReadOnlyTx(ctx context.Context, pool *Pool, fn TxFn) error {
|
|
return WithTxOptions(ctx, pool, TxOptions{ReadOnly: true}, fn)
|
|
}
|
|
|
|
// SerializableTx executes a function in a serializable transaction.
|
|
// Use this for operations that require the strongest isolation.
|
|
// Note: May need retry logic for serialization failures.
|
|
func SerializableTx(ctx context.Context, pool *Pool, fn TxFn) error {
|
|
return WithTxOptions(ctx, pool, TxOptions{Isolation: sql.LevelSerializable}, fn)
|
|
}
|