Skip to content

Commit 41a0db0

Browse files
authored
Fix xgboost CSV data train bug (#2863)
* fix pai xgb bug * fix shape error * fix tf train generator * add e2e ut * polish * fix python derivation
1 parent 343a923 commit 41a0db0

14 files changed

Lines changed: 144 additions & 25 deletions

File tree

go/cmd/sqlflowserver/e2e_common_cases.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,3 +958,63 @@ func caseTestOptimizeClauseWithGroupBy(t *testing.T) {
958958
a.True(reflect.DeepEqual(decodedRows[2], []interface{}{"plantB", "marketA", int64(30)}))
959959
a.True(reflect.DeepEqual(decodedRows[3], []interface{}{"plantB", "marketB", int64(60)}))
960960
}
961+
962+
// caseEnd2EndXGBoostDenseFeatureColumn runs a train / predict / evaluate flow
// for an xgboost.gbtree model whose only feature is the string column c3,
// declared in the COLUMN clause as DENSE(c3, 4).
//
// isPai selects the table and model names: when true, the data tables are
// qualified with caseDB and the model name has no database prefix; when false,
// every name lives under feature_derivation_case.
//
// Each statement is executed through connectAndRunSQL. Only the returned error
// is checked; the other return values are discarded.
func caseEnd2EndXGBoostDenseFeatureColumn(t *testing.T, isPai bool) {
963+
trainTableName := "feature_derivation_case.train"
964+
modelName := "feature_derivation_case.xgb_dense_column_model"
965+
predictTableName := "feature_derivation_case.xgb_dense_column_predict_table"
966+
evaluateTableName := "feature_derivation_case.xgb_dense_column_evaluate_table"
967+
968+
// PAI variant: tables are qualified by caseDB and the model name is unqualified.
if isPai {
969+
trainTableName = caseDB + ".feature_derivation_train"
970+
modelName = "my_xgb_dense_column_model"
971+
predictTableName = caseDB + ".xgb_dense_column_predict_table"
972+
evaluateTableName = caseDB + ".xgb_dense_column_evaluate_table"
973+
}
974+
975+
// Placeholders, in the order given to fmt.Sprintf below:
// %[1]s train table, %[2]s model, %[3]s predict table, %[4]s evaluate table.
sqlTemplate := `SELECT c3, class FROM %[1]s
976+
TO TRAIN xgboost.gbtree
977+
WITH objective="binary:logistic",
978+
validation.select="SELECT c3, class FROM %[1]s",
979+
train.num_boost_round=100,
980+
eta=0.3,
981+
max_depth=5
982+
column DENSE(c3, 4)
983+
LABEL class
984+
INTO %[2]s;
985+
986+
SELECT c3 FROM %[1]s TO PREDICT %[3]s.class USING %[2]s;
987+
988+
SELECT * FROM %[3]s;
989+
990+
SELECT c3, class FROM %[1]s
991+
TO EVALUATE %[2]s
992+
WITH
993+
validation.metrics="accuracy_score,f1_score"
994+
LABEL class
995+
INTO %[4]s;
996+
997+
SELECT * FROM %[4]s;`
998+
999+
// NOTE(review): despite its name, this statement selects from %[2]s, which
// the Sprintf call below fills with modelName, not the train table.
// Presumably it verifies that the trained model was stored as a table in the
// non-PAI case — confirm, and consider renaming the constant.
const selectTrainTableSQL = `SELECT * FROM %[2]s;`
1000+
1001+
// The extra statement is appended only for the non-PAI run.
if !isPai {
1002+
sqlTemplate += selectTrainTableSQL
1003+
}
1004+
1005+
sqls := fmt.Sprintf(sqlTemplate, trainTableName, modelName, predictTableName, evaluateTableName)
1006+
1007+
a := assert.New(t)
// The program is split on ';' and run one statement at a time. None of the
// statements above contains a ';' inside a quoted string, so the split is
// safe for this template.
for _, sql := range strings.Split(sqls, ";") {
1009+
sql := strings.TrimSpace(sql)
1010+
// Skip the empty piece produced after the final ';'.
if sql == "" {
1011+
continue
1012+
}
1013+
1014+
// Restore the terminator removed by the split before running the statement.
sql += ";"
1015+
_, _, _, err := connectAndRunSQL(sql)
1016+
// A failure is reported with the offending statement; there is no break or
// return here, so the remaining statements are still attempted.
if err != nil {
1017+
a.Fail(fmt.Sprintf("Run SQL failure:\n%s\n%s", sql, err.Error()))
1018+
}
1019+
}
1020+
}

go/cmd/sqlflowserver/e2e_mysql_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ func TestEnd2EndMySQL(t *testing.T) {
107107
t.Run("CaseEnd2EndCrossFeatureColumn", caseEnd2EndCrossFeatureColumn)
108108

109109
t.Run("CaseXGBoostSparseKeyValueColumn", caseXGBoostSparseKeyValueColumn)
110+
t.Run("CaseEnd2EndXGBoostDenseFeatureColumn", func(t *testing.T) {
111+
caseEnd2EndXGBoostDenseFeatureColumn(t, false)
112+
})
110113

111114
// Cases for optimize
112115
t.Run("CaseTestOptimizeClauseWithoutGroupBy", caseTestOptimizeClauseWithoutGroupBy)

go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,9 @@ func TestEnd2EndMaxComputePAI(t *testing.T) {
507507
// FIXME(typhoonzero): Add this test back when we solve error: model already exist issue on the CI.
508508
// t.Run("CaseTrainPAIRandomForests", CaseTrainPAIRandomForests)
509509
t.Run("CaseXGBoostSparseKeyValueColumn", caseXGBoostSparseKeyValueColumn)
510+
t.Run("CaseEnd2EndXGBoostDenseFeatureColumn", func(t *testing.T) {
511+
caseEnd2EndXGBoostDenseFeatureColumn(t, true)
512+
})
510513
})
511514

512515
}

