Skip to content

Commit 1762656

Browse files
authored
refine codes in python/runtime/xgboost by flake8 (#2768)
1 parent b5e09a8 commit 1762656

8 files changed

Lines changed: 70 additions & 45 deletions

File tree

python/runtime/xgboost/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from runtime.xgboost import feature_column
14+
from runtime.xgboost import feature_column # noqa: F401
1515

1616

1717
class DataTypeCollection(object):

python/runtime/xgboost/dataset.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -165,8 +165,8 @@ def dump_dmatrix(filename,
165165

166166
f.write("\t".join(row_data) + "\n")
167167
row_id += 1
168-
# batch_size == None meas use all data in generator
169-
if batch_size == None:
168+
# batch_size == None means use all data in generator
169+
if batch_size is None:
170170
continue
171171
if row_id >= batch_size:
172172
break
@@ -224,8 +224,9 @@ def get_pai_table_slice_count(table, nworkers, batch_size):
224224

225225
row_cnt = db.get_pai_table_row_num(table)
226226

227-
assert row_cnt >= nworkers, "Data number {} should not less than worker number {}".format(
228-
row_cnt, nworkers)
227+
assert row_cnt >= nworkers, "Data number {} should not " \
228+
"less than worker number {}"\
229+
.format(row_cnt, nworkers)
229230

230231
slice_num_per_worker = max(int(row_cnt / (nworkers * batch_size)), 1)
231232
slice_count = slice_num_per_worker * nworkers
@@ -279,7 +280,8 @@ def thread_worker(slice_id):
279280
raw_data_dir
280281
]))
281282

282-
assert p.returncode == 0, "The subprocess raises error when reading data"
283+
assert p.returncode == 0, \
284+
"The subprocess raises error when reading data"
283285
complete_queue.put(slice_id)
284286

285287
slice_id = rank

python/runtime/xgboost/evaluate.py

Lines changed: 29 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -12,21 +12,33 @@
1212
# limitations under the License.
1313

1414
import numpy as np
15-
import sklearn
15+
import sklearn.metrics
1616
import xgboost as xgb
1717
from runtime import db
1818
from runtime.xgboost.dataset import xgb_dataset
19-
# yapf: disable
20-
from sklearn.metrics import (accuracy_score, average_precision_score,
21-
balanced_accuracy_score, brier_score_loss,
22-
cohen_kappa_score, explained_variance_score,
23-
f1_score, fbeta_score, hamming_loss, hinge_loss,
24-
log_loss, mean_absolute_error, mean_squared_error,
25-
mean_squared_log_error, median_absolute_error,
26-
precision_score, r2_score, recall_score,
27-
roc_auc_score, zero_one_loss)
2819

29-
# yapf: enable
20+
SKLEARN_METRICS = [
21+
'accuracy_score',
22+
'average_precision_score',
23+
'balanced_accuracy_score',
24+
'brier_score_loss',
25+
'cohen_kappa_score',
26+
'explained_variance_score',
27+
'f1_score',
28+
'fbeta_score',
29+
'hamming_loss',
30+
'hinge_loss',
31+
'log_loss',
32+
'mean_absolute_error',
33+
'mean_squared_error',
34+
'mean_squared_log_error',
35+
'median_absolute_error',
36+
'precision_score',
37+
'r2_score',
38+
'recall_score',
39+
'roc_auc_score',
40+
'zero_one_loss',
41+
]
3042

3143
DEFAULT_PREDICT_BATCH_SIZE = 10000
3244

