package zdb

import (
	"context"

	"github.com/georgysavva/scany/v2/pgxscan"
	"github.com/jackc/pgx/v5"
	"github.com/pkg/errors"
)

// Tx wraps a pgx transaction together with the context used for every
// operation issued through it and the Pool it was started from.
type Tx struct {
	// tx is the underlying pgx transaction all queries are sent to.
	tx pgx.Tx
	// ctx is passed to every pgx/pgxscan call made through this Tx
	// (Commit, Rollback, Get, Select, Exec...).
	// NOTE(review): a stored context means cancelling it aborts every later
	// call on this Tx, including Rollback — confirm this is intended.
	ctx context.Context

	// pool is the originating Pool; its prepare helper is used by the
	// *Named and *Opts methods to rewrite queries before execution.
	pool *Pool
}

// Tx starts a transaction bound to the pool's own context.
// It is shorthand for d.TxNew(d.ctx).
func (d *Pool) Tx() (*Tx, error) {
	poolCtx := d.ctx
	return d.TxNew(poolCtx)
}

// TxNew begins a transaction on d.SrvMaster using ctx. The returned Tx
// issues all of its queries with that same ctx.
//
// On failure it returns a nil *Tx and the error from Begin. The caller never
// receives a half-built Tx whose methods would panic on a nil pgx.Tx.
func (d *Pool) TxNew(ctx context.Context) (*Tx, error) {
	pgTx, err := d.SrvMaster.Begin(ctx)
	if err != nil {
		return nil, err
	}

	return &Tx{tx: pgTx, ctx: ctx, pool: d}, nil
}

// TxNewOpts is like TxNew but begins the transaction on d.SrvMaster with
// explicit pgx.TxOptions (e.g. isolation level or access mode).
//
// On failure it returns a nil *Tx and the error from BeginTx. The caller never
// receives a half-built Tx whose methods would panic on a nil pgx.Tx.
func (d *Pool) TxNewOpts(ctx context.Context, opts pgx.TxOptions) (*Tx, error) {
	pgTx, err := d.SrvMaster.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}

	return &Tx{tx: pgTx, ctx: ctx, pool: d}, nil
}

// MustTx begins a transaction bound to the pool's own context and panics if
// it cannot be started.
//
// Previously the error was silently dropped. The returned Tx then panicked
// later with an uninformative nil dereference on first use. Panicking here
// surfaces the real cause.
func (d *Pool) MustTx() *Tx {
	tx, err := d.TxNew(d.ctx)
	if err != nil {
		panic(errors.Wrap(err, "zdb: begin transaction"))
	}

	return tx
}

// MustTxCtx begins a transaction bound to ctx and panics if it cannot be
// started.
//
// Previously the error was silently dropped. The returned Tx then panicked
// later with an uninformative nil dereference on first use. Panicking here
// surfaces the real cause.
func (d *Pool) MustTxCtx(ctx context.Context) *Tx {
	tx, err := d.TxNew(ctx)
	if err != nil {
		panic(errors.Wrap(err, "zdb: begin transaction"))
	}

	return tx
}

// Rollback aborts the transaction, using the context the Tx was created with.
func (t *Tx) Rollback() error {
	ctx, pgTx := t.ctx, t.tx
	return pgTx.Rollback(ctx)
}

// Commit commits the transaction, using the context the Tx was created with.
func (t *Tx) Commit() error {
	ctx, pgTx := t.ctx, t.tx
	return pgTx.Commit(ctx)
}

// Invoke runs fn inside t and then finishes the transaction:
//   - if fn returns an error, the transaction is rolled back and fn's error
//     is returned. If the rollback also fails, its message is attached, but
//     fn's error stays the cause, so errors.Is / errors.Cause still find it.
//   - otherwise the transaction is committed. If Commit fails, a best-effort
//     Rollback is attempted and the commit error is returned.
//   - if fn panics, the transaction is rolled back and the panic is re-raised.
//
// Invoke always commits or rolls back t, so callers should not reuse t
// afterwards.
func (t *Tx) Invoke(fn func(*Tx) error) error {
	// Without this, a panic in fn would leave the transaction (and its
	// connection) open.
	defer func() {
		if p := recover(); p != nil {
			_ = t.Rollback() // best effort; the panic is what matters
			panic(p)
		}
	}()

	if err := fn(t); err != nil {
		if rbErr := t.Rollback(); rbErr != nil {
			// Keep fn's error as the cause; the rollback failure is context.
			return errors.WithMessagef(err, "rollback failed: %v", rbErr)
		}

		return err
	}

	if err := t.Commit(); err != nil {
		// Best effort: the transaction may already be closed after a failed
		// commit, in which case Rollback just reports that and is ignored.
		_ = t.Rollback()

		return err
	}

	return nil
}

