package zdb

import (
	"context"
	"errors"
	"fmt"
	"log"
	"reflect"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.barsukov.pro/barsukov/zgo/zutils"
	"github.com/jackc/pgx/v5/pgxpool"
)

// Balance selects how a replica is chosen when more than one is available.
type Balance int

const (
	// BalanceRoundRobin cycles through the replicas in turn.
	BalanceRoundRobin Balance = iota
	// BalanceLeastConn picks the replica with the fewest acquired
	// connections (the first one wins on a tie).
	BalanceLeastConn
)

// Pool routes queries to a master connection and to synchronous and
// asynchronous replicas. A replica that fails with one of the Continues
// errors is parked and the query is retried elsewhere; the background loop
// started by Start pings parked replicas and puts them back once they answer.
//
// Pools derived with WithContext, WithTimeout or WithDeadline share the mutex
// and the round-robin counters with their parent but hold their own copies of
// the connection lists: changes made through one Pool (new connections,
// parked or revived replicas) are not visible through the others.
type Pool struct {
	ctx    context.Context
	logger Logger

	// SrvMaster is the master connection. It is the last fallback of
	// sync()/async() selection and is never parked; its Alive flag is
	// refreshed by the Start loop.
	SrvMaster *Conn
	// SrvSync holds the usable synchronous replicas.
	SrvSync []*Conn
	// SrvAsync holds the usable asynchronous replicas. When it is empty,
	// async selection falls back to SrvSync and then SrvMaster.
	SrvAsync []*Conn

	// mu guards SrvMaster, SrvSync, SrvAsync, notAliveConns, stop and the
	// Alive flag of registered connections.
	mu *sync.RWMutex
	// notAliveConns holds parked replicas waiting to be revived by Start.
	notAliveConns []*Conn
	// slavesIter and slavesAsyncIter are the round-robin counters for
	// SrvSync and SrvAsync.
	slavesIter      *atomic.Int64
	slavesAsyncIter *atomic.Int64
	stop            bool

	// PgTsFormat is a timestamp layout (default "2006-01-02 15:04:05").
	// prepare does not use it; it always formats with time.DateTime.
	PgTsFormat string
	// Continues lists error substrings after which a replica is parked and
	// the query is retried on another connection.
	Continues []string
	// ContinuesTry lists error substrings after which the query is first
	// retried on the same replica, up to TryOnError times, TryOnSleep apart.
	ContinuesTry []string
	TryOnError   int
	TryOnSleep   time.Duration
	// Balance selects the replica balancing strategy.
	Balance Balance
}

// New returns a Pool bound to ctx with the default retry and balancing
// settings. No logger is set; log output is discarded while the logger is
// nil (NewDefault installs log.Default()).
func New(ctx context.Context) *Pool {
	return &Pool{
		ctx:             ctx,
		mu:              &sync.RWMutex{},
		slavesIter:      &atomic.Int64{},
		slavesAsyncIter: &atomic.Int64{},
		PgTsFormat:      "2006-01-02 15:04:05",
		Continues:       []string{"connect", "EOF", "conflict with recovery"},
		ContinuesTry:    []string{"conflict with recovery"},
		TryOnError:      1,
		TryOnSleep:      time.Second,
		Balance:         BalanceLeastConn,
	}
}

// NewDefault returns a Pool bound to context.Background() that logs through
// log.Default().
func NewDefault() *Pool {
	p := New(context.Background())
	p.logger = log.Default()
	return p
}

