Skip to content

Commit 8cc20f2

Browse files
kyleconroyclaude
andcommitted
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 3d86155 commit 8cc20f2

File tree

11 files changed

+155
-62
lines changed

11 files changed

+155
-62
lines changed

internal/gen.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type tmplCtx struct {
4141
UsesBatch bool
4242
OmitSqlcVersion bool
4343
BuildTags string
44+
WrapErrors bool
4445
}
4546

4647
func (t *tmplCtx) OutputQuery(sourceName string) bool {
@@ -98,6 +99,9 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
9899
case ":execrows", ":execlastid":
99100
return "result, err :=", nil
100101
case ":execresult":
102+
if t.WrapErrors {
103+
return "result, err :=", nil
104+
}
101105
return "return", nil
102106
default:
103107
return "", fmt.Errorf("unhandled q.Cmd case %q", q.Cmd)
@@ -187,6 +191,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
187191
SqlcVersion: req.SqlcVersion,
188192
BuildTags: options.BuildTags,
189193
OmitSqlcVersion: options.OmitSqlcVersion,
194+
WrapErrors: options.WrapErrors,
190195
}
191196

192197
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL {
@@ -378,7 +383,7 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
378383
keepTypes[query.Ret.Type()] = struct{}{}
379384
if query.Ret.IsStruct() {
380385
for _, field := range query.Ret.Struct.Fields {
381-
keepTypes[field.Type] = struct{}{}
386+
keepTypes[strings.TrimPrefix(field.Type, "[]")] = struct{}{}
382387
for _, embedField := range field.EmbedFields {
383388
keepTypes[embedField.Type] = struct{}{}
384389
}
@@ -391,7 +396,8 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
391396
for _, enum := range enums {
392397
_, keep := keepTypes[enum.Name]
393398
_, keepNull := keepTypes["Null"+enum.Name]
394-
if keep || keepNull {
399+
_, keepPointer := keepTypes["*"+enum.Name]
400+
if keep || keepNull || keepPointer {
395401
keepEnums = append(keepEnums, enum)
396402
}
397403
}

internal/go_type.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, o
1414
if oride.GoType.StructTags == nil {
1515
continue
1616
}
17+
if override.MatchesColumn(col) {
18+
for k, v := range oride.GoType.StructTags {
19+
tags[k] = v
20+
}
21+
continue
22+
}
1723
if !override.Matches(col.Table, req.Catalog.DefaultSchema) {
1824
// Different table.
1925
continue
@@ -64,16 +70,13 @@ func goType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Colu
6470
}
6571

6672
func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
67-
columnType := sdk.DataType(col.Type)
68-
notNull := col.NotNull || col.IsArray
69-
7073
// package overrides have a higher precedence
7174
for _, override := range options.Overrides {
7275
oride := override.ShimOverride
7376
if oride.GoType.TypeName == "" {
7477
continue
7578
}
76-
if oride.DbType != "" && oride.DbType == columnType && oride.Nullable != notNull && oride.Unsigned == col.Unsigned {
79+
if override.MatchesColumn(col) {
7780
return oride.GoType.TypeName
7881
}
7982
}

internal/imports.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,10 @@ func (i *importer) queryImports(filename string) fileImports {
402402
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
403403
}
404404

405+
if i.Options.WrapErrors {
406+
std["fmt"] = struct{}{}
407+
}
408+
405409
return sortedImports(std, pkg)
406410
}
407411

internal/mysql_type.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C
6464
}
6565
return "sql.NullInt32"
6666

67-
case "bigint":
67+
case "bigint", "bigint unsigned", "bigint signed":
68+
// "bigint unsigned" and "bigint signed" are MySQL CAST types
69+
// Note: We use int64 for CAST AS UNSIGNED to match original behavior,
70+
// even though uint64 would be more semantically correct.
71+
// The Unsigned flag on columns (from table schema) still uses uint64.
6872
if notNull {
6973
if unsigned {
7074
return "uint64"

internal/opts/options.go

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,43 @@ import (
1010
)
1111

1212
type Options struct {
13-
EmitInterface bool `json:"emit_interface" yaml:"emit_interface"`
14-
EmitJsonTags bool `json:"emit_json_tags" yaml:"emit_json_tags"`
15-
JsonTagsIdUppercase bool `json:"json_tags_id_uppercase" yaml:"json_tags_id_uppercase"`
16-
EmitDbTags bool `json:"emit_db_tags" yaml:"emit_db_tags"`
17-
EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"`
18-
EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"`
19-
EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"`
20-
EmitExportedQueries bool `json:"emit_exported_queries" yaml:"emit_exported_queries"`
21-
EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"`
22-
EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"`
23-
EmitMethodsWithDbArgument bool `json:"emit_methods_with_db_argument,omitempty" yaml:"emit_methods_with_db_argument"`
24-
EmitPointersForNullTypes bool `json:"emit_pointers_for_null_types" yaml:"emit_pointers_for_null_types"`
25-
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
26-
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
27-
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
28-
JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
29-
Package string `json:"package" yaml:"package"`
30-
Out string `json:"out" yaml:"out"`
31-
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`
32-
Rename map[string]string `json:"rename,omitempty" yaml:"rename"`
33-
SqlPackage string `json:"sql_package" yaml:"sql_package"`
34-
SqlDriver string `json:"sql_driver" yaml:"sql_driver"`
35-
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
36-
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
37-
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
38-
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
39-
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
40-
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
41-
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"`
42-
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
43-
OmitSqlcVersion bool `json:"omit_sqlc_version,omitempty" yaml:"omit_sqlc_version"`
44-
OmitUnusedStructs bool `json:"omit_unused_structs,omitempty" yaml:"omit_unused_structs"`
45-
BuildTags string `json:"build_tags,omitempty" yaml:"build_tags"`
46-
Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms"`
13+
EmitInterface bool `json:"emit_interface" yaml:"emit_interface"`
14+
EmitJsonTags bool `json:"emit_json_tags" yaml:"emit_json_tags"`
15+
JsonTagsIdUppercase bool `json:"json_tags_id_uppercase" yaml:"json_tags_id_uppercase"`
16+
EmitDbTags bool `json:"emit_db_tags" yaml:"emit_db_tags"`
17+
EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"`
18+
EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"`
19+
EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"`
20+
EmitExportedQueries bool `json:"emit_exported_queries" yaml:"emit_exported_queries"`
21+
EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"`
22+
EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"`
23+
EmitMethodsWithDbArgument bool `json:"emit_methods_with_db_argument,omitempty" yaml:"emit_methods_with_db_argument"`
24+
EmitPointersForNullTypes bool `json:"emit_pointers_for_null_types" yaml:"emit_pointers_for_null_types"`
25+
// nil inherits EmitPointersForNullTypes; non-nil overrides for enums only.
26+
EmitPointersForNullEnumTypes *bool `json:"emit_pointers_for_null_enum_types,omitempty" yaml:"emit_pointers_for_null_enum_types"`
27+
EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"`
28+
EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"`
29+
EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"`
30+
JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"`
31+
Package string `json:"package" yaml:"package"`
32+
Out string `json:"out" yaml:"out"`
33+
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`
34+
Rename map[string]string `json:"rename,omitempty" yaml:"rename"`
35+
SqlPackage string `json:"sql_package" yaml:"sql_package"`
36+
SqlDriver string `json:"sql_driver" yaml:"sql_driver"`
37+
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
38+
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
39+
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
40+
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
41+
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
42+
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
43+
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"`
44+
WrapErrors bool `json:"wrap_errors,omitempty" yaml:"wrap_errors"`
45+
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
46+
OmitSqlcVersion bool `json:"omit_sqlc_version,omitempty" yaml:"omit_sqlc_version"`
47+
OmitUnusedStructs bool `json:"omit_unused_structs,omitempty" yaml:"omit_unused_structs"`
48+
BuildTags string `json:"build_tags,omitempty" yaml:"build_tags"`
49+
Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms"`
4750

4851
InitialismsMap map[string]struct{} `json:"-" yaml:"-"`
4952
}

internal/opts/override.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"os"
66
"strings"
77

8+
"github.com/sqlc-dev/plugin-sdk-go/sdk"
89
"github.com/sqlc-dev/plugin-sdk-go/pattern"
910
"github.com/sqlc-dev/plugin-sdk-go/plugin"
1011
)
@@ -76,6 +77,12 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
7677
return true
7778
}
7879

80+
func (o *Override) MatchesColumn(col *plugin.Column) bool {
81+
columnType := sdk.DataType(col.Type)
82+
notNull := col.NotNull || col.IsArray
83+
return o.DBType != "" && o.DBType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned
84+
}
85+
7986
func (o *Override) parse(req *plugin.GenerateRequest) (err error) {
8087
// validate deprecated postgres_type field
8188
if o.Deprecated_PostgresType != "" {

internal/postgresql_type.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
3939
notNull := col.NotNull || col.IsArray
4040
driver := parseDriver(options.SqlPackage)
4141
emitPointersForNull := driver.IsPGX() && options.EmitPointersForNullTypes
42+
emitPointersForNullEnums := emitPointersForNull
43+
if options.EmitPointersForNullEnumTypes != nil {
44+
emitPointersForNullEnums = driver.IsPGX() && *options.EmitPointersForNullEnumTypes
45+
}
4246

4347
switch columnType {
4448
case "serial", "serial4", "pg_catalog.serial4":
@@ -165,7 +169,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
165169
}
166170
return "sql.NullBool"
167171

168-
case "json":
172+
case "json", "pg_catalog.json":
169173
switch driver {
170174
case opts.SQLDriverPGXV5:
171175
return "[]byte"
@@ -181,7 +185,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
181185
return "interface{}"
182186
}
183187

184-
case "jsonb":
188+
case "jsonb", "pg_catalog.jsonb":
185189
switch driver {
186190
case opts.SQLDriverPGXV5:
187191
return "[]byte"
@@ -233,7 +237,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
233237
}
234238
return "sql.NullTime"
235239

236-
case "pg_catalog.timestamp":
240+
case "pg_catalog.timestamp", "timestamp":
237241
if driver == opts.SQLDriverPGXV5 {
238242
return "pgtype.Timestamp"
239243
}
@@ -503,6 +507,11 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
503507
return "pgtype.XID"
504508
}
505509

510+
case "xid8":
511+
if driver == opts.SQLDriverPGXV5 {
512+
return "pgtype.Uint64"
513+
}
514+
506515
case "box":
507516
if driver.IsPGX() {
508517
return "pgtype.Box"
@@ -577,10 +586,14 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
577586
}
578587
return StructName(schema.Name+"_"+enum.Name, options)
579588
} else {
589+
nullPrefix := "Null"
590+
if emitPointersForNullEnums {
591+
nullPrefix = "*"
592+
}
580593
if schema.Name == req.Catalog.DefaultSchema {
581-
return "Null" + StructName(enum.Name, options)
594+
return nullPrefix + StructName(enum.Name, options)
582595
}
583-
return "Null" + StructName(schema.Name+"_"+enum.Name, options)
596+
return nullPrefix + StructName(schema.Name+"_"+enum.Name, options)
584597
}
585598
}
586599
}

internal/result.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string
141141
}
142142

