Skip to content

Commit 5a9d9d2

Browse files
authored
support xgboost regression model (#797)
1 parent 5ae82c0 commit 5a9d9d2

3 files changed

Lines changed: 86 additions & 3 deletions

File tree

cmd/sqlflowserver/main_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ func TestEnd2EndMySQL(t *testing.T) {
256256
t.Run("CaseSparseFeature", CaseSparseFeature)
257257
t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin)
258258
t.Run("CaseTrainRegression", CaseTrainRegression)
259+
t.Run("CaseTrainXGBoostRegression", CaseTrainXGBoostRegression)
259260
t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
260261

261262
}
@@ -1000,3 +1001,68 @@ FROM housing.predict LIMIT 5;`)
10001001
a.False(nilCount == 13)
10011002
}
10021003
}
1004+
1005+
// CaseTrainXGBoostRegression is used to test xgboost regression models
1006+
func CaseTrainXGBoostRegression(t *testing.T) {
1007+
a := assert.New(t)
1008+
trainSQL := fmt.Sprintf(`
1009+
SELECT *
1010+
FROM housing.train
1011+
TRAIN xgboost.gbtree
1012+
WITH
1013+
objective="reg:squarederror",
1014+
train.num_boost_round = 30
1015+
COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
1016+
LABEL target
1017+
INTO sqlflow_models.my_xgb_regression_model;
1018+
`)
1019+
1020+
conn, err := createRPCConn()
1021+
a.NoError(err)
1022+
defer conn.Close()
1023+
cli := pb.NewSQLFlowClient(conn)
1024+
1025+
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
1026+
defer cancel()
1027+
1028+
stream, err := cli.Run(ctx, sqlRequest(trainSQL))
1029+
if err != nil {
1030+
a.Fail("Check if the server started successfully. %v", err)
1031+
}
1032+
// call ParseRow only to wait train finish
1033+
ParseRow(stream)
1034+
1035+
predSQL := fmt.Sprintf(`SELECT *
1036+
FROM housing.test
1037+
PREDICT housing.xgb_predict.target
1038+
USING sqlflow_models.my_xgb_regression_model;`)
1039+
1040+
stream, err = cli.Run(ctx, sqlRequest(predSQL))
1041+
if err != nil {
1042+
a.Fail("Check if the server started successfully. %v", err)
1043+
}
1044+
// call ParseRow only to wait predict finish
1045+
ParseRow(stream)
1046+
1047+
showPred := fmt.Sprintf(`SELECT *
1048+
FROM housing.xgb_predict LIMIT 5;`)
1049+
1050+
stream, err = cli.Run(ctx, sqlRequest(showPred))
1051+
if err != nil {
1052+
a.Fail("Check if the server started successfully. %v", err)
1053+
}
1054+
_, rows := ParseRow(stream)
1055+
1056+
for _, row := range rows {
1057+
// NOTE: predict result maybe random, only check predicted
1058+
// class >=0, need to change to more flexible checks than
1059+
// checking expectedPredClasses := []int64{2, 1, 0, 2, 0}
1060+
AssertGreaterEqualAny(a, row[13], float64(0))
1061+
1062+
// avoiding nil features in predict result
1063+
nilCount := 0
1064+
for ; nilCount < 13 && row[nilCount] == nil; nilCount++ {
1065+
}
1066+
a.False(nilCount == 13)
1067+
}
1068+
}

sql/executor_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ func TestExecuteXGBoost(t *testing.T) {
8484
})
8585
}
8686

87+
func TestExecuteXGBoostRegression(t *testing.T) {
88+
a := assert.New(t)
89+
modelDir := ""
90+
a.NotPanics(func() {
91+
stream := runExtendedSQL(testXGBoostTrainSelectIris, testDB, modelDir, nil)
92+
a.True(goodStream(stream.ReadAll()))
93+
stream = runExtendedSQL(testAnalyzeTreeModelSelectIris, testDB, modelDir, nil)
94+
a.True(goodStream(stream.ReadAll()))
95+
stream = runExtendedSQL(testXGBoostPredictIris, testDB, modelDir, nil)
96+
a.True(goodStream(stream.ReadAll()))
97+
})
98+
}
99+
87100
func TestExecutorTrainAndPredictDNN(t *testing.T) {
88101
a := assert.New(t)
89102
modelDir := ""

sql/template_xgboost.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,11 @@ dpred = xgb_dataset('predict.txt', """{{.PredictionDatasetSQL}}""")
134134
bst = xgb.Booster({'nthread': 4}) # init model
135135
bst.load_model("{{.Save}}") # load data
136136
preds = bst.predict(dpred)
137-
# TODO(typhoonzero): regression models may have different behavior
138-
pred_classes = np.argmax(np.array(preds), axis=1)
137+
138+
# TODO(Yancey1989): using the train parameters to decide regressoin model or classifier model
139+
if len(preds.shape) == 2:
140+
# classifier result
141+
preds = np.argmax(np.array(preds), axis=1)
139142
140143
feature_file_read = open("predict.txt", "r")
141144
@@ -149,7 +152,8 @@ with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100
149152
if not line:
150153
break
151154
row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]]
152-
row.append(pred_classes[line_no])
155+
row.append(preds[line_no])
153156
w.write(row)
154157
line_no += 1
158+
print("Done predicting. Predict table : {{.TableName}}")
155159
`

0 commit comments

Comments
 (0)