// WithContext returns a shallow copy of d bound to ctx. The copy shares the
// mutex and round-robin counters with d but gets its own snapshot of the
// connection lists, taken under the read lock.
func (d *Pool) WithContext(ctx context.Context) *Pool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return &Pool{
		ctx:             ctx,
		logger:          d.logger,
		SrvMaster:       d.SrvMaster,
		SrvSync:         d.SrvSync,
		SrvAsync:        d.SrvAsync,
		mu:              d.mu,
		notAliveConns:   d.notAliveConns,
		slavesIter:      d.slavesIter,
		slavesAsyncIter: d.slavesAsyncIter,
		stop:            d.stop,
		PgTsFormat:      d.PgTsFormat,
		Continues:       d.Continues,
		ContinuesTry:    d.ContinuesTry,
		TryOnError:      d.TryOnError,
		TryOnSleep:      d.TryOnSleep,
		Balance:         d.Balance,
	}
}

// WithTimeout returns a derived Pool (see WithContext) whose context expires
// after dur. The signature cannot return the cancel function, so it is
// dropped; the context releases its resources when the timeout fires or the
// parent context is cancelled.
func (d *Pool) WithTimeout(dur time.Duration) *Pool {
	ctx, _ := context.WithTimeout(d.ctx, dur) // cancel intentionally dropped, see doc comment
	return d.WithContext(ctx)
}

// WithDeadline returns a derived Pool (see WithContext) whose context expires
// at deadline. As with WithTimeout, the cancel function is dropped and the
// context's resources are released when the deadline passes or the parent
// context is cancelled.
func (d *Pool) WithDeadline(deadline time.Time) *Pool {
	ctx, _ := context.WithDeadline(d.ctx, deadline) // cancel intentionally dropped, see doc comment
	return d.WithContext(ctx)
}

// NewConn connects to pgConnString and registers the connection in the given
// role: ConnModeMaster replaces SrvMaster, and ConnModeSync/ConnModeAsync add
// a replica. The mode is validated before any connection is opened.
//
// If the server does not answer the initial ping, the connection is still
// registered and the ping error is returned. A master is stored as SrvMaster
// with Alive=false. A replica is parked so that the Start loop can revive it.
// If the pgx pool cannot be created at all (e.g. an invalid connection
// string), nothing is registered and the error is returned.
//
// Connecting happens without holding the Pool's lock, so a slow server does
// not block queries routed through the Pool in the meantime.
func (d *Pool) NewConn(mode connMode, pgConnString string) error {
	switch mode {
	case ConnModeMaster, ConnModeSync, ConnModeAsync:
	default:
		return errors.New("unknown mode")
	}

	q, err := d.newConn(mode, pgConnString)
	if q == nil {
		return err
	}

	d.mu.Lock()
	defer d.mu.Unlock()
	switch {
	case mode == ConnModeMaster:
		d.SrvMaster = q
	case !q.Alive:
		d.notAliveConns = connsWith(d.notAliveConns, q)
	case mode == ConnModeSync:
		d.SrvSync = connsWith(d.SrvSync, q)
	default:
		d.SrvAsync = connsWith(d.SrvAsync, q)
	}
	return err
}

// NewConns calls NewConn for every connection string in order and stops at
// the first error.
func (d *Pool) NewConns(mode connMode, pgConnString ...string) error {
	for _, s := range pgConnString {
		if err := d.NewConn(mode, s); err != nil {
			return err
		}
	}
	return nil
}

// newConn creates a pgx pool for pgConnString and pings it once.
//
// Before parsing, it fills in defaults for default_query_exec_mode
// (simple_protocol), connect_timeout (3) and sslmode (disable) when the
// string does not set them. It returns (nil, err) when the pool cannot be
// created. It returns (conn, err) with conn.Alive == false when only the
// ping fails.
func (d *Pool) newConn(mode connMode, pgConnString string) (*Conn, error) {
	pgConnString = withConnParam(pgConnString, "default_query_exec_mode", "simple_protocol")
	pgConnString = withConnParam(pgConnString, "connect_timeout", "3")
	pgConnString = withConnParam(pgConnString, "sslmode", "disable")

	pgxConfig, err := pgxpool.ParseConfig(pgConnString)
	if err != nil {
		return nil, err
	}
	pgxPool, err := pgxpool.NewWithConfig(d.ctx, pgxConfig)
	if err != nil {
		// No usable pool: never hand out a Conn wrapping a nil *pgxpool.Pool.
		return nil, err
	}

	q := &Conn{Pool: pgxPool, Alive: false, Mode: mode}
	if err = d.ping(q); err != nil {
		return q, err
	}
	q.Alive = true
	return q, nil
}

