@@ -54,6 +54,25 @@ const mysqlInListQuery = `/* name: FooByList :many */
5454SELECT a, b FROM foo WHERE foo.a IN (?, ?);
5555`
5656
57+ const starExpansionSeriesSchema = `
58+ CREATE TABLE alertreport (
59+ eventdate date
60+ );
61+ `
62+
63+ const starExpansionSeriesQuery = `-- name: CountAlertReportBy :many
64+ select DATE_TRUNC($1,ts)::text as datetime,coalesce(count,0) as count from
65+ (
66+ SELECT DATE_TRUNC($1,eventdate) as hr ,count(*)
67+ FROM alertreport
68+ where eventdate between $2 and $3
69+ GROUP BY 1
70+ ) AS cnt
71+ right outer join ( SELECT * FROM generate_series ( $2, $3, CONCAT('1 ',$1)::interval) AS ts ) as dte
72+ on DATE_TRUNC($1, ts ) = cnt.hr
73+ order by 1 asc;
74+ `
75+
5776type stubAnalyzer struct {
5877 analyze func (context.Context , ast.Node , string , []string , * named.ParamSet ) (* analysispb.Analysis , error )
5978}
@@ -126,6 +145,36 @@ func newMySQLInListCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
126145 }, stmts [0 ].Raw
127146}
128147
148+ func newStarExpansionSeriesCompiler (t * testing.T ) (* Compiler , * ast.RawStmt ) {
149+ t .Helper ()
150+
151+ parser := postgresql .NewParser ()
152+ catalog := postgresql .NewCatalog ()
153+
154+ schema , err := parser .Parse (strings .NewReader (starExpansionSeriesSchema ))
155+ if err != nil {
156+ t .Fatal (err )
157+ }
158+ if err := catalog .Build (schema ); err != nil {
159+ t .Fatal (err )
160+ }
161+
162+ stmts , err := parser .Parse (strings .NewReader (starExpansionSeriesQuery ))
163+ if err != nil {
164+ t .Fatal (err )
165+ }
166+ if len (stmts ) != 1 {
167+ t .Fatalf ("expected 1 statement, got %d" , len (stmts ))
168+ }
169+
170+ return & Compiler {
171+ conf : config.SQL {Engine : config .EnginePostgreSQL },
172+ parser : parser ,
173+ catalog : catalog ,
174+ selector : newDefaultSelector (),
175+ }, stmts [0 ].Raw
176+ }
177+
129178func assertBatchParameterNames (t * testing.T , params []Parameter ) {
130179 t .Helper ()
131180
@@ -168,6 +217,40 @@ func assertBatchParameterNames(t *testing.T, params []Parameter) {
168217 }
169218}
170219
220+ func assertStarExpansionSeriesParameterNames (t * testing.T , params []Parameter ) {
221+ t .Helper ()
222+
223+ checks := []struct {
224+ idx int
225+ number int
226+ name string
227+ typ string
228+ }{
229+ {idx : 0 , number : 1 , name : "date_trunc" , typ : "text" },
230+ {idx : 1 , number : 2 , name : "eventdate" , typ : "date" },
231+ {idx : 2 , number : 3 , name : "eventdate" , typ : "date" },
232+ }
233+ if len (params ) != len (checks ) {
234+ t .Fatalf ("expected %d params, got %d" , len (checks ), len (params ))
235+ }
236+
237+ for _ , check := range checks {
238+ param := params [check .idx ]
239+ if param .Number != check .number {
240+ t .Fatalf ("param %d number mismatch: got %d want %d" , check .idx , param .Number , check .number )
241+ }
242+ if param .Column == nil {
243+ t .Fatalf ("param %d column is nil" , check .idx )
244+ }
245+ if param .Column .Name != check .name {
246+ t .Fatalf ("param %d name mismatch: got %q want %q" , check .idx , param .Column .Name , check .name )
247+ }
248+ if param .Column .DataType != check .typ && param .Column .DataType != "pg_catalog." + check .typ {
249+ t .Fatalf ("param %d type mismatch: got %q want %q or %q" , check .idx , param .Column .DataType , check .typ , "pg_catalog." + check .typ )
250+ }
251+ }
252+ }
253+
171254func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams (t * testing.T ) {
172255 t .Parallel ()
173256
@@ -247,3 +330,38 @@ func TestInferQueryPreservesDistinctMySQLInListParams(t *testing.T) {
247330 }
248331 }
249332}
333+
334+ func TestInferQueryPreservesStarExpansionSeriesParamNames (t * testing.T ) {
335+ t .Parallel ()
336+
337+ comp , raw := newStarExpansionSeriesCompiler (t )
338+ anlys , err := comp .inferQuery (raw , starExpansionSeriesQuery )
339+ if err != nil {
340+ t .Fatal (err )
341+ }
342+ if anlys == nil {
343+ t .Fatal ("expected non-nil analysis" )
344+ }
345+
346+ assertStarExpansionSeriesParameterNames (t , anlys .Parameters )
347+ }
348+
349+ func TestParseQueryManagedDBPreservesStarExpansionSeriesParamNames (t * testing.T ) {
350+ t .Parallel ()
351+
352+ comp , raw := newStarExpansionSeriesCompiler (t )
353+ comp .analyzer = stubAnalyzer {analyze : func (_ context.Context , _ ast.Node , _ string , _ []string , _ * named.ParamSet ) (* analysispb.Analysis , error ) {
354+ return & analysispb.Analysis {Params : []* analysispb.Parameter {
355+ {Number : 1 , Column : & analysispb.Column {DataType : "pg_catalog.text" }},
356+ {Number : 2 , Column : & analysispb.Column {DataType : "pg_catalog.date" }},
357+ {Number : 3 , Column : & analysispb.Column {DataType : "pg_catalog.date" }},
358+ }}, nil
359+ }}
360+
361+ query , err := comp .parseQuery (raw , starExpansionSeriesQuery , opts.Parser {})
362+ if err != nil {
363+ t .Fatal (err )
364+ }
365+
366+ assertStarExpansionSeriesParameterNames (t , query .Params )
367+ }
0 commit comments