go/cmd/sqlflowserver/testing.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ func prepareTestData(dbStr string) error {
237237
datasets = append(datasets,
238238
fmt.Sprintf(testdata.IrisMaxComputeSQL, caseDB),
239239
fmt.Sprintf(testdata.ChurnMaxComputeSQL, caseDB),
240-
fmt.Sprintf(testdata.XGBoostMaxComputeSparseDataCaseSQL, caseDB))
240+
fmt.Sprintf(testdata.XGBoostMaxComputeSparseDataCaseSQL, caseDB),
241+
fmt.Sprintf(testdata.FeatureDerivationCaseSQLMaxCompute, caseDB))
241242
default:
242243
return fmt.Errorf("unrecognized SQLFLOW_TEST_DB %s", db)
243244
}

go/ir/derivation.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,15 @@ func fillCSVFieldDesc(cellData string, fieldDescMap FieldDescMap, fieldName stri
138138
size *= s
139139
}
140140

141-
values := strings.Split(cellData, ",")
141+
rawValues := strings.Split(cellData, ",")
142+
values := make([]string, 0, len(rawValues))
143+
for _, value := range rawValues {
144+
trimmedValue := strings.TrimSpace(value)
145+
if trimmedValue != "" {
146+
values = append(values, trimmedValue)
147+
}
148+
}
149+
142150
// set shape only when the column is "DENSE"
143151
if fieldDescMap[fieldName].IsSparse == false && fieldDescMap[fieldName].Shape == nil {
144152
fieldDescMap[fieldName].Shape = []int{len(values)}
@@ -224,7 +232,7 @@ func inferStringDataFormat(strData string) string {
224232
const realNumberRegex = "((\\+|-)?([0-9]+)(\\.[0-9]+)?)|((\\+|-)?\\.?[0-9]+)"
225233

226234
// string in the form of "3,5,7"
227-
csvRegex := regexp.MustCompile(fmt.Sprintf("^((%s)\\,)+(%s)$", realNumberRegex, realNumberRegex))
235+
csvRegex := regexp.MustCompile(fmt.Sprintf("^\\s*((%s)\\s*\\,\\s*)+(%s)\\s*(\\,?)\\s*$", realNumberRegex, realNumberRegex))
228236
if csvRegex.MatchString(strData) {
229237
return csv
230238
}

go/ir/derivation_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ func TestCSVRegex(t *testing.T) {
2727
"1,2,3,4",
2828
"1.3,-3.2,132,32",
2929
"33,-33",
30+
"33,-33,",
31+
" 33 , -70 , 80 , ",
32+
" 33 , -70 , 80 ,",
33+
" 33 , -70 , 80, ",
34+
" 33 , -70 , 80,",
3035
}
3136
for _, s := range csvStings {
3237
if inferStringDataFormat(s) != csv {

go/sql/testdata/feature_derivation_case.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
package testdata
1515

1616
// FeatureDerivationCaseSQL is .sql format data samples to test feature derivation.
17-
var FeatureDerivationCaseSQL = `CREATE DATABASE IF NOT EXISTS feature_derivation_case;
17+
const FeatureDerivationCaseSQL = `CREATE DATABASE IF NOT EXISTS feature_derivation_case;
1818
DROP TABLE IF EXISTS feature_derivation_case.train;
1919
CREATE TABLE feature_derivation_case.train (
2020
c1 float,
@@ -25,14 +25,14 @@ CREATE TABLE feature_derivation_case.train (
2525
c6 CHAR(255),
2626
class TINYINT);
2727
INSERT INTO feature_derivation_case.train VALUES
28-
(6.4,2.8, '1,4,2,3', '1,3,2,6', '3,140', 'MALE', 0),
29-
(5.0,2.3, '1,3,8,3', '3,2,5,3', '93,12,1,392,49,13,398', 'FEMALE', 1),
28+
(6.4,2.8, '1,4,2,3,', '1,3,2,6', '3,140', 'MALE', 0),
29+
(5.0,2.3, '1,3,8,3,', '3,2,5,3', '93,12,1,392,49,13,398', 'FEMALE', 1),
3030
(4.9,2.5, '9,2,2,2', '1.2,4.8,3.2,1', '10,11,32,32,1', 'FEMALE', 1),
3131
(5.1,2.2, '2,1,8,5', '5.0,3,2,1', '23,22,1', 'FEMALE', 1),
3232
(4.8,3.1, '3,3,2,6', '3,2,3,5', '30,3,1,32', 'NULL', 0);`
3333

3434
// FeatureDerivationCaseSQLHive is .sql format data samples to test feature derivation.
35-
var FeatureDerivationCaseSQLHive = `CREATE DATABASE IF NOT EXISTS feature_derivation_case;
35+
const FeatureDerivationCaseSQLHive = `CREATE DATABASE IF NOT EXISTS feature_derivation_case;
3636
DROP TABLE IF EXISTS feature_derivation_case.train;
3737
CREATE TABLE feature_derivation_case.train (
3838
c1 float,
@@ -48,3 +48,22 @@ INSERT INTO TABLE feature_derivation_case.train VALUES
4848
(4.9,2.5, '9,2,2,2', '1.2,4.8,3.2,1', '10,11,32,32,1', 'FEMALE', 1),
4949
(5.1,2.2, '2,1,8,5', '5.0,3,2,1', '23,22,1', 'FEMALE', 1),
5050
(4.8,3.1, '3,3,2,6', '3,2,3,5', '30,3,1,32', 'NULL', 0);`
51+
52+
// FeatureDerivationCaseSQLMaxCompute is .sql format data samples to test feature derivation on MaxCompute.
// It is a format template, not ready-to-run SQL: every %[1]s is the name that
// qualifies the feature_derivation_train table and must be filled in with
// fmt.Sprintf before the statements are executed.
// The c3 values of the first two rows end with a trailing comma ('1,4,2,3,'),
// presumably to exercise CSV columns with a trailing separator — confirm.
53+
const FeatureDerivationCaseSQLMaxCompute = `
54+
DROP TABLE IF EXISTS %[1]s.feature_derivation_train;
55+
CREATE TABLE %[1]s.feature_derivation_train (
56+
c1 DOUBLE,
57+
c2 DOUBLE,
58+
c3 STRING,
59+
c4 STRING,
60+
c5 STRING,
61+
c6 STRING,
62+
class INT);
63+
INSERT INTO %[1]s.feature_derivation_train VALUES
64+
(6.4,2.8, '1,4,2,3,', '1,3,2,6', '3,140', 'MALE', 0),
65+
(5.0,2.3, '1,3,8,3,', '3,2,5,3', '93,12,1,392,49,13,398', 'FEMALE', 1),
66+
(4.9,2.5, '9,2,2,2', '1.2,4.8,3.2,1', '10,11,32,32,1', 'FEMALE', 1),
67+
(5.1,2.2, '2,1,8,5', '5.0,3,2,1', '23,22,1', 'FEMALE', 1),
68+
(4.8,3.1, '3,3,2,6', '3,2,3,5', '30,3,1,32', 'NULL', 0);
69+
`

python/runtime/db.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,19 @@ def read_feature(raw_val, feature_spec, feature_name):
4949
elif feature_spec["delimiter"] != "":
5050
# Dense string vector
5151
if feature_spec["dtype"] == "float32":
52-
return np.fromstring(raw_val,
53-
dtype=np.float32,
54-
sep=feature_spec["delimiter"])
52+
vec = np.fromstring(raw_val,
53+
dtype=np.float32,
54+
sep=feature_spec["delimiter"])
5555
elif feature_spec["dtype"] == "int64":
56-
return np.fromstring(raw_val,
57-
dtype=np.int64,
58-
sep=feature_spec["delimiter"])
56+
vec = np.fromstring(raw_val,
57+
dtype=np.int64,
58+
sep=feature_spec["delimiter"])
5959
else:
6060
raise ValueError('unrecognize dtype {}'.format(
61-
feature_spec[feature_name]["dtype"]))
61+
feature_spec["dtype"]))
62+
63+
vec = vec.reshape(list(feature_spec["shape"]))
64+
return vec,
6265
elif feature_spec["dtype"] == "float32":
6366
return float(raw_val),
6467
elif feature_spec["dtype"] == "int64":

python/runtime/feature/derivation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def new_default_field_desc(name):
116116

117117
# A regular expression to match the form of "3,5,7", optionally with spaces around the numbers and one trailing comma, e.g. " 3 , 5 , 7 ,"
118118
CSV_PATTERN = re.compile(
119-
"((%s)\\,)+(%s)" %
119+
"\\s*((%s)\\s*\\,\\s*)+(%s)\\s*(\\,?)\\s*" %
120120
(REAL_NUMBER_PATTERN.pattern, REAL_NUMBER_PATTERN.pattern))
121121

122122
# A regular expression to match the form of "0:3.2 7:-2.3"
@@ -160,7 +160,13 @@ def fill_csv_field_desc(cell, field_desc):
160160
Returns:
161161
None.
162162
"""
163-
values = cell.split(",")
163+
raw_values = cell.split(",")
164+
values = []
165+
for v in raw_values:
166+
v = v.strip()
167+
if v:
168+
values.append(v)
169+
164170
if field_desc.is_sparse:
165171
assert field_desc.shape is not None, \
166172
"the shape of CSV format data must be given"

python/runtime/feature/derivation_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ def test_csv_strings(self):
2828
"1,2,3,4",
2929
"1.3,-3.2,132,32",
3030
"33,-33",
31+
"33,-33,",
32+
" 33 , -70 , 80 , ",
33+
" 33 , -70 , 80 ,",
34+
" 33 , -70 , 80, ",
35+
" 33 , -70 , 80,",
3136
]
3237

3338
for s in csv_strs:

0 commit comments

Comments
 (0)