Skip to content

Commit a1c85a4

Browse files
authored
Add xgboost evaluation local code (#2844)
* add xgboost evaluation local code * replace temp_file apis * polish
1 parent 6443307 commit a1c85a4

3 files changed

Lines changed: 229 additions & 20 deletions

File tree

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import os
15+
16+
import numpy as np
17+
import runtime.temp_file as temp_file
18+
import runtime.xgboost as xgboost_extended
19+
import sklearn.metrics
20+
import xgboost as xgb
21+
from runtime import db
22+
from runtime.feature.compile import compile_ir_feature_columns
23+
from runtime.feature.derivation import get_ordered_field_descs
24+
from runtime.feature.field_desc import DataType
25+
from runtime.local.xgboost_submitter.predict import _calc_predict_result
26+
from runtime.model.model import Model
27+
from runtime.xgboost.dataset import xgb_dataset
28+
29+
# Whitelist of metric names callers may request; each entry must be the name
# of a callable in sklearn.metrics (resolved via getattr at evaluation time).
SKLEARN_METRICS = [
    'accuracy_score',
    'average_precision_score',
    'balanced_accuracy_score',
    'brier_score_loss',
    'cohen_kappa_score',
    'explained_variance_score',
    'f1_score',
    'fbeta_score',
    'hamming_loss',
    'hinge_loss',
    'log_loss',
    'mean_absolute_error',
    'mean_squared_error',
    'mean_squared_log_error',
    'median_absolute_error',
    'precision_score',
    'r2_score',
    'recall_score',
    'roc_auc_score',
    'zero_one_loss',
]
51+
52+
53+
def evaluate(datasource,
             select,
             result_table,
             load,
             pred_label_name=None,
             validation_metrics=None):
    """
    Do evaluation to a trained XGBoost model.

    Args:
        datasource (str): the database connection string.
        select (str): the input data to predict.
        result_table (str): the output data table.
        load (str): where the trained model stores.
        pred_label_name (str): the label column name.
        validation_metrics (list[str]): the evaluation metric names.
            Defaults to ["accuracy_score"].

    Returns:
        None.
    """
    # Avoid a mutable default argument; the effective default is unchanged.
    if validation_metrics is None:
        validation_metrics = ["accuracy_score"]

    model = Model.load_from_db(datasource, load)
    model_params = model.get_meta("attributes")
    train_fc_map = model.get_meta("features")
    train_label_desc = model.get_meta("label").get_field_desc()[0]
    if pred_label_name:
        train_label_desc.name = pred_label_name

    field_descs = get_ordered_field_descs(train_fc_map)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict()) for fd in field_descs])

    # NOTE: in the current implementation, we are generating a transform_fn
    # from the COLUMN clause. The transform_fn is executed during the process
    # of dumping the original data into DMatrix SVM file.
    compiled_fc = compile_ir_feature_columns(train_fc_map, model.get_type())
    transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(
        feature_column_names, *compiled_fc["feature_columns"])

    bst = xgb.Booster()
    # NOTE(review): "my_model" appears to be the fixed file name the loaded
    # model is dumped under in the current working directory — confirm
    # against Model.load_from_db.
    bst.load_model("my_model")
    conn = db.connect_with_data_source(datasource)

    try:
        result_column_names = _create_evaluate_table(conn, result_table,
                                                     validation_metrics)

        with temp_file.TemporaryDirectory() as tmp_dir_name:
            pred_fn = os.path.join(tmp_dir_name, "predict.txt")

            dpred = xgb_dataset(datasource=datasource,
                                fn=pred_fn,
                                dataset_sql=select,
                                feature_metas=feature_metas,
                                feature_column_names=feature_column_names,
                                label_meta=train_label_desc.to_dict(),
                                cache=True,
                                batch_size=10000,
                                transform_fn=transform_fn)

            for i, pred_dmatrix in enumerate(dpred):
                # xgb_dataset dumps batch i into the file "<pred_fn>_<i>".
                feature_file_name = pred_fn + "_%d" % i
                preds = _calc_predict_result(bst, pred_dmatrix, model_params)
                _store_evaluate_result(preds, feature_file_name,
                                       train_label_desc, result_table,
                                       result_column_names,
                                       validation_metrics, conn)
    finally:
        # Close the connection even if table creation or evaluation raises;
        # the original code leaked it on any exception.
        conn.close()
