package zdb

import (
	"context"
	"errors"
	"fmt"
	"log"
	"reflect"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.barsukov.pro/barsukov/zdb/ztype"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"
)

// Pool dispatches queries to a master connection and to optional synchronous
// and asynchronous replicas, picked round-robin. A replica whose query fails
// with one of the Continues errors is moved to a not-alive list and the query
// is re-dispatched; the health check started by Start pings not-alive
// replicas and returns them to rotation once they answer.
type Pool struct {
	ctx    context.Context
	logger Logger

	srvMaster      *conn
	srvSlaves      []*conn
	srvSlavesAsync []*conn

	// mu guards srvMaster, the replica lists, notAliveConns, the Alive flags
	// written in this file, and stopCh. Pools derived via WithContext share it.
	mu              *sync.RWMutex
	notAliveConns   []*conn
	slavesIter      *atomic.Int64
	slavesAsyncIter *atomic.Int64

	// stopCh is non-nil while the goroutine started by Start is running;
	// closing it makes that goroutine exit.
	stopCh chan struct{}

	// PgTsFormat is a timestamp layout. NOTE(review): it is not read in this
	// file (prepare formats with time.DateTime); confirm where it is used.
	PgTsFormat string
	// Continues lists error substrings after which a replica is marked not
	// alive and the query is re-dispatched to another connection.
	Continues []string
	// ContinuesTry lists error substrings after which the query is first
	// retried on the same replica, up to TryOnError times, sleeping
	// TryOnSleep before each retry.
	ContinuesTry []string
	TryOnError   int
	TryOnSleep   time.Duration
}

// New returns a Pool with the default retry settings, a background context
// and the standard library's default logger. Add connections with NewConn or
// NewConns.
func New() *Pool {
	return &Pool{
		ctx:             context.Background(),
		logger:          log.Default(),
		mu:              &sync.RWMutex{},
		slavesIter:      &atomic.Int64{},
		slavesAsyncIter: &atomic.Int64{},
		PgTsFormat:      "2006-01-02 15:04:05",
		Continues:       []string{"connect", "EOF", "conflict with recovery"},
		ContinuesTry:    []string{"conflict with recovery"},
		TryOnError:      1,
		TryOnSleep:      time.Second * 5,
	}
}

// NewDefault returns New() with context.Background() and log.Default() set
// explicitly (New already sets both).
func NewDefault() *Pool {
	p := New()
	p.ctx = context.Background()
	p.logger = log.Default()
	return p
}

// WithContext returns a copy of d that uses ctx. The copy shares the mutex,
// round-robin counters and *conn values with d, carries over the logger and
// retry configuration, and gets its own copies of the connection lists (so
// evictions in one pool cannot reshuffle the other's slices). It does not
// inherit a running health check.
func (d *Pool) WithContext(ctx context.Context) *Pool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return &Pool{
		ctx:             ctx,
		logger:          d.logger,
		srvMaster:       d.srvMaster,
		srvSlaves:       append([]*conn(nil), d.srvSlaves...),
		srvSlavesAsync:  append([]*conn(nil), d.srvSlavesAsync...),
		mu:              d.mu,
		notAliveConns:   append([]*conn(nil), d.notAliveConns...),
		slavesIter:      d.slavesIter,
		slavesAsyncIter: d.slavesAsyncIter,
		PgTsFormat:      d.PgTsFormat,
		Continues:       d.Continues,
		ContinuesTry:    d.ContinuesTry,
		TryOnError:      d.TryOnError,
		TryOnSleep:      d.TryOnSleep,
	}
}

// NewConn opens a connection pool for pgConnString and registers it as the
// master (ConnModeMaster, replacing any previous one) or appends it to the
// synchronous (ConnModeSync) or asynchronous (ConnModeAsync) replica list.
// If the connection was created but the initial ping failed, it is still
// registered (with Alive=false) and the ping error is returned. If the
// connection could not be created at all, nothing is registered.
func (d *Pool) NewConn(mode connMode, pgConnString string) error {
	switch mode {
	case ConnModeMaster, ConnModeSync, ConnModeAsync:
	default:
		return errors.New("unknown mode")
	}

	// Dial and ping without holding d.mu so queries on already registered
	// connections are not blocked by network I/O.
	q, err := d.newConn(mode, pgConnString)
	if q == nil {
		return err
	}

	d.mu.Lock()
	defer d.mu.Unlock()
	switch mode {
	case ConnModeMaster:
		d.srvMaster = q
	case ConnModeSync:
		d.srvSlaves = append(d.srvSlaves, q)
	case ConnModeAsync:
		d.srvSlavesAsync = append(d.srvSlavesAsync, q)
	}
	return err
}

// NewConns calls NewConn for each connection string in order and stops at the
// first error.
func (d *Pool) NewConns(mode connMode, pgConnString ...string) error {
	// No locking here: NewConn takes d.mu itself, and sync.RWMutex is not
	// reentrant.
	for _, s := range pgConnString {
		if err := d.NewConn(mode, s); err != nil {
			return err
		}
	}
	return nil
}

// newConn parses pgConnString, creates a pgxpool with the package's JSON
// codecs registered on every new connection, and pings it.
// It returns (nil, err) if the pool could not be created, (q, err) with
// q.Alive == false if only the ping failed, and (q, nil) with q.Alive == true
// on success.
func (d *Pool) newConn(mode connMode, pgConnString string) (*conn, error) {
	connString := pgConnString
	if !strings.Contains(connString, "default_query_exec_mode=") {
		// Keyword/value strings take a space-separated pair; URL strings need
		// the setting as a query parameter or the URL is corrupted.
		switch {
		case strings.HasPrefix(connString, "postgres://"), strings.HasPrefix(connString, "postgresql://"):
			sep := "?"
			if strings.Contains(connString, "?") {
				sep = "&"
			}
			connString += sep + "default_query_exec_mode=simple_protocol"
		default:
			connString += " default_query_exec_mode=simple_protocol"
		}
	}

	pgxConfig, err := pgxpool.ParseConfig(connString)
	if err != nil {
		return nil, err
	}
	pgxConfig.AfterConnect = func(_ context.Context, c *pgx.Conn) error {
		tm := c.TypeMap()
		jsonType := &pgtype.Type{Name: "json", OID: pgtype.JSONOID, Codec: ztype.JSONCodec{}}
		jsonbType := &pgtype.Type{Name: "jsonb", OID: pgtype.JSONBOID, Codec: ztype.JSONBCodec{}}
		tm.RegisterType(jsonbType)
		tm.RegisterType(jsonType)
		// Each array type must use its own element type: _json holds json,
		// _jsonb holds jsonb.
		tm.RegisterType(&pgtype.Type{Name: "_json", OID: pgtype.JSONArrayOID, Codec: &pgtype.ArrayCodec{ElementType: jsonType}})
		tm.RegisterType(&pgtype.Type{Name: "_jsonb", OID: pgtype.JSONBArrayOID, Codec: &pgtype.ArrayCodec{ElementType: jsonbType}})
		return nil
	}

	pgxPool, err := pgxpool.NewWithConfig(d.ctx, pgxConfig)
	if err != nil {
		// pgxPool is nil here; returning a conn wrapping it would only
		// panic later.
		return nil, err
	}

	q := &conn{Pool: pgxPool, Alive: false, Mode: mode}
	if err := d.ping(q); err != nil {
		return q, err
	}
	q.Alive = true
	return q, nil
}

// sync returns the next synchronous replica (round-robin), or the master when
// no synchronous replica is registered. The result is nil if neither exists.
func (d *Pool) sync() *conn {
	d.mu.RLock()
	defer d.mu.RUnlock()
	// The length is read under the lock so a concurrent eviction cannot
	// shrink the slice between the check and the index (modulo-by-zero).
	if n := len(d.srvSlaves); n > 0 {
		return d.srvSlaves[uint64(d.slavesIter.Add(1))%uint64(n)]
	}
	return d.srvMaster
}

// async returns the next asynchronous replica (round-robin), falling back to
// the same choice sync would make when none is registered.
func (d *Pool) async() *conn {
	d.mu.RLock()
	defer d.mu.RUnlock()
	if n := len(d.srvSlavesAsync); n > 0 {
		return d.srvSlavesAsync[uint64(d.slavesAsyncIter.Add(1))%uint64(n)]
	}
	// Same fallback as sync(), inlined: calling sync() here would take a
	// recursive RLock, which can deadlock if a writer is waiting.
	if n := len(d.srvSlaves); n > 0 {
		return d.srvSlaves[uint64(d.slavesIter.Add(1))%uint64(n)]
	}
	return d.srvMaster
}

// execWrapper runs f on a connection chosen by pool (ConnModeSync -> sync(),
// anything else -> async()).
//
// Errors from the master are returned as is. For a replica:
//   - an error matching ContinuesTry is retried on the same replica up to
//     TryOnError times, waiting TryOnSleep before each retry (the wait is cut
//     short if the pool's context is done);
//   - an error still matching Continues marks the replica not alive and the
//     query is re-dispatched to another connection;
//   - any other error is returned.
func (d *Pool) execWrapper(pool connMode, dst any, f func(conn *conn, dst1 any) error) error {
	var ctxDone <-chan struct{}
	if d.ctx != nil {
		ctxDone = d.ctx.Done()
	}

	for {
		var q *conn
		if pool == ConnModeSync {
			q = d.sync()
		} else {
			q = d.async()
		}
		if q == nil {
			return errors.New("no database connection available")
		}

		err := f(q, dst)
		if err == nil {
			return nil
		}
		if q.Mode == ConnModeMaster {
			return err
		}

		for try := 1; try <= d.TryOnError && contains(err.Error(), d.ContinuesTry); try++ {
			d.logger.Printf("DB_EXEC_WRAPPER_REPEAT_ERR: SRV: %s TRY: %d; %s", q.ToString(), try, err.Error())
			timer := time.NewTimer(d.TryOnSleep)
			select {
			case <-ctxDone:
				timer.Stop()
				return errors.Join(err, d.ctx.Err())
			case <-timer.C:
			}
			if err = f(q, dst); err == nil {
				return nil
			}
		}

		if !contains(err.Error(), d.Continues) {
			return err
		}
		// Evicting q guarantees the next iteration picks a different
		// connection (eventually the master, whose errors are returned).
		d.setNotAliveConn(q)
		d.logger.Printf("DB_EXEC_WRAPPER_ERR: SRV: %s; %s", q.ToString(), err.Error())
	}
}

// ping runs "SELECT 1" on q and logs the failure, if any, before returning it.
func (d *Pool) ping(q *conn) error {
	var n any
	err := d.qGet(q, &n, "SELECT 1")
	if err != nil {
		d.logger.Printf("DB_PING_ERR: SRV: %s; %v", q.ToString(), err)
	}
	return err
}

// setNotAliveConn moves q from the synchronous or asynchronous replica list to
// notAliveConns and clears its Alive flag. A connection found in neither
// replica list (e.g. the master) is left untouched.
func (d *Pool) setNotAliveConn(q *conn) {
	d.mu.Lock()
	defer d.mu.Unlock()
	for _, list := range []*[]*conn{&d.srvSlaves, &d.srvSlavesAsync} {
		for i, c := range *list {
			if c == q {
				q.Alive = false
				d.notAliveConns = append(d.notAliveConns, q)
				*list = remove(*list, i)
				return
			}
		}
	}
}

// Start launches a background goroutine that, about once per second, pings
// every not-alive replica (returning those that answer to their replica list)
// and refreshes the master's Alive flag. Calling Start while it is already
// running is a no-op; Stop ends it.
func (d *Pool) Start() {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.stopCh != nil {
		return
	}
	done := make(chan struct{})
	d.stopCh = done

	go func() {
		for {
			select {
			case <-done:
				return
			default:
			}
			d.healthCheck()
			timer := time.NewTimer(time.Second)
			select {
			case <-done:
				timer.Stop()
				return
			case <-timer.C:
			}
		}
	}()
}

// healthCheck performs one pass of the background health check. Pings run
// without holding d.mu; the lists are only locked to apply the results.
func (d *Pool) healthCheck() {
	d.mu.RLock()
	// Iterate over a snapshot: revived entries are removed from
	// notAliveConns below, which would invalidate indices of a live range.
	dead := append([]*conn(nil), d.notAliveConns...)
	master := d.srvMaster
	d.mu.RUnlock()

	for _, q := range dead {
		if d.ping(q) != nil {
			continue
		}
		d.mu.Lock()
		for i, c := range d.notAliveConns {
			if c == q {
				d.notAliveConns = remove(d.notAliveConns, i)
				q.Alive = true
				if q.Mode == ConnModeSync {
					d.srvSlaves = append(d.srvSlaves, q)
				} else if q.Mode == ConnModeAsync {
					d.srvSlavesAsync = append(d.srvSlavesAsync, q)
				}
				break
			}
		}
		d.mu.Unlock()
	}

	if master != nil {
		alive := d.ping(master) == nil
		d.mu.Lock()
		master.Alive = alive
		d.mu.Unlock()
	}
}

// Stop signals the goroutine started by Start to exit. It is safe to call when
// the health check is not running.
func (d *Pool) Stop() {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.stopCh != nil {
		close(d.stopCh)
		d.stopCh = nil
	}
}

