package zdb

import (
	"context"
	"fmt"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// Pool routes PostgreSQL queries (through pgx pools) across a master
// connection and two groups of replica connections, "sync" and "async".
// Replicas are taken out of rotation when queries fail with configured
// errors, and startPing pings them back in.
type Pool struct {
	// SrvMaster is the connection registered with ConnModeMaster. sync()
	// falls back to it when no sync replica is available. execWrapper
	// returns errors from a master-mode connection without retry or failover.
	SrvMaster     *Conn
	// SrvSync holds the in-rotation ConnModeSync connections; sync() picks
	// from it.
	SrvSync       []*Conn
	// SrvAsync holds the in-rotation ConnModeAsync connections; async()
	// picks from it and falls back to sync() when it is empty.
	SrvAsync      []*Conn
	// notAliveConns holds replicas removed by setNotAliveConn. startPing
	// pings them and moves them back when Ping returns nil.
	notAliveConns []*Conn

	// slavesIter and slavesAsyncIter are not referenced in this part of the
	// file. NOTE(review): possibly leftovers from round-robin selection;
	// confirm they are still used.
	slavesIter      *atomic.Int64
	slavesAsyncIter *atomic.Int64
	// Continues lists error-message substrings that make execWrapper take
	// the replica out of rotation and re-run the call on another connection.
	Continues       []string
	// ContinuesTry lists error-message substrings that make execWrapper
	// retry the call on the same replica.
	ContinuesTry    []string
	// TryOnError is the maximum number of ContinuesTry retries execWrapper
	// makes against one selected connection.
	TryOnError      int
	// TryOnSleep is the pause between ContinuesTry retries.
	TryOnSleep      time.Duration
	// PingTimeout is the timeout applied to each Ping query.
	PingTimeout     time.Duration
	// PingTry is the threshold Ping compares against Conn.PingTry.
	PingTry         int
	// Debug makes prepare log every generated SQL statement and its args.
	Debug           bool

	// ctx is the pool's context. newConn passes it to
	// pgxpool.NewWithConfig, and WithTimeout/WithDeadline derive from it.
	ctx    context.Context
	// logger receives the pool's diagnostic messages.
	logger Logger
	// mu is write-locked while SrvMaster, SrvSync, SrvAsync or notAliveConns
	// are modified (NewConn, setNotAliveConn, startPing). It is read-locked
	// while sync()/async() select a connection. Being a pointer, it is
	// shared with pools derived via WithContext.
	mu     *sync.RWMutex
}

// WithContext returns a shallow copy of the pool that uses ctx as its
// context.
//
// Every field is copied, including Debug and any field added later. The
// mutex, the iterator counters and the *Conn values are pointers, so they
// are shared with d.
//
// NOTE(review): the SrvSync/SrvAsync/notAliveConns slice headers are copied,
// not shared. setNotAliveConn, NewConn and startPing reassign these fields on
// their receiver, so membership changes made through a derived pool are not
// visible in d (and vice versa), while both may still share a backing
// array. Confirm whether derived pools are expected to change membership.
func (d *Pool) WithContext(ctx context.Context) *Pool {
	derived := *d
	derived.ctx = ctx

	return &derived
}

// WithTimeout returns a copy of the pool whose context is d.ctx limited by a
// timeout of dur (see context.WithTimeout).
//
// The signature cannot return the cancel function, so the function is
// scheduled to run once dur has elapsed. By then the context's own timer has
// already cancelled it, so the call only releases nothing extra and keeps go
// vet's lostcancel check satisfied. time.AfterFunc costs a runtime timer
// instead of a goroutine blocked in time.Sleep for the whole timeout, which
// matters because every Ping goes through here.
func (d *Pool) WithTimeout(dur time.Duration) *Pool {
	ctx, cancel := context.WithTimeout(d.ctx, dur)
	time.AfterFunc(dur, cancel)

	return d.WithContext(ctx)
}

// WithDeadline returns a copy of the pool whose context is d.ctx limited by
// the deadline dur (see context.WithDeadline).
//
// The cancel function is scheduled for the deadline itself. time.Until
// gives the remaining (positive) wait for a future deadline. time.Since
// would be negative there, which fires cancel immediately and hands callers
// an already-cancelled context. A past deadline yields a non-positive wait
// and cancels at once, matching context.WithDeadline's own behavior.
func (d *Pool) WithDeadline(dur time.Time) *Pool {
	ctx, cancel := context.WithDeadline(d.ctx, dur)
	time.AfterFunc(time.Until(dur), cancel)

	return d.WithContext(ctx)
}

// NewConn opens a connection for pgConnString (see newConn) and registers it
// in the role given by mode:
//   - ConnModeMaster replaces SrvMaster;
//   - ConnModeSync and ConnModeAsync append to SrvSync / SrvAsync and set the
//     connection's Index to its position in that slice.
//
// Any other mode panics, before any connection is opened.
//
// If the connection string cannot be parsed or the pgx pool cannot be
// created, nothing is registered and the error is returned. If the pool was
// created but its initial ping failed, the connection is still registered
// (with Alive == false) and the ping error is returned.
func (d *Pool) NewConn(mode connMode, pgConnString string) error {
	switch mode {
	case ConnModeMaster, ConnModeSync, ConnModeAsync:
	default:
		panic("unknown mode")
	}

	// Connect outside the lock: newConn pings the server, and holding the
	// write lock across that round-trip would stall every sync()/async()
	// selection for up to the connect/ping timeout.
	q, err := d.newConn(mode, pgConnString)
	if q == nil || q.Pool == nil {
		// Parsing or pool creation failed. There is no usable connection
		// (q may even be nil), so registering it would panic here on q.Index
		// or later on first use.
		return err
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	switch mode {
	case ConnModeMaster:
		d.SrvMaster = q
	case ConnModeSync:
		q.Index = len(d.SrvSync)
		d.SrvSync = append(d.SrvSync, q)
	case ConnModeAsync:
		q.Index = len(d.SrvAsync)
		d.SrvAsync = append(d.SrvAsync, q)
	}

	return err
}

// NewConns calls NewConn for each connection string, in order, all with the
// same mode. It stops at the first error and returns it. Connections
// registered by earlier strings are not rolled back.
func (d *Pool) NewConns(mode connMode, pgConnString ...string) error {
	for i := 0; i < len(pgConnString); i++ {
		err := d.NewConn(mode, pgConnString[i])
		if err != nil {
			return err
		}
	}

	return nil
}

// newConn builds a pgx pool for pgConnString and pings it once.
//
// Before parsing, a default is added for each of these settings that the
// string does not already mention: default_query_exec_mode=simple_protocol,
// connect_timeout=3, sslmode=disable and pool_max_conns=2*runtime.NumCPU().
// pgx accepts both keyword/value strings ("host=... user=...") and URLs
// ("postgres://..."), so each default is added in the matching syntax (see
// withConnDefault).
//
// Results:
//   - nil and the error if the string cannot be parsed;
//   - a Conn with a nil Pool and the error if the pgx pool cannot be created;
//   - the Conn with Alive == false and the error if the ping fails;
//   - the Conn with Alive == true and nil otherwise.
func (d *Pool) newConn(mode connMode, pgConnString string) (q *Conn, err error) {
	var pgxPool *pgxpool.Pool
	var pgxConfig *pgxpool.Config

	pgConnString = withConnDefault(pgConnString, "default_query_exec_mode", "simple_protocol")
	pgConnString = withConnDefault(pgConnString, "connect_timeout", "3")
	pgConnString = withConnDefault(pgConnString, "sslmode", "disable")
	pgConnString = withConnDefault(pgConnString, "pool_max_conns", fmt.Sprint(runtime.NumCPU()*2))

	if pgxConfig, err = pgxpool.ParseConfig(pgConnString); err != nil {
		return nil, err
	}

	if pgxPool, err = pgxpool.NewWithConfig(d.ctx, pgxConfig); err != nil {
		return &Conn{Pool: pgxPool, Alive: false, Mode: mode}, err
	}

	q = &Conn{Pool: pgxPool, Alive: false, Mode: mode}

	if err = d.Ping(q); err != nil {
		return q, err
	}

	q.Alive = true

	return q, nil
}

// withConnDefault returns connString with key=value added, unless
// connString already contains "key=".
//
// URL-style strings (postgres:// or postgresql://) get the setting as a
// query parameter ("?" or "&"). A " key=value" suffix on a URL is not read as
// a connection parameter by the URL parser. Keyword/value strings get a
// space-separated pair.
func withConnDefault(connString, key, value string) string {
	if strings.Contains(connString, key+"=") {
		return connString
	}

	pair := key + "=" + value

	if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
		if strings.Contains(connString, "?") {
			return connString + "&" + pair
		}

		return connString + "?" + pair
	}

	return connString + " " + pair
}

// least returns the connection in s with the lowest ratio of acquired to
// maximum pgx connections. Ties go to the earlier connection. The chosen
// connection is passed to logConnStat before being returned.
//
// An empty s yields nil without logging.
func (d *Pool) least(s []*Conn) *Conn {
	if len(s) == 0 {
		return nil
	}

	var best *Conn
	bestRatio := 0.0

	for i, conn := range s {
		// One Stat snapshot per connection. Stat builds a fresh snapshot on
		// every call.
		stat := conn.Stat()
		ratio := float64(stat.AcquiredConns()) / float64(stat.MaxConns())

		if i == 0 || ratio < bestRatio {
			best, bestRatio = conn, ratio
		}
	}

	logConnStat(best)

	return best
}

// sync picks a connection from SrvSync: the only one if there is exactly
// one, otherwise the least-loaded one (see least). It returns SrvMaster
// (possibly nil) when SrvSync is empty.
//
// The emptiness check and the pick happen under one read lock. Checking
// len(d.SrvSync) before locking races with setNotAliveConn, which can empty
// the slice in between and make least return nil.
func (d *Pool) sync() *Conn {
	d.mu.RLock()
	defer d.mu.RUnlock()

	switch len(d.SrvSync) {
	case 0:
		return d.SrvMaster
	case 1:
		return d.SrvSync[0]
	default:
		return d.least(d.SrvSync)
	}
}

// async picks a connection from SrvAsync: the only one if there is exactly
// one, otherwise the least-loaded one (see least). When SrvAsync is empty it
// falls back to sync().
//
// The emptiness check and the pick happen under one read lock, so
// setNotAliveConn cannot empty the slice in between. The lock is released
// before the fallback, because sync() takes it again and sync.RWMutex read
// locks must not be acquired recursively.
func (d *Pool) async() *Conn {
	picked := func() *Conn {
		d.mu.RLock()
		defer d.mu.RUnlock()

		switch len(d.SrvAsync) {
		case 0:
			return nil
		case 1:
			return d.SrvAsync[0]
		default:
			return d.least(d.SrvAsync)
		}
	}()

	if picked == nil {
		return d.sync()
	}

	return picked
}

// execWrapper runs f against a connection selected for the requested mode
// and applies the pool's retry and failover policy. ConnModeSync selects via
// sync(); any other mode selects via async().
//
//   - Errors from a master-mode connection are returned as-is.
//   - If the error text contains one of d.ContinuesTry, f is retried on the
//     same connection up to d.TryOnError times, pausing d.TryOnSleep before
//     each retry. The pause is cut short, returning the last error, if d.ctx
//     is done.
//   - If the (last) error text contains one of d.Continues, the connection is
//     taken out of rotation via setNotAliveConn and f is re-dispatched to a
//     freshly selected connection.
//   - Any other error is returned.
//
// When no connection is available at all (no replica and no master), an
// error is returned instead of calling f.
func (d *Pool) execWrapper(pool connMode, dst any, f func(conn *Conn, dst1 any) error) error {
	// A nil channel never becomes ready, so without a context the retry
	// pause below is a plain sleep.
	var done <-chan struct{}
	if d.ctx != nil {
		done = d.ctx.Done()
	}

	for {
		var q *Conn
		if pool == ConnModeSync {
			q = d.sync()
		} else {
			q = d.async()
		}

		if q == nil {
			return fmt.Errorf("zdb: no connection available")
		}

		err := f(q, dst)
		if err == nil {
			return nil
		}

		// Errors from the master are returned without retry or failover.
		if q.Mode == ConnModeMaster {
			return err
		}

		for try := 1; try <= d.TryOnError && contains(err.Error(), d.ContinuesTry); try++ {
			d.logger.Printf("ZDB_EXEC_WRAPPER_REPEAT_ERR: SRV: %s TRY: %d; %s", q.ToString(), try, err.Error())

			pause := time.NewTimer(d.TryOnSleep)
			select {
			case <-done:
				pause.Stop()
				return err
			case <-pause.C:
			}

			if err = f(q, dst); err == nil {
				return nil
			}
		}

		if contains(err.Error(), d.Continues) {
			d.setNotAliveConn(q)
			d.logger.Printf("ZDB_EXEC_WRAPPER_ERR: SRV: %s; %s", q.ToString(), err.Error())

			continue
		}

		return err
	}
}

// Ping checks q by running "SELECT 1" through qGet on a copy of the pool
// whose context has a timeout of d.PingTimeout.
//
// On success it resets q.PingTry to 0 and returns nil. On failure it
// increments q.PingTry, logs the error, and then:
//   - if d.PingTry <= q.PingTry, returns nil (the failed ping is reported as
//     success) and leaves q.PingTry as is;
//   - otherwise falls through, resets q.PingTry to 0 and returns the error.
//
// NOTE(review): because the reset below also runs on the failure path, for
// d.PingTry >= 2 q.PingTry never exceeds 1 and every failure is returned.
// For d.PingTry <= 1 every failure returns nil. The threshold therefore never
// acts as a count of tolerated consecutive failures. Returning nil for a
// failed ping also lets startPing move a still-unreachable replica back into
// rotation and mark the master Alive. The intended semantics are unclear;
// confirm before changing, since newConn and both startPing checks rely on
// this result.
func (d *Pool) Ping(q *Conn) (err error) {
	var n any

	if err = d.WithTimeout(d.PingTimeout).qGet(q, &n, "SELECT 1"); err != nil {
		q.PingTry++

		d.logger.Printf("ZDB_PING_ERR: SRV: %s; TRY: %d; %v",
			q.ToString(),
			q.PingTry,
			err,
		)

		// Threshold reached: report the failed ping as success (see NOTE).
		if d.PingTry <= q.PingTry {
			return nil
		}
	}

	// Reached on success AND on a failure below the threshold.
	q.PingTry = 0

	return
}

// setNotAliveConn takes conn out of rotation. If conn is found in SrvSync
// (searched first) or else in SrvAsync, it is flagged not alive, appended to
// notAliveConns and removed from that slice. A connection in neither slice,
// such as the master, is left untouched. The whole operation runs under the
// write lock.
func (d *Pool) setNotAliveConn(conn *Conn) {
	d.mu.Lock()
	defer d.mu.Unlock()

	for _, list := range []*[]*Conn{&d.SrvSync, &d.SrvAsync} {
		for pos := range *list {
			if (*list)[pos] != conn {
				continue
			}

			conn.Alive = false
			d.notAliveConns = append(d.notAliveConns, conn)
			*list = remove(*list, pos)

			return
		}
	}
}

// startPing launches a background goroutine that loops forever, sleeping one
// second after each pass. Each pass:
//   - pings every connection in notAliveConns once, and moves those whose Ping
//     returns nil back into SrvSync or SrvAsync according to their Mode
//     (flagging them Alive);
//   - sets SrvMaster.Alive from a ping of the master, if one is registered.
//
// notAliveConns is read as a snapshot taken under the read lock, and each
// revival happens under the write lock. setNotAliveConn appends to the slice
// concurrently, so ranging over it unlocked would be a data race. Pings run
// without the lock held.
//
// NOTE(review): the goroutine has no stop mechanism; confirm that is intended.
func (d *Pool) startPing() {
	go func() {
		for {
			d.mu.RLock()
			pending := append([]*Conn(nil), d.notAliveConns...)
			d.mu.RUnlock()

			for _, q := range pending {
				if d.Ping(q) != nil {
					continue
				}

				d.mu.Lock()
				for i, c := range d.notAliveConns {
					if c != q {
						continue
					}

					q.Alive = true
					d.notAliveConns = remove(d.notAliveConns, i)
					if q.Mode == ConnModeSync {
						d.SrvSync = append(d.SrvSync, q)
					} else if q.Mode == ConnModeAsync {
						d.SrvAsync = append(d.SrvAsync, q)
					}

					break
				}
				d.mu.Unlock()
			}

			d.mu.RLock()
			master := d.SrvMaster
			d.mu.RUnlock()

			if master != nil {
				master.Alive = d.Ping(master) == nil
			}

			time.Sleep(time.Second)
		}
	}()
}

// IsAlive reports whether a master connection is registered and its Alive
// flag is set. newConn sets the flag after a successful initial ping, and
// startPing refreshes it from each master ping.
func (d *Pool) IsAlive() bool {
	master := d.SrvMaster
	if master == nil {
		return false
	}

	return master.Alive
}

// prepare rewrites named placeholders in sql and collects the matching
// positional arguments.
//
// Every occurrence of ":name:" for a key in param is replaced according to
// the value's type:
//   - nil → NULL literal;
//   - bool / *bool → TRUE or FALSE literal (NULL for a nil *bool);
//   - time.Time / *time.Time → $N::timestamptz, with the argument formatted
//     as time.RFC3339Nano (NULL for a nil *time.Time);
//   - string / *string → $N with the value as argument. The string "NULL"
//     and a nil *string become a NULL literal;
//   - numbers and pointers to numbers → $N with the value as argument (NULL
//     for a nil pointer);
//   - slices and arrays → $N plus a ::text[], ::bool[] or ::int[] cast chosen
//     from the element kind, with the value rendered by pgArrayLiteral. The
//     cast is omitted when sql already contains ":name:::";
//   - anything else → $N with the value as argument.
//
// Keys starting with "_" are options, not parameters, and are never
// substituted. "_debug" (true, 1 or "1") enables SQL logging for this call,
// as does d.Debug. Keys that do not occur in sql are ignored, and so is the
// empty key: indexing it would panic, and its pattern "::" would match
// PostgreSQL casts.
//
// Placeholder numbers follow map iteration order and may differ between
// calls. The returned SQL and args always agree with each other.
func (d *Pool) prepare(sql string, param map[string]any) (string, []any) {
	args := make([]any, 0)
	idx := 0

	// literal substitutes text for every ":name:" occurrence.
	literal := func(name, text string) {
		sql = strings.ReplaceAll(sql, ":"+name+":", text)
	}

	// bind substitutes the next positional placeholder, plus an optional
	// cast suffix, for every ":name:" occurrence and records arg for it.
	bind := func(name, cast string, arg any) {
		idx++
		sql = strings.ReplaceAll(sql, ":"+name+":", fmt.Sprintf("$%d%s", idx, cast))
		args = append(args, arg)
	}

	for n, t := range param {
		if n == "" || n[0] == '_' {
			continue
		}

		if !strings.Contains(sql, ":"+n+":") {
			continue
		}

		switch v := t.(type) {
		case time.Time:
			bind(n, "::timestamptz", v.Format(time.RFC3339Nano))
		case *time.Time:
			if v == nil {
				literal(n, "NULL")
			} else {
				bind(n, "::timestamptz", v.Format(time.RFC3339Nano))
			}
		case nil:
			literal(n, "NULL")
		case bool:
			if v {
				literal(n, "TRUE")
			} else {
				literal(n, "FALSE")
			}
		case *string:
			if v == nil || *v == "NULL" {
				literal(n, "NULL")
			} else {
				bind(n, "", v)
			}
		case string:
			if v == "NULL" {
				literal(n, "NULL")
			} else {
				bind(n, "", v)
			}
		case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *float32, *float64:
			// With several types in one case, v is still an interface, so
			// v == nil is false for a typed nil pointer. Check the pointer.
			if reflect.ValueOf(v).IsNil() {
				literal(n, "NULL")
			} else {
				bind(n, "", v)
			}
		case *bool:
			switch {
			case v == nil:
				literal(n, "NULL")
			case *v:
				literal(n, "TRUE")
			default:
				literal(n, "FALSE")
			}
		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
			bind(n, "", v)
		default:
			switch reflect.TypeOf(v).Kind() {
			case reflect.Slice, reflect.Array:
				cast := ""
				if !strings.Contains(sql, ":"+n+":::") {
					switch reflect.TypeOf(v).Elem().Kind() {
					case reflect.String:
						cast = "::text[]"
					case reflect.Bool:
						cast = "::bool[]"
					default:
						cast = "::int[]"
					}
				}

				bind(n, cast, pgArrayLiteral(v))
			default:
				bind(n, "", v)
			}
		}
	}

	debug := d.Debug
	if !debug {
		switch vv := param["_debug"].(type) {
		case bool:
			debug = vv
		case int:
			debug = vv == 1
		case uint:
			debug = vv == 1
		case string:
			debug = vv == "1"
		}
	}

	if debug {
		d.logger.Printf("\n---SQL_START---\n%v\n--SQL_END---\n\n---ARGS_START---\n%+v\n---ARGS_END---", sql, args)
	}

	return sql, args
}

// pgArrayLiteral renders the slice or array v as a PostgreSQL array literal.
// For example, []int{1, 2} becomes "{1,2}" and []string{"a b", "c"} becomes
// `{"a b",c}`.
//
// Each element is formatted with fmt.Sprint. An element is double-quoted,
// with backslash and double quote escaped, if it is empty or contains
// braces, commas, double quotes, backslashes or whitespace, as PostgreSQL's
// array input syntax requires. Other elements are written bare, so an
// element spelled NULL is still read as SQL NULL. v must be a slice or array.
func pgArrayLiteral(v any) string {
	rv := reflect.ValueOf(v)

	var b strings.Builder
	b.WriteByte('{')

	for i := 0; i < rv.Len(); i++ {
		if i > 0 {
			b.WriteByte(',')
		}

		elem := fmt.Sprint(rv.Index(i).Interface())
		if elem != "" && !strings.ContainsAny(elem, "{}\",\\ \t\n\r\v\f") {
			b.WriteString(elem)
			continue
		}

		// Escape backslashes first so the quote escapes are not doubled.
		elem = strings.ReplaceAll(elem, `\`, `\\`)
		elem = strings.ReplaceAll(elem, `"`, `\"`)

		b.WriteByte('"')
		b.WriteString(elem)
		b.WriteByte('"')
	}

	b.WriteByte('}')

	return b.String()
}
