@@ -99,10 +99,6 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
9999 if err != nil {
100100 return "" , err
101101 }
102- submitter := os .Getenv ("SQLFLOW_submitter" )
103- if submitter == "" {
104- submitter = "local"
105- }
106102
107103 dbConnStr , err := GeneratePyDbConnStr (session )
108104 if err != nil {
@@ -126,7 +122,7 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
126122 DiskCache : diskCache ,
127123 BatchSize : batchSize ,
128124 Epoch : epoch ,
129- Submitter : submitter ,
125+ Submitter : getSubmitter ( session , "local" ) ,
130126 }
131127 var program bytes.Buffer
132128 var trainTemplate = template .Must (template .New ("Train" ).Parse (xgbTrainTemplate ))
@@ -140,7 +136,6 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
140136const xgbTrainTemplate = `
141137def step_entry_{{.StepIndex}}():
142138 import json
143- import os
144139 import runtime.temp_file as temp_file
145140 import runtime.feature.column as fc
146141 import runtime.feature.field_desc as fd
@@ -157,7 +152,6 @@ def step_entry_{{.StepIndex}}():
157152 train_params = json.loads('''{{.TrainParamsJSON}}''')
158153
159154 with temp_file.TemporaryDirectory(as_cwd=True) as temp_dir:
160- os.chdir(temp_dir)
161155 train_params["original_sql"] = '''{{.OriginalSQL}}'''
162156 train_params["model_image"] = '''{{.ModelImage}}'''
163157 train_params["feature_column_map"] = feature_column_map
@@ -176,6 +170,67 @@ def step_entry_{{.StepIndex}}():
176170 train_params=train_params)
177171`
178172
173+ type xgbPredFiller struct {
174+ StepIndex int
175+ DataSource string
176+ Select string
177+ PredLabelName string
178+ ResultTable string
179+ Load string
180+ Submitter string
181+ }
182+
183+ // XGBoostGeneratePredict generates the XGBoost prediction code
184+ func XGBoostGeneratePredict (predStmt * ir.PredictStmt , stepIndex int , session * pb.Session ) (string , error ) {
185+ dbConnStr , err := GeneratePyDbConnStr (session )
186+ if err != nil {
187+ return "" , err
188+ }
189+
190+ filler := & xgbPredFiller {
191+ StepIndex : stepIndex ,
192+ DataSource : dbConnStr ,
193+ Select : replaceNewLineRuneAndTrimSpace (predStmt .Select ),
194+ PredLabelName : predStmt .ResultColumn ,
195+ ResultTable : predStmt .ResultTable ,
196+ Load : predStmt .Using ,
197+ Submitter : getSubmitter (session , "local" ),
198+ }
199+
200+ var program bytes.Buffer
201+ predTmpl := template .Must (template .New ("Train" ).Parse (xgbPredTemplate ))
202+ err = predTmpl .Execute (& program , filler )
203+ if err != nil {
204+ return "" , err
205+ }
206+ return program .String (), nil
207+ }
208+
209+ const xgbPredTemplate = `
210+ def step_entry_{{.StepIndex}}():
211+ import runtime.temp_file as temp_file
212+ from runtime.{{.Submitter}} import pred
213+
214+ with temp_file.TemporaryDirectory(as_cwd=True):
215+ pred(datasource='''{{.DataSource}}''',
216+ select='''{{.Select}}''',
217+ result_table='''{{.ResultTable}}''',
218+ pred_label_name='''{{.PredLabelName}}''',
219+ load='''{{.Load}}''')
220+ `
221+
222+ func getSubmitter (session * pb.Session , defaultValue string ) string {
223+ if session .Submitter != "" {
224+ return session .Submitter
225+ }
226+
227+ submitter := os .Getenv ("SQLFLOW_submitter" )
228+ if submitter != "" {
229+ return submitter
230+ }
231+ return defaultValue
232+ }
233+
179234func generateFeatureColumnCode (fcList []ir.FeatureColumn ) (string , error ) {
180235 fcCodes := make ([]string , 0 , len (fcList ))
181236 for _ , fc := range fcList {
0 commit comments