120+
121+
def _create_evaluate_table(conn, result_table, validation_metrics):
    """
    (Re)create the table that stores the evaluation result.

    Args:
        conn: the database connection object.
        result_table (str): the output data table.
        validation_metrics (list[str]): the evaluation metric names.

    Returns:
        The column names of the created table.
    """
    # The schema is one FLOAT column per result: "loss" first, then one
    # column for each requested metric.
    result_columns = ['loss'] + validation_metrics
    float_field_type = DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
    column_defs = ",".join(
        "%s %s" % (name, float_field_type) for name in result_columns)

    conn.execute("DROP TABLE IF EXISTS %s;" % result_table)
    conn.execute("CREATE TABLE %s (%s);" % (result_table, column_defs))

    return result_columns
147+
148+
def _store_evaluate_result(preds, feature_file_name, label_desc, result_table,
                           result_column_names, validation_metrics, conn):
    """
    Save the evaluation result in the table.

    Args:
        preds: the prediction result.
        feature_file_name (str): the file path where the feature dumps.
        label_desc (FieldDesc): the label FieldDesc object.
        result_table (str): the result table name.
        result_column_names (list[str]): the result column names.
        validation_metrics (list[str]): the evaluation metric names.
        conn: the database connection object.

    Returns:
        None.

    Raises:
        ValueError: if a metric name is not listed in SKLEARN_METRICS.
        TypeError: if the label dtype is neither INT64 nor FLOAT32.
    """
    y_test = []
    with open(feature_file_name, 'r') as f:
        # Iterate the file lazily instead of materializing readlines();
        # also drop the pointless identity list comprehension.
        for line in f:
            row = line.strip().split("\t")
            # DMatrix store label in the first column
            if label_desc.dtype == DataType.INT64:
                y_test.append(int(row[0]))
            elif label_desc.dtype == DataType.FLOAT32:
                y_test.append(float(row[0]))
            else:
                raise TypeError("unsupported data type {}".format(
                    label_desc.dtype))

    y_test = np.array(y_test)

    # BUG FIX: normalize the metric names once. The original code stored
    # results under the stripped name but looked them up in the write loop
    # with the raw name, raising KeyError for whitespace-padded names.
    metric_names = [m.strip() for m in validation_metrics]

    evaluate_results = dict()
    for metric_name in metric_names:
        if metric_name not in SKLEARN_METRICS:
            raise ValueError("unsupported metrics %s" % metric_name)
        metric_func = getattr(sklearn.metrics, metric_name)
        evaluate_results[metric_name] = metric_func(y_test, preds)

    # write evaluation result to result table; the "loss" column is always
    # written as the placeholder value 0.0
    with db.buffered_db_writer(conn, result_table, result_column_names) as w:
        row = ["0.0"]
        for mn in metric_names:
            row.append(str(evaluate_results[mn]))
        w.write(row)

python/runtime/local/xgboost_submitter/local_submitter_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import os
1514
import unittest
1615

1716
import runtime.db as db
1817
import runtime.temp_file as temp_file
1918
import runtime.testing as testing
2019
from runtime.feature.column import NumericColumn
2120
from runtime.feature.field_desc import FieldDesc
21+
from runtime.local.xgboost_submitter.evaluate import evaluate
2222
from runtime.local.xgboost_submitter.predict import pred
2323
from runtime.local.xgboost_submitter.train import train
2424

@@ -37,7 +37,7 @@ def get_table_schema(self, conn, table):
3737