// withConnParam appends key=value to a pgx connection string unless
// "key=" already occurs in it. URL-form strings (postgres:// or
// postgresql://) get a query parameter; keyword/value strings get a
// space-separated pair.
func withConnParam(connString, key, value string) string {
	if strings.Contains(connString, key+"=") {
		return connString
	}
	if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
		sep := "?"
		if strings.Contains(connString, "?") {
			sep = "&"
		}
		return connString + sep + key + "=" + value
	}
	return connString + " " + key + "=" + value
}

// sync picks a synchronous replica according to Balance. It falls back to
// SrvMaster when none is usable, so it returns nil only when no connection
// is registered at all.
func (d *Pool) sync() *Conn {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.syncLocked()
}

// syncLocked is sync without locking; the caller must hold d.mu.
func (d *Pool) syncLocked() *Conn {
	if q := d.pickConn(d.SrvSync, d.slavesIter); q != nil {
		return q
	}
	return d.SrvMaster
}

// async picks an asynchronous replica according to Balance. When there is
// none, it falls back to the sync selection (a sync replica, then
// SrvMaster). It returns nil only when no connection is registered at all.
func (d *Pool) async() *Conn {
	d.mu.RLock()
	defer d.mu.RUnlock()
	if q := d.pickConn(d.SrvAsync, d.slavesAsyncIter); q != nil {
		return q
	}
	return d.syncLocked()
}

// pickConn selects one of conns (nil if conns is empty). With
// BalanceRoundRobin it advances iter. Otherwise it returns the connection
// with the fewest acquired pgx connections. The caller must hold d.mu.
func (d *Pool) pickConn(conns []*Conn, iter *atomic.Int64) *Conn {
	switch len(conns) {
	case 0:
		return nil
	case 1:
		return conns[0]
	}

	if d.Balance == BalanceRoundRobin {
		// Unsigned modulo keeps the index valid even after the counter wraps.
		return conns[uint64(iter.Add(1))%uint64(len(conns))]
	}

	best := conns[0]
	bestLoad := best.Stat().AcquiredConns()
	for _, q := range conns[1:] {
		if load := q.Stat().AcquiredConns(); load < bestLoad {
			best, bestLoad = q, load
		}
	}
	return best
}

// execWrapper runs f on a connection picked for the requested mode
// (ConnModeSync uses sync(); any other mode uses async()) and handles
// fail-over:
//
//   - errors on the master are returned as-is;
//   - on a replica, errors matching ContinuesTry are retried on the same
//     server up to TryOnError times, TryOnSleep apart;
//   - if the final error matches Continues, the replica is parked and f is
//     run again on the next selected connection;
//   - any other error is returned.
//
// Fail-over stops once the Pool's context is done, so a cancelled or
// timed-out call does not park healthy replicas. If no connection is
// registered at all, an error is returned.
func (d *Pool) execWrapper(pool connMode, dst any, f func(conn *Conn, dst1 any) error) error {
	for {
		var q *Conn
		if pool == ConnModeSync {
			q = d.sync()
		} else {
			q = d.async()
		}
		if q == nil {
			return errors.New("zdb: no database connection available")
		}

		err := d.execOnConn(q, dst, f)
		if err == nil {
			return nil
		}
		if q.Mode == ConnModeMaster || d.ctx.Err() != nil || !zutils.Contains(err.Error(), d.Continues) {
			return err
		}
		// Guard against an endless loop when q cannot be parked, e.g. a
		// replica-mode Conn assigned directly to SrvMaster.
		if !d.parkConn(q) {
			return err
		}
		d.logf("DB_EXEC_WRAPPER_ERR: SRV: %s; %s", q.ToString(), err.Error())
	}
}

