Skip to content

Commit b645325

Browse files
feat(mysql): infer types for simple numeric expressions
1 parent d7f2758 commit b645325

File tree

1 file changed

+273
-0
lines changed

1 file changed

+273
-0
lines changed
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package compiler
2+
3+
import (
4+
"github.com/sqlc-dev/sqlc/internal/config"
5+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
6+
)
7+
8+
//
9+
// ==============================
10+
// Internal Type System
11+
// ==============================
12+
//
13+
14+
type Kind int
15+
16+
const (
17+
KindUnknown Kind = iota // inference not supported
18+
KindInt
19+
KindFloat
20+
KindDecimal
21+
KindAny
22+
)
23+
24+
type Type struct {
25+
Kind Kind
26+
NotNull bool
27+
Valid bool // explicit signal: inference succeeded
28+
}
29+
30+
func unknownType() Type {
31+
return Type{Kind: KindUnknown, Valid: false}
32+
}
33+
34+
//
35+
// ==============================
36+
// Entry Point
37+
// ==============================
38+
//
39+
40+
func (c *Compiler) inferExprType(node ast.Node, tables []*Table) *Column {
41+
if node == nil {
42+
return nil
43+
}
44+
45+
switch c.conf.Engine {
46+
case config.EngineMySQL:
47+
t := c.inferMySQLExpr(node, tables)
48+
return c.mysqlTypeToColumn(t)
49+
50+
// case config.EnginePostgreSQL:
51+
// t := c.inferPostgresExpr(node, tables)
52+
// return c.postgresTypeToColumn(t)
53+
54+
default:
55+
return nil
56+
}
57+
}
58+
59+
//
60+
// ==============================
61+
// MySQL Inference
62+
// ==============================
63+
//
64+
65+
func (c *Compiler) inferMySQLExpr(node ast.Node, tables []*Table) Type {
66+
switch n := node.(type) {
67+
case *ast.ColumnRef:
68+
return c.inferMySQLColumnRef(n, tables)
69+
70+
case *ast.A_Const:
71+
return inferConst(n)
72+
73+
case *ast.TypeCast:
74+
return c.inferMySQLTypeCast(n, tables)
75+
76+
case *ast.A_Expr:
77+
return c.inferMySQLBinary(n, tables)
78+
79+
default:
80+
return unknownType()
81+
}
82+
}
83+
84+
//
85+
// ------------------------------
86+
// Leaf nodes
87+
// ------------------------------
88+
//
89+
90+
func (c *Compiler) inferMySQLColumnRef(ref *ast.ColumnRef, tables []*Table) Type {
91+
cols, err := outputColumnRefs(&ast.ResTarget{}, tables, ref)
92+
if err != nil || len(cols) == 0 {
93+
return unknownType()
94+
}
95+
96+
col := cols[0]
97+
98+
return Type{
99+
Kind: mapMySQLKind(col.DataType),
100+
NotNull: col.NotNull,
101+
Valid: true,
102+
}
103+
}
104+
105+
func inferConst(node *ast.A_Const) Type {
106+
if node == nil || node.Val == nil {
107+
return unknownType()
108+
}
109+
110+
switch node.Val.(type) {
111+
case *ast.Integer:
112+
return Type{Kind: KindInt, NotNull: true, Valid: true}
113+
114+
case *ast.Float:
115+
return Type{Kind: KindFloat, NotNull: true, Valid: true}
116+
117+
case *ast.Null:
118+
return Type{Kind: KindAny, NotNull: false, Valid: true}
119+
120+
default:
121+
return unknownType()
122+
}
123+
}
124+
125+
func (c *Compiler) inferMySQLTypeCast(node *ast.TypeCast, tables []*Table) Type {
126+
if node == nil || node.TypeName == nil {
127+
return unknownType()
128+
}
129+
130+
base := toColumn(node.TypeName)
131+
if base == nil {
132+
return unknownType()
133+
}
134+
135+
arg := c.inferMySQLExpr(node.Arg, tables)
136+
137+
t := Type{
138+
Kind: mapMySQLKind(base.DataType),
139+
Valid: true,
140+
}
141+
142+
// propagate nullability
143+
if arg.Valid {
144+
t.NotNull = arg.NotNull
145+
}
146+
147+
// explicit NULL literal
148+
if constant, ok := node.Arg.(*ast.A_Const); ok {
149+
if _, isNull := constant.Val.(*ast.Null); isNull {
150+
t.NotNull = false
151+
}
152+
}
153+
154+
return t
155+
}
156+
157+
//
158+
// ------------------------------
159+
// Binary expressions
160+
// ------------------------------
161+
//
162+
163+
func (c *Compiler) inferMySQLBinary(node *ast.A_Expr, tables []*Table) Type {
164+
op := joinOperator(node)
165+
166+
left := c.inferMySQLExpr(node.Lexpr, tables)
167+
right := c.inferMySQLExpr(node.Rexpr, tables)
168+
169+
if !left.Valid || !right.Valid {
170+
return unknownType()
171+
}
172+
173+
// NOTE: only normal division ("/") is supported for now.
174+
// Unsupported operators intentionally fall back to the existing behavior.
175+
return promoteMySQLNumeric(op, left, right)
176+
}
177+
178+
//
179+
// ==============================
180+
// Promotion Rules (MySQL-specific for now)
181+
// ==============================
182+
//
183+
184+
// promoteMySQLNumeric applies simplified numeric promotion rules for MySQL.
185+
// It currently only supports "/" and intentionally falls back for other operators.
186+
func promoteMySQLNumeric(op string, a, b Type) Type {
187+
notNull := a.NotNull && b.NotNull
188+
189+
switch op {
190+
case "/":
191+
if a.Kind == KindFloat || b.Kind == KindFloat {
192+
return Type{
193+
Kind: KindFloat,
194+
NotNull: notNull,
195+
Valid: true,
196+
}
197+
}
198+
199+
return Type{
200+
Kind: KindDecimal,
201+
NotNull: notNull,
202+
Valid: true,
203+
}
204+
}
205+
206+
return unknownType()
207+
}
208+
209+
//
210+
// ==============================
211+
// Engine-specific Mapping
212+
// ==============================
213+
//
214+
215+
func (c *Compiler) mysqlTypeToColumn(t Type) *Column {
216+
if !t.Valid {
217+
return nil
218+
}
219+
220+
col := &Column{
221+
NotNull: t.NotNull,
222+
}
223+
224+
switch t.Kind {
225+
case KindInt:
226+
col.DataType = "int"
227+
228+
case KindFloat:
229+
col.DataType = "float"
230+
231+
case KindDecimal:
232+
col.DataType = "decimal"
233+
234+
default:
235+
col.DataType = "any"
236+
}
237+
238+
return col
239+
}
240+
241+
func mapMySQLKind(dt string) Kind {
242+
switch dt {
243+
case "int", "integer", "bigint", "smallint":
244+
return KindInt
245+
246+
case "float", "double", "real":
247+
return KindFloat
248+
249+
case "decimal", "numeric":
250+
return KindDecimal
251+
252+
default:
253+
return KindUnknown
254+
}
255+
}
256+
257+
//
258+
// ==============================
259+
// AST helpers
260+
// ==============================
261+
//
262+
263+
func joinOperator(node *ast.A_Expr) string {
264+
if node == nil || node.Name == nil || len(node.Name.Items) == 0 {
265+
return ""
266+
}
267+
268+
if s, ok := node.Name.Items[0].(*ast.String); ok {
269+
return s.Str
270+
}
271+
272+
return ""
273+
}

0 commit comments

Comments
 (0)