package zdb

import (
	"context"
	"errors"
	"fmt"
	"log"
	"reflect"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// Balance selects how a replica is chosen when several are in rotation.
type Balance int

const (
	// BalanceRoundRobin rotates through the replicas in order.
	BalanceRoundRobin Balance = iota
	// BalanceLeastConn picks the replica with the lowest ratio of acquired to
	// maximum pool connections.
	BalanceLeastConn
)

// Pool routes queries to a master connection and to optional synchronous and
// asynchronous replica connections. Copies made by WithContext, WithTimeout
// and WithDeadline share the mutex and round-robin counters with the original.
type Pool struct {
	ctx    context.Context
	logger Logger

	// SrvMaster is the master connection; failures on it are returned to the
	// caller and it is never taken out of rotation.
	SrvMaster *Conn
	// SrvSync and SrvAsync hold the replicas currently in rotation.
	SrvSync  []*Conn
	SrvAsync []*Conn

	// mu guards the connection lists, notAliveConns and stop.
	mu *sync.RWMutex
	// notAliveConns holds replicas removed after a failure; the loop started
	// by Start pings them and returns them to rotation once they answer.
	notAliveConns   []*Conn
	slavesIter      *atomic.Int64
	slavesAsyncIter *atomic.Int64
	stop            bool

	// Continues lists error substrings that make execWrapper take the replica
	// out of rotation and move on to another connection.
	Continues []string
	// ContinuesTry lists error substrings that make execWrapper retry the same
	// replica up to TryOnError times, sleeping TryOnSleep between attempts.
	ContinuesTry []string
	TryOnError   int
	TryOnSleep   time.Duration
	// Balance selects the replica balancing strategy.
	Balance Balance
	// PingTimout bounds each health-check query.
	PingTimout time.Duration
	// PingTry is kept for configuration compatibility; Ping reports every
	// failed ping as an error and records consecutive failures in Conn.PingTry.
	PingTry int
}

// New returns a Pool bound to ctx with the default retry, balancing and
// health-check settings. It logs through log.Default(): a nil logger would
// panic on the first logged ping or query error.
func New(ctx context.Context) *Pool {
	return &Pool{
		ctx:             ctx,
		logger:          log.Default(),
		mu:              &sync.RWMutex{},
		slavesIter:      &atomic.Int64{},
		slavesAsyncIter: &atomic.Int64{},
		Continues:       []string{"connect", "EOF", "conflict with recovery", "context deadline exceeded"},
		ContinuesTry:    []string{"conflict with recovery"},
		TryOnError:      1,
		TryOnSleep:      time.Second,
		Balance:         BalanceLeastConn,
		PingTimout:      time.Second * 5,
		PingTry:         5,
	}
}

// NewDefault returns New(context.Background()) logging through log.Default().
func NewDefault() *Pool {
	return New(context.Background())
}

// WithContext returns a shallow copy of d that issues queries under ctx. The
// copy shares the mutex, the round-robin counters and the *Conn values with d.
//
// NOTE(review): the connection slices are copied by header only, so a replica
// removed through the copy (setNotAliveConn) is not removed from d and,
// depending on how remove is implemented, may rearrange the shared backing
// array; confirm whether derived pools are meant to mutate the lists.
func (d *Pool) WithContext(ctx context.Context) *Pool {
	cp := *d
	cp.ctx = ctx
	return &cp
}

// WithTimeout returns a copy of d whose context expires after dur. The cancel
// function is released by a timer when dur elapses instead of a parked
// goroutine.
func (d *Pool) WithTimeout(dur time.Duration) *Pool {
	ctx, cancel := context.WithTimeout(d.ctx, dur)
	time.AfterFunc(dur, cancel)
	return d.WithContext(ctx)
}

// WithDeadline returns a copy of d whose context expires at dur.
func (d *Pool) WithDeadline(dur time.Time) *Pool {
	ctx, cancel := context.WithDeadline(d.ctx, dur)
	// time.Until, not time.Since: the wait is the time remaining until the
	// deadline. time.Since is negative for a future deadline, which used to
	// cancel the context immediately.
	time.AfterFunc(time.Until(dur), cancel)
	return d.WithContext(ctx)
}

// NewConn opens a connection for pgConnString and registers it as the master
// or appends it to the sync/async replica list according to mode. A
// connection whose initial ping fails is still registered (with Alive=false)
// and the ping error is returned. Connections that could not be created at
// all are not registered. An unknown mode yields an error before any pool is
// opened.
func (d *Pool) NewConn(mode connMode, pgConnString string) error {
	if mode != ConnModeMaster && mode != ConnModeSync && mode != ConnModeAsync {
		return fmt.Errorf("zdb: unknown connection mode %v", mode)
	}

	// Connect and ping without holding the lock so readers are not blocked
	// for up to PingTimout.
	q, err := d.newConn(mode, pgConnString)
	if q == nil {
		return err
	}

	d.mu.Lock()
	defer d.mu.Unlock()
	switch mode {
	case ConnModeMaster:
		d.SrvMaster = q
	case ConnModeSync:
		q.Index = len(d.SrvSync)
		d.SrvSync = append(d.SrvSync, q)
	case ConnModeAsync:
		q.Index = len(d.SrvAsync)
		d.SrvAsync = append(d.SrvAsync, q)
	}
	return err
}

// NewConns calls NewConn for each connection string in order and stops at the
// first error.
func (d *Pool) NewConns(mode connMode, pgConnString ...string) error {
	for _, s := range pgConnString {
		if err := d.NewConn(mode, s); err != nil {
			return err
		}
	}
	return nil
}

// newConn builds a pgx pool for pgConnString, filling in
// default_query_exec_mode=simple_protocol, connect_timeout=3 and
// sslmode=disable when the string does not set them, then pings it.
// It returns (nil, err) when the pool cannot be created, and (conn, err) with
// conn.Alive=false when only the ping fails.
func (d *Pool) newConn(mode connMode, pgConnString string) (*Conn, error) {
	pgConnString = withConnParam(pgConnString, "default_query_exec_mode", "simple_protocol")
	pgConnString = withConnParam(pgConnString, "connect_timeout", "3")
	pgConnString = withConnParam(pgConnString, "sslmode", "disable")

	pgxConfig, err := pgxpool.ParseConfig(pgConnString)
	if err != nil {
		return nil, err
	}
	pgxPool, err := pgxpool.NewWithConfig(d.ctx, pgxConfig)
	if err != nil {
		return nil, err
	}

	q := &Conn{Pool: pgxPool, Alive: false, Mode: mode}
	if err := d.Ping(q); err != nil {
		return q, err
	}
	q.Alive = true
	return q, nil
}

// withConnParam appends key=value to a connection string unless key= already
// occurs in it. Keyword/value strings get a space-separated keyword; URL
// strings ("postgres://", "postgresql://") get a query parameter, since a
// space-separated suffix would corrupt the URL.
func withConnParam(conn, key, value string) string {
	if strings.Contains(conn, key+"=") {
		return conn
	}
	if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
		sep := "?"
		if strings.Contains(conn, "?") {
			sep = "&"
		}
		return conn + sep + key + "=" + value
	}
	return conn + " " + key + "=" + value
}

// least returns the connection in s with the lowest AcquiredConns/MaxConns
// ratio (the first one on ties) and logs its stats. It returns nil for an
// empty s.
func (d *Pool) least(s []*Conn) *Conn {
	var out *Conn
	best := 0.0
	for i, conn := range s {
		st := conn.Stat() // one snapshot so both counters agree
		ratio := float64(st.AcquiredConns()) / float64(st.MaxConns())
		if i == 0 || ratio < best {
			best, out = ratio, conn
		}
	}
	if out != nil {
		logConnStat(out)
	}
	return out
}

// pickConn chooses a connection from conns using d.Balance, or returns nil if
// conns is empty. The caller must hold d.mu for reading.
func (d *Pool) pickConn(conns []*Conn, iter *atomic.Int64) *Conn {
	switch len(conns) {
	case 0:
		return nil
	case 1:
		return conns[0]
	}
	if d.Balance == BalanceRoundRobin {
		// Unsigned modulo keeps the index valid even if the counter wraps.
		return conns[uint64(iter.Add(1))%uint64(len(conns))]
	}
	return d.least(conns)
}

// sync returns a synchronous replica, or the master when none is in rotation.
// The emptiness check runs under the lock so a concurrent removal cannot
// leave the selection with an empty list.
func (d *Pool) sync() *Conn {
	d.mu.RLock()
	defer d.mu.RUnlock()
	if q := d.pickConn(d.SrvSync, d.slavesIter); q != nil {
		return q
	}
	return d.SrvMaster
}

// async returns an asynchronous replica, falling back to sync() when none is
// in rotation. The read lock is released before the fallback so it is never
// acquired recursively.
func (d *Pool) async() *Conn {
	d.mu.RLock()
	q := d.pickConn(d.SrvAsync, d.slavesAsyncIter)
	d.mu.RUnlock()
	if q != nil {
		return q
	}
	return d.sync()
}

// execWrapper runs f against a connection chosen for pool (sync replicas for
// ConnModeSync, async replicas otherwise). Errors from the master, errors not
// matching Continues, and errors arising after d's own context is done are
// returned as-is. Other replica errors take the replica out of rotation and
// f is retried on the next connection.
func (d *Pool) execWrapper(pool connMode, dst any, f func(conn *Conn, dst1 any) error) error {
	for {
		var q *Conn
		if pool == ConnModeSync {
			q = d.sync()
		} else {
			q = d.async()
		}
		if q == nil {
			return errors.New("zdb: no database connection available")
		}

		err := d.execOnConn(q, dst, f)
		if err == nil {
			return nil
		}
		// A failure caused by the caller's own cancelled/expired context says
		// nothing about the replica; without this check every replica would be
		// evicted in turn ("context deadline exceeded" is in Continues).
		if q.Mode == ConnModeMaster || d.ctx.Err() != nil || !contains(err.Error(), d.Continues) {
			return err
		}
		d.setNotAliveConn(q)
		d.logger.Printf("ZDB_EXEC_WRAPPER_ERR: SRV: %s; %s", q.ToString(), err.Error())
	}
}

// execOnConn runs f on q, retrying on the same replica up to d.TryOnError
// times (sleeping d.TryOnSleep between attempts) while the error matches
// ContinuesTry. Master errors are never retried. A retry sleep interrupted by
// d's context returns the last error.
func (d *Pool) execOnConn(q *Conn, dst any, f func(conn *Conn, dst1 any) error) error {
	try := 0
	for {
		err := f(q, dst)
		if err == nil || q.Mode == ConnModeMaster || try >= d.TryOnError || !contains(err.Error(), d.ContinuesTry) {
			return err
		}
		try++
		d.logger.Printf("ZDB_EXEC_WRAPPER_REPEAT_ERR: SRV: %s TRY: %d; %s", q.ToString(), try, err.Error())
		if !d.sleepOrDone(d.TryOnSleep) {
			return err
		}
	}
}

// sleepOrDone waits for dur and reports true, or reports false as soon as d's
// context is done.
func (d *Pool) sleepOrDone(dur time.Duration) bool {
	t := time.NewTimer(dur)
	defer t.Stop()
	select {
	case <-t.C:
		return true
	case <-d.ctx.Done():
		return false
	}
}

// Ping runs "SELECT 1" on q, bounded by d.PingTimout. On success it resets
// q.PingTry and returns nil. On failure it increments q.PingTry (consecutive
// failures), logs, and returns the error.
//
// NOTE(review): the previous version compared q.PingTry with d.PingTry in a
// way that returned nil (reported the connection alive) on a failed ping when
// d.PingTry <= 1, and reset the counter on every failure otherwise. A failed
// ping is now always reported; confirm whether a failure threshold was
// intended.
func (d *Pool) Ping(q *Conn) error {
	ctx, cancel := context.WithTimeout(d.ctx, d.PingTimout)
	defer cancel()

	var n any
	err := d.WithContext(ctx).qGet(q, &n, "SELECT 1")
	if err == nil {
		q.PingTry = 0
		return nil
	}
	q.PingTry++
	d.logger.Printf("ZDB_PING_ERR: SRV: %s; TRY: %d; %v", q.ToString(), q.PingTry, err)
	return err
}

