@@ -256,6 +256,7 @@ func TestEnd2EndMySQL(t *testing.T) {
256256 t .Run ("CaseSparseFeature" , CaseSparseFeature )
257257 t .Run ("CaseSQLByPassLeftJoin" , CaseSQLByPassLeftJoin )
258258 t .Run ("CaseTrainRegression" , CaseTrainRegression )
259+ t .Run ("CaseTrainXGBoostRegression" , CaseTrainXGBoostRegression )
259260 t .Run ("CaseTrainDeepWideModel" , CaseTrainDeepWideModel )
260261
261262}
@@ -1000,3 +1001,68 @@ FROM housing.predict LIMIT 5;`)
10001001 a .False (nilCount == 13 )
10011002 }
10021003}
1004+
1005+ // CaseTrainXGBoostRegression is used to test xgboost regression models
1006+ func CaseTrainXGBoostRegression (t * testing.T ) {
1007+ a := assert .New (t )
1008+ trainSQL := fmt .Sprintf (`
1009+ SELECT *
1010+ FROM housing.train
1011+ TRAIN xgboost.gbtree
1012+ WITH
1013+ objective="reg:squarederror",
1014+ train.num_boost_round = 30
1015+ COLUMN f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13
1016+ LABEL target
1017+ INTO sqlflow_models.my_xgb_regression_model;
1018+ ` )
1019+
1020+ conn , err := createRPCConn ()
1021+ a .NoError (err )
1022+ defer conn .Close ()
1023+ cli := pb .NewSQLFlowClient (conn )
1024+
1025+ ctx , cancel := context .WithTimeout (context .Background (), 300 * time .Second )
1026+ defer cancel ()
1027+
1028+ stream , err := cli .Run (ctx , sqlRequest (trainSQL ))
1029+ if err != nil {
1030+ a .Fail ("Check if the server started successfully. %v" , err )
1031+ }
1032+ // call ParseRow only to wait train finish
1033+ ParseRow (stream )
1034+
1035+ predSQL := fmt .Sprintf (`SELECT *
1036+ FROM housing.test
1037+ PREDICT housing.xgb_predict.target
1038+ USING sqlflow_models.my_xgb_regression_model;` )
1039+
1040+ stream , err = cli .Run (ctx , sqlRequest (predSQL ))
1041+ if err != nil {
1042+ a .Fail ("Check if the server started successfully. %v" , err )
1043+ }
1044+ // call ParseRow only to wait predict finish
1045+ ParseRow (stream )
1046+
1047+ showPred := fmt .Sprintf (`SELECT *
1048+ FROM housing.xgb_predict LIMIT 5;` )
1049+
1050+ stream , err = cli .Run (ctx , sqlRequest (showPred ))
1051+ if err != nil {
1052+ a .Fail ("Check if the server started successfully. %v" , err )
1053+ }
1054+ _ , rows := ParseRow (stream )
1055+
1056+ for _ , row := range rows {
1057+ // NOTE: predict result maybe random, only check predicted
1058+ // class >=0, need to change to more flexible checks than
1059+ // checking expectedPredClasses := []int64{2, 1, 0, 2, 0}
1060+ AssertGreaterEqualAny (a , row [13 ], float64 (0 ))
1061+
1062+ // avoiding nil features in predict result
1063+ nilCount := 0
1064+ for ; nilCount < 13 && row [nilCount ] == nil ; nilCount ++ {
1065+ }
1066+ a .False (nilCount == 13 )
1067+ }
1068+ }
0 commit comments