Skip to content

Commit 2143d55

Browse files
committed
Validate unknown columns in SELECT WHERE clauses
1 parent 428d4e6 commit 2143d55

File tree

13 files changed

+144
-0
lines changed

13 files changed

+144
-0
lines changed

internal/compiler/output_columns.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
6767
targets = n.TargetList
6868
isUnion := len(targets.Items) == 0 && n.Larg != nil
6969

70+
if err := c.findColumnsInClause(qc, n.WhereClause, [][]*Table{tables}); err != nil {
71+
return nil, err
72+
}
73+
7074
if n.GroupClause != nil {
7175
for _, item := range n.GroupClause.Items {
7276
if err := findColumnForNode(item, tables, targets); err != nil {
@@ -722,6 +726,108 @@ func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) err
722726
return findColumnForRef(ref, tables, targetList)
723727
}
724728

729+
func (c *Compiler) findColumnsInClause(qc *QueryCatalog, node ast.Node, scopes [][]*Table) error {
730+
if node == nil {
731+
return nil
732+
}
733+
734+
validator := &columnRefClauseValidator{
735+
compiler: c,
736+
qc: qc,
737+
scopes: scopes,
738+
}
739+
astutils.Walk(validator, node)
740+
return validator.err
741+
}
742+
743+
type columnRefClauseValidator struct {
744+
compiler *Compiler
745+
qc *QueryCatalog
746+
scopes [][]*Table
747+
err error
748+
}
749+
750+
func (v *columnRefClauseValidator) Visit(node ast.Node) astutils.Visitor {
751+
if node == nil || v.err != nil {
752+
return nil
753+
}
754+
755+
if selectStmt, ok := node.(*ast.SelectStmt); ok {
756+
tables, err := v.compiler.sourceTables(v.qc, selectStmt)
757+
if err != nil {
758+
v.err = err
759+
return nil
760+
}
761+
scopes := append([][]*Table{tables}, v.scopes...)
762+
if err := v.compiler.findColumnsInClause(v.qc, selectStmt.WhereClause, scopes); err != nil {
763+
v.err = err
764+
}
765+
return nil
766+
}
767+
768+
if ref, ok := node.(*ast.ColumnRef); ok {
769+
if err := findColumnForRefInScopes(ref, v.scopes); err != nil {
770+
v.err = err
771+
return nil
772+
}
773+
}
774+
775+
return v
776+
}
777+
778+
func findColumnForRefInScopes(ref *ast.ColumnRef, scopes [][]*Table) error {
779+
parts := stringSlice(ref.Fields)
780+
var schema, alias, name string
781+
switch len(parts) {
782+
case 1:
783+
name = parts[0]
784+
case 2:
785+
alias = parts[0]
786+
name = parts[1]
787+
case 3:
788+
schema = parts[0]
789+
alias = parts[1]
790+
name = parts[2]
791+
default:
792+
return fmt.Errorf("unknown number of fields: %d", len(parts))
793+
}
794+
795+
for _, tables := range scopes {
796+
var found int
797+
for _, t := range tables {
798+
if schema != "" && t.Rel.Schema != schema {
799+
continue
800+
}
801+
if alias != "" && t.Rel.Name != alias {
802+
continue
803+
}
804+
for _, c := range t.Columns {
805+
if c.Name == name {
806+
found++
807+
break
808+
}
809+
}
810+
}
811+
if found == 0 {
812+
continue
813+
}
814+
if found > 1 {
815+
return &sqlerr.Error{
816+
Code: "42703",
817+
Message: fmt.Sprintf("column reference %q is ambiguous", name),
818+
Location: ref.Location,
819+
}
820+
}
821+
return nil
822+
}
823+
824+
return &sqlerr.Error{
825+
Code: "42703",
826+
Message: fmt.Sprintf("column reference %q not found", name),
827+
Location: ref.Location,
828+
}
829+
}
830+
725831
func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) error {
726832
parts := stringSlice(ref.Fields)
727833
var alias, name string
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#4264
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
-- name: GetUsers :many
2+
select * from "user" where is_deleted = false;
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
create table "user" ("name" text not null, deleted bool not null);
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
version: "2"
2+
sql:
3+
- engine: "postgresql"
4+
schema: "schema.sql"
5+
queries: "query.sql"
6+
gen:
7+
go:
8+
package: "querytest"
9+
out: "go"
10+
sql_package: "pgx/v5"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:2:28: column reference "is_deleted" not found
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:2:41: column "is_deleted" does not exist
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#4264
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
-- name: GetUsers :many
2+
select * from "user" where exists (
3+
select 1 from "user" as u2 where is_deleted = false
4+
);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
create table "user" ("name" text not null, deleted bool not null);

0 commit comments

Comments
 (0)