// setNotAliveConn marks conn dead and moves it from SrvSync or SrvAsync to
// notAliveConns. Connections in neither list (e.g. the master) are ignored.
func (d *Pool) setNotAliveConn(conn *Conn) {
	d.mu.Lock()
	defer d.mu.Unlock()
	for _, list := range []*[]*Conn{&d.SrvSync, &d.SrvAsync} {
		for i, c := range *list {
			if c == conn {
				conn.Alive = false
				d.notAliveConns = append(d.notAliveConns, conn)
				*list = remove(*list, i)
				return
			}
		}
	}
}

// Start launches the background health loop: about once a second it pings
// the connections in notAliveConns and returns responsive ones to rotation,
// then refreshes SrvMaster.Alive. The loop exits after Stop or when d's
// context is done.
func (d *Pool) Start() {
	d.mu.Lock()
	d.stop = false
	d.mu.Unlock()
	go d.healthLoop()
}

// isStopped reports whether Stop has been called since the last Start.
func (d *Pool) isStopped() bool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.stop
}

// healthLoop is the body of the goroutine started by Start.
func (d *Pool) healthLoop() {
	for !d.isStopped() {
		d.reviveNotAliveConns()

		d.mu.RLock()
		master := d.SrvMaster
		d.mu.RUnlock()
		if master != nil {
			master.Alive = d.Ping(master) == nil
		}

		if !d.sleepOrDone(time.Second) {
			return
		}
	}
}

// reviveNotAliveConns pings a snapshot of notAliveConns without holding the
// lock (pings can take up to PingTimout) and moves every connection that
// answers back to SrvSync or SrvAsync according to its Mode.
func (d *Pool) reviveNotAliveConns() {
	d.mu.RLock()
	candidates := append([]*Conn(nil), d.notAliveConns...)
	d.mu.RUnlock()

	for _, q := range candidates {
		if d.Ping(q) != nil {
			continue
		}
		d.mu.Lock()
		// Look the connection up again: the list may have changed while the
		// lock was released.
		for i, c := range d.notAliveConns {
			if c == q {
				q.Alive = true
				d.notAliveConns = remove(d.notAliveConns, i)
				if q.Mode == ConnModeSync {
					d.SrvSync = append(d.SrvSync, q)
				} else if q.Mode == ConnModeAsync {
					d.SrvAsync = append(d.SrvAsync, q)
				}
				break
			}
		}
		d.mu.Unlock()
	}
}

// Stop asks the health loop started by Start to exit after its current pass.
func (d *Pool) Stop() {
	d.mu.Lock()
	d.stop = true
	d.mu.Unlock()
}

// IsAlive reports whether a master is registered and its last health check
// (or initial ping) succeeded.
func (d *Pool) IsAlive() bool {
	return d.SrvMaster != nil && d.SrvMaster.Alive
}

// prepare replaces every ":name:" placeholder in sql with a literal rendered
// from param[name] (see prepareLiteral) and returns the resulting statement.
// If param["_debug"] is true, int 1, uint 1 or "1", the final SQL is logged.
//
// All placeholders are substituted in a single pass, so text inserted for one
// parameter is never rescanned for another placeholder (previously the result
// depended on map iteration order).
//
// NOTE(review): quoting doubles single quotes only, which assumes
// standard_conforming_strings=on (the PostgreSQL default); with it off,
// backslashes inside values would be treated as escapes.
func (d *Pool) prepare(sql string, param map[string]any) string {
	if len(param) > 0 {
		names := make([]string, 0, len(param))
		for n := range param {
			names = append(names, n)
		}
		// Deterministic order; longer names first so the Replacer, which
		// tries patterns in argument order, prefers the longest placeholder.
		sort.Slice(names, func(i, j int) bool {
			if len(names[i]) != len(names[j]) {
				return len(names[i]) > len(names[j])
			}
			return names[i] < names[j]
		})
		pairs := make([]string, 0, 2*len(names))
		for _, n := range names {
			pairs = append(pairs, ":"+n+":", prepareLiteral(param[n]))
		}
		sql = strings.NewReplacer(pairs...).Replace(sql)
	}

	if prepareDebug(param["_debug"]) {
		// SQL may contain '%' (e.g. LIKE patterns); never use it as a format.
		d.logger.Printf("%s", sql)
	}
	return sql
}