143143
fields := make([]Field, len(s.Fields))
144-
for i, f := range s.Fields {
145-
fields[i] = f
146-
}
144+
copy(fields, s.Fields)
147145

148146
return &goEmbed{
149147
modelType: s.Name,
@@ -270,8 +268,26 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
270268
c := query.Columns[0]
271269
name := columnName(c, 0)
272270
name = strings.Replace(name, "$", "_", -1)
271+
retName := escape(name)
272+
// For :one queries the scan destination lives in the same scope as
273+
// the query parameters, so reusing a parameter's name would cause
274+
// Scan to overwrite the input and leak it back to the caller on
275+
// sql.ErrNoRows (see sqlc-dev/sqlc#4354). Rename the return
276+
// variable when it would collide.
277+
if query.Cmd == metadata.CmdOne {
278+
argNames := map[string]struct{}{}
279+
for _, p := range gq.Arg.Pairs() {
280+
argNames[p.Name] = struct{}{}
281+
}
282+
for {
283+
if _, conflict := argNames[retName]; !conflict {
284+
break
285+
}
286+
retName += "_2"
287+
}
288+
}
273289
gq.Ret = QueryValue{
274-
Name: escape(name),
290+
Name: retName,
275291
DBName: name,
276292
Typ: goType(req, options, c),
277293
SQLDriver: sqlpkg,

internal/sqlite_type.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ func sqliteType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.
5656
}
5757
return "sql.NullTime"
5858

59+
case "json", "jsonb":
60+
return "json.RawMessage"
61+
5962
case "any":
6063
return "interface{}"
6164

internal/templates/pgx/queryCode.tmpl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De
3737
var {{.Ret.Name}} {{.Ret.Type}}
3838
{{- end}}
3939
err := row.Scan({{.Ret.Scan}})
40+
{{- if $.WrapErrors}}
41+
if err != nil {
42+
err = fmt.Errorf("query {{.MethodName}}: %w", err)
43+
}
44+
{{- end}}
4045
return {{.Ret.ReturnName}}, err
4146
}
4247
{{end}}
@@ -52,7 +57,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.
5257
rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
5358
{{- end}}
5459
if err != nil {
55-
return nil, err
60+
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
5661
}
5762
defer rows.Close()
5863
{{- if $.EmitEmptySlices}}
@@ -63,12 +68,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.
6368
for rows.Next() {
6469
var {{.Ret.Name}} {{.Ret.Type}}
6570
if err := rows.Scan({{.Ret.Scan}}); err != nil {
66-
return nil, err
71+
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
6772
}
6873
items = append(items, {{.Ret.ReturnName}})
6974
}
7075
if err := rows.Err(); err != nil {
71-
return nil, err
76+
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
7277
}
7378
return items, nil
7479
}
@@ -84,7 +89,14 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) e
8489
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {
8590
_, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
8691
{{- end}}
92+
{{- if $.WrapErrors }}
93+
if err != nil {
94+
return fmt.Errorf("query {{.MethodName}}: %w", err)
95+
}
96+
return nil
97+
{{- else }}
8798
return err
99+
{{- end }}
88100
}
89101
{{end}}
90102

@@ -99,7 +111,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
99111
result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
100112
{{- end}}
101113
if err != nil {
102-
return 0, err
114+
return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
103115
}
104116
return result.RowsAffected(), nil
105117
}
@@ -110,11 +122,17 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
110122
{{end -}}
111123
{{- if $.EmitMethodsWithDBArgument -}}
112124
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
113-
return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
125+
{{queryRetval .}} db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
114126
{{- else -}}
115127
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
116-
return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
128+
{{queryRetval .}} q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
117129
{{- end}}
130+
{{- if $.WrapErrors}}
131+
if err != nil {
132+
err = fmt.Errorf("query {{.MethodName}}: %w", err)
133+
}
134+
return result, err
135+
{{- end}}
118136
}
119137
{{end}}
120138

0 commit comments

Comments
 (0)