From 45bffcb401b29d76265af61315b87f08a43c7fbc Mon Sep 17 00:00:00 2001 From: Vladimir Barsukov Date: Sat, 8 Feb 2025 11:48:33 +0200 Subject: [PATCH] fix sql prepare --- go.mod | 1 + go.sum | 4 +-- zdb/exec.go | 7 +++-- zdb/get.go | 26 +++++++++++----- zdb/pool.go | 85 ++++++++++++++++++++++++++++++--------------------- zdb/select.go | 21 +++++++++---- 6 files changed, 92 insertions(+), 52 deletions(-) diff --git a/go.mod b/go.mod index 02fd44a..4eb2a93 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/go.sum b/go.sum index e8739f4..6270230 100644 --- a/go.sum +++ b/go.sum @@ -71,8 +71,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E= -github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/zdb/exec.go b/zdb/exec.go index 7a84c0a..17234a3 100644 --- a/zdb/exec.go +++ b/zdb/exec.go @@ -4,12 +4,15 @@ func (d *Pool) WExec(sql string, args ...any) error { return d.qExec(d.SrvMaster, sql, args...) } func (d *Pool) WExecNamed(sql string, args map[string]any) error { - return d.qExec(d.SrvMaster, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qExec(d.SrvMaster, newSql, newArgs) } func (d *Pool) WExecOpts(opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) - return d.qExec(d.SrvMaster, d.prepare(sql, args)) + return d.qExec(d.SrvMaster, newSql, newArgs) } func (d *Pool) qExec(q *Conn, sql string, args ...any) error { diff --git a/zdb/get.go b/zdb/get.go index ec42f97..d3ab450 100644 --- a/zdb/get.go +++ b/zdb/get.go @@ -1,17 +1,22 @@ package zdb -import "github.com/georgysavva/scany/v2/pgxscan" +import ( + "github.com/georgysavva/scany/v2/pgxscan" +) func (d *Pool) WGet(dst any, sql string, args ...any) error { return d.qGet(d.SrvMaster, dst, sql, args...) } func (d *Pool) WGetNamed(dst any, sql string, args map[string]any) error { - return d.qGet(d.SrvMaster, dst, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qGet(d.SrvMaster, dst, newSql, newArgs) } func (d *Pool) WGetOpts(dst any, opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) - return d.qGet(d.SrvMaster, dst, d.prepare(sql, args)) + return d.qGet(d.SrvMaster, dst, newSql, newArgs) } func (d *Pool) Get(dst any, sql string, args ...any) error { @@ -21,14 +26,17 @@ func (d *Pool) Get(dst any, sql string, args ...any) error { } func (d *Pool) GetNamed(dst any, sql string, args map[string]any) error { return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error { - return d.qGet(q, dst1, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qGet(q, dst1, newSql, newArgs...) }) } func (d *Pool) GetOpts(dst any, opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) return d.execWrapper(ConnModeSync, dst, func(q *Conn, dst1 any) error { - return d.qGet(q, dst1, d.prepare(sql, args)) + return d.qGet(q, dst1, newSql, newArgs) }) } @@ -39,14 +47,18 @@ func (d *Pool) GetAsync(dst any, sql string, args ...any) error { } func (d *Pool) GetAsyncNamed(dst any, sql string, args map[string]any) error { return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error { - return d.qGet(q, dst1, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qGet(q, dst1, newSql, newArgs) }) } func (d *Pool) GetAsyncOpts(dst any, opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) return d.execWrapper(ConnModeAsync, dst, func(q *Conn, dst1 any) error { - return d.qGet(q, dst1, d.prepare(sql, args)) + + return d.qGet(q, dst1, newSql, newArgs) }) } diff --git a/zdb/pool.go b/zdb/pool.go index 7451e8c..a33846d 100644 --- a/zdb/pool.go +++ b/zdb/pool.go @@ -6,7 +6,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "reflect" "runtime" - "strconv" "strings" "sync" "sync/atomic" @@ -347,75 +346,91 @@ func (d *Pool) IsAlive() bool { return d.SrvMaster != nil && d.SrvMaster.Alive } -func (d *Pool) prepare(sql string, param map[string]any) string { - for n, t1 := range param { - switch tv := t1.(type) { +func (d *Pool) prepare(sql string, param map[string]any) (string, []any) { + args := make([]any, 0) + + idx := 0 + for n, t := range param { + if n[0] == '_' { + continue + } + + if !strings.Contains(sql, ":"+n+":") { + continue + } + + switch v := t.(type) { case time.Time: - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%s'::timestamptz", tv.Format(time.RFC3339Nano))) + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d::timestamptz", idx)) + args = append(args, v.Format(time.RFC3339Nano)) case *time.Time: - if tv == nil { + if v == nil { sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") } else { - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%s'::timestamptz", tv.Format(time.RFC3339Nano))) + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d::timestamptz", idx)) + args = append(args, v.Format(time.RFC3339Nano)) } case nil: sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") case bool: - if tv { + if v { sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE") } else { sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE") } case *string: - if tv == nil || *tv == "NULL" { + if v == nil || *v == "NULL" { sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") } else { - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", strings.ReplaceAll(*tv, "'", "''"))) + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx)) + args = append(args, v) } case string: - if tv == "NULL" { + if v == "NULL" { sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") } else { - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", strings.ReplaceAll(tv, "'", "''"))) + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx)) + args = append(args, v) } - case *int: - if tv == nil { + case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *float32, *float64: + if v == nil { sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") } else { - sql = strings.ReplaceAll(sql, ":"+n+":", strconv.Itoa(*tv)) + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx)) + args = append(args, v) } case *bool: - if tv == nil { + if v == nil { sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") } else { - if *tv { + if *v { sql = strings.ReplaceAll(sql, ":"+n+":", "TRUE") } else { sql = strings.ReplaceAll(sql, ":"+n+":", "FALSE") } } - case *int64: - if tv == nil { - sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") - } else { - sql = strings.ReplaceAll(sql, ":"+n+":", strconv.FormatInt(*tv, 10)) - } - case *float64: - if tv == nil { - sql = strings.ReplaceAll(sql, ":"+n+":", "NULL") - } else { - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("%v", *tv)) - } - case int, int64: - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("%v", tv)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx)) + args = append(args, v) default: - switch reflect.TypeOf(tv).Kind() { + switch reflect.TypeOf(v).Kind() { case reflect.Slice, reflect.Array: - sql = strings.ReplaceAll(sql, ":"+n+":", "'{"+strings.Trim(strings.Join(strings.Split(fmt.Sprint(tv), " "), ","), "[]")+"}'") + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx)) + args = append(args, v) default: - sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("'%v'", tv)) + idx++ + sql = strings.ReplaceAll(sql, ":"+n+":", fmt.Sprintf("$%d", idx)) + args = append(args, v) } } + } if d.Debug { @@ -437,5 +452,5 @@ func (d *Pool) prepare(sql string, param map[string]any) string { } } - return sql + return sql, args } diff --git a/zdb/select.go b/zdb/select.go index 9cba8c2..4a8ba6c 100644 --- a/zdb/select.go +++ b/zdb/select.go @@ -8,12 +8,15 @@ func (d *Pool) WSelect(dst any, sql string, args ...any) error { return d.qSelect(d.SrvMaster, dst, sql, args...) } func (d *Pool) WSelectNamed(dst any, sql string, args map[string]any) error { - return d.qSelect(d.SrvMaster, dst, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qSelect(d.SrvMaster, dst, newSql, newArgs) } func (d *Pool) WSelectOpts(dst any, opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) - return d.qSelect(d.SrvMaster, dst, d.prepare(sql, args)) + return d.qSelect(d.SrvMaster, dst, newSql, newArgs) } func (d *Pool) Select(dst any, sql string, args ...any) error { @@ -23,14 +26,17 @@ func (d *Pool) Select(dst any, sql string, args ...any) error { } func (d *Pool) SelectNamed(dst any, sql string, args map[string]any) error { return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error { - return d.qSelect(conn, dst1, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qSelect(conn, dst1, newSql, newArgs) }) } func (d *Pool) SelectOpts(dst any, opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) return d.execWrapper(ConnModeSync, dst, func(conn *Conn, dst1 any) error { - return d.qSelect(conn, dst1, d.prepare(sql, args)) + return d.qSelect(conn, dst1, newSql, newArgs) }) } @@ -41,14 +47,17 @@ func (d *Pool) SelectAsync(dst any, sql string, args ...any) error { } func (d *Pool) SelectAsyncNamed(dst any, sql string, args map[string]any) error { return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error { - return d.qSelect(conn, dst1, d.prepare(sql, args)) + newSql, newArgs := d.prepare(sql, args) + + return d.qSelect(conn, dst1, newSql, newArgs) }) } func (d *Pool) SelectAsyncOpts(dst any, opts Opts) error { sql, args := opts.Opts() + newSql, newArgs := d.prepare(sql, args) return d.execWrapper(ConnModeAsync, dst, func(conn *Conn, dst1 any) error { - return d.qSelect(conn, dst1, d.prepare(sql, args)) + return d.qSelect(conn, dst1, newSql, newArgs) }) } -- GitLab