Commit 45bffcb4 authored by Vladimir Barsukov's avatar Vladimir Barsukov
Browse files

fix sql prepare

parent 9b08a5de
......@@ -31,6 +31,7 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
......
......@@ -4,12 +4,15 @@ func (d *Pool) WExec(sql string, args ...any) error {
return d.qExec(d.SrvMaster, sql, args...)
}
func (d *Pool) WExecNamed(sql string, args map[string]any) error {
return d.qExec(d.SrvMaster, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qExec(d.SrvMaster, newSql, newArgs)
}
func (d *Pool) WExecOpts(opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.qExec(d.SrvMaster, d.prepare(sql, args))
return d.qExec(d.SrvMaster, newSql, newArgs)
}
func (d *Pool) qExec(q *Conn, sql string, args ...any) error {
......
package zdb
import "github.com/georgysavva/scany/v2/pgxscan"
import (
"github.com/georgysavva/scany/v2/pgxscan"
)
func (d *Pool) WGet(dst any, sql string, args ...any) error {
return d.qGet(d.SrvMaster, dst, sql, args...)
}
func (d *Pool) WGetNamed(dst any, sql string, args map[string]any) error {
return d.qGet(d.SrvMaster, dst, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qGet(d.SrvMaster, dst, newSql, newArgs)
}
func (d *Pool) WGetOpts(dst any, opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.qGet(d.SrvMaster, dst, d.prepare(sql, args))
return d.qGet(d.SrvMaster, dst, newSql, newArgs)
}
func (d *Pool) Get(dst any, sql string, args ...any) error {
......@@ -21,14 +26,17 @@ func (d *Pool) Get(dst any, sql string, args ...any) error {
}
func (d *Pool) GetNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qGet(q, dst1, newSql, newArgs...)
})
}
func (d *Pool) GetOpts(dst any, opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args))
return d.qGet(q, dst1, newSql, newArgs)
})
}
......@@ -39,14 +47,18 @@ func (d *Pool) GetAsync(dst any, sql string, args ...any) error {
}
func (d *Pool) GetAsyncNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qGet(q, dst1, newSql, newArgs)
})
}
func (d *Pool) GetAsyncOpts(dst any, opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args))
return d.qGet(q, dst1, newSql, newArgs)
})
}
......
......@@ -6,7 +6,6 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
......@@ -347,75 +346,91 @@ func (d *Pool) IsAlive() bool {
return d.SrvMaster != nil && d.SrvMaster.Alive
}
func (d *Pool) prepare(sql string, param map[string]any) string {
for n, t1 := range param {
switch tv := t1.(type) {
func (d *Pool) prepare(sql string, param map[string]any) (string, []any) {
args := make([]any, 0)
idx := 0
for n, t := range param {
if n[0] == '_' {
continue
}
if !strings.Contains(sql, ":"+n+":") {
continue
}
switch v := t.(type) {
case time.Time:
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%s'::timestamptz", tv.Format(time.RFC3339Nano)))
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d::timestamptz", idx))
args = append(args, v.Format(time.RFC3339Nano))
case *time.Time:
if tv == nil {
if v == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%s'::timestamptz", tv.Format(time.RFC3339Nano)))
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d::timestamptz", idx))
args = append(args, v.Format(time.RFC3339Nano))
}
case nil:
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
case bool:
if tv {
if v {
sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE")
}
case *string:
if tv == nil || *tv == "NULL" {
if v == nil || *v == "NULL" {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", strings.ReplaceAll(*tv, "'", "''")))
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
}
case string:
if tv == "NULL" {
if v == "NULL" {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", strings.ReplaceAll(tv, "'", "''")))
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
}
case *int:
if tv == nil {
case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *float32, *float64:
if v == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", strconv.Itoa(*tv))
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
}
case *bool:
if tv == nil {
if v == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
if *tv {
if *v {
sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE")
}
}
case *int64:
if tv == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", strconv.FormatInt(*tv, 10))
}
case *float64:
if tv == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("%v", *tv))
}
case int, int64:
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("%v", tv))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
default:
switch reflect.TypeOf(tv).Kind() {
switch reflect.TypeOf(v).Kind() {
case reflect.Slice, reflect.Array:
sql = strings.ReplaceAll(sql, ":"+n+":", "'{"+strings.Trim(strings.Join(strings.Split(fmt.Sprint(tv), " "), ","), "[]")+"}'")
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
default:
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", tv))
idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
}
}
}
if d.Debug {
......@@ -437,5 +452,5 @@ func (d *Pool) prepare(sql string, param map[string]any) string {
}
}
return sql
return sql, args
}
......@@ -8,12 +8,15 @@ func (d *Pool) WSelect(dst any, sql string, args ...any) error {
return d.qSelect(d.SrvMaster, dst, sql, args...)
}
func (d *Pool) WSelectNamed(dst any, sql string, args map[string]any) error {
return d.qSelect(d.SrvMaster, dst, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qSelect(d.SrvMaster, dst, newSql, newArgs)
}
func (d *Pool) WSelectOpts(dst any, opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.qSelect(d.SrvMaster, dst, d.prepare(sql, args))
return d.qSelect(d.SrvMaster, dst, newSql, newArgs)
}
func (d *Pool) Select(dst any, sql string, args ...any) error {
......@@ -23,14 +26,17 @@ func (d *Pool) Select(dst any, sql string, args ...any) error {
}
func (d *Pool) SelectNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qSelect(conn, dst1, newSql, newArgs)
})
}
func (d *Pool) SelectOpts(dst any, opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args))
return d.qSelect(conn, dst1, newSql, newArgs)
})
}
......@@ -41,14 +47,17 @@ func (d *Pool) SelectAsync(dst any, sql string, args ...any) error {
}
func (d *Pool) SelectAsyncNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args))
newSql, newArgs := d.prepare(sql, args)
return d.qSelect(conn, dst1, newSql, newArgs)
})
}
func (d *Pool) SelectAsyncOpts(dst any, opts Opts) error {
sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args))
return d.qSelect(conn, dst1, newSql, newArgs)
})
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment