Commit 45bffcb4 authored by Vladimir Barsukov's avatar Vladimir Barsukov
Browse files

fix sql prepare

parent 9b08a5de
...@@ -31,6 +31,7 @@ require ( ...@@ -31,6 +31,7 @@ require (
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
......
...@@ -71,8 +71,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= ...@@ -71,8 +71,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
......
...@@ -4,12 +4,15 @@ func (d *Pool) WExec(sql string, args ...any) error { ...@@ -4,12 +4,15 @@ func (d *Pool) WExec(sql string, args ...any) error {
return d.qExec(d.SrvMaster, sql, args...) return d.qExec(d.SrvMaster, sql, args...)
} }
func (d *Pool) WExecNamed(sql string, args map[string]any) error { func (d *Pool) WExecNamed(sql string, args map[string]any) error {
return d.qExec(d.SrvMaster, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qExec(d.SrvMaster, newSql, newArgs)
} }
func (d *Pool) WExecOpts(opts Opts) error { func (d *Pool) WExecOpts(opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.qExec(d.SrvMaster, d.prepare(sql, args)) return d.qExec(d.SrvMaster, newSql, newArgs)
} }
func (d *Pool) qExec(q *Conn, sql string, args ...any) error { func (d *Pool) qExec(q *Conn, sql string, args ...any) error {
......
package zdb package zdb
import "github.com/georgysavva/scany/v2/pgxscan" import (
"github.com/georgysavva/scany/v2/pgxscan"
)
func (d *Pool) WGet(dst any, sql string, args ...any) error { func (d *Pool) WGet(dst any, sql string, args ...any) error {
return d.qGet(d.SrvMaster, dst, sql, args...) return d.qGet(d.SrvMaster, dst, sql, args...)
} }
func (d *Pool) WGetNamed(dst any, sql string, args map[string]any) error { func (d *Pool) WGetNamed(dst any, sql string, args map[string]any) error {
return d.qGet(d.SrvMaster, dst, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qGet(d.SrvMaster, dst, newSql, newArgs)
} }
func (d *Pool) WGetOpts(dst any, opts Opts) error { func (d *Pool) WGetOpts(dst any, opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.qGet(d.SrvMaster, dst, d.prepare(sql, args)) return d.qGet(d.SrvMaster, dst, newSql, newArgs)
} }
func (d *Pool) Get(dst any, sql string, args ...any) error { func (d *Pool) Get(dst any, sql string, args ...any) error {
...@@ -21,14 +26,17 @@ func (d *Pool) Get(dst any, sql string, args ...any) error { ...@@ -21,14 +26,17 @@ func (d *Pool) Get(dst any, sql string, args ...any) error {
} }
func (d *Pool) GetNamed(dst any, sql string, args map[string]any) error { func (d *Pool) GetNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error { return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qGet(q, dst1, newSql, newArgs...)
}) })
} }
func (d *Pool) GetOpts(dst any, opts Opts) error { func (d *Pool) GetOpts(dst any, opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error { return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args)) return d.qGet(q, dst1, newSql, newArgs)
}) })
} }
...@@ -39,14 +47,18 @@ func (d *Pool) GetAsync(dst any, sql string, args ...any) error { ...@@ -39,14 +47,18 @@ func (d *Pool) GetAsync(dst any, sql string, args ...any) error {
} }
func (d *Pool) GetAsyncNamed(dst any, sql string, args map[string]any) error { func (d *Pool) GetAsyncNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error { return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qGet(q, dst1, newSql, newArgs)
}) })
} }
func (d *Pool) GetAsyncOpts(dst any, opts Opts) error { func (d *Pool) GetAsyncOpts(dst any, opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error { return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error {
return d.qGet(q, dst1, d.prepare(sql, args))
return d.qGet(q, dst1, newSql, newArgs)
}) })
} }
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"reflect" "reflect"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
...@@ -347,75 +346,91 @@ func (d *Pool) IsAlive() bool { ...@@ -347,75 +346,91 @@ func (d *Pool) IsAlive() bool {
return d.SrvMaster != nil && d.SrvMaster.Alive return d.SrvMaster != nil && d.SrvMaster.Alive
} }
func (d *Pool) prepare(sql string, param map[string]any) string { func (d *Pool) prepare(sql string, param map[string]any) (string, []any) {
for n, t1 := range param { args := make([]any, 0)
switch tv := t1.(type) {
idx := 0
for n, t := range param {
if n[0] == '_' {
continue
}
if !strings.Contains(sql, ":"+n+":") {
continue
}
switch v := t.(type) {
case time.Time: case time.Time:
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%s'::timestamptz", tv.Format(time.RFC3339Nano))) idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d::timestamptz", idx))
args = append(args, v.Format(time.RFC3339Nano))
case *time.Time: case *time.Time:
if tv == nil { if v == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else { } else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%s'::timestamptz", tv.Format(time.RFC3339Nano))) idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d::timestamptz", idx))
args = append(args, v.Format(time.RFC3339Nano))
} }
case nil: case nil:
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
case bool: case bool:
if tv { if v {
sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE") sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE")
} else { } else {
sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE") sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE")
} }
case *string: case *string:
if tv == nil || *tv == "NULL" { if v == nil || *v == "NULL" {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else { } else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", strings.ReplaceAll(*tv, "'", "''"))) idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
} }
case string: case string:
if tv == "NULL" { if v == "NULL" {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else { } else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", strings.ReplaceAll(tv, "'", "''"))) idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
} }
case *int: case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *float32, *float64:
if tv == nil { if v == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else { } else {
sql = strings.ReplaceAll(sql, ":"+n+":", strconv.Itoa(*tv)) idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
} }
case *bool: case *bool:
if tv == nil { if v == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else { } else {
if *tv { if *v {
sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE") sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE")
} else { } else {
sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE") sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE")
} }
} }
case *int64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
if tv == nil { idx++
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
} else { args = append(args, v)
sql = strings.ReplaceAll(sql, ":"+n+":", strconv.FormatInt(*tv, 10))
}
case *float64:
if tv == nil {
sql = strings.ReplaceAll(sql, ":"+n+":", "NULL")
} else {
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("%v", *tv))
}
case int, int64:
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("%v", tv))
default: default:
switch reflect.TypeOf(tv).Kind() { switch reflect.TypeOf(v).Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
sql = strings.ReplaceAll(sql, ":"+n+":", "'{"+strings.Trim(strings.Join(strings.Split(fmt.Sprint(tv), " "), ","), "[]")+"}'") idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
default: default:
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", tv)) idx++
sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx))
args = append(args, v)
} }
} }
} }
if d.Debug { if d.Debug {
...@@ -437,5 +452,5 @@ func (d *Pool) prepare(sql string, param map[string]any) string { ...@@ -437,5 +452,5 @@ func (d *Pool) prepare(sql string, param map[string]any) string {
} }
} }
return sql return sql, args
} }
...@@ -8,12 +8,15 @@ func (d *Pool) WSelect(dst any, sql string, args ...any) error { ...@@ -8,12 +8,15 @@ func (d *Pool) WSelect(dst any, sql string, args ...any) error {
return d.qSelect(d.SrvMaster, dst, sql, args...) return d.qSelect(d.SrvMaster, dst, sql, args...)
} }
func (d *Pool) WSelectNamed(dst any, sql string, args map[string]any) error { func (d *Pool) WSelectNamed(dst any, sql string, args map[string]any) error {
return d.qSelect(d.SrvMaster, dst, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qSelect(d.SrvMaster, dst, newSql, newArgs)
} }
func (d *Pool) WSelectOpts(dst any, opts Opts) error { func (d *Pool) WSelectOpts(dst any, opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.qSelect(d.SrvMaster, dst, d.prepare(sql, args)) return d.qSelect(d.SrvMaster, dst, newSql, newArgs)
} }
func (d *Pool) Select(dst any, sql string, args ...any) error { func (d *Pool) Select(dst any, sql string, args ...any) error {
...@@ -23,14 +26,17 @@ func (d *Pool) Select(dst any, sql string, args ...any) error { ...@@ -23,14 +26,17 @@ func (d *Pool) Select(dst any, sql string, args ...any) error {
} }
func (d *Pool) SelectNamed(dst any, sql string, args map[string]any) error { func (d *Pool) SelectNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error { return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qSelect(conn, dst1, newSql, newArgs)
}) })
} }
func (d *Pool) SelectOpts(dst any, opts Opts) error { func (d *Pool) SelectOpts(dst any, opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error { return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args)) return d.qSelect(conn, dst1, newSql, newArgs)
}) })
} }
...@@ -41,14 +47,17 @@ func (d *Pool) SelectAsync(dst any, sql string, args ...any) error { ...@@ -41,14 +47,17 @@ func (d *Pool) SelectAsync(dst any, sql string, args ...any) error {
} }
func (d *Pool) SelectAsyncNamed(dst any, sql string, args map[string]any) error { func (d *Pool) SelectAsyncNamed(dst any, sql string, args map[string]any) error {
return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error { return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args)) newSql, newArgs := d.prepare(sql, args)
return d.qSelect(conn, dst1, newSql, newArgs)
}) })
} }
func (d *Pool) SelectAsyncOpts(dst any, opts Opts) error { func (d *Pool) SelectAsyncOpts(dst any, opts Opts) error {
sql, args := opts.Opts() sql, args := opts.Opts()
newSql, newArgs := d.prepare(sql, args)
return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error { return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error {
return d.qSelect(conn, dst1, d.prepare(sql, args)) return d.qSelect(conn, dst1, newSql, newArgs)
}) })
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment