Skip to content

Commit 015253b

Browse files
authored
Move ir_generator.go to the ir package and split ir_generate_test.go into two files (#2491)
1 parent 56868d3 commit 015253b

11 files changed

Lines changed: 584 additions & 535 deletions

File tree

pkg/executor/executor.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (s *pythonExecutor) SaveModel(cl *ir.TrainStmt) error {
122122
if s.ModelDir != "" {
123123
modelURI = fmt.Sprintf("file://%s/%s", s.ModelDir, cl.Into)
124124
}
125-
return m.Save(modelURI, cl, s.Session)
125+
return m.Save(modelURI, s.Session)
126126
}
127127

128128
func (s *pythonExecutor) runCommand(program string, logStderr bool) error {
@@ -158,7 +158,7 @@ func (s *pythonExecutor) ExecuteQuery(stmt *ir.NormalStmt) error {
158158

159159
func (s *pythonExecutor) ExecuteTrain(cl *ir.TrainStmt) (e error) {
160160
var code string
161-
if isXGBoostModel(cl.Estimator) {
161+
if cl.GetModelKind() == ir.XGBoost {
162162
if code, e = xgboost.Train(cl, s.Session); e != nil {
163163
return e
164164
}
@@ -180,7 +180,7 @@ func (s *pythonExecutor) ExecutePredict(cl *ir.PredictStmt) (e error) {
180180
}
181181

182182
var code string
183-
if isXGBoostModel(cl.TrainStmt.Estimator) {
183+
if cl.TrainStmt.GetModelKind() == ir.XGBoost {
184184
if code, e = xgboost.Pred(cl, s.Session); e != nil {
185185
return e
186186
}
@@ -201,7 +201,7 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
201201
return err
202202
}
203203
defer db.Close()
204-
if isXGBoostModel(cl.TrainStmt.Estimator) {
204+
if cl.TrainStmt.GetModelKind() == ir.XGBoost {
205205
code, err = xgboost.Explain(cl, s.Session)
206206
// TODO(typhoonzero): deal with XGBoost model explain result table creation.
207207
} else {
@@ -236,7 +236,7 @@ func (s *pythonExecutor) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
236236
// NOTE(typhoonzero): model is already loaded under s.Cwd
237237
var code string
238238
var err error
239-
if isXGBoostModel(cl.TrainStmt.Estimator) {
239+
if cl.TrainStmt.GetModelKind() == ir.XGBoost {
240240
code, err = xgboost.Evaluate(cl, s.Session)
241241
if err != nil {
242242
return err

pkg/executor/pre_exec.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@ package executor
1616
import (
1717
"bytes"
1818
"fmt"
19-
"strings"
2019

2120
"sqlflow.org/sqlflow/pkg/database"
2221
"sqlflow.org/sqlflow/pkg/ir"
2322
pb "sqlflow.org/sqlflow/pkg/proto"
24-
"sqlflow.org/sqlflow/pkg/step/feature"
2523
"sqlflow.org/sqlflow/pkg/verifier"
2624
)
2725

@@ -98,7 +96,7 @@ func createPredictionResultTable(predStmt *ir.PredictStmt, db *database.DB, sess
9896
// getSQLFieldType is quiet like verify but accept a SQL string as input, and returns
9997
// an ordered list of the field types.
10098
func getSQLFieldType(slct string, db *database.DB) ([]string, []string, error) {
101-
rows, err := feature.FetchSamples(db, slct)
99+
rows, err := verifier.FetchSamples(db, slct)
102100
if err != nil {
103101
return nil, nil, err
104102
}
@@ -128,9 +126,3 @@ func getSQLFieldType(slct string, db *database.DB) ([]string, []string, error) {
128126

129127
return flds, ft, nil
130128
}
131-
132-
// TODO(yancey1989): need to discuss fill esimator type in IR,
133-
// that we don't need the duplicate judgement with pkg/sql/ir_generator.go
134-
func isXGBoostModel(estimator string) bool {
135-
return strings.HasPrefix(strings.ToUpper(estimator), `XGB`)
136-
}

0 commit comments

Comments
 (0)