// Get runs sql with positional args inside the transaction and scans a single
// row into dst via pgxscan.Get.
func (t *Tx) Get(dst any, sql string, args ...any) error {
	ctx, querier := t.ctx, t.tx
	return pgxscan.Get(ctx, querier, dst, sql, args...)
}
// GetNamed is Get for named parameters: sql and args are first passed through
// the pool's prepare step, and the result is executed via Get.
func (t *Tx) GetNamed(dst any, sql string, args map[string]any) error {
	query, positional := t.pool.prepare(sql, args)
	return t.Get(dst, query, positional...)
}
// GetOpts builds the query from opts.Opts(), passes it through the pool's
// prepare step, and executes it via Get.
func (t *Tx) GetOpts(dst any, opts Opts) error {
	rawQuery, rawArgs := opts.Opts()
	query, positional := t.pool.prepare(rawQuery, rawArgs)
	return t.Get(dst, query, positional...)
}

// Select runs sql with positional args inside the transaction and scans all
// resulting rows into dst via pgxscan.Select.
func (t *Tx) Select(dst any, sql string, args ...any) error {
	ctx, querier := t.ctx, t.tx
	return pgxscan.Select(ctx, querier, dst, sql, args...)
}
// SelectNamed is Select for named parameters: sql and args are first passed
// through the pool's prepare step, and the result is executed via Select.
func (t *Tx) SelectNamed(dst any, sql string, args map[string]any) error {
	query, positional := t.pool.prepare(sql, args)
	return t.Select(dst, query, positional...)
}
// SelectOpts builds the query from opts.Opts(), passes it through the pool's
// prepare step, and executes it via Select.
func (t *Tx) SelectOpts(dst any, opts Opts) error {
	rawQuery, rawArgs := opts.Opts()
	query, positional := t.pool.prepare(rawQuery, rawArgs)
	return t.Select(dst, query, positional...)
}

// Exec runs a statement with positional args and discards the affected-row
// count. See ExecQty to get the count.
func (t *Tx) Exec(sql string, args ...any) error {
	if _, err := t.ExecQty(sql, args...); err != nil {
		return err
	}
	return nil
}
// ExecNamed runs a statement with named parameters and discards the
// affected-row count. See ExecNamedQty to get the count.
func (t *Tx) ExecNamed(sql string, args map[string]any) error {
	if _, err := t.ExecNamedQty(sql, args); err != nil {
		return err
	}
	return nil
}
// ExecOpts runs the statement described by opts and discards the affected-row
// count. See ExecOptsQty to get the count.
func (t *Tx) ExecOpts(opts Opts) error {
	if _, err := t.ExecOptsQty(opts); err != nil {
		return err
	}
	return nil
}

// ExecQty runs a statement with positional args inside the transaction and
// returns the number of rows it affected (0 on error).
func (t *Tx) ExecQty(sql string, args ...any) (int, error) {
	tag, err := t.tx.Exec(t.ctx, sql, args...)
	if err == nil {
		return int(tag.RowsAffected()), nil
	}
	return 0, err
}
// ExecNamedQty is ExecQty for named parameters: sql and args are first passed
// through the pool's prepare step.
func (t *Tx) ExecNamedQty(sql string, args map[string]any) (int, error) {
	query, positional := t.pool.prepare(sql, args)
	return t.ExecQty(query, positional...)
}
// ExecOptsQty builds the statement from opts.Opts(), passes it through the
// pool's prepare step, and executes it via ExecQty.
func (t *Tx) ExecOptsQty(opts Opts) (int, error) {
	rawQuery, rawArgs := opts.Opts()
	query, positional := t.pool.prepare(rawQuery, rawArgs)
	return t.ExecQty(query, positional...)
}
