1212# limitations under the License.
1313
1414import numpy as np
15- import sklearn
15+ import sklearn . metrics
1616import xgboost as xgb
1717from runtime import db
1818from runtime .xgboost .dataset import xgb_dataset
19- # yapf: disable
20- from sklearn .metrics import (accuracy_score , average_precision_score ,
21- balanced_accuracy_score , brier_score_loss ,
22- cohen_kappa_score , explained_variance_score ,
23- f1_score , fbeta_score , hamming_loss , hinge_loss ,
24- log_loss , mean_absolute_error , mean_squared_error ,
25- mean_squared_log_error , median_absolute_error ,
26- precision_score , r2_score , recall_score ,
27- roc_auc_score , zero_one_loss )
2819
29- # yapf: enable
# Whitelist of sklearn.metrics function names that callers may request as
# validation metrics; unknown names are rejected before the getattr lookup
# that resolves each name to its sklearn.metrics function.
SKLEARN_METRICS = (
    'accuracy_score average_precision_score balanced_accuracy_score '
    'brier_score_loss cohen_kappa_score explained_variance_score '
    'f1_score fbeta_score hamming_loss hinge_loss log_loss '
    'mean_absolute_error mean_squared_error mean_squared_log_error '
    'median_absolute_error precision_score r2_score recall_score '
    'roc_auc_score zero_one_loss'
).split()
3042
# Default number of rows fetched per batch when streaming prediction input
# (presumably tuned to bound memory per batch — TODO confirm against callers).
DEFAULT_PREDICT_BATCH_SIZE = 10000
3244
@@ -95,8 +107,9 @@ def evaluate_and_store_result(bst, dpred, feature_file_id, validation_metrics,
95107 # using the original prediction result of predict API by default
96108 pass
97109 else :
98- # prediction output with multi-class job has two dimensions, this is a temporary
99- # way, can remove this else branch when we can load the model meta not only on PAI submitter.
110+ # prediction output with multi-class job has two dimensions, this
111+ # is a temporary way, can remove this else branch when we can load
112+ # the model meta not only on PAI submitter.
100113 if len (preds .shape ) == 2 :
101114 preds = np .argmax (np .array (preds ), axis = 1 )
102115
@@ -121,7 +134,9 @@ def evaluate_and_store_result(bst, dpred, feature_file_id, validation_metrics,
121134
122135 evaluate_results = dict ()
123136 for metric_name in validation_metrics :
124- metric_func = eval (metric_name )
137+ if metric_name not in SKLEARN_METRICS :
138+ raise ValueError ("unsupported metric: %s" % metric_name )
139+ metric_func = getattr (sklearn .metrics , metric_name )
125140 metric_value = metric_func (y_test , preds )
126141 evaluate_results [metric_name ] = metric_value
127142
0 commit comments