@@ -18,7 +18,6 @@ import (
1818 "encoding/json"
1919 "fmt"
2020 "os"
21- "reflect"
2221 "strings"
2322 "text/template"
2423
@@ -82,14 +81,8 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
8281 return "" , fmt .Errorf ("xgboost only support 0 or 1 feature column set, received %d" , len (trainStmt .Features ))
8382 }
8483 // featureColumnCode is the source text of a Python dict literal, e.g. {"feature_columns": [...]}
85- featureColumnCode := ""
86- if len (trainStmt .Features ) == 1 {
87- featureColumnCode , err = generateFeatureColumnCode (trainStmt .Features ["feature_columns" ])
88- if err != nil {
89- return "" , err
90- }
91- }
92- labelColumnCode , err := generateFeatureColumnCode ([]ir.FeatureColumn {trainStmt .Label })
84+ featureColumnCode := generateFeatureColumnCode (trainStmt .Features )
85+ labelColumnCode := trainStmt .Label .GenPythonCode ()
9386
9487 mp , err := json .Marshal (params ["" ])
9588 if err != nil {
@@ -137,15 +130,11 @@ const xgbTrainTemplate = `
137130def step_entry_{{.StepIndex}}():
138131 import json
139132 import runtime.temp_file as temp_file
140- import runtime.feature.column as fc
141- import runtime.feature.field_desc as fd
133+ import runtime.feature.column
134+ import runtime.feature.field_desc
142135 from runtime.{{.Submitter}} import train
143136
144- {{ if .FeatureColumnCode }}
145- feature_column_map = {"feature_columns": [{{.FeatureColumnCode}}]}
146- {{ else }}
147- feature_column_map = None
148- {{ end }}
137+ feature_column_map = {{.FeatureColumnCode}}
149138 label_column = {{.LabelColumnCode}}
150139
151140 model_params = json.loads('''{{.ModelParamsJSON}}''')
@@ -231,37 +220,20 @@ func getSubmitter(session *pb.Session, defaultValue string) string {
231220 return defaultValue
232221}
233222
234- func generateFeatureColumnCode (fcList []ir.FeatureColumn ) (string , error ) {
235- fcCodes := make ([]string , 0 , len (fcList ))
236- for _ , fc := range fcList {
237- // xgboost have no cross feature column, just get the first field desc
238- fd := fc .GetFieldDesc ()[0 ]
239- // pass format = "" to let runtime feature derivation to fill it in.
240- tmpl := `fc.%s(fd.FieldDesc(name="%s", dtype=fd.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s))`
241- fcTypeName := reflect .TypeOf (fc ).Elem ().Name ()
242- isSparseStr := "False"
243- if fd .IsSparse {
244- isSparseStr = "True"
245- }
246- vocabList := []string {}
247- for k := range fd .Vocabulary {
248- vocabList = append (vocabList , k )
223+ func generateFeatureColumnCode (fcMap map [string ][]ir.FeatureColumn ) string {
224+ allFCCodes := make ([]string , 0 )
225+ for target , fcList := range fcMap {
226+ if len (fcList ) == 0 {
227+ continue
249228 }
250- shape := [] int { 1 }
251- if len ( fd . Shape ) != 0 {
252- shape = fd . Shape
229+ codeList := make ([] string , 0 )
230+ for _ , fc := range fcList {
231+ codeList = append ( codeList , fc . GenPythonCode ())
253232 }
254-
255- code := fmt .Sprintf (tmpl , fcTypeName , fd .Name ,
256- strings .ToUpper (ir .DTypeToString (fd .DType )),
257- fd .Delimiter ,
258- ir .AttrToPythonValue (shape ),
259- isSparseStr ,
260- ir .AttrToPythonValue (vocabList ))
261- fcCodes = append (fcCodes , code )
233+ code := fmt .Sprintf (`"%s":[%s]` , target , strings .Join (codeList , "," ))
234+ allFCCodes = append (allFCCodes , code )
262235 }
263-
264- return strings .Join (fcCodes , ",\n " ), nil
236+ return fmt .Sprintf ("{%s}" , strings .Join (allFCCodes , "," ))
265237}
266238
267239// TODO(typhoonzero): the functions below are copied from codegen/xgboost/codegen.go
0 commit comments