Skip to content

Commit ee28761

Browse files
authored
Add workflow codegen for XGBoost predict (#2845)
* add codegen predict * replace temp_file apis * update * update
1 parent 12c0b6b commit ee28761

9 files changed

Lines changed: 115 additions & 23 deletions

File tree

go/cmd/sqlflowserver/e2e_workflow_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,14 @@ SELECT * FROM iris.train
373373
TO TRAIN xgboost.gbtree
374374
WITH objective="multi:softmax",num_class=3
375375
LABEL class
376-
INTO sqlflow_models.xgb_classification;`
376+
INTO sqlflow_models.xgb_classification;
377+
378+
SELECT * FROM iris.test
379+
TO PREDICT iris.test_result_table.class
380+
USING sqlflow_models.xgb_classification;
381+
382+
SELECT * FROM iris.test_result_table;
383+
`
377384

378385
conn, err := createRPCConn()
379386
if err != nil {

go/codegen/experimental/codegen_normal_stmt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
pb "sqlflow.org/sqlflow/go/proto"
2121
)
2222

23-
var normalStmtStepTmpl = `
23+
const normalStmtStepTmpl = `
2424
def step_entry_{{.StepIndex}}():
2525
import runtime
2626
import runtime.dbapi

go/codegen/experimental/codegen_step.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,19 @@ import (
2424
pb "sqlflow.org/sqlflow/go/proto"
2525
)
2626

27-
func generateStepCode(stmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (string, error) {
28-
switch stmt.(type) {
27+
// TODO(sneaxiy): implement this method to distinguish whether
28+
// a model is a XGBoost model.
29+
func isTrainedXBoostModel(modelName string) bool {
30+
return true
31+
}
32+
33+
func generateStepCode(sqlStmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (string, error) {
34+
switch stmt := sqlStmt.(type) {
2935
case *ir.TrainStmt:
30-
trainStmt := stmt.(*ir.TrainStmt)
31-
return generateTrainCode(trainStmt, stepIndex, session)
36+
return generateTrainCode(stmt, stepIndex, session)
37+
case *ir.PredictStmt:
38+
return generatePredictCode(stmt, stepIndex, session)
3239
case *ir.NormalStmt:
33-
stmt := stmt.(*ir.NormalStmt)
3440
return GenerateNormalStmtStep(string(*stmt), session, stepIndex)
3541
default:
3642
return "", fmt.Errorf("not implemented stmt execution type %v", stmt)
@@ -44,6 +50,13 @@ func generateTrainCode(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Sessi
4450
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
4551
}
4652

53+
func generatePredictCode(predStmt *ir.PredictStmt, stepIndex int, session *pb.Session) (string, error) {
54+
if isTrainedXBoostModel(predStmt.Using) {
55+
return XGBoostGeneratePredict(predStmt, stepIndex, session)
56+
}
57+
return "", fmt.Errorf("not implemented model type")
58+
}
59+
4760
func initializeAndCheckAttributes(stmt ir.SQLFlowStmt) error {
4861
switch s := stmt.(type) {
4962
case *ir.TrainStmt:

go/codegen/experimental/xgboost.go

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,6 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
9999
if err != nil {
100100
return "", err
101101
}
102-
submitter := os.Getenv("SQLFLOW_submitter")
103-
if submitter == "" {
104-
submitter = "local"
105-
}
106102

107103
dbConnStr, err := GeneratePyDbConnStr(session)
108104
if err != nil {
@@ -126,7 +122,7 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
126122
DiskCache: diskCache,
127123
BatchSize: batchSize,
128124
Epoch: epoch,
129-
Submitter: submitter,
125+
Submitter: getSubmitter(session, "local"),
130126
}
131127
var program bytes.Buffer
132128
var trainTemplate = template.Must(template.New("Train").Parse(xgbTrainTemplate))
@@ -140,7 +136,6 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Se
140136
const xgbTrainTemplate = `
141137
def step_entry_{{.StepIndex}}():
142138
import json
143-
import os
144139
import runtime.temp_file as temp_file
145140
import runtime.feature.column as fc
146141
import runtime.feature.field_desc as fd
@@ -157,7 +152,6 @@ def step_entry_{{.StepIndex}}():
157152
train_params = json.loads('''{{.TrainParamsJSON}}''')
158153
159154
with temp_file.TemporaryDirectory(as_cwd=True) as temp_dir:
160-
os.chdir(temp_dir)
161155
train_params["original_sql"] = '''{{.OriginalSQL}}'''
162156
train_params["model_image"] = '''{{.ModelImage}}'''
163157
train_params["feature_column_map"] = feature_column_map
@@ -176,6 +170,67 @@ def step_entry_{{.StepIndex}}():
176170
train_params=train_params)
177171
`
178172

173+
type xgbPredFiller struct {
174+
StepIndex int
175+
DataSource string
176+
Select string
177+
PredLabelName string
178+
ResultTable string
179+
Load string
180+
Submitter string
181+
}
182+
183+
// XGBoostGeneratePredict generates the XGBoost prediction code
184+
func XGBoostGeneratePredict(predStmt *ir.PredictStmt, stepIndex int, session *pb.Session) (string, error) {
185+
dbConnStr, err := GeneratePyDbConnStr(session)
186+
if err != nil {
187+
return "", err
188+
}
189+
190+
filler := &xgbPredFiller{
191+
StepIndex: stepIndex,
192+
DataSource: dbConnStr,
193+
Select: replaceNewLineRuneAndTrimSpace(predStmt.Select),
194+
PredLabelName: predStmt.ResultColumn,
195+
ResultTable: predStmt.ResultTable,
196+
Load: predStmt.Using,
197+
Submitter: getSubmitter(session, "local"),
198+
}
199+
200+
var program bytes.Buffer
201+
predTmpl := template.Must(template.New("Train").Parse(xgbPredTemplate))
202+
err = predTmpl.Execute(&program, filler)
203+
if err != nil {
204+
return "", err
205+
}
206+
return program.String(), nil
207+
}
208+
209+
const xgbPredTemplate = `
210+
def step_entry_{{.StepIndex}}():
211+
import runtime.temp_file as temp_file
212+
from runtime.{{.Submitter}} import pred
213+
214+
with temp_file.TemporaryDirectory(as_cwd=True):
215+
pred(datasource='''{{.DataSource}}''',
216+
select='''{{.Select}}''',
217+
result_table='''{{.ResultTable}}''',
218+
pred_label_name='''{{.PredLabelName}}''',
219+
load='''{{.Load}}''')
220+
`
221+
222+
func getSubmitter(session *pb.Session, defaultValue string) string {
223+
if session.Submitter != "" {
224+
return session.Submitter
225+
}
226+
227+
submitter := os.Getenv("SQLFLOW_submitter")
228+
if submitter != "" {
229+
return submitter
230+
}
231+
return defaultValue
232+
}
233+
179234
func generateFeatureColumnCode(fcList []ir.FeatureColumn) (string, error) {
180235
fcCodes := make([]string, 0, len(fcList))
181236
for _, fc := range fcList {

python/runtime/local/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
from runtime.local.submitter import submit_local_pred as pred # noqa: F401
1415
from runtime.local.submitter import submit_local_train as train # noqa: F401

python/runtime/local/submitter.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
from runtime.local.xgboost_submitter.predict import pred as xgboost_pred
1415
from runtime.local.xgboost_submitter.train import train as xgboost_train
16+
from runtime.model.model import EstimatorType, Model
1517

1618

1719
def submit_local_train(datasource, estimator_string, select, validation_select,
@@ -60,3 +62,12 @@ def submit_local_train(datasource, estimator_string, select, validation_select,
6062
else:
6163
raise NotImplementedError("not implemented model type: %s" %
6264
estimator_string)
65+
66+
67+
def submit_local_pred(datasource, select, result_table, pred_label_name, load):
68+
model = Model.load_from_db(datasource, load)
69+
if model.get_type() == EstimatorType.XGBOOST:
70+
xgboost_pred(datasource, select, result_table, pred_label_name, model)
71+
else:
72+
raise NotImplementedError("not implemented model type: %s" %
73+
model.get_type())

python/runtime/local/xgboost_submitter/train_predict_test.py renamed to python/runtime/local/xgboost_submitter/local_submitter_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# limitations under the License.
1313

1414
import os
15-
import tempfile
1615
import unittest
1716

1817
import runtime.db as db
18+
import runtime.temp_file as temp_file
1919
import runtime.testing as testing
2020
from runtime.feature.column import NumericColumn
2121
from runtime.feature.field_desc import FieldDesc
@@ -57,8 +57,7 @@ def test_train_and_predict(self):
5757
class_name = "class"
5858

5959
old_dir_name = os.getcwd()
60-
with tempfile.TemporaryDirectory() as tmp_dir_name:
61-
os.chdir(tmp_dir_name)
60+
with temp_file.TemporaryDirectory(as_cwd=True):
6261
eval_result = train(original_sql=original_sql,
6362
model_image="sqlflow:step",
6463
estimator_string="xgboost.gbtree",

python/runtime/local/xgboost_submitter/predict.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# limitations under the License.
1313

1414
import os
15-
import tempfile
1615

1716
import numpy as np
17+
import runtime.temp_file as temp_file
1818
import runtime.xgboost as xgboost_extended
19+
import six
1920
import xgboost as xgb
2021
from runtime import db
2122
from runtime.feature.compile import compile_ir_feature_columns
@@ -25,7 +26,7 @@
2526
from runtime.xgboost.dataset import xgb_dataset
2627

2728

28-
def pred(datasource, select, result_table, pred_label_name, load):
29+
def pred(datasource, select, result_table, pred_label_name, model):
2930
"""
3031
Do prediction using a trained model.
3132
@@ -34,12 +35,17 @@ def pred(datasource, select, result_table, pred_label_name, load):
3435
select (str): the input data to predict.
3536
result_table (str): the output data table.
3637
pred_label_name (str): the output label name to predict.
37-
load (str): where the trained model stores.
38+
model (Model|str): the model object, or the location to load the model from.
3839
3940
Returns:
4041
None.
4142
"""
42-
model = Model.load_from_db(datasource, load)
43+
if isinstance(model, six.string_types):
44+
model = Model.load_from_db(datasource, model)
45+
else:
46+
assert isinstance(model,
47+
Model), "not supported model type %s" % type(model)
48+
4349
model_params = model.get_meta("attributes")
4450
train_fc_map = model.get_meta("features")
4551
train_label_desc = model.get_meta("label").get_field_desc()[0]
@@ -62,7 +68,7 @@ def pred(datasource, select, result_table, pred_label_name, load):
6268
result_column_names, train_label_idx = _create_predict_table(
6369
conn, select, result_table, train_label_desc, pred_label_name)
6470

65-
with tempfile.TemporaryDirectory() as tmp_dir_name:
71+
with temp_file.TemporaryDirectory() as tmp_dir_name:
6672
pred_fn = os.path.join(tmp_dir_name, "predict.txt")
6773
raw_data_dir = os.path.join(tmp_dir_name, "predict_raw_dir")
6874

python/runtime/tensorflow/import_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,5 @@ def import_model(model):
6262
Returns:
6363
An imported model class or function.
6464
"""
65-
import_model_package(model, globals())
65+
import_model_package(model, locals())
6666
return eval(model)

0 commit comments

Comments
 (0)