Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
targets = n.TargetList
isUnion := len(targets.Items) == 0 && n.Larg != nil

if err := c.findColumnsInClause(qc, n.WhereClause, [][]*Table{tables}); err != nil {
return nil, err
}

if n.GroupClause != nil {
for _, item := range n.GroupClause.Items {
if err := findColumnForNode(item, tables, targets); err != nil {
Expand Down Expand Up @@ -722,6 +726,108 @@ func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) err
return findColumnForRef(ref, tables, targetList)
}

func (c *Compiler) findColumnsInClause(qc *QueryCatalog, node ast.Node, scopes [][]*Table) error {
if node == nil {
return nil
}

validator := &columnRefClauseValidator{
compiler: c,
qc: qc,
scopes: scopes,
}
astutils.Walk(validator, node)
return validator.err
}

type columnRefClauseValidator struct {
compiler *Compiler
qc *QueryCatalog
scopes [][]*Table
err error
}

func (v *columnRefClauseValidator) Visit(node ast.Node) astutils.Visitor {
if node == nil || v.err != nil {
return nil
}

if selectStmt, ok := node.(*ast.SelectStmt); ok {
tables, err := v.compiler.sourceTables(v.qc, selectStmt)
if err != nil {
v.err = err
return nil
}
scopes := append([][]*Table{tables}, v.scopes...)
if err := v.compiler.findColumnsInClause(v.qc, selectStmt.WhereClause, scopes); err != nil {
v.err = err
}
return nil
}

if ref, ok := node.(*ast.ColumnRef); ok {
if err := findColumnForRefInScopes(ref, v.scopes); err != nil {
v.err = err
return nil
}
}

return v
}

func findColumnForRefInScopes(ref *ast.ColumnRef, scopes [][]*Table) error {
parts := stringSlice(ref.Fields)
var schema, alias, name string
switch len(parts) {
case 1:
name = parts[0]
case 2:
alias = parts[0]
name = parts[1]
case 3:
schema = parts[0]
alias = parts[1]
name = parts[2]
default:
return fmt.Errorf("unknown number of fields: %d", len(parts))
}

for _, tables := range scopes {
var found int
for _, t := range tables {
if schema != "" && t.Rel.Schema != schema {
continue
}
if alias != "" && t.Rel.Name != alias {
continue
}
for _, c := range t.Columns {
if c.Name == name {
found++
break
}
}
}
if found == 0 {
continue
}
if found > 1 {
return &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column reference %q is ambiguous", name),
Location: ref.Location,
}
}
return nil
}

return &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column reference %q not found", name),
Location: ref.Location,
}
}

func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) error {
parts := stringSlice(ref.Fields)
var alias, name string
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#4264
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- name: GetUsers :many
select * from "user" where is_deleted = false;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
create table "user" ("name" text not null, deleted bool not null);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: "2"
sql:
- engine: "postgresql"
schema: "schema.sql"
queries: "query.sql"
gen:
go:
package: "querytest"
out: "go"
sql_package: "pgx/v5"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:2:28: column reference "is_deleted" not found
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:2:41: column "is_deleted" does not exist
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#4264
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- name: GetUsers :many
select * from "user" where exists (
select 1 from "user" as u2 where is_deleted = false
);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
create table "user" ("name" text not null, deleted bool not null);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: "2"
sql:
- engine: "postgresql"
schema: "schema.sql"
queries: "query.sql"
gen:
go:
package: "querytest"
out: "go"
sql_package: "pgx/v5"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:3:36: column reference "is_deleted" not found
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:3:49: column "is_deleted" does not exist
Loading