Skip to content

Commit 45645b6

Browse files
authored
Refine feature column code generation (#2860)
* use codegen way to get py fc * fix ut
1 parent eaa2473 commit 45645b6

6 files changed

Lines changed: 47 additions & 61 deletions

File tree

go/cmd/sqlflowserver/e2e_workflow_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,13 @@ WITH objective="multi:softmax",num_class=3
375375
LABEL class
376376
INTO sqlflow_models.xgb_classification;
377377
378+
SELECT * FROM iris.train
379+
TO TRAIN xgboost.gbtree
380+
WITH objective="multi:softmax",num_class=3
381+
COLUMN sepal_length, DENSE(sepal_width)
382+
LABEL class
383+
INTO sqlflow_models.xgb_classification;
384+
378385
SELECT * FROM iris.test
379386
TO PREDICT iris.test_result_table.class
380387
USING sqlflow_models.xgb_classification;

go/codegen/experimental/codegen_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func TestExperimentalXGBCodegen(t *testing.T) {
4343
if err != nil {
4444
t.Errorf("error %s", err)
4545
}
46-
expected := `feature_column_map = {"feature_columns": [fc.NumericColumn(fd.FieldDesc(name="petal_length", dtype=fd.DataType.FLOAT32, delimiter="", format="", shape=[1], is_sparse=False, vocabulary=[]))]}`
46+
expected := `feature_column_map = {"feature_columns":[runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name="petal_length", dtype=runtime.feature.field_desc.DataType.FLOAT32, delimiter="", format="", shape=[1], is_sparse=False, vocabulary=[]))]}`
4747
a.True(strings.Contains(coulerCode, expected))
4848
}
4949

go/codegen/experimental/xgboost.go

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"encoding/json"
1919
"fmt"
2020
"os"
21-
"reflect"
2221
"strings"
2322
"text/template"
2423

@@ -82,14 +81,8 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
8281
return "", fmt.Errorf("xgboost only support 0 or 1 feature column set, received %d", len(trainStmt.Features))
8382
}
8483
// featureColumnCode is a python map definition code like fc_map = {"feature_columns": [...]}
85-
featureColumnCode := ""
86-
if len(trainStmt.Features) == 1 {
87-
featureColumnCode, err = generateFeatureColumnCode(trainStmt.Features["feature_columns"])
88-
if err != nil {
89-
return "", err
90-
}
91-
}
92-
labelColumnCode, err := generateFeatureColumnCode([]ir.FeatureColumn{trainStmt.Label})
84+
featureColumnCode := generateFeatureColumnCode(trainStmt.Features)
85+
labelColumnCode := trainStmt.Label.GenPythonCode()
9386

