Skip to content

Commit 1970b41

Browse files
authored
refine python/runtime/pai by flake8 (#2767)
1 parent 8112a1f commit 1970b41

17 files changed

Lines changed: 147 additions & 181 deletions

python/runtime/pai/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from runtime.pai.submitter import submit_pai_evaluate as evaluate
15-
from runtime.pai.submitter import submit_pai_explain as explain
16-
from runtime.pai.submitter import submit_pai_predict as predict
17-
from runtime.pai.submitter import submit_pai_train as train
14+
from runtime.pai.submitter import submit_pai_evaluate as evaluate # noqa: F401
15+
from runtime.pai.submitter import submit_pai_explain as explain # noqa: F401
16+
from runtime.pai.submitter import submit_pai_predict as predict # noqa: F401
17+
from runtime.pai.submitter import submit_pai_train as train # noqa: F401

python/runtime/pai/cluster_conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,6 @@ def get_cluster_config(attrs):
6464
else:
6565
raise SQLFlowDiagnostic("train.num_evaluator should only be 1 or 0")
6666
conf = {"ps": ps, "worker": worker}
67-
if evaluator != None:
67+
if evaluator is not None:
6868
conf["evaluator"] = evaluator
6969
return conf

python/runtime/pai/entry.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,30 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import os
1514
import pickle
16-
import types
1715
from inspect import getargspec
1816

19-
from runtime import oss
2017
from runtime.diagnostics import SQLFlowDiagnostic
2118
from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
2219
from runtime.pai.tensorflow.evaluate import evaluate as evaluate_tf
2320
from runtime.pai.tensorflow.explain import explain as explain_tf
2421
from runtime.pai.tensorflow.predict import predict as predict_tf
2522
from runtime.pai.tensorflow.train import train as train_tf
26-
from runtime.tensorflow import is_tf_estimator
2723

2824
try:
29-
#(TODO: lhw) split entry.py into multiple files,
25+
# (TODO: lhw) split entry.py into multiple files,
3026
# so, we can only import needed packages
3127
from runtime.pai.xgboost.predict import predict as predict_xgb
3228
from runtime.pai.xgboost.train import train as train_xgb
3329
from runtime.pai.xgboost.explain import explain as explain_xgb
3430
from runtime.pai.xgboost.evaluate import evaluate as evaluate_xgb
35-
except:
31+
except: # noqa: E722
3632
pass
3733

3834

3935
def call_fun(func, params):
4036
"""Call a function with given params, entries in params will be treated
41-
as func' param if the key matches some argument name. Do not support
37+
as func's param if the key matches some argument name. Do not support
4238
var-args in func.
4339
4440
Args:

python/runtime/pai/pai_distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
try:
2222
import tensorflow.compat.v1 as tf
23-
except:
23+
except: # noqa: E722
2424
import tensorflow as tf
2525

2626
# This module contain utilities for PAI distributed training.
@@ -104,7 +104,7 @@ def make_estimator_distributed_runconfig(FLAGS,
104104
is_distributed,
105105
save_checkpoints_steps=100):
106106
if is_distributed:
107-
cluster, task_type, task_index = make_distributed_info_without_evaluator(
107+
cluster, task_type, task_index = make_distributed_info_without_evaluator( # noqa: E501
108108
FLAGS)
109109
dump_into_tf_config(cluster, task_type, task_index)
110110
device_filters = None

python/runtime/pai/random_forest.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,17 @@ def get_explain_random_forest_pai_cmd(datasource, model_name, data_table,
4848
datasource: current datasource
4949
model_name: model name on PAI
5050
data_table: input data table name
51-
result_table: name of the result table, PAI will automatically create this table
51+
result_table: name of the result table, PAI will automatically
52+
create this table
5253
label_column: name of the label column
53-
54+
5455
Returns:
5556
A string which is a PAI cmd
5657
"""
57-
# NOTE(typhoonzero): for PAI random forests predicting, we can not load the TrainStmt
58-
# since the model saving is fully done by PAI. We directly use the columns in SELECT
59-
# statement for prediction, error will be reported by PAI job if the columns not match.
58+
# NOTE(typhoonzero): for PAI random forests predicting, we can not load
59+
# the TrainStmt since the model saving is fully done by PAI. We directly
60+
# use the columns in SELECT statement for prediction, error will be
61+
# reported by PAI job if the columns not match.
6062
if not label_column:
6163
return ("must specify WITH label_column when using "
6264
"pai random forest to explain models")

python/runtime/pai/submitter.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def drop_tables(tables, datasource):
106106
if table != "":
107107
drop_sql = "DROP TABLE IF EXISTS %s" % table
108108
db.execute(conn, drop_sql)
109-
except:
109+
except: # noqa: E722
110110
# odps will clear table itself, so even fail here, we do
111111
# not need to raise error
112112
print("Encounter error on drop tmp table")
@@ -311,7 +311,8 @@ def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
311311
job_name = "_".join(["sqlflow", model_name]).replace(".", "_")
312312
cf_quote = json.dumps(cluster_config).replace("\"", "\\\"")
313313

314-
# submit table should format as: odps://<project>/tables/<table >,odps://<project>/tables/<table > ...
314+
# submit table should format as: odps://<project>/tables/<table >,
315+
# odps://<project>/tables/<table > ...
315316
submit_tables = max_compute_table_url(train_table)
316317
if train_table != val_table and val_table:
317318
val_table = max_compute_table_url(val_table)
@@ -321,7 +322,8 @@ def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
321322
table = max_compute_table_url(res_table)
322323
output_tables = "-Doutputs=%s" % table
323324

324-
# NOTE(typhoonzero): use - DhyperParameters to define flags passing OSS credentials.
325+
# NOTE(typhoonzero): use - DhyperParameters to define flags passing
326+
# OSS credentials.
325327
# TODO(typhoonzero): need to find a more secure way to pass credentials.
326328
cmd = ("pai -name tensorflow1150 -project algo_public_dev "
327329
"-DmaxHungTimeBeforeGCInSeconds=0 -DjobName=%s -Dtags=dnn "
@@ -333,8 +335,8 @@ def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
333335
oss_checkpoint_configs = os.getenv("SQLFLOW_OSS_CHECKPOINT_CONFIG")
334336
if not oss_checkpoint_configs:
335337
raise SQLFlowDiagnostic(
336-
"need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when submitting to PAI"
337-
)
338+
"need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when "
339+
"submitting to PAI")
338340
ckpt_conf = json.loads(oss_checkpoint_configs)
339341
model_url = get_oss_model_url(oss_model_path)
340342
role_name = get_project_role_name(project)
@@ -406,7 +408,8 @@ def submit_pai_train(datasource, estimator_string, select, validation_select,
406408
407409
Args:
408410
datasource: string
409-
Like: odps://access_id:access_key@service.com/api?curr_project=test_ci&scheme=http
411+
Like: odps://access_id:access_key@service.com/api?
412+
curr_project=test_ci&scheme=http
410413
estimator_string: string
411414
Tensorflow estimator name, Keras class name, or XGBoost
412415
select: string
@@ -489,8 +492,9 @@ def get_oss_saved_model_type_and_estimator(model_name, project):
489492
If model is TensorFlow model, return type and estimator name
490493
If model is XGBoost, or other PAI model, just return model type
491494
"""
492-
# FIXME(typhoonzero): if the model not exist on OSS, assume it's a random forest model
493-
# should use a general method to fetch the model and see the model type.
495+
# FIXME(typhoonzero): if the model not exist on OSS, assume it's a random
496+
# forest model should use a general method to fetch the model and see the
497+
# model type.
494498
bucket = oss.get_models_bucket()
495499
tf = bucket.object_exists(model_name + "/tensorflow_model_desc")
496500
if tf:
@@ -529,9 +533,10 @@ def get_pai_predict_cmd(cluster_conf, datasource, project, oss_model_path,
529533
Returns:
530534
The command to submit PAI prediction task
531535
"""
532-
# NOTE(typhoonzero): for PAI machine learning toolkit predicting, we can not load the TrainStmt
533-
# since the model saving is fully done by PAI. We directly use the columns in SELECT
534-
# statement for prediction, error will be reported by PAI job if the columns not match.
536+
# NOTE(typhoonzero): for PAI machine learning toolkit predicting, we can
537+
# not load the TrainStmt since the model saving is fully done by PAI.
538+
# We directly use the columns in SELECT statement for prediction, error
539+
# will be reported by PAI job if the columns not match.
535540
conn = db.connect_with_data_source(datasource)
536541
if model_type == EstimatorType.PAIML:
537542
schema = db.get_table_schema(conn, predict_table)
@@ -621,13 +626,13 @@ def submit_pai_predict(datasource, select, result_table, label_column,
621626
params = dict(locals())
622627

623628
cwd = tempfile.mkdtemp(prefix="sqlflow", dir="/tmp")
624-
# TODO(typhoonzero): Do **NOT** create tmp table when the select statement is like:
625-
# "SELECT fields,... FROM table"
629+
# TODO(typhoonzero): Do **NOT** create tmp table when the select statement
630+
# is like: "SELECT fields,... FROM table"
626631
data_table = create_tmp_table_from_select(select, datasource)
627632
params["data_table"] = data_table
628633

629-
# format resultTable name to "db.table" to let the codegen form a submitting
630-
# argument of format "odps://project/tables/table_name"
634+
# format resultTable name to "db.table" to let the codegen form a
635+
# submitting argument of format "odps://project/tables/table_name"
631636
project = get_project(datasource)
632637
if result_table.count(".") == 0:
633638
result_table = "%s.%s" % (project, result_table)
@@ -740,9 +745,10 @@ def get_explain_random_forests_cmd(datasource, model_name, data_table,
740745
Returns:
741746
a PAI cmd to explain the data using given model
742747
"""
743-
# NOTE(typhoonzero): for PAI random forests predicting, we can not load the TrainStmt
744-
# since the model saving is fully done by PAI. We directly use the columns in SELECT
745-
# statement for prediction, error will be reported by PAI job if the columns not match.
748+
# NOTE(typhoonzero): for PAI random forests predicting, we can not load
749+
# the TrainStmt since the model saving is fully done by PAI. We directly
750+
# use the columns in SELECT statement for prediction, error will be
751+
# reported by PAI job if the columns not match.
746752
if not label_column:
747753
raise SQLFlowDiagnostic("must specify WITH label_column when using "
748754
"pai random forest to explain models")
@@ -752,11 +758,12 @@ def get_explain_random_forests_cmd(datasource, model_name, data_table,
752758
db.execute(conn, "DROP TABLE IF EXISTS %s;" % result_table)
753759
schema = db.get_table_schema(conn, data_table)
754760
fields = [f[0] for f in schema if f[0] != label_column]
755-
return (
756-
'''pai -name feature_importance -project algo_public '''
757-
'''-DmodelName="%s" -DinputTableName="%s" '''
758-
'''-DoutputTableName="%s" -DlabelColName="%s" -DfeatureColNames="%s" '''
759-
) % (model_name, data_table, result_table, label_column, ",".join(fields))
761+
return ('''pai -name feature_importance -project algo_public '''
762+
'''-DmodelName="%s" -DinputTableName="%s" '''
763+
'''-DoutputTableName="%s" -DlabelColName="%s" '''
764+
'''-DfeatureColNames="%s" ''') % (model_name, data_table,
765+
result_table, label_column,
766+
",".join(fields))
760767

761768

762769
def submit_pai_explain(datasource, select, result_table, model_name,
@@ -774,13 +781,13 @@ def submit_pai_explain(datasource, select, result_table, model_name,
774781
params = dict(locals())
775782

776783
cwd = tempfile.mkdtemp(prefix="sqlflow", dir="/tmp")
777-
# TODO(typhoonzero): Do **NOT** create tmp table when the select statement is like:
778-
# "SELECT fields,... FROM table"
784+
# TODO(typhoonzero): Do **NOT** create tmp table when the select statement
785+
# is like: "SELECT fields,... FROM table"
779786
data_table = create_tmp_table_from_select(select, datasource)
780787
params["data_table"] = data_table
781788

782-
# format resultTable name to "db.table" to let the codegen form a submitting
783-
# argument of format "odps://project/tables/table_name"
789+
# format resultTable name to "db.table" to let the codegen form a
790+
# submitting argument of format "odps://project/tables/table_name"
784791
project = get_project(datasource)
785792
if result_table.count(".") == 0:
786793
result_table = "%s.%s" % (project, result_table)

python/runtime/pai/submitter_test.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from unittest import TestCase
1717

1818
import runtime.testing as testing
19-
import runtime.xgboost as xgboost_extended
20-
import tensorflow as tf
19+
import runtime.xgboost as xgboost_extended # noqa: F401
2120
from runtime.pai import submitter
2221
from runtime.pai.cluster_conf import get_cluster_config
2322

@@ -28,8 +27,10 @@ def test_get_oss_model_url(self):
2827
self.assertEqual("oss://sqlflow-models/user_a/model", url)
2928

3029
def test_get_datasource_dsn(self):
31-
ds = "odps://access_id:access_key@service.com/api?curr_project=test_ci&scheme=http"
32-
expected_dsn = "access_id:access_key@service.com/api?curr_project=test_ci&scheme=http"
30+
ds = "odps://access_id:access_key@service.com/api?" \
31+
"curr_project=test_ci&scheme=http"
32+
expected_dsn = "access_id:access_key@service.com/api?" \
33+
"curr_project=test_ci&scheme=http"
3334
dsn = submitter.get_datasource_dsn(ds)
3435
self.assertEqual(expected_dsn, dsn)
3536
project = "test_ci"
@@ -38,17 +39,22 @@ def test_get_datasource_dsn(self):
3839
def test_get_pai_tf_cmd(self):
3940
conf = get_cluster_config({})
4041
os.environ[
41-
"SQLFLOW_OSS_CHECKPOINT_CONFIG"] = '''{"arn":"arn", "host":"host"}'''
42+
"SQLFLOW_OSS_CHECKPOINT_CONFIG"] = '{"arn":"arn", "host":"host"}'
4243
cmd = submitter.get_pai_tf_cmd(
4344
conf, "job.tar.gz", "params.txt", "entry.py", "my_dnn_model",
4445
"user1/my_dnn_model", "test_project.input_table",
4546
"test_project.val_table", "test_project.res_table", "test_project")
4647
expected = (
47-
"pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 "
48-
"-DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=job.tar.gz -DentryFile=entry.py "
49-
"-Dtables=odps://test_project/tables/input_table,odps://test_project/tables/val_table "
50-
"-Doutputs=odps://test_project/tables/res_table -DhyperParameters='params.txt' "
51-
"-DcheckpointDir='oss://sqlflow-models/user1/my_dnn_model/?role_arn=arn/pai2osstestproject&host=host' "
48+
"pai -name tensorflow1150 -project algo_public_dev "
49+
"-DmaxHungTimeBeforeGCInSeconds=0 "
50+
"-DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=job.tar.gz "
51+
"-DentryFile=entry.py "
52+
"-Dtables=odps://test_project/tables/input_table,"
53+
"odps://test_project/tables/val_table "
54+
"-Doutputs=odps://test_project/tables/res_table "
55+
"-DhyperParameters='params.txt' "
56+
"-DcheckpointDir='oss://sqlflow-models/user1/my_dnn_model/?"
57+
"role_arn=arn/pai2osstestproject&host=host' "
5258
"-DgpuRequired='0'")
5359
self.assertEqual(expected, cmd)
5460

@@ -58,13 +64,18 @@ def test_get_pai_tf_cmd(self):
5864
"user1/my_dnn_model", "test_project.input_table",
5965
"test_project.val_table", "test_project.res_table", "test_project")
6066
expected = (
61-
"pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 "
62-
"-DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=job.tar.gz -DentryFile=entry.py "
63-
"-Dtables=odps://test_project/tables/input_table,odps://test_project/tables/val_table "
64-
"-Doutputs=odps://test_project/tables/res_table -DhyperParameters='params.txt' "
65-
"-DcheckpointDir='oss://sqlflow-models/user1/my_dnn_model/?role_arn=arn/pai2osstestproject&host=host' "
66-
r'''-Dcluster="{\"ps\": {\"count\": 1, \"cpu\": 200, \"gpu\": 0}, \"worker\": {\"count\": 5, \"cpu\": 400, \"gpu\": 0}}"'''
67-
)
67+
"pai -name tensorflow1150 -project algo_public_dev "
68+
"-DmaxHungTimeBeforeGCInSeconds=0 "
69+
"-DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=job.tar.gz "
70+
"-DentryFile=entry.py "
71+
"-Dtables=odps://test_project/tables/input_table,"
72+
"odps://test_project/tables/val_table "
73+
"-Doutputs=odps://test_project/tables/res_table "
74+
"-DhyperParameters='params.txt' "
75+
"-DcheckpointDir='oss://sqlflow-models/user1/my_dnn_model/?"
76+
"role_arn=arn/pai2osstestproject&host=host' "
77+
r'''-Dcluster="{\"ps\": {\"count\": 1, \"cpu\": 200, \"gpu\": 0}'''
78+
r''', \"worker\": {\"count\": 5, \"cpu\": 400, \"gpu\": 0}}"''')
6879
self.assertEqual(expected, cmd)
6980
del os.environ["SQLFLOW_OSS_CHECKPOINT_CONFIG"]
7081

@@ -136,12 +147,14 @@ def test_submit_pai_train_task(self):
136147
model_params["hidden_units"] = [10, 20]
137148
model_params["n_classes"] = 3
138149

139-
# feature_columns_code will be used to save the training informations together
140-
# with the saved model.
141-
feature_columns_code = """{"feature_columns": [tf.feature_column.numeric_column("sepal_length", shape=[1]),
142-
tf.feature_column.numeric_column("sepal_width", shape=[1]),
143-
tf.feature_column.numeric_column("petal_length", shape=[1]),
144-
tf.feature_column.numeric_column("petal_width", shape=[1])]}"""
150+
# feature_columns_code will be used to save the training information
151+
# together with the saved model.
152+
feature_columns_code = """{"feature_columns": [
153+
tf.feature_column.numeric_column("sepal_length", shape=[1]),
154+
tf.feature_column.numeric_column("sepal_width", shape=[1]),
155+
tf.feature_column.numeric_column("petal_length", shape=[1]),
156+
tf.feature_column.numeric_column("petal_width", shape=[1]),
157+
]}"""
145158
feature_columns = eval(feature_columns_code)
146159

147160
submitter.submit_pai_train(
@@ -172,12 +185,12 @@ def test_submit_pai_train_task(self):
172185
is_pai=True,
173186
feature_columns_code=feature_columns_code,
174187
model_repo_image="",
175-
original_sql=
176-
'''SELECT * FROM alifin_jtest_dev.sqlflow_test_iris_train
177-
TO TRAIN DNNClassifier
178-
WITH model.n_classes = 3, model.hidden_units = [10, 20]
179-
LABEL class
180-
INTO e2etest_pai_dnn;''')
188+
original_sql='''
189+
SELECT * FROM alifin_jtest_dev.sqlflow_test_iris_train
190+
TO TRAIN DNNClassifier
191+
WITH model.n_classes = 3, model.hidden_units = [10, 20]
192+
LABEL class
193+
INTO e2etest_pai_dnn;''')
181194

182195
def test_submit_pai_predict_task(self):
183196
submitter.submit_pai_predict(

0 commit comments

Comments
 (0)