// IsAlive reports whether a master is registered and its Alive flag is set.
func (d *Pool) IsAlive() bool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	return d.srvMaster != nil && d.srvMaster.Alive
}

// prepare substitutes every ":name:" placeholder in sql with param["name"]
// rendered as an SQL literal (see prepareLiteral) and returns the result.
// All placeholders are replaced in a single left-to-right pass, so text inside
// an already substituted value is never treated as a placeholder. When param
// contains a "_debug" key, the resulting SQL is logged.
//
// Values are inlined rather than sent as bind parameters; strings are
// single-quoted with ' doubled. NOTE(review): that escaping assumes the server
// runs with standard_conforming_strings=on — confirm.
func (d *Pool) prepare(sql string, param map[string]any) string {
	if len(param) == 0 {
		return sql
	}

	// Longest name first: where one placeholder is a prefix of another at the
	// same position (":a:" vs ":a:b:"), strings.Replacer prefers the earlier
	// argument, so the longer placeholder wins, deterministically.
	names := make([]string, 0, len(param))
	for n := range param {
		names = append(names, n)
	}
	sort.Slice(names, func(i, j int) bool {
		if len(names[i]) != len(names[j]) {
			return len(names[i]) > len(names[j])
		}
		return names[i] < names[j]
	})

	pairs := make([]string, 0, 2*len(names))
	for _, n := range names {
		pairs = append(pairs, ":"+n+":", prepareLiteral(param[n]))
	}
	sql = strings.NewReplacer(pairs...).Replace(sql)

	if _, ok := param["_debug"]; ok {
		// sql may contain '%' (e.g. LIKE patterns): never use it as a format.
		d.logger.Printf("%s", sql)
	}
	return sql
}

// prepareLiteral renders v as an SQL literal for prepare:
//   - nil, nil pointers, and the string "NULL" (also via *string) become NULL;
//   - bool and *bool become TRUE or FALSE;
//   - int, int64, *int, *int64 and *float64 are written unquoted;
//   - time.Time and *time.Time are converted to UTC and quoted in the
//     time.DateTime layout (fractional seconds and zone are dropped);
//   - slices become a quoted array literal such as '{1,2,3}';
//   - other non-nil pointers are dereferenced;
//   - anything else is formatted with %v and quoted.
func prepareLiteral(v any) string {
	switch t := v.(type) {
	case nil:
		return "NULL"
	case time.Time:
		return prepareQuote(t.UTC().Format(time.DateTime))
	case *time.Time:
		if t == nil {
			return "NULL"
		}
		return prepareLiteral(*t)
	case bool:
		if t {
			return "TRUE"
		}
		return "FALSE"
	case *bool:
		if t == nil {
			return "NULL"
		}
		return prepareLiteral(*t)
	case string:
		if t == "NULL" {
			return "NULL"
		}
		return prepareQuote(t)
	case *string:
		if t == nil {
			return "NULL"
		}
		return prepareLiteral(*t)
	case int, int64:
		return fmt.Sprint(t)
	case *int:
		if t == nil {
			return "NULL"
		}
		return fmt.Sprint(*t)
	case *int64:
		if t == nil {
			return "NULL"
		}
		return fmt.Sprint(*t)
	case *float64:
		if t == nil {
			return "NULL"
		}
		return fmt.Sprint(*t)
	}

	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Slice:
		return prepareArray(rv)
	case reflect.Pointer:
		// Previously printed as a memory address ('0xc000…' / '<nil>').
		if rv.IsNil() {
			return "NULL"
		}
		return prepareLiteral(rv.Elem().Interface())
	}
	return prepareQuote(fmt.Sprint(v))
}

// prepareArray renders a slice as a quoted PostgreSQL array literal, one
// element per slice entry, each formatted with %v.
func prepareArray(rv reflect.Value) string {
	elems := make([]string, rv.Len())
	for i := range elems {
		elems[i] = prepareArrayElem(fmt.Sprint(rv.Index(i).Interface()))
	}
	return prepareQuote("{" + strings.Join(elems, ",") + "}")
}

// prepareArrayElem returns s as an array-literal element, double-quoting it
// (with \ and " backslash-escaped) only when it is empty or contains a
// character that is special inside an array literal. An unquoted NULL is kept
// as is, matching this package's "NULL" string convention.
func prepareArrayElem(s string) string {
	if s != "" && !strings.ContainsAny(s, "{}\",\\ \t\r\n\v\f") {
		return s
	}
	return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(s) + `"`
}

// prepareQuote wraps s in single quotes, doubling any embedded single quote.
func prepareQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}