// execOnConn runs f on q. On a replica, errors matching ContinuesTry are
// retried on the same connection up to TryOnError times, TryOnSleep apart.
// The wait is cut short, and the last error returned, if the Pool's context
// is done.
func (d *Pool) execOnConn(q *Conn, dst any, f func(conn *Conn, dst1 any) error) error {
	for try := 0; ; try++ {
		err := f(q, dst)
		if err == nil || q.Mode == ConnModeMaster {
			return err
		}
		if try >= d.TryOnError || !zutils.Contains(err.Error(), d.ContinuesTry) {
			return err
		}
		d.logf("DB_EXEC_WRAPPER_REPEAT_ERR: SRV: %s TRY: %d; %s", q.ToString(), try+1, err.Error())
		if !d.sleepCtx(d.TryOnSleep) {
			return err
		}
	}
}

// sleepCtx waits for dur or until the Pool's context is done. It reports
// whether the wait completed while the context was still active.
func (d *Pool) sleepCtx(dur time.Duration) bool {
	if dur <= 0 {
		return d.ctx.Err() == nil
	}
	timer := time.NewTimer(dur)
	defer timer.Stop()
	select {
	case <-d.ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}

// ping runs "SELECT 1" on q and logs a failure.
func (d *Pool) ping(q *Conn) error {
	var n any
	err := d.qGet(q, &n, "SELECT 1")
	if err != nil {
		d.logf("DB_PING_ERR: SRV: %s; %v", q.ToString(), err)
	}
	return err
}

// logf forwards to the configured logger. It is a no-op when no logger is
// set (e.g. a Pool created with New) instead of panicking on a nil logger.
func (d *Pool) logf(format string, args ...any) {
	if d.logger != nil {
		d.logger.Printf(format, args...)
	}
}

// setNotAliveConn marks conn as not alive and moves it from SrvSync or
// SrvAsync to the parked list. It does nothing if conn is in neither list.
func (d *Pool) setNotAliveConn(conn *Conn) {
	d.parkConn(conn)
}

// parkConn implements setNotAliveConn. It reports whether conn is now parked,
// either by this call or by an earlier one.
func (d *Pool) parkConn(conn *Conn) bool {
	d.mu.Lock()
	defer d.mu.Unlock()

	var found bool
	if d.SrvSync, found = connsWithout(d.SrvSync, conn); !found {
		d.SrvAsync, found = connsWithout(d.SrvAsync, conn)
	}
	if !found {
		// Possibly parked concurrently by another caller.
		return connsIndex(d.notAliveConns, conn) >= 0
	}
	conn.Alive = false
	d.notAliveConns = connsWith(d.notAliveConns, conn)
	return true
}

// Start launches the background health-check loop. Each pass does two
// things. It pings every parked replica and moves the ones that answer back
// into SrvSync or SrvAsync. It also refreshes SrvMaster.Alive. The loop then
// waits one second. It exits after Stop is called or when the Pool's
// context is done.
func (d *Pool) Start() {
	d.mu.Lock()
	d.stop = false
	d.mu.Unlock()

	go func() {
		for !d.stopped() {
			d.reviveConns()
			if !d.sleepCtx(time.Second) {
				return
			}
		}
	}()
}

// stopped reports whether Stop has been called since the last Start.
func (d *Pool) stopped() bool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.stop
}