// prepareLiteral renders v as SQL text:
//   - nil and nil pointers -> NULL
//   - time.Time / *time.Time -> '<RFC3339Nano>'::timestamptz
//   - bool / *bool -> TRUE / FALSE
//   - string / *string -> NULL when exactly "NULL", otherwise a quoted literal
//   - int, int64, *int, *int64, *float64 -> unquoted number
//   - slices and arrays -> quoted PostgreSQL array literal (see prepareArray)
//   - anything else -> its fmt %v text as a quoted literal
func prepareLiteral(v any) string {
	switch tv := v.(type) {
	case nil:
		return "NULL"
	case time.Time:
		return "'" + tv.Format(time.RFC3339Nano) + "'::timestamptz"
	case *time.Time:
		if tv == nil {
			return "NULL"
		}
		return "'" + tv.Format(time.RFC3339Nano) + "'::timestamptz"
	case bool:
		if tv {
			return "TRUE"
		}
		return "FALSE"
	case *bool:
		if tv == nil {
			return "NULL"
		}
		if *tv {
			return "TRUE"
		}
		return "FALSE"
	case string:
		if tv == "NULL" {
			return "NULL"
		}
		return prepareQuote(tv)
	case *string:
		if tv == nil || *tv == "NULL" {
			return "NULL"
		}
		return prepareQuote(*tv)
	case int, int64:
		return fmt.Sprint(tv)
	case *int:
		if tv == nil {
			return "NULL"
		}
		return fmt.Sprint(*tv)
	case *int64:
		if tv == nil {
			return "NULL"
		}
		return fmt.Sprint(*tv)
	case *float64:
		if tv == nil {
			return "NULL"
		}
		return fmt.Sprint(*tv)
	}

	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Slice, reflect.Array:
		return prepareQuote(prepareArray(v, rv))
	case reflect.Pointer:
		if rv.IsNil() {
			return "NULL"
		}
	}
	// Quote-escaped: previously values of unlisted types (including named
	// string types) were inserted verbatim between quotes.
	return prepareQuote(fmt.Sprint(v))
}

// prepareQuote returns s as a single-quoted SQL string literal, doubling any
// embedded single quotes.
func prepareQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

// prepareArray renders a slice/array as an unquoted PostgreSQL array literal
// body ("{...}"). Element type string gets per-element quoting so values
// containing spaces, commas, braces, quotes or backslashes survive intact.
// Other element types (and containers with their own String method) keep the
// legacy rendering derived from fmt's "[a b c]" output, which is only
// reliable for elements without spaces.
func prepareArray(v any, rv reflect.Value) string {
	if _, custom := v.(fmt.Stringer); !custom && rv.Type().Elem() == reflect.TypeOf("") {
		parts := make([]string, rv.Len())
		for i := range parts {
			parts[i] = prepareArrayElem(rv.Index(i).String())
		}
		return "{" + strings.Join(parts, ",") + "}"
	}
	return "{" + strings.Trim(strings.Join(strings.Split(fmt.Sprint(v), " "), ","), "[]") + "}"
}

// prepareArrayElem returns s as an array element: bare when it has no
// characters that are special in array syntax, otherwise double-quoted with
// backslash and double-quote escaped. Empty strings are quoted so they are
// not lost.
func prepareArrayElem(s string) string {
	if s != "" && !strings.ContainsAny(s, "{}\",\\ \t\r\n\v\f") {
		return s
	}
	return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(s) + `"`
}

// prepareDebug reports whether the "_debug" parameter value enables logging:
// true, int 1, uint 1 or the string "1".
func prepareDebug(v any) bool {
	switch vv := v.(type) {
	case bool:
		return vv
	case int:
		return vv == 1
	case uint:
		return vv == 1
	case string:
		return vv == "1"
	}
	return false
}