Skip to content

Commit eaa2473

Browse files
authored
Polish pai submitter structure (#2859)
* polish structure * update * fix lint
1 parent 41a0db0 commit eaa2473

34 files changed

Lines changed: 1272 additions & 1083 deletions

go/codegen/pai/template_tf.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ const tfPredictTmplText = tfImportsText + `
113113
import os
114114
import types
115115
import traceback
116-
from runtime.pai.tensorflow import predict
116+
from runtime.pai.tensorflow_submitter import predict
117117
118118
try:
119119
import sqlflow_models
@@ -178,7 +178,7 @@ if os.environ.get('DISPLAY', '') == '':
178178
import json
179179
import types
180180
import sys
181-
from runtime.pai.tensorflow import explain
181+
from runtime.pai.tensorflow_submitter import explain
182182
183183
try:
184184
tf.enable_eager_execution()
@@ -241,7 +241,7 @@ if os.environ.get('DISPLAY', '') == '':
241241
import json
242242
import types
243243
import sys
244-
from runtime.pai.tensorflow import evaluate
244+
from runtime.pai.tensorflow_submitter import evaluate
245245
246246
try:
247247
tf.enable_eager_execution()

go/codegen/tensorflow/template_train.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ import traceback
4242
import tensorflow as tf
4343
import runtime
4444
{{ if .IsPAI }}
45-
from runtime.pai.tensorflow.train import train
45+
from runtime.pai.tensorflow_submitter.train import train
4646
{{ else }}
4747
from runtime.tensorflow.train import train
4848
{{ end }}

python/runtime/dbapi/connection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,15 @@ def execute(self, statement):
166166
Returns:
167167
True on success, False otherwise
168168
"""
169+
rs = None
169170
try:
170171
rs = self._get_result_set(statement)
171172
return rs.success()
172-
except: # noqa: E722
173-
return False
173+
except Exception as e: # noqa: E722
174+
raise e
174175
finally:
175-
rs.close()
176+
if rs:
177+
rs.close()
176178

177179
def get_table_schema(self, table_name):
178180
"""Get table schema for given table

python/runtime/dbapi/maxcompute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _get_result_set(self, statement):
116116
instance = self._conn.execute_sql(statement)
117117
return MaxComputeResultSet(instance)
118118
except Exception as e:
119-
return MaxComputeResultSet(None, str(e))
119+
raise e
120120

121121
def close(self):
122122
if self._conn:

python/runtime/local/submitter.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,14 @@ def submit_local_train(datasource, estimator_string, select, validation_select,
3636
The pre-trained model name to load
3737
train_params: dict
3838
Extra train params, will be passed to runtime.tensorflow.train
39-
or runtime.xgboost.train, required fields: original_sql,
40-
model_image, feature_column_map, label_column; optional fields:
41-
disk_cache, batch_size, epoch.
39+
or runtime.xgboost.train. Required fields:
40+
- original_sql: Original SQLFlow statement.
41+
- model_image: Docker image used for training.
42+
- feature_column_map: A map of Python feature column IR.
43+
- label_column: Feature column instance describing the label.
44+
- disk_cache (optional): Use dmatrix disk cache if True.
45+
- batch_size (optional): Split data to batches and train.
46+
- epoch (optional): Epochs to train.
4247
"""
4348
if estimator_string.lower().startswith("xgboost"):
4449
# pop required params from train_params

python/runtime/pai/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from runtime.pai.submitter import submit_pai_evaluate as evaluate # noqa: F401
15-
from runtime.pai.submitter import submit_pai_explain as explain # noqa: F401
16-
from runtime.pai.submitter import submit_pai_predict as predict # noqa: F401
17-
from runtime.pai.submitter import submit_pai_train as train # noqa: F401
14+
from runtime.pai.submitter_evaluate import submit_pai_evaluate as evaluate # noqa
15+
from runtime.pai.submitter_explain import submit_pai_explain as explain # noqa
16+
from runtime.pai.submitter_predict import submit_pai_predict as predict # noqa
17+
from runtime.pai.submitter_train import submit_pai_train as train # noqa
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from runtime import db
15+
from runtime.diagnostics import SQLFlowDiagnostic
16+
from runtime.model import EstimatorType
17+
from runtime.pai import table_ops
18+
19+
20+
def create_predict_result_table(datasource, select, result_table, label_column,
21+
train_label_column, model_type):
22+
"""Create predict result table with given name and label column
23+
24+
Args:
25+
datasource: current datasource
26+
select: sql statement to get prediction data set
27+
result_table: the table name to save result
28+
label_column: name of the label column, if not exist in select
29+
result, we will add an int column in the result table
30+
train_label_column: name of the label column when training
31+
model_type: type of model defined in runtime.model.oss
32+
"""
33+
conn = db.connect_with_data_source(datasource)
34+
conn.execute("DROP TABLE IF EXISTS %s" % result_table)
35+
# PAI ml will create result table itself
36+
if model_type == EstimatorType.PAIML:
37+
return
38+
39+
create_table_sql = "CREATE TABLE %s AS SELECT * FROM %s LIMIT 0" % (
40+
result_table, select)
41+
conn.execute(create_table_sql)
42+
43+
# if label is not in data table, add an int column for it
44+
schema = db.get_table_schema(conn, result_table)
45+
col_type = "INT"
46+
for (name, ctype) in schema:
47+
if name == train_label_column or name == label_column:
48+
col_type = ctype
49+
break
50+
col_names = [col[0] for col in schema]
51+
if label_column not in col_names:
52+
conn.execute(
53+
conn, "ALTER TABLE %s ADD %s %s" %
54+
(result_table, label_column, col_type))
55+
if train_label_column != label_column and train_label_column in col_names:
56+
conn.execute(
57+
conn, "ALTER TABLE %s DROP COLUMN %s" %
58+
(result_table, train_label_column))
59+
60+
61+
# (TODO: lhw) This function is a common tool for prediction
62+
# on all platforms, we need to move it to a new file
63+
def create_explain_result_table(datasource, data_table, result_table,
64+
model_type, estimator, label_column):
65+
"""Create explain result table from given datasource
66+
67+
Args:
68+
datasource: current datasource
69+
data_table: input data table name
70+
result_table: table name to store the result
71+
model_type: type of the model to use
72+
estimator: estimator class if the model is TensorFlow estimator
73+
label_column: column name of the predict label
74+
"""
75+
conn = db.connect_with_data_source(datasource)
76+
drop_stmt = "DROP TABLE IF EXISTS %s" % result_table
77+
conn.execute(drop_stmt)
78+
79+
create_stmt = ""
80+
if model_type == EstimatorType.PAIML:
81+
return
82+
elif model_type == EstimatorType.TENSORFLOW:
83+
if estimator.startswith("BoostedTrees"):
84+
column_def = ""
85+
if conn.driver == "mysql":
86+
column_def = "(feature VARCHAR(255), dfc FLOAT, gain FLOAT)"
87+
else:
88+
# Hive & MaxCompute
89+
column_def = "(feature STRING, dfc STRING, gain STRING)"
90+
create_stmt = "CREATE TABLE IF NOT EXISTS %s %s;" % (result_table,
91+
column_def)
92+
else:
93+
if not label_column:
94+
raise SQLFlowDiagnostic(
95+
"need to specify WITH label_col=lable_col_name "
96+
"when explaining deep models")
97+
create_stmt = get_create_shap_result_sql(conn, data_table,
98+
result_table,
99+
label_column)
100+
elif model_type == EstimatorType.XGBOOST:
101+
if not label_column:
102+
raise SQLFlowDiagnostic(
103+
"need to specify WITH label_col=lable_col_name "
104+
"when explaining xgboost models")
105+
create_stmt = get_create_shap_result_sql(conn, data_table,
106+
result_table, label_column)
107+
else:
108+
raise SQLFlowDiagnostic(
109+
"not supported modelType %d for creating Explain result table" %
110+
model_type)
111+
112+
if not conn.execute(create_stmt):
113+
raise SQLFlowDiagnostic("Can't create explain result table")
114+
115+
116+
def get_create_shap_result_sql(conn, data_table, result_table, label_column):
117+
"""Get a sql statement which create a result table for SHAP
118+
119+
Args:
120+
conn: a database connection
121+
data_table: table name to read data from
122+
result_table: result table name
123+
label_column: column name of label
124+
125+
Returns:
126+
a sql statement to create SHAP result table
127+
"""
128+
schema = db.get_table_schema(conn, data_table)
129+
fields = ["%s STRING" % f[0] for f in schema if f[0] != label_column]
130+
return "CREATE TABLE IF NOT EXISTS %s (%s)" % (result_table,
131+
",".join(fields))
132+
133+
134+
def create_evaluate_result_table(datasource, result_table, metrics):
135+
"""Create a table to hold the evaluation result
136+
137+
Args:
138+
datasource: current datasource
139+
result_table: the table name to save result
140+
metrics: list of evaluation metrics names
141+
"""
142+
table_ops.drop_tables([result_table], datasource)
143+
# Always add loss
144+
ext_metrics = ["loss"]
145+
if isinstance(metrics, list):
146+
ext_metrics.extend(metrics)
147+
fields = ["%s STRING" % m for m in ext_metrics]
148+
sql = "CREATE TABLE IF NOT EXISTS %s (%s);" % (result_table,
149+
",".join(fields))
150+
conn = db.connect_with_data_source(datasource)
151+
conn.execute(sql)

python/runtime/pai/entry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@
1616

1717
from runtime.diagnostics import SQLFlowDiagnostic
1818
from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
19-
from runtime.pai.tensorflow.evaluate import evaluate as evaluate_tf
20-
from runtime.pai.tensorflow.explain import explain as explain_tf
21-
from runtime.pai.tensorflow.predict import predict as predict_tf
22-
from runtime.pai.tensorflow.train import train as train_tf
19+
from runtime.pai.tensorflow_submitter.evaluate import evaluate as evaluate_tf
20+
from runtime.pai.tensorflow_submitter.explain import explain as explain_tf
21+
from runtime.pai.tensorflow_submitter.predict import predict as predict_tf
22+
from runtime.pai.tensorflow_submitter.train import train as train_tf
2323

2424
try:
2525
# (TODO: lhw) split entry.py into multiple files,
2626
# so, we can only import needed packages
27-
from runtime.pai.xgboost.predict import predict as predict_xgb
28-
from runtime.pai.xgboost.train import train as train_xgb
29-
from runtime.pai.xgboost.explain import explain as explain_xgb
30-
from runtime.pai.xgboost.evaluate import evaluate as evaluate_xgb
27+
from runtime.pai.xgboost_submitter.predict import predict as predict_xgb
28+
from runtime.pai.xgboost_submitter.train import train as train_xgb
29+
from runtime.pai.xgboost_submitter.explain import explain as explain_xgb
30+
from runtime.pai.xgboost_submitter.evaluate import evaluate as evaluate_xgb
3131
except: # noqa: E722
3232
pass
3333

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import json
15+
import os
16+
import string
17+
18+
from runtime.diagnostics import SQLFlowDiagnostic
19+
from runtime.pai import pai_model
20+
21+
JOB_ARCHIVE_FILE = "job.tar.gz"
22+
PARAMS_FILE = "params.txt"
23+
ENTRY_FILE = "entry.py"
24+
25+
26+
def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
27+
model_name, oss_model_path, train_table, val_table,
28+
res_table, project):
29+
"""Get PAI-TF cmd for training
30+
31+
Args:
32+
cluster_config: PAI cluster config
33+
tarball: the zipped resource name
34+
params_file: PAI param file name
35+
entry_file: entry file in the tarball
36+
model_name: trained model name
37+
oss_model_path: path to save the model
38+
train_table: train data table
39+
val_table: evaluate data table
40+
res_table: table to save train model, if given
41+
project: current odps project
42+
43+
Returns:
44+
The cmd to run on PAI
45+
"""
46+
job_name = "_".join(["sqlflow", model_name]).replace(".", "_")
47+
cf_quote = json.dumps(cluster_config).replace("\"", "\\\"")
48+
49+
# submit table should format as: odps://<project>/tables/<table >,
50+
# odps://<project>/tables/<table > ...
51+
submit_tables = _max_compute_table_url(train_table)
52+
if train_table != val_table and val_table:
53+
val_table = _max_compute_table_url(val_table)
54+
submit_tables = "%s,%s" % (submit_tables, val_table)
55+
output_tables = ""
56+
if res_table != "":
57+
table = _max_compute_table_url(res_table)
58+
output_tables = "-Doutputs=%s" % table
59+
60+
# NOTE(typhoonzero): use -DhyperParameters to define flags passing
61+
# OSS credentials.
62+
# TODO(typhoonzero): need to find a more secure way to pass credentials.
63+
cmd = ("pai -name tensorflow1150 -project algo_public_dev "
64+
"-DmaxHungTimeBeforeGCInSeconds=0 -DjobName=%s -Dtags=dnn "
65+
"-Dscript=%s -DentryFile=%s -Dtables=%s %s -DhyperParameters='%s'"
66+
) % (job_name, tarball, entry_file, submit_tables, output_tables,
67+
params_file)
68+
69+
# format the oss checkpoint path with ARN authorization.
70+
oss_checkpoint_configs = os.getenv("SQLFLOW_OSS_CHECKPOINT_CONFIG")
71+
if not oss_checkpoint_configs:
72+
raise SQLFlowDiagnostic(
73+
"need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when "
74+
"submitting to PAI")
75+
ckpt_conf = json.loads(oss_checkpoint_configs)
76+
model_url = pai_model.get_oss_model_url(oss_model_path)
77+
role_name = _get_project_role_name(project)
78+
# format the oss checkpoint path with ARN authorization.
79+
oss_checkpoint_path = "%s/?role_arn=%s/%s&host=%s" % (
80+
model_url, ckpt_conf["arn"], role_name, ckpt_conf["host"])
81+
cmd = "%s -DcheckpointDir='%s'" % (cmd, oss_checkpoint_path)
82+
83+
if cluster_config["worker"]["count"] > 1:
84+
cmd = "%s -Dcluster=\"%s\"" % (cmd, cf_quote)
85+
else:
86+
cmd = "%s -DgpuRequired='%d'" % (cmd, cluster_config["worker"]["gpu"])
87+
return cmd
88+
89+
90+
def _get_project_role_name(project):
91+
"""Get oss role form project name.
92+
A valid role name contains letters and numbers only.
93+
The prefix 'pai2oss' of the role name denotes PAI access to OSS.
94+
95+
Args:
96+
project: string
97+
project name
98+
99+
Returns:
100+
role name for the project
101+
"""
102+
return "pai2oss" + "".join(x for x in project.lower()
103+
if x in string.ascii_lowercase + string.digits)
104+
105+
106+
def _max_compute_table_url(table):
107+
parts = table.split(".")
108+
if len(parts) != 2:
109+
raise SQLFlowDiagnostic("odps table: %s should be format db.table" %
110+
table)
111+
return "odps://%s/tables/%s" % (parts[0], parts[1])

0 commit comments

Comments
 (0)