// reviveConns performs one health-check pass. The pings run without holding
// d.mu so an unreachable server does not block query routing. The lists are
// then updated under the lock, matching connections by identity rather than
// by index.
func (d *Pool) reviveConns() {
	d.mu.RLock()
	parked := append([]*Conn(nil), d.notAliveConns...)
	master := d.SrvMaster
	d.mu.RUnlock()

	var revived []*Conn
	for _, q := range parked {
		if d.ping(q) == nil {
			revived = append(revived, q)
		}
	}
	masterAlive := master != nil && d.ping(master) == nil

	d.mu.Lock()
	defer d.mu.Unlock()
	for _, q := range revived {
		var found bool
		if d.notAliveConns, found = connsWithout(d.notAliveConns, q); !found {
			continue
		}
		q.Alive = true
		switch q.Mode {
		case ConnModeSync:
			d.SrvSync = connsWith(d.SrvSync, q)
		case ConnModeAsync:
			d.SrvAsync = connsWith(d.SrvAsync, q)
		}
	}
	// Skip the update if SrvMaster was replaced while we were pinging.
	if master != nil && d.SrvMaster == master {
		master.Alive = masterAlive
	}
}

// Stop asks the loop started by Start to exit; it takes effect at the
// loop's next iteration.
func (d *Pool) Stop() {
	d.mu.Lock()
	d.stop = true
	d.mu.Unlock()
}

// IsAlive reports whether a master is registered and its Alive flag is set.
func (d *Pool) IsAlive() bool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.SrvMaster != nil && d.SrvMaster.Alive
}

// connsIndex returns the position of target in conns, or -1 if absent.
func connsIndex(conns []*Conn, target *Conn) int {
	for i, q := range conns {
		if q == target {
			return i
		}
	}
	return -1
}

// connsWithout returns a new slice without the first occurrence of target,
// and whether target was present. conns itself is never modified, so other
// holders of the old slice (e.g. derived pools) are unaffected.
func connsWithout(conns []*Conn, target *Conn) ([]*Conn, bool) {
	i := connsIndex(conns, target)
	if i < 0 {
		return conns, false
	}
	out := make([]*Conn, 0, len(conns)-1)
	out = append(out, conns[:i]...)
	return append(out, conns[i+1:]...), true
}

// connsWith returns conns with q appended. It always reallocates, so it
// never writes into a backing array that other slice headers may share.
func connsWith(conns []*Conn, q *Conn) []*Conn {
	return append(conns[:len(conns):len(conns)], q)
}

// prepare substitutes placeholders of the form :name: in sql with SQL
// literals built from param (see prepareLiteral). All placeholders are
// replaced in one left-to-right pass, so text inserted for one parameter is
// never re-scanned for another parameter's placeholder. If param contains
// the key "_debug", the resulting statement is logged.
//
// NOTE: values are inlined rather than sent as bind parameters. String
// escaping only doubles single quotes, which is safe with PostgreSQL's
// standard_conforming_strings=on (the server default since 9.1).
func (d *Pool) prepare(sql string, param map[string]any) string {
	if len(param) > 0 {
		pairs := make([]string, 0, 2*len(param))
		for name, value := range param {
			pairs = append(pairs, ":"+name+":", prepareLiteral(value))
		}
		sql = strings.NewReplacer(pairs...).Replace(sql)
	}
	if _, ok := param["_debug"]; ok {
		d.logf("%s", sql)
	}
	return sql
}

