@@ -122,7 +122,7 @@ func (s *pythonExecutor) SaveModel(cl *ir.TrainStmt) error {
122122 if s .ModelDir != "" {
123123 modelURI = fmt .Sprintf ("file://%s/%s" , s .ModelDir , cl .Into )
124124 }
125- return m .Save (modelURI , cl , s .Session )
125+ return m .Save (modelURI , s .Session )
126126}
127127
128128func (s * pythonExecutor ) runCommand (program string , logStderr bool ) error {
@@ -158,7 +158,7 @@ func (s *pythonExecutor) ExecuteQuery(stmt *ir.NormalStmt) error {
158158
159159func (s * pythonExecutor ) ExecuteTrain (cl * ir.TrainStmt ) (e error ) {
160160 var code string
161- if isXGBoostModel ( cl .Estimator ) {
161+ if cl .GetModelKind () == ir . XGBoost {
162162 if code , e = xgboost .Train (cl , s .Session ); e != nil {
163163 return e
164164 }
@@ -180,7 +180,7 @@ func (s *pythonExecutor) ExecutePredict(cl *ir.PredictStmt) (e error) {
180180 }
181181
182182 var code string
183- if isXGBoostModel ( cl .TrainStmt .Estimator ) {
183+ if cl .TrainStmt .GetModelKind () == ir . XGBoost {
184184 if code , e = xgboost .Pred (cl , s .Session ); e != nil {
185185 return e
186186 }
@@ -201,7 +201,7 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
201201 return err
202202 }
203203 defer db .Close ()
204- if isXGBoostModel ( cl .TrainStmt .Estimator ) {
204+ if cl .TrainStmt .GetModelKind () == ir . XGBoost {
205205 code , err = xgboost .Explain (cl , s .Session )
206206 // TODO(typhoonzero): deal with XGBoost model explain result table creation.
207207 } else {
@@ -236,7 +236,7 @@ func (s *pythonExecutor) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
236236 // NOTE(typhoonzero): model is already loaded under s.Cwd
237237 var code string
238238 var err error
239- if isXGBoostModel ( cl .TrainStmt .Estimator ) {
239+ if cl .TrainStmt .GetModelKind () == ir . XGBoost {
240240 code , err = xgboost .Evaluate (cl , s .Session )
241241 if err != nil {
242242 return err
0 commit comments