3838
@unittest.skipUnless(testing.get_driver() == "mysql",
3939
"skip non mysql tests")
40-
def test_train_and_predict(self):
40+
def test_main(self):
4141
ds = testing.get_datasource()
4242
original_sql = """SELECT * FROM iris.train
4343
TO TRAIN xgboost.gbtree
@@ -56,7 +56,6 @@ def test_train_and_predict(self):
5656
save_name = "iris.xgboost_train_model_test"
5757
class_name = "class"
5858

59-
old_dir_name = os.getcwd()
6059
with temp_file.TemporaryDirectory(as_cwd=True):
6160
eval_result = train(original_sql=original_sql,
6261
model_image="sqlflow:step",
@@ -97,7 +96,12 @@ def test_train_and_predict(self):
9796
diff_schema = schema2.keys() - schema1.keys()
9897
self.assertEqual(len(diff_schema), 0)
9998

100-
os.chdir(old_dir_name)
99+
evaluate(ds, pred_select, "iris.evaluate_result_table", save_name,
100+
'class', ['accuracy_score'])
101+
eval_schema = self.get_table_schema(conn,
102+
"iris.evaluate_result_table")
103+
self.assertEqual(eval_schema.keys(),
104+
set(['loss', 'accuracy_score']))
101105

102106

103107
if __name__ == '__main__':

python/runtime/local/xgboost_submitter/predict.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,33 +88,25 @@ def pred(datasource, select, result_table, pred_label_name, model):
8888
for idx, pred_dmatrix in enumerate(dpred):
8989
feature_file_name = os.path.join(
9090
tmp_dir_name, "predict_raw_dir/predict.txt_%d" % idx)
91-
_predict_and_store_result(bst, pred_dmatrix, model_params,
92-
result_table, result_column_names,
93-
train_label_idx, feature_file_name, conn)
91+
preds = _calc_predict_result(bst, pred_dmatrix, model_params)
92+
_store_predict_result(preds, result_table, result_column_names,
93+
train_label_idx, feature_file_name, conn)
9494
print("Done predicting. Predict table : %s" % result_table)
9595

9696
conn.close()
9797

9898

99-
def _predict_and_store_result(bst, dpred, model_params, result_table,
100-
result_column_names, train_label_idx,
101-
feature_file_name, conn):
99+
def _calc_predict_result(bst, dpred, model_params):
102100
"""
103-
Do prediction and save the prediction result in the table.
101+
Calculate the prediction result.
104102
105103
Args:
106104
bst: the XGBoost booster object.
107105
dpred: the XGBoost DMatrix input data to predict.
108106
model_params (dict): the XGBoost model parameters.
109-
result_table (str): the result table name.
110-
result_column_names (list[str]): the result column names.
111-
train_label_idx (int): the index where the trained label is inside
112-
result_column_names.
113-
feature_file_name (str): the file path where the feature dumps.
114-
conn: the database connection object.
115107
116108
Returns:
117-
None.
109+
The prediction result.
118110
"""
119111
preds = bst.predict(dpred)
120112

@@ -128,8 +120,27 @@ def _predict_and_store_result(bst, dpred, model_params, result_table,
128120
elif objective.startswith("multi:") and len(preds) == 2:
129121
preds = np.argmax(np.array(preds), axis=1)
130122

131-
with db.buffered_db_writer(conn, result_table, result_column_names,
132-
100) as w:
123+
return preds
124+
125+
126+
def _store_predict_result(preds, result_table, result_column_names,
127+
train_label_idx, feature_file_name, conn):
128+
"""
129+
Save the prediction result in the table.
130+
131+
Args:
132+
preds: the prediction result to save.
133+
result_table (str): the result table name.
134+
result_column_names (list[str]): the result column names.
135+
train_label_idx (int): the index where the trained label is inside
136+
result_column_names.
137+
feature_file_name (str): the file path where the feature dumps.
138+
conn: the database connection object.
139+
140+
Returns:
141+
None.
142+
"""
143+
with db.buffered_db_writer(conn, result_table, result_column_names) as w:
133144
with open(feature_file_name, "r") as feature_file_read:
134145
line_no = 0
135146
for line in feature_file_read.readlines():

0 commit comments

Comments
 (0)