@@ -95,8 +107,9 @@ def evaluate_and_store_result(bst, dpred, feature_file_id, validation_metrics,
95107
# using the original prediction result of predict API by default
96108
pass
97109
else:
98-
# prediction output with multi-class job has two dimensions, this is a temporary
99-
# way, can remove this else branch when we can load the model meta not only on PAI submitter.
110+
# prediction output with multi-class job has two dimensions, this
111+
# is a temporary way, can remove this else branch when we can load
112+
# the model meta not only on PAI submitter.
100113
if len(preds.shape) == 2:
101114
preds = np.argmax(np.array(preds), axis=1)
102115

@@ -121,7 +134,9 @@ def evaluate_and_store_result(bst, dpred, feature_file_id, validation_metrics,
121134

122135
evaluate_results = dict()
123136
for metric_name in validation_metrics:
124-
metric_func = eval(metric_name)
137+
if metric_name not in SKLEARN_METRICS:
138+
raise ValueError("unsupported metric: %s" % metric_name)
139+
metric_func = getattr(sklearn.metrics, metric_name)
125140
metric_value = metric_func(y_test, preds)
126141
evaluate_results[metric_name] = metric_value
127142

python/runtime/xgboost/explain.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -184,8 +184,10 @@ def explain(datasource,
184184

185185
if result_table != "":
186186
if is_pai:
187-
# TODO(typhoonzero): the shape of shap_values is (3, num_samples, num_features)
188-
# use the first dimension here, should find out how to use the other two.
187+
# TODO(typhoonzero): the shape of shap_values is
188+
# (3, num_samples, num_features), use the first
189+
# dimension here, should find out how to use
190+
# the other two.
189191
write_shap_values(shap_values[0], "pai_maxcompute", None,
190192
result_table, feature_column_names,
191193
hdfs_namenode_addr, hive_location, hdfs_user,

python/runtime/xgboost/feature_column.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
if six.PY2:
3333

3434
def hashing(x):
35-
return long(hashlib.sha1(x).hexdigest(), 16)
35+
return long(hashlib.sha1(x).hexdigest(), 16) # noqa: F821
3636
else:
3737

3838
def hashing(x):
@@ -139,9 +139,9 @@ def elementwise_transform_fn(x):
139139
if self.default_value is not None:
140140
return self.default_value
141141
else:
142-
raise ValueError(
143-
'The categorical value of column {} out of range [0, {})'
144-
.format(self.key, self.num_buckets))
142+
raise ValueError('The categorical value of column {} '
143+
'out of range [0, {})'.format(
144+
self.key, self.num_buckets))
145145

146146
if isinstance(slot_value, np.ndarray):
147147
output = elementwise_transform(
@@ -174,7 +174,7 @@ def num_classes(self):
174174
return len(self.vocabulary_list)
175175

176176
def __call__(self, inputs):
177-
fn = lambda x: self.vocabulary_list.index(x)
177+
fn = lambda x: self.vocabulary_list.index(x) # noqa: E731
178178

179179
def transform_fn(slot_value):
180180
if isinstance(slot_value, np.ndarray):
@@ -208,7 +208,7 @@ def num_classes(self):
208208
return self.hash_bucket_size
209209

210210
def __call__(self, inputs):
211-
fn = lambda x: hashing(x) % self.hash_bucket_size
211+
fn = lambda x: hashing(x) % self.hash_bucket_size # noqa: E731
212212

213213
def transform_fn(slot_value):
214214
if isinstance(slot_value, np.ndarray):
@@ -230,7 +230,9 @@ def categorical_column_with_hash_bucket(key, hash_bucket_size, dtype='string'):
230230
class IndicatorColumnTransformer(BaseColumnTransformer):
231231
def __init__(self, categorical_column):
232232
assert isinstance(categorical_column, CategoricalColumnTransformer), \
233-
"categorical_column must be type of CategoricalColumnTransformer but got {}".format(type(categorical_column))
233+
"categorical_column must be type of " \
234+
"CategoricalColumnTransformer but got {}".format(
235+
type(categorical_column))
234236
self.categorical_column = categorical_column
235237

236238
def _set_feature_column_names(self, names):

python/runtime/xgboost/predict.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,10 @@ def predict_and_store_result(bst, dpred, feature_file_id, model_params,
8686
hdfs_user, hdfs_pass):
8787
preds = bst.predict(dpred)
8888

89-
#TODO(yancey1989): should save train_params and model_params not only on PAI submitter
90-
#TODO(yancey1989): output the original result for various objective function.
89+
# TODO(yancey1989): should save train_params and model_params
90+
# not only on PAI submitter
91+
# TODO(yancey1989): output the original result for various
92+
# objective function.
9193
if model_params:
9294
obj = model_params["objective"]
9395
if obj.startswith("binary:"):
@@ -98,8 +100,9 @@ def predict_and_store_result(bst, dpred, feature_file_id, model_params,
98100
# using the original prediction result of predict API by default
99101
pass
100102
else:
101-
# prediction output with multi-class job has two dimensions, this is a temporary
102-
# way, can remove this else branch when we can load the model meta not only on PAI submitter.
103+
# prediction output with multi-class job has two dimensions, this
104+
# is a temporary way, can remove this else branch when we can load
105+
# the model meta not only on PAI submitter.
103106
if len(preds.shape) == 2:
104107
preds = np.argmax(np.array(preds), axis=1)
105108

@@ -134,12 +137,12 @@ def predict_and_store_result(bst, dpred, feature_file_id, model_params,
134137
hive_location=hive_location,
135138
hdfs_user=hdfs_user,
136139
hdfs_pass=hdfs_pass) as w:
137-
import sys
138140
while True:
139141
line = feature_file_read.readline()
140142
if not line:
141143
break
142-
# FIXME(typhoonzero): how to output columns that are not used as features, like ids?
144+
# FIXME(typhoonzero): how to output columns that are not used
145+
# as features, like ids?
143146
row = [
144147
item for i, item in enumerate(line.strip().split("/"))
145148
if i != train_label_index

python/runtime/xgboost/tracker.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@
2424
- help nodes to establish links with each other
2525
Tianqi Chen
2626
"""
27-
# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals
27+
# pylint: disable=invalid-name, missing-docstring
28+
# pylint: disable=too-many-arguments, too-many-locals
2829
# pylint: disable=too-many-branches, too-many-statements
2930
from __future__ import absolute_import
3031

@@ -436,9 +437,8 @@ def get_host_ip(hostIP=None):
436437
try:
437438
hostIP = socket.gethostbyname(socket.getfqdn())
438439
except gaierror:
439-
logging.warn(
440-
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
441-
)
440+
logging.warn('gethostbyname(socket.getfqdn()) failed... '
441+
'trying on hostname()')
442442
hostIP = socket.gethostbyname(socket.gethostname())
443443
if hostIP.startswith("127."):
444444
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

python/runtime/xgboost/train.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# limitations under the License.
1313

1414
import json
15-
import os
1615
import sys
1716

1817
import runtime.pai.pai_distributed as pai_dist
@@ -200,8 +199,9 @@ def save_model_to_local_file(booster, model_params, meta, filename):
200199
from sklearn2pmml import PMMLPipeline, sklearn2pmml
201200
try:
202201
from xgboost.compat import XGBoostLabelEncoder
203-
except:
204-
# xgboost==0.82.0 does not have XGBoostLabelEncoder in xgboost.compat.py
202+
except: # noqa: E722
203+
# xgboost==0.82.0 does not have XGBoostLabelEncoder
204+
# in xgboost.compat.py
205205
from xgboost.sklearn import XGBLabelEncoder as XGBoostLabelEncoder
206206

207207
objective = model_params.get("objective")
@@ -212,10 +212,11 @@ def save_model_to_local_file(booster, model_params, meta, filename):
212212
num_class = 2
213213
else:
214214
num_class = model_params.get("num_class")
215-
assert num_class is not None and num_class > 0, "num_class should not be None"
215+
assert num_class is not None and num_class > 0, \
216+
"num_class should not be None"
216217

217-
# To fake a trained XGBClassifier, there must be "_le", "classes_", inside
218-
# XGBClassifier. See here:
218+
# To fake a trained XGBClassifier, there must be "_le", "classes_",
219+
# inside XGBClassifier. See here:
219220
# https://github.com/dmlc/xgboost/blob/d19cec70f1b40ea1e1a35101ca22e46dd4e4eecd/python-package/xgboost/sklearn.py#L356
220221
model = xgb.XGBClassifier()
221222
label_encoder = XGBoostLabelEncoder()

0 commit comments

Comments
 (0)