Skip to content

Commit aac2130

Browse files
authored
Distinguish XGBoost model when generating prediction workflow code (#2854)
* judge XGBoost model * fix pre-commit * polish * polish
1 parent 143a7b7 commit aac2130

File tree

12 files changed

+350
-134
lines changed

12 files changed

+350
-134
lines changed

go/cmd/sqlflowserver/e2e_workflow_test.go

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,25 @@ func TestEnd2EndFluidWorkflow(t *testing.T) {
367367
func CaseWorkflowTrainXgboost(t *testing.T) {
368368
a := assert.New(t)
369369

370-
sqlProgram := `SELECT * FROM iris.train LIMIT 100;
370+
testMain := func(sqlProgram string) {
371+
conn, err := createRPCConn()
372+
if err != nil {
373+
a.Fail("Create gRPC client error: %v", err)
374+
}
375+
defer conn.Close()
376+
377+
cli := pb.NewSQLFlowClient(conn)
378+
ctx, cancel := context.WithTimeout(context.Background(), 3600*time.Second)
379+
defer cancel()
380+
381+
stream, err := cli.Run(ctx, &pb.Request{Sql: sqlProgram, Session: &pb.Session{DbConnStr: testDatasource}})
382+
if err != nil {
383+
a.Fail("Create gRPC client error: %v", err)
384+
}
385+
a.NoError(checkWorkflow(ctx, cli, stream))
386+
}
387+
388+
extraTrainSQLProgram := `SELECT * FROM iris.train LIMIT 100;
371389
372390
SELECT * FROM iris.train
373391
TO TRAIN xgboost.gbtree
@@ -382,26 +400,16 @@ COLUMN sepal_length, DENSE(sepal_width)
382400
LABEL class
383401
INTO sqlflow_models.xgb_classification;
384402
403+
SELECT * FROM sqlflow_models.xgb_classification;
404+
`
405+
406+
sqlProgram := `
385407
SELECT * FROM iris.test
386408
TO PREDICT iris.test_result_table.class
387409
USING sqlflow_models.xgb_classification;
388410
389411
SELECT * FROM iris.test_result_table;
390412
`
391-
392-
conn, err := createRPCConn()
393-
if err != nil {
394-
a.Fail("Create gRPC client error: %v", err)
395-
}
396-
defer conn.Close()
397-
398-
cli := pb.NewSQLFlowClient(conn)
399-
ctx, cancel := context.WithTimeout(context.Background(), 3600*time.Second)
400-
defer cancel()
401-
402-
stream, err := cli.Run(ctx, &pb.Request{Sql: sqlProgram, Session: &pb.Session{DbConnStr: testDatasource}})
403-
if err != nil {
404-
a.Fail("Create gRPC client error: %v", err)
405-
}
406-
a.NoError(checkWorkflow(ctx, cli, stream))
413+
testMain(extraTrainSQLProgram + sqlProgram)
414+
testMain(sqlProgram)
407415
}

go/codegen/experimental/codegen_couler.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"strconv"
2121
"text/template"
2222

23-
"sqlflow.org/sqlflow/go/ir"
2423
pb "sqlflow.org/sqlflow/go/proto"
2524
"sqlflow.org/sqlflow/go/workflow/couler"
2625
)
@@ -53,17 +52,14 @@ func GenerateCodeCouler(sqlProgram string, session *pb.Session) (string, error)
5352
if err != nil {
5453
return "", err
5554
}
56-
stepList := []*stepContext{}
55+
stepList := make([]*stepContext, 0)
5756
for idx, stmt := range stmts {
58-
stepCode, err := generateStepCode(stmt, idx, session)
57+
stepCode, image, err := generateStepCodeAndImage(stmt, idx, session, stmts)
5958
if err != nil {
6059
return "", err
6160
}
62-
image := defaultDockerImage
63-
if trainStmt, ok := stmt.(*ir.TrainStmt); ok {
64-
if trainStmt.ModelImage != "" {
65-
image = trainStmt.ModelImage
66-
}
61+
if image == "" {
62+
image = defaultDockerImage
6763
}
6864
// TODO(typhoonzero): find out the image that should be used by the predict statements.
6965
step := &stepContext{

go/codegen/experimental/codegen_normal_stmt.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ type normalStmtFiller struct {
4848
Stmt string
4949
}
5050

51-
// GenerateNormalStmtStep generate step Python code to run a normal SQL statement.
52-
func GenerateNormalStmtStep(stmt string, session *pb.Session, stepIndex int) (string, error) {
51+
// generateNormalStmtStep generates the step's Python code to run a normal SQL statement.
52+
func generateNormalStmtStep(stmt string, stepIndex int, session *pb.Session) (string, error) {
5353
filler := &normalStmtFiller{
5454
StepIndex: stepIndex,
5555
DataSource: session.DbConnStr,

go/codegen/experimental/codegen_step.go

Lines changed: 116 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
package experimental
1515

1616
import (
17+
"encoding/binary"
1718
"fmt"
19+
"github.com/bitly/go-simplejson"
1820
"net/url"
21+
"sqlflow.org/sqlflow/go/model"
22+
"sqlflow.org/sqlflow/go/sqlfs"
1923
"strings"
2024

2125
"sqlflow.org/sqlflow/go/database"
@@ -24,37 +28,130 @@ import (
2428
pb "sqlflow.org/sqlflow/go/proto"
2529
)
2630

27-
// TODO(sneaxiy): implement this method to distinguish whether
28-
// a model is a XGBoost model.
29-
func isTrainedXBoostModel(modelName string) bool {
30-
return true
31-
}
32-
33-
func generateStepCode(sqlStmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (string, error) {
31+
func generateStepCodeAndImage(sqlStmt ir.SQLFlowStmt, stepIndex int, session *pb.Session, sqlStmts []ir.SQLFlowStmt) (string, string, error) {
3432
switch stmt := sqlStmt.(type) {
3533
case *ir.TrainStmt:
36-
return generateTrainCode(stmt, stepIndex, session)
34+
return generateTrainCodeAndImage(stmt, stepIndex, session)
3735
case *ir.PredictStmt:
38-
return generatePredictCode(stmt, stepIndex, session)
36+
return generatePredictCodeAndImage(stmt, stepIndex, session, sqlStmts)
3937
case *ir.NormalStmt:
40-
return GenerateNormalStmtStep(string(*stmt), session, stepIndex)
38+
code, err := generateNormalStmtStep(string(*stmt), stepIndex, session)
39+
return code, "", err
4140
default:
42-
return "", fmt.Errorf("not implemented stmt execution type %v", stmt)
41+
return "", "", fmt.Errorf("not implemented stmt execution type %v", stmt)
42+
}
43+
}
44+
45+
func generateTrainCodeAndImage(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Session) (string, string, error) {
46+
isXGBoost := isXGBoostEstimator(trainStmt.Estimator)
47+
if isXGBoost {
48+
code, err := XGBoostGenerateTrain(trainStmt, stepIndex, session)
49+
if err != nil {
50+
return "", "", err
51+
}
52+
return code, trainStmt.ModelImage, nil
53+
}
54+
return "", "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
55+
}
56+
57+
func generatePredictCodeAndImage(predStmt *ir.PredictStmt, stepIndex int, session *pb.Session, sqlStmts []ir.SQLFlowStmt) (string, string, error) {
58+
trainStmt := findModelGenerationTrainStmt(predStmt.Using, stepIndex, sqlStmts)
59+
image := ""
60+
isXGBoost := false
61+
if trainStmt != nil {
62+
image = trainStmt.ModelImage
63+
isXGBoost = isXGBoostEstimator(trainStmt.Estimator)
64+
} else {
65+
meta, err := getModelMetadata(session, predStmt.Using)
66+
if err != nil {
67+
return "", "", err
68+
}
69+
image = meta.imageName()
70+
isXGBoost = meta.isXGBoostModel()
4371
}
72+
73+
if isXGBoost {
74+
code, err := XGBoostGeneratePredict(predStmt, stepIndex, session)
75+
if err != nil {
76+
return "", "", err
77+
}
78+
return code, image, nil
79+
}
80+
return "", "", fmt.Errorf("not implemented model type")
4481
}
4582

46-
func generateTrainCode(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Session) (string, error) {
47-
if strings.HasPrefix(strings.ToUpper(trainStmt.Estimator), "XGBOOST.") {
48-
return XGBoostGenerateTrain(trainStmt, stepIndex, session)
83+
// findModelGenerationTrainStmt finds the *ir.TrainStmt that generates the model named `modelName`.
84+
// TODO(sneaxiy): find a better way to do this once we have a well-designed dependency analysis.
85+
func findModelGenerationTrainStmt(modelName string, idx int, sqlStmts []ir.SQLFlowStmt) *ir.TrainStmt {
86+
idx--
87+
for idx >= 0 {
88+
trainStmt, ok := sqlStmts[idx].(*ir.TrainStmt)
89+
if ok && trainStmt.Into == modelName {
90+
return trainStmt
91+
}
92+
idx--
4993
}
50-
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
94+
return nil
95+
}
96+
97+
func isXGBoostEstimator(estimator string) bool {
98+
return strings.HasPrefix(strings.ToUpper(estimator), "XGBOOST.")
99+
}
100+
101+
type metadata simplejson.Json
102+
103+
func (m *metadata) imageName() string {
104+
return (*simplejson.Json)(m).Get("model_repo_image").MustString()
51105
}
52106

53-
func generatePredictCode(predStmt *ir.PredictStmt, stepIndex int, session *pb.Session) (string, error) {
54-
if isTrainedXBoostModel(predStmt.Using) {
55-
return XGBoostGeneratePredict(predStmt, stepIndex, session)
107+
func (m *metadata) isXGBoostModel() bool {
108+
return (*simplejson.Json)(m).Get("model_type").MustInt() == model.XGBOOST
109+
}
110+
111+
func getModelMetadata(session *pb.Session, table string) (*metadata, error) {
112+
submitter := getSubmitter(session)
113+
if submitter == "local" {
114+
return getModelMetadataFromDB(session.DbConnStr, table)
115+
}
116+
return nil, fmt.Errorf("not supported submitter %s", submitter)
117+
}
118+
119+
func getModelMetadataFromDB(dbConnStr, table string) (*metadata, error) {
120+
db, err := database.OpenAndConnectDB(dbConnStr)
121+
if err != nil {
122+
return nil, err
123+
}
124+
defer db.Close()
125+
126+
fs, err := sqlfs.Open(db.DB, table)
127+
if err != nil {
128+
return nil, err
129+
}
130+
defer fs.Close()
131+
132+
lengthBytes := make([]byte, 8)
133+
readCnt, err := fs.Read(lengthBytes)
134+
if err != nil {
135+
return nil, err
136+
}
137+
if readCnt != 8 {
138+
return nil, fmt.Errorf("invalid model table")
139+
}
140+
141+
length := binary.LittleEndian.Uint64(lengthBytes)
142+
jsonBytes := make([]byte, length)
143+
readCnt, err = fs.Read(jsonBytes)
144+
if err != nil {
145+
return nil, err
146+
}
147+
if readCnt != int(length) {
148+
return nil, fmt.Errorf("invalid model metadata")
149+
}
150+
json, err := simplejson.NewJson(jsonBytes)
151+
if err != nil {
152+
return nil, err
56153
}
57-
return "", fmt.Errorf("not implemented model type")
154+
return (*metadata)(json), nil
58155
}
59156

60157
func initializeAndCheckAttributes(stmt ir.SQLFlowStmt) error {

go/codegen/experimental/xgboost.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
115115
DiskCache: diskCache,
116116
BatchSize: batchSize,
117117
Epoch: epoch,
118-
Submitter: getSubmitter(session, "local"),
118+
Submitter: getSubmitter(session),
119119
}
120120
var program bytes.Buffer
121121
var trainTemplate = template.Must(template.New("Train").Parse(xgbTrainTemplate))
@@ -183,7 +183,7 @@ func XGBoostGeneratePredict(predStmt *ir.PredictStmt, stepIndex int, session *pb
183183
PredLabelName: predStmt.ResultColumn,
184184
ResultTable: predStmt.ResultTable,
185185
Load: predStmt.Using,
186-
Submitter: getSubmitter(session, "local"),
186+
Submitter: getSubmitter(session),
187187
}
188188

189189
var program bytes.Buffer
@@ -208,7 +208,7 @@ def step_entry_{{.StepIndex}}():
208208
load='''{{.Load}}''')
209209
`
210210

211-
func getSubmitter(session *pb.Session, defaultValue string) string {
211+
func getSubmitter(session *pb.Session) string {
212212
if session.Submitter != "" {
213213
return session.Submitter
214214
}
@@ -217,7 +217,7 @@ func getSubmitter(session *pb.Session, defaultValue string) string {
217217
if submitter != "" {
218218
return submitter
219219
}
220-
return defaultValue
220+
return "local"
221221
}
222222

223223
func generateFeatureColumnCode(fcMap map[string][]ir.FeatureColumn) string {

go/codegen/pai/codegen.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package pai
1616
import (
1717
"bytes"
1818
"fmt"
19+
"sqlflow.org/sqlflow/go/model"
1920
"strings"
2021
"text/template"
2122

@@ -27,15 +28,6 @@ import (
2728
"sqlflow.org/sqlflow/go/verifier"
2829
)
2930

30-
const (
31-
// ModelTypeTF is the mode type that trained by PAI TensorFlow.
32-
ModelTypeTF = iota
33-
// ModelTypeXGBoost is the model type that use PAI TensorFlow to train XGBoost models.
34-
ModelTypeXGBoost
35-
// ModelTypePAIML is the model type that trained by PAI machine learning algorithm toolkit
36-
ModelTypePAIML
37-
)
38-
3931
const entryFile = "entry.py"
4032

4133
// BucketName is the OSS bucket to save trained models
@@ -154,11 +146,11 @@ func Predict(ir *ir.PredictStmt, session *pb.Session, tarball, paramsFile, model
154146
if e != nil {
155147
return
156148
}
157-
if modelType == ModelTypePAIML {
149+
if modelType == model.PAIML {
158150
if paiCmd, e = getPAIPredictCmd(ir, session); e != nil {
159151
return
160152
}
161-
} else if modelType == ModelTypeXGBoost {
153+
} else if modelType == model.XGBOOST {
162154
requirements, e = genRequirements(true)
163155
ossURI := OSSModelURL(ossModelPath)
164156
var xgbPredCode bytes.Buffer
@@ -220,15 +212,15 @@ func Explain(ir *ir.ExplainStmt, session *pb.Session, tarball, paramsFile, model
220212
}
221213

222214
expn := newExplainRender(session.UserId)
223-
if modelType == ModelTypePAIML {
215+
if modelType == model.PAIML {
224216
if ir.Into == "" {
225217
return nil, fmt.Errorf("explain PAI random forests model need INTO clause to output the explain result to a table")
226218
}
227219
if expn.Requirements, err = genRequirements(false); err != nil {
228220
return nil, err
229221
}
230222
expn.PaiCmd, err = getExplainRandomForestsPAICmd(ir, session)
231-
} else if modelType == ModelTypeXGBoost {
223+
} else if modelType == model.XGBOOST {
232224
if expn.Requirements, err = genRequirements(true); err != nil {
233225
return nil, err
234226
}
@@ -284,9 +276,9 @@ func Evaluate(ir *ir.EvaluateStmt, session *pb.Session, tarball, paramsFile, mod
284276
return "", "", "", err
285277
}
286278

287-
if modelType == ModelTypePAIML {
279+
if modelType == model.PAIML {
288280
return "", "", "", fmt.Errorf("evaluate PAI ML model is not supported for now")
289-
} else if modelType == ModelTypeXGBoost {
281+
} else if modelType == model.XGBOOST {
290282
if requirements, err = genRequirements(true); err != nil {
291283
return "", "", "", err
292284
}

go/codegen/pai/codegen_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"fmt"
1818
"os"
1919
"regexp"
20+
"sqlflow.org/sqlflow/go/model"
2021
"strings"
2122
"testing"
2223

@@ -148,7 +149,7 @@ func TestPredictCodegen(t *testing.T) {
148149
ossModelPath := "iris/sqlflow/my_dnn_model"
149150
scriptPath := "file:///tmp/task.tar.gz"
150151
paramsPath := "file:///tmp/params.txt"
151-
paiTFCode, paiCmd, _, e := Predict(ir, sess, scriptPath, paramsPath, "my_dnn_model", ossModelPath, "", ModelTypeTF)
152+
paiTFCode, paiCmd, _, e := Predict(ir, sess, scriptPath, paramsPath, "my_dnn_model", ossModelPath, "", model.TENSORFLOW)
152153
a.NoError(e)
153154
a.False(hasUnknownParameters(paiTFCode, knownPredictParams))
154155
tfCode, err := tensorflow.Pred(ir, sess)

0 commit comments

Comments
 (0)