Skip to content

Commit e48074d

Browse files
authored
Unify local and PAI submitter (#2841)
* unify local and pai submitter * update * update * fix merge * update * update * update * update
1 parent bc0be7a commit e48074d

16 files changed

Lines changed: 183 additions & 48 deletions

File tree

docker/step/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ RUN apt-get update && \
3636
unzip -qq odpscmd_public.zip -d /install/local/odpscmd && \
3737
rm -rf odpscmd_public.zip
3838

39-
RUN wget -q https://minlp.com/downloads/xecs/baron/current/baron-lin64.zip && \
39+
RUN wget -q https://sqlflow-models.oss-cn-zhangjiakou.aliyuncs.com/baron-lin64.zip && \
4040
unzip -qq baron-lin64.zip -d /install && \
4141
mv /install/baron-lin64/baron /usr/bin && \
4242
rm -rf /install/baron-lin64 && \

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ require (
4242
github.com/spf13/cobra v0.0.5 // indirect
4343
github.com/stretchr/objx v0.2.0 // indirect
4444
github.com/stretchr/testify v1.4.0
45+
github.com/topicai/candy v0.0.0-20160816022300-1b9030d056fa // indirect
46+
github.com/wangkuiyi/ipynb v0.0.0-20190916115031-f33fb706ed27 // indirect
4547
go.starlark.net v0.0.0-20191218235703-9fcb808a6221 // indirect
4648
golang.org/x/arch v0.0.0-20191126211547-368ea8f32fff // indirect
4749
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfK
434434
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
435435
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 h1:lYIiVDtZnyTWlNwiAxLj0bbpTcx1BWCFhXjfsvmPdNc=
436436
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
437+
github.com/topicai/candy v0.0.0-20160816022300-1b9030d056fa h1:24E2c2W5B4DUt6PDuRzqFTSMqX0665pn+WDts8taA9c=
438+
github.com/topicai/candy v0.0.0-20160816022300-1b9030d056fa/go.mod h1:INujiuYA1CkXgGrv0uWxYpX6NLNj4dzalZCUoou2F0U=
437439
github.com/twinj/uuid v1.0.0 h1:fzz7COZnDrXGTAOHGuUGYd6sG+JMq+AoE7+Jlu0przk=
438440
github.com/twinj/uuid v1.0.0/go.mod h1:mMgcE1RHFUFqe5AfiwlINXisXfDGro23fWdPUfOMjRY=
439441
github.com/uber-go/atomic v1.3.2 h1:Azu9lPBWRNKzYXSIwRfgRuDuS0YKsK4NFhiQv98gkxo=
@@ -452,6 +454,8 @@ github.com/unrolled/render v0.0.0-20180914162206-b9786414de4d h1:ggUgChAeyge4NZ4
452454
github.com/unrolled/render v0.0.0-20180914162206-b9786414de4d/go.mod h1:tu82oB5W2ykJRVioYsB+IQKcft7ryBr7w12qMBUPyXg=
453455
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
454456
github.com/urfave/negroni v0.3.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
457+
github.com/wangkuiyi/ipynb v0.0.0-20190916115031-f33fb706ed27 h1:ao/Qf6LLr9DfAfn4cgWzvcgfvNnGkprQ7SUynXDRkA0=
458+
github.com/wangkuiyi/ipynb v0.0.0-20190916115031-f33fb706ed27/go.mod h1:gs1oId3tYIgGyuNKTQYtDmJCY9ovWfUXHYcnbWNLTCs=
455459
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8=
456460
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
457461
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ package experimental
1616
import (
1717
"fmt"
1818
"net/url"
19-
"sqlflow.org/sqlflow/go/database"
2019
"strings"
2120

21+
"sqlflow.org/sqlflow/go/database"
22+
2223
"sqlflow.org/sqlflow/go/ir"
2324
pb "sqlflow.org/sqlflow/go/proto"
2425
)
@@ -27,10 +28,7 @@ func generateStepCode(stmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (
2728
switch stmt.(type) {
2829
case *ir.TrainStmt:
2930
trainStmt := stmt.(*ir.TrainStmt)
30-
if strings.HasPrefix(strings.ToUpper(trainStmt.Estimator), "XGBOOST.") {
31-
return XGBoostGenerateTrain(trainStmt, stepIndex, session)
32-
}
33-
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
31+
return generateTrainCode(trainStmt, stepIndex, session)
3432
case *ir.NormalStmt:
3533
stmt := stmt.(*ir.NormalStmt)
3634
return GenerateNormalStmtStep(string(*stmt), session, stepIndex)
@@ -39,6 +37,13 @@ func generateStepCode(stmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (
3937
}
4038
}
4139

40+
func generateTrainCode(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Session) (string, error) {
41+
if strings.HasPrefix(strings.ToUpper(trainStmt.Estimator), "XGBOOST.") {
42+
return XGBoostGenerateTrain(trainStmt, stepIndex, session)
43+
}
44+
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
45+
}
46+
4247
func initializeAndCheckAttributes(stmt ir.SQLFlowStmt) error {
4348
switch s := stmt.(type) {
4449
case *ir.TrainStmt:

go/codegen/experimental/xgboost.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def step_entry_{{.StepIndex}}():
144144
import runtime.temp_file as temp_file
145145
import runtime.feature.column as fc
146146
import runtime.feature.field_desc as fd
147-
import runtime.{{.Submitter}}.xgboost as xgboost_submitter
147+
from runtime.{{.Submitter}} import train
148148
149149
{{ if .FeatureColumnCode }}
150150
feature_column_map = {"feature_columns": [{{.FeatureColumnCode}}]}
@@ -157,21 +157,23 @@ def step_entry_{{.StepIndex}}():
157157
train_params = json.loads('''{{.TrainParamsJSON}}''')
158158
159159
with temp_file.TemporaryDirectory(as_cwd=True) as temp_dir:
160-
xgboost_submitter.train(original_sql='''{{.OriginalSQL}}''',
161-
model_image='''{{.ModelImage}}''',
162-
estimator='''{{.Estimator}}''',
163-
datasource='''{{.DataSource}}''',
164-
select='''{{.Select}}''',
165-
validation_select='''{{.ValidationSelect}}''',
166-
model_params=model_params,
167-
train_params=train_params,
168-
feature_column_map=feature_column_map,
169-
label_column=label_column,
170-
save='''{{.Save}}''',
171-
load='''{{.Load}}''',
172-
disk_cache="{{.DiskCache}}"=="true",
173-
batch_size={{.BatchSize}},
174-
epoch={{.Epoch}})
160+
os.chdir(temp_dir)
161+
train_params["original_sql"] = '''{{.OriginalSQL}}'''
162+
train_params["model_image"] = '''{{.ModelImage}}'''
163+
train_params["feature_column_map"] = feature_column_map
164+
train_params["label_column"] = label_column
165+
train_params["disk_cache"] = "{{.DiskCache}}"=="true"
166+
train_params["batch_size"] = {{.BatchSize}}
167+
train_params["epoch"] = {{.Epoch}}
168+
169+
train(datasource='''{{.DataSource}}''',
170+
estimator_string='''{{.Estimator}}''',
171+
select='''{{.Select}}''',
172+
validation_select='''{{.ValidationSelect}}''',
173+
model_params=model_params,
174+
save='''{{.Save}}''',
175+
load='''{{.Load}}''',
176+
train_params=train_params)
175177
`
176178

177179
func generateFeatureColumnCode(fcList []ir.FeatureColumn) (string, error) {

python/runtime/local/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@
1010
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
13+
14+
from runtime.local.submitter import submit_local_train as train # noqa: F401

python/runtime/local/submitter.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from runtime.local.xgboost_submitter.train import train as xgboost_train
15+
16+
17+
def submit_local_train(datasource, estimator_string, select, validation_select,
18+
model_params, save, load, train_params):
19+
"""This function run train task locally.
20+
21+
Args:
22+
datasource: string
23+
Like: odps://access_id:access_key@service.com/api?
24+
curr_project=test_ci&scheme=http
25+
estimator_string: string
26+
TensorFlow estimator name, Keras class name, or XGBoost
27+
select: string
28+
The SQL statement for selecting data for train
29+
validation_select: string
30+
Ths SQL statement for selecting data for validation
31+
model_params: dict
32+
Params for training, crossponding to WITH clause
33+
load: string
34+
The pre-trained model name to load
35+
train_params: dict
36+
Extra train params, will be passed to runtime.tensorflow.train
37+
or runtime.xgboost.train, required fields: original_sql,
38+
model_image, feature_column_map, label_column; optional fields:
39+
disk_cache, batch_size, epoch.
40+
"""
41+
if estimator_string.lower().startswith("xgboost"):
42+
# pop required params from train_params
43+
original_sql = train_params.pop("original_sql")
44+
model_image = train_params.pop("model_image")
45+
feature_column_map = train_params.pop("feature_column_map")
46+
label_column = train_params.pop("label_column")
47+
48+
return xgboost_train(original_sql,
49+
model_image,
50+
estimator_string,
51+
datasource,
52+
select,
53+
validation_select,
54+
model_params,
55+
train_params,
56+
feature_column_map,
57+
label_column,
58+
save,
59+
load=load)
60+
else:
61+
raise NotImplementedError("not implemented model type: %s" %
62+
estimator_string)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import unittest
15+
16+
import runtime.testing as testing
17+
from runtime.feature.column import NumericColumn
18+
from runtime.feature.field_desc import FieldDesc
19+
from runtime.local import train
20+
21+
22+
class TestXGBoostTrain(unittest.TestCase):
23+
@unittest.skipUnless(testing.get_driver() == "mysql",
24+
"skip non mysql tests")
25+
def test_train(self):
26+
ds = testing.get_datasource()
27+
original_sql = """SELECT * FROM iris.train
28+
TO TRAIN xgboost.gbtree
29+
WITH
30+
objective="multi:softmax",
31+
num_boost_round=20,
32+
num_class=3,
33+
validation.select="SELECT * FROM iris.test"
34+
INTO iris.xgboost_train_model_test;
35+
"""
36+
37+
select = "SELECT * FROM iris.train"
38+
val_select = "SELECT * FROM iris.test"
39+
train_params = {
40+
"num_boost_round": 20,
41+
"original_sql": original_sql,
42+
"feature_column_map": None,
43+
"label_column": NumericColumn(FieldDesc(name="class")),
44+
"model_image": "sqlflow:step"
45+
}
46+
model_params = {"num_class": 3, "objective": "multi:softmax"}
47+
eval_result = train(ds, "xgboost.gbtree", select, val_select,
48+
model_params, "iris.xgboost_train_model_test",
49+
None, train_params)
50+
self.assertLess(eval_result['train']['merror'][-1], 0.01)
51+
self.assertLess(eval_result['validate']['merror'][-1], 0.01)
52+
53+
54+
if __name__ == '__main__':
55+
unittest.main()

python/runtime/local/xgboost/__init__.py renamed to python/runtime/local/xgboost_submitter/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,3 @@
1010
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
13-
14-
from runtime.local.xgboost.predict import pred # noqa: F401
15-
from runtime.local.xgboost.train import train # noqa: F401
File renamed without changes.

0 commit comments

Comments
 (0)