Skip to content

Commit 995bc66

Browse files
typhoonzerotonyyang-svail
authored andcommitted
use lexer to split multiple sql (#811)
* use lexer to split multiple sql * normalize sqlflowserver_test
1 parent 4a0506e commit 995bc66

4 files changed

Lines changed: 56 additions & 30 deletions

File tree

server/sqlflowserver.go

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ package server
2121

2222
import (
2323
"fmt"
24-
"strings"
2524
"time"
2625

2726
"github.com/golang/protobuf/proto"
@@ -56,26 +55,17 @@ func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
5655
}
5756
defer db.Close()
5857
}
59-
60-
// FIXME(typhoonzero): split by ; can not deal with situations like
61-
// "SELECT * from mytable where col <> ';';", should be fixed.
62-
sqlStatements := strings.Split(req.Sql, ";")
63-
trimedStatements := []string{}
64-
for _, singleSQL := range sqlStatements {
65-
sqlToRun := strings.TrimSpace(singleSQL)
66-
if sqlToRun == "" {
67-
continue
68-
}
69-
trimedStatements = append(trimedStatements, sqlToRun)
58+
sqlStatements, err := sf.SplitMultipleSQL(req.Sql)
59+
if err != nil {
60+
return err
7061
}
71-
for _, singleSQL := range trimedStatements {
72-
sqlToRun := fmt.Sprintf("%s;", singleSQL)
62+
for _, singleSQL := range sqlStatements {
7363
var pr *sf.PipeReader
7464
startTime := time.Now().UnixNano()
7565
if s.enableSession == true {
76-
pr = s.run(sqlToRun, db, s.modelDir, req.Session)
66+
pr = s.run(singleSQL, db, s.modelDir, req.Session)
7767
} else {
78-
pr = s.run(sqlToRun, db, s.modelDir, nil)
68+
pr = s.run(singleSQL, db, s.modelDir, nil)
7969
}
8070

8171
defer pr.Close()
@@ -102,7 +92,7 @@ func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
10292
}
10393
}
10494
// Send EndOfExecution message if have multiple requests.
105-
if len(trimedStatements) > 1 {
95+
if len(sqlStatements) > 1 {
10696
eoe := &pb.EndOfExecution{}
10797
eoe.Sql = singleSQL
10898
eoe.SpentTimeSeconds = time.Now().UnixNano() - startTime

server/sqlflowserver_test.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"log"
2121
"net"
2222
"os"
23-
"strings"
2423
"testing"
2524
"time"
2625

@@ -36,11 +35,12 @@ import (
3635
)
3736

3837
const (
39-
testErrorSQL = "ERROR ..."
40-
testQuerySQL = "SELECT ..."
41-
testExecuteSQL = "INSERT ..."
42-
testExtendedSQL = "SELECT ... TRAIN ..."
43-
testExtendedSQLWithSpace = "SELECT ... TRAIN ...; \n\t"
38+
testErrorSQL = "ERROR ..."
39+
testQuerySQL = "SELECT * FROM some_table;"
40+
testExecuteSQL = "INSERT INTO some_table VALUES (1,2,3,4);"
41+
testExtendedSQL = "SELECT * FROM some_table TRAIN SomeModel;"
42+
testExtendedSQLNoSemicolon = "SELECT * FROM some_table TRAIN SomeModel"
43+
testExtendedSQLWithSpace = "SELECT * FROM some_table TRAIN SomeModel; \n\t"
4444
)
4545

4646
var testServerAddress string
@@ -49,8 +49,6 @@ func mockRun(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.Pi
4949
rd, wr := sf.Pipe()
5050
go func() {
5151
defer wr.Close()
52-
// the server may automatically add a trailing ";", remove it
53-
sql = strings.Trim(sql, ";")
5452
switch sql {
5553
case testErrorSQL:
5654
wr.Write(fmt.Errorf("run error: %v", testErrorSQL))
@@ -65,7 +63,7 @@ func mockRun(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.Pi
6563
wr.Write([]interface{}{time.Now(), nil})
6664
case testExecuteSQL:
6765
wr.Write("success; 0 rows affected")
68-
case testExtendedSQL:
66+
case testExtendedSQL, testExtendedSQLNoSemicolon, testExtendedSQLWithSpace:
6967
wr.Write("log 0")
7068
wr.Write("log 1")
7169
default:
@@ -124,11 +122,10 @@ func TestSQL(t *testing.T) {
124122
stream, err := c.Run(ctx, &pb.Request{Sql: testErrorSQL})
125123
a.NoError(err)
126124
_, err = stream.Recv()
127-
a.Equal(status.Error(codes.Unknown, fmt.Sprintf("run error: %v", testErrorSQL)), err)
125+
a.Equal(status.Error(codes.Unknown, "Lex: Unknown problem ..."), err)
128126

129-
testMultipleSQL := fmt.Sprintf("%s; %s", testQuerySQL, testExtendedSQL)
130-
131-
for _, s := range []string{testQuerySQL, testExecuteSQL, testExtendedSQL, testExtendedSQLWithSpace, testMultipleSQL} {
127+
testMultipleSQL := fmt.Sprintf("%s %s", testQuerySQL, testExtendedSQL)
128+
for _, s := range []string{testQuerySQL, testExecuteSQL, testExtendedSQL, testExtendedSQLWithSpace, testExtendedSQLNoSemicolon, testMultipleSQL} {
132129
stream, err := c.Run(ctx, &pb.Request{Sql: s})
133130
a.NoError(err)
134131
for {

sql/executor.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,35 @@ func splitExtendedSQL(slct string) ([]string, error) {
9393
return []string{slct}, nil
9494
}
9595

96+
// SplitMultipleSQL returns a list of SQL statements if the input statements contains mutiple
97+
// SQL statements separated by ;
98+
func SplitMultipleSQL(statements string) ([]string, error) {
99+
l := newLexer(statements)
100+
var n sqlSymType
101+
var sqlList []string
102+
splitPos := 0
103+
for {
104+
t := l.Lex(&n)
105+
if t < 0 {
106+
return []string{}, fmt.Errorf("Lex: Unknown problem %s", statements[0-t:])
107+
}
108+
if t == 0 {
109+
if len(sqlList) == 0 {
110+
// NOTE: this line support executing SQL statement without a trailing ";"
111+
sqlList = append(sqlList, statements)
112+
}
113+
break
114+
}
115+
if t == ';' {
116+
splited := statements[splitPos:l.pos]
117+
splited = strings.TrimSpace(splited)
118+
sqlList = append(sqlList, splited)
119+
splitPos = l.pos
120+
}
121+
}
122+
return sqlList, nil
123+
}
124+
96125
// TODO(weiguo): isQuery is a hacky way to decide which API to call:
97126
// https://golang.org/pkg/database/sql/#DB.Exec .
98127
// We will need to extend our parser to be a full SQL parser in the future.

sql/executor_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ func TestSplitExtendedSQL(t *testing.T) {
7171
a.Equal(`train a with b;`, s[0])
7272
}
7373

74+
func TestSplitMulipleSQL(t *testing.T) {
75+
a := assert.New(t)
76+
splited, err := SplitMultipleSQL(`CREATE TABLE copy_table_1 AS SELECT a,b,c FROM table_1 WHERE c<>";";
77+
SELECT * FROM copy_table_1;SELECT * FROM copy_table_1 TRAIN DNNClassifier WITH n_classes=2 INTO test_model;`)
78+
a.NoError(err)
79+
a.Equal("CREATE TABLE copy_table_1 AS SELECT a,b,c FROM table_1 WHERE c<>\";\";", splited[0])
80+
a.Equal("SELECT * FROM copy_table_1;", splited[1])
81+
a.Equal("SELECT * FROM copy_table_1 TRAIN DNNClassifier WITH n_classes=2 INTO test_model;", splited[2])
82+
}
83+
7484
func TestExecuteXGBoost(t *testing.T) {
7585
a := assert.New(t)
7686
modelDir := ""

0 commit comments

Comments
 (0)