9487
mp, err := json.Marshal(params[""])
9588
if err != nil {
@@ -137,15 +130,11 @@ const xgbTrainTemplate = `
137130
def step_entry_{{.StepIndex}}():
138131
import json
139132
import runtime.temp_file as temp_file
140-
import runtime.feature.column as fc
141-
import runtime.feature.field_desc as fd
133+
import runtime.feature.column
134+
import runtime.feature.field_desc
142135
from runtime.{{.Submitter}} import train
143136
144-
{{ if .FeatureColumnCode }}
145-
feature_column_map = {"feature_columns": [{{.FeatureColumnCode}}]}
146-
{{ else }}
147-
feature_column_map = None
148-
{{ end }}
137+
feature_column_map = {{.FeatureColumnCode}}
149138
label_column = {{.LabelColumnCode}}
150139
151140
model_params = json.loads('''{{.ModelParamsJSON}}''')
@@ -231,37 +220,20 @@ func getSubmitter(session *pb.Session, defaultValue string) string {
231220
return defaultValue
232221
}
233222

234-
func generateFeatureColumnCode(fcList []ir.FeatureColumn) (string, error) {
235-
fcCodes := make([]string, 0, len(fcList))
236-
for _, fc := range fcList {
237-
// xgboost have no cross feature column, just get the first field desc
238-
fd := fc.GetFieldDesc()[0]
239-
// pass format = "" to let runtime feature derivation to fill it in.
240-
tmpl := `fc.%s(fd.FieldDesc(name="%s", dtype=fd.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s))`
241-
fcTypeName := reflect.TypeOf(fc).Elem().Name()
242-
isSparseStr := "False"
243-
if fd.IsSparse {
244-
isSparseStr = "True"
245-
}
246-
vocabList := []string{}
247-
for k := range fd.Vocabulary {
248-
vocabList = append(vocabList, k)
223+
func generateFeatureColumnCode(fcMap map[string][]ir.FeatureColumn) string {
224+
allFCCodes := make([]string, 0)
225+
for target, fcList := range fcMap {
226+
if len(fcList) == 0 {
227+
continue
249228
}
250-
shape := []int{1}
251-
if len(fd.Shape) != 0 {
252-
shape = fd.Shape
229+
codeList := make([]string, 0)
230+
for _, fc := range fcList {
231+
codeList = append(codeList, fc.GenPythonCode())
253232
}
254-
255-
code := fmt.Sprintf(tmpl, fcTypeName, fd.Name,
256-
strings.ToUpper(ir.DTypeToString(fd.DType)),
257-
fd.Delimiter,
258-
ir.AttrToPythonValue(shape),
259-
isSparseStr,
260-
ir.AttrToPythonValue(vocabList))
261-
fcCodes = append(fcCodes, code)
233+
code := fmt.Sprintf(`"%s":[%s]`, target, strings.Join(codeList, ","))
234+
allFCCodes = append(allFCCodes, code)
262235
}
263-
264-
return strings.Join(fcCodes, ",\n"), nil
236+
return fmt.Sprintf("{%s}", strings.Join(allFCCodes, ","))
265237
}
266238

267239
// TODO(typhoonzero): below functions are copied from codegen/xgboost/codegen.go

go/ir/codegen_python_values.go

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,40 @@ func DTypeToString(dt int) string {
3535

3636
// AttrToPythonValue format the WITH attributes to corresponding Python code.
3737
func AttrToPythonValue(attr interface{}) string {
38-
switch attr.(type) {
38+
switch a := attr.(type) {
3939
case bool:
40-
return strings.Title(fmt.Sprintf("%v", attr.(bool)))
40+
return strings.Title(fmt.Sprintf("%v", a))
4141
case int:
42-
return fmt.Sprintf("%d", attr.(int))
42+
return fmt.Sprintf("%d", a)
4343
case int64:
44-
return fmt.Sprintf("%d", attr.(int64))
44+
return fmt.Sprintf("%d", a)
4545
case float32:
46-
return fmt.Sprintf("%f", attr.(float32))
46+
return fmt.Sprintf("%f", a)
4747
case float64: // FIXME(typhoonzero): may never use
4848
return fmt.Sprintf("%f", attr.(float64))
4949
case []int:
50-
intArrayAttrStr, _ := MarshalToJSONString(attr.([]int))
50+
if a == nil {
51+
return "None"
52+
}
53+
intArrayAttrStr, _ := MarshalToJSONString(a)
5154
return intArrayAttrStr
5255
case []string:
53-
l := attr.([]string)
54-
if len(l) == 0 {
56+
if a == nil {
57+
return "None"
58+
}
59+
if len(a) == 0 {
5560
return "[]"
5661
}
57-
stringListStr, _ := MarshalToJSONString(l)
62+
stringListStr, _ := MarshalToJSONString(a)
5863
return stringListStr
5964
case []interface{}:
60-
tmplist := attr.([]interface{})
61-
if len(tmplist) > 0 {
62-
if _, ok := tmplist[0].(int); ok {
65+
if a == nil {
66+
return "None"
67+
}
68+
if len(a) > 0 {
69+
if _, ok := a[0].(int); ok {
6370
intlist := []int{}
64-
for _, v := range tmplist {
71+
for _, v := range a {
6572
intlist = append(intlist, v.(int))
6673
}
6774
intlistStr, _ := MarshalToJSONString(intlist)
@@ -71,7 +78,7 @@ func AttrToPythonValue(attr interface{}) string {
7178
// TODO(typhoonzero): support []float etc.
7279
return "[]"
7380
case string:
74-
return fmt.Sprintf("\"%s\"", attr.(string))
81+
return fmt.Sprintf(`"%s"`, a)
7582
default:
7683
return ""
7784
}

go/ir/feature_column.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (fd *FieldDesc) GenPythonCode() string {
5858
vocabList = append(vocabList, k)
5959
}
6060
// pass format = "" to let runtime feature derivation to fill it in.
61-
return fmt.Sprintf(`runtime.feature.field_desc.FieldDesc(name="%s", dtype=fd.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s)`,
61+
return fmt.Sprintf(`runtime.feature.field_desc.FieldDesc(name="%s", dtype=runtime.feature.field_desc.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s)`,
6262
fd.Name,
6363
strings.ToUpper(DTypeToString(fd.DType)),
6464
fd.Delimiter,

go/ir/feature_column_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func TestFeatureColumnGenPythonCode(t *testing.T) {
3232
DType: 0,
3333
},
3434
}
35-
a.Equal("runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name=\"testcol\", dtype=fd.DataType.INT64, delimiter=\"\", format=\"\", shape=[10], is_sparse=False, vocabulary=[]))",
35+
a.Equal("runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name=\"testcol\", dtype=runtime.feature.field_desc.DataType.INT64, delimiter=\"\", format=\"\", shape=[10], is_sparse=False, vocabulary=[]))",
3636
nc.GenPythonCode())
3737

3838
emd := &EmbeddingColumn{

0 commit comments

Comments
 (0)