// prepareLiteral renders one Go value as a PostgreSQL literal:
//
//   - nil, nil pointers and the string "NULL" become NULL;
//   - other strings are single-quoted with embedded quotes doubled;
//   - bool becomes TRUE/FALSE;
//   - time.Time is converted to UTC, formatted with time.DateTime and quoted;
//   - int and int64 are emitted unquoted, as are the values behind non-nil
//     *int, *int64 and *float64;
//   - values implementing fmt.Stringer or error are formatted with %v and
//     quoted;
//   - other slices become a quoted array literal such as '{1,2,3}' (see
//     prepareArray);
//   - other non-nil pointers are dereferenced and rendered by these rules;
//   - anything else is formatted with %v and quoted.
func prepareLiteral(v any) string {
	switch t := v.(type) {
	case nil:
		return "NULL"
	case string:
		if t == "NULL" {
			return "NULL"
		}
		return quoteSQL(t)
	case *string:
		if t == nil {
			return "NULL"
		}
		return prepareLiteral(*t)
	case bool:
		if t {
			return "TRUE"
		}
		return "FALSE"
	case *bool:
		if t == nil {
			return "NULL"
		}
		return prepareLiteral(*t)
	case time.Time:
		return quoteSQL(t.UTC().Format(time.DateTime))
	case *time.Time:
		if t == nil {
			return "NULL"
		}
		return prepareLiteral(*t)
	case int, int64:
		return fmt.Sprint(t)
	case *int:
		if t == nil {
			return "NULL"
		}
		return fmt.Sprint(*t)
	case *int64:
		if t == nil {
			return "NULL"
		}
		return fmt.Sprint(*t)
	case *float64:
		if t == nil {
			return "NULL"
		}
		return fmt.Sprint(*t)
	}

	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Ptr && rv.IsNil() {
		return "NULL"
	}
	switch v.(type) {
	case fmt.Stringer, error:
		return quoteSQL(fmt.Sprint(v))
	}
	switch rv.Kind() {
	case reflect.Slice:
		return quoteSQL(prepareArray(rv))
	case reflect.Ptr:
		return prepareLiteral(rv.Elem().Interface())
	}
	return quoteSQL(fmt.Sprint(v))
}

// quoteSQL wraps s in single quotes, doubling any embedded single quote.
func quoteSQL(s string) string {
	return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

// prepareArray renders a slice as the text of a PostgreSQL array literal
// ({e1,e2,...}); nested slices become nested arrays.
func prepareArray(rv reflect.Value) string {
	var b strings.Builder
	b.WriteByte('{')
	for i := 0; i < rv.Len(); i++ {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(prepareArrayElem(rv.Index(i)))
	}
	b.WriteByte('}')
	return b.String()
}

// prepareArrayElem renders one array element.
//
// Nil interfaces and nil pointers become the NULL element. time.Time values
// use the same UTC time.DateTime format as scalars. Nested slices recurse.
// Other pointers are dereferenced unless they implement fmt.Stringer or
// error. Everything else is formatted with %v and quoted as the array
// syntax requires (see prepareArrayText).
func prepareArrayElem(ev reflect.Value) string {
	if ev.Kind() == reflect.Interface {
		ev = ev.Elem() // zero Value for a nil interface
	}
	if !ev.IsValid() || (ev.Kind() == reflect.Ptr && ev.IsNil()) {
		return "NULL"
	}

	var s string
	switch t := ev.Interface().(type) {
	case time.Time:
		s = t.UTC().Format(time.DateTime)
	case *time.Time:
		s = t.UTC().Format(time.DateTime)
	case fmt.Stringer, error:
		s = fmt.Sprint(t)
	default:
		switch ev.Kind() {
		case reflect.Slice:
			return prepareArray(ev)
		case reflect.Ptr:
			return prepareArrayElem(ev.Elem())
		}
		s = fmt.Sprint(t)
	}
	return prepareArrayText(s)
}

// prepareArrayQuoter escapes backslashes and double quotes inside a
// double-quoted array element.
var prepareArrayQuoter = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// prepareArrayText returns s as an array element.
//
// The exact string "NULL" stays an unquoted NULL element, mirroring the
// scalar convention. The element is double-quoted when it is empty, is
// another spelling of null, or contains characters that are special in
// array syntax (braces, comma, double quote, backslash, whitespace).
func prepareArrayText(s string) string {
	if s == "NULL" {
		return s
	}
	if s != "" && !strings.EqualFold(s, "NULL") && !strings.ContainsAny(s, "{}\",\\ \t\n\r\v\f") {
		return s
	}
	return `"` + prepareArrayQuoter.Replace(s) + `"`
}