@@ -144,7 +144,7 @@ def step_entry_{{.StepIndex}}():
144144 import runtime.temp_file as temp_file
145145 import runtime.feature.column as fc
146146 import runtime.feature.field_desc as fd
147- import runtime.{{.Submitter}}.xgboost as xgboost_submitter
147+ from runtime.{{.Submitter}} import train
148148
149149 {{ if .FeatureColumnCode }}
150150 feature_column_map = {"feature_columns": [{{.FeatureColumnCode}}]}
@@ -157,21 +157,23 @@ def step_entry_{{.StepIndex}}():
157157 train_params = json.loads('''{{.TrainParamsJSON}}''')
158158
159159 with temp_file.TemporaryDirectory(as_cwd=True) as temp_dir:
160- xgboost_submitter.train(original_sql='''{{.OriginalSQL}}''',
161- model_image='''{{.ModelImage}}''',
162- estimator='''{{.Estimator}}''',
163- datasource='''{{.DataSource}}''',
164- select='''{{.Select}}''',
165- validation_select='''{{.ValidationSelect}}''',
166- model_params=model_params,
167- train_params=train_params,
168- feature_column_map=feature_column_map,
169- label_column=label_column,
170- save='''{{.Save}}''',
171- load='''{{.Load}}''',
172- disk_cache="{{.DiskCache}}"=="true",
173- batch_size={{.BatchSize}},
174- epoch={{.Epoch}})
160+ os.chdir(temp_dir)
161+ train_params["original_sql"] = '''{{.OriginalSQL}}'''
162+ train_params["model_image"] = '''{{.ModelImage}}'''
163+ train_params["feature_column_map"] = feature_column_map
164+ train_params["label_column"] = label_column
165+ train_params["disk_cache"] = "{{.DiskCache}}"=="true"
166+ train_params["batch_size"] = {{.BatchSize}}
167+ train_params["epoch"] = {{.Epoch}}
168+
169+ train(datasource='''{{.DataSource}}''',
170+ estimator_string='''{{.Estimator}}''',
171+ select='''{{.Select}}''',
172+ validation_select='''{{.ValidationSelect}}''',
173+ model_params=model_params,
174+ save='''{{.Save}}''',
175+ load='''{{.Load}}''',
176+ train_params=train_params)
175177`
176178
177179func generateFeatureColumnCode (fcList []ir.FeatureColumn ) (string , error ) {
0 commit comments