Skip to content

Commit 277ab0f

Browse files
authored
Add xgboost predict method (#2835)
* add xgboost predict * fix ut
1 parent 1b0233f commit 277ab0f

7 files changed

Lines changed: 355 additions & 58 deletions

File tree

python/runtime/db.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,27 @@ def read_features_from_row(row, select_cols, feature_column_names,
149149
return tuple(features)
150150

151151

152+
def to_db_field_type(driver, dtype):
    """
    Convert a column type name into a field type string that the
    CREATE TABLE statement of the target DBMS accepts.

    Args:
        driver (str): the DBMS driver type.
        dtype (str): the data type name.

    Returns:
        A field type string usable inside a CREATE TABLE statement.
    """
    if dtype not in ("VARCHAR", "CHAR"):
        return dtype
    # Character types need an explicit length on MySQL; other engines
    # (e.g. Hive) use STRING instead.
    return dtype + "(255)" if driver == "mysql" else "STRING"
171+
172+
152173
def db_generator(conn, statement, label_meta=None):
153174
def reader():
154175
rs = conn.query(statement)

python/runtime/feature/field_desc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@ class DataType(object):
2727
FLOAT32 = 1
2828
STRING = 2
2929

30+
@staticmethod
def to_db_field_type(driver, dtype):
    """
    Map a DataType value to the field type string accepted by the
    CREATE TABLE statement of the target DBMS.

    Args:
        driver (str): the DBMS driver type.
        dtype (enum): the data type. One of FLOAT32, INT64 and STRING.

    Returns:
        A field type string that the CREATE TABLE statement accepts.

    Raises:
        ValueError: if dtype is not a supported DataType value.
    """
    numeric_types = {
        DataType.INT64: "BIGINT",
        DataType.FLOAT32: "DOUBLE",
    }
    if dtype in numeric_types:
        return numeric_types[dtype]
    if dtype == DataType.STRING:
        # MySQL requires a sized character type; other engines
        # accept STRING directly.
        return "VARCHAR(255)" if driver == "mysql" else "STRING"
    raise ValueError("unsupported data type {}".format(dtype))
55+
3056

3157
# DataFormat is used in FieldDesc to represent the data format
3258
# of a database field.

python/runtime/local/xgboost/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
from runtime.local.xgboost.predict import pred # noqa: F401
1415
from runtime.local.xgboost.train import train # noqa: F401
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import os
15+
import tempfile
16+
17+
import numpy as np
18+
import runtime.xgboost as xgboost_extended
19+
import xgboost as xgb
20+
from runtime import db
21+
from runtime.feature.compile import compile_ir_feature_columns
22+
from runtime.feature.derivation import get_ordered_field_descs
23+
from runtime.feature.field_desc import DataType
24+
from runtime.model.model import Model
25+
from runtime.xgboost.dataset import xgb_dataset
26+
27+
28+
def pred(datasource, select, result_table, pred_label_name, load):
    """
    Run prediction over the input data with a trained XGBoost model and
    write the results into a table.

    Args:
        datasource (str): the database connection string.
        select (str): the input data to predict.
        result_table (str): the output data table.
        pred_label_name (str): the output label name to predict.
        load (str): where the trained model stores.

    Returns:
        None.
    """
    model = Model.load_from_db(datasource, load)
    model_params = model.get_meta("attributes")
    fc_map = model.get_meta("features")
    label_desc = model.get_meta("label").get_field_desc()[0]

    ordered_descs = get_ordered_field_descs(fc_map)
    feature_column_names = [desc.name for desc in ordered_descs]
    feature_metas = {desc.name: desc.to_dict() for desc in ordered_descs}

    # NOTE: in the current implementation, we are generating a transform_fn
    # from the COLUMN clause. The transform_fn is executed during the process
    # of dumping the original data into DMatrix SVM file.
    compiled_fc = compile_ir_feature_columns(fc_map, model.get_type())
    transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(
        feature_column_names, *compiled_fc["feature_columns"])

    bst = xgb.Booster()
    # NOTE(review): assumes Model.load_from_db restores the model file into
    # the current working directory under the fixed name "my_model" — confirm.
    bst.load_model("my_model")

    conn = db.connect_with_data_source(datasource)
    result_column_names, train_label_idx = _create_predict_table(
        conn, select, result_table, label_desc, pred_label_name)

    with tempfile.TemporaryDirectory() as tmp_dir_name:
        pred_fn = os.path.join(tmp_dir_name, "predict.txt")
        raw_data_dir = os.path.join(tmp_dir_name, "predict_raw_dir")

        dpred = xgb_dataset(
            datasource=datasource,
            fn=pred_fn,
            dataset_sql=select,
            feature_metas=feature_metas,
            feature_column_names=feature_column_names,
            label_meta=None,
            cache=True,
            batch_size=10000,
            transform_fn=transform_fn,
            raw_data_dir=raw_data_dir)  # NOTE: default to use external memory

        print("Start predicting XGBoost model...")
        for idx, pred_dmatrix in enumerate(dpred):
            feature_file_name = os.path.join(
                tmp_dir_name, "predict_raw_dir/predict.txt_%d" % idx)
            _predict_and_store_result(bst, pred_dmatrix, model_params,
                                      result_table, result_column_names,
                                      train_label_idx, feature_file_name,
                                      conn)
        print("Done predicting. Predict table : %s" % result_table)

    conn.close()
92+
93+
def _predict_and_store_result(bst, dpred, model_params, result_table,
                              result_column_names, train_label_idx,
                              feature_file_name, conn):
    """
    Do prediction and save the prediction result in the table.

    Args:
        bst: the XGBoost booster object.
        dpred: the XGBoost DMatrix input data to predict.
        model_params (dict): the XGBoost model parameters.
        result_table (str): the result table name.
        result_column_names (list[str]): the result column names.
        train_label_idx (int): the index where the trained label is inside
            result_column_names, or -1 if absent.
        feature_file_name (str): the file path where the feature dumps.
        conn: the database connection object.

    Returns:
        None.
    """
    preds = bst.predict(dpred)

    # TODO(yancey1989): should save train_params and model_params
    # not only on PAI submitter
    # TODO(yancey1989): output the original result for various
    # objective function.
    objective = model_params.get("objective", "")
    if objective.startswith("binary:"):
        # Binary objectives output probabilities; threshold at 0.5.
        preds = (preds > 0.5).astype(np.int64)
    elif objective.startswith("multi:") and len(preds.shape) == 2:
        # BUGFIX: was `len(preds) == 2`, which tests the number of
        # prediction rows instead of the array rank. Multi-class
        # objectives that output probabilities (e.g. multi:softprob)
        # return a 2-D (rows, num_class) matrix; pick the most likely
        # class per row. A rank-1 array would make argmax(axis=1) raise.
        preds = np.argmax(np.array(preds), axis=1)

    with db.buffered_db_writer(conn, result_table, result_column_names,
                               100) as w:
        with open(feature_file_name, "r") as feature_file_read:
            for line_no, line in enumerate(feature_file_read):
                if not line:
                    break

                # Drop the trained-label column (if present) from the
                # raw feature dump, then append the prediction.
                row = [
                    item for i, item in enumerate(line.strip().split("/"))
                    if i != train_label_idx
                ]
                row.append(str(preds[line_no]))
                w.write(row)
140+
141+
142+
def _create_predict_table(conn, select, result_table, train_label_desc,
                          pred_label_name):
    """
    Create the result prediction table.

    The result table keeps every selected column except the trained
    label, and appends a column for the predicted label.

    Args:
        conn: the database connection object.
        select (str): the input data to predict.
        result_table (str): the output data table.
        train_label_desc (FieldDesc): the FieldDesc of the trained label.
        pred_label_name (str): the output label name to predict.

    Returns:
        A tuple of (result_column_names, train_label_index).
    """
    name_and_types = db.selected_columns_and_types(conn, select)

    # Locate the trained label among the selected columns; -1 if absent.
    train_label_index = -1
    for idx, (col_name, _) in enumerate(name_and_types):
        if col_name == train_label_desc.name:
            train_label_index = idx
            break
    if train_label_index >= 0:
        del name_and_types[train_label_index]

    column_strs = [
        "%s %s" % (col_name, db.to_db_field_type(conn.driver, col_type))
        for col_name, col_type in name_and_types
    ]
    train_label_field_type = DataType.to_db_field_type(
        conn.driver, train_label_desc.dtype)
    column_strs.append("%s %s" % (pred_label_name, train_label_field_type))

    conn.execute("DROP TABLE IF EXISTS %s;" % result_table)
    conn.execute("CREATE TABLE %s (%s);" % (result_table,
                                            ",".join(column_strs)))

    result_column_names = [pair[0] for pair in name_and_types]
    result_column_names.append(pred_label_name)
    return result_column_names, train_label_index
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import os
15+
import tempfile
16+
import unittest
17+
18+
import runtime.db as db
19+
import runtime.testing as testing
20+
from runtime.feature.column import NumericColumn
21+
from runtime.feature.field_desc import FieldDesc
22+
from runtime.local.xgboost import pred, train
23+
24+
25+
class TestXGBoostTrain(unittest.TestCase):
    """Train an XGBoost model locally, run prediction with it, and check
    the prediction result table against the input table."""

    def get_table_row_count(self, conn, table):
        """Return the row count of `table`."""
        ret = list(conn.query("SELECT COUNT(*) FROM %s" % table))
        self.assertEqual(len(ret), 1)
        ret = ret[0]
        self.assertEqual(len(ret), 1)
        return ret[0]

    def get_table_schema(self, conn, table):
        """Return the schema of `table` as a {column_name: type} dict."""
        name_and_types = conn.get_table_schema(table)
        return dict(name_and_types)

    @unittest.skipUnless(testing.get_driver() == "mysql",
                         "skip non mysql tests")
    def test_train_and_predict(self):
        ds = testing.get_datasource()
        original_sql = """SELECT * FROM iris.train
        TO TRAIN xgboost.gbtree
        WITH
            objective="multi:softmax",
            num_boost_round=20,
            num_class=3,
            validation.select="SELECT * FROM iris.test"
        INTO iris.xgboost_train_model_test;
        """

        select = "SELECT * FROM iris.train"
        val_select = "SELECT * FROM iris.test"
        train_params = {"num_boost_round": 20}
        model_params = {"num_class": 3, "objective": "multi:softmax"}
        save_name = "iris.xgboost_train_model_test"
        class_name = "class"

        old_dir_name = os.getcwd()
        try:
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                # The trained model file is written into the current
                # working directory, so run inside a scratch dir.
                os.chdir(tmp_dir_name)
                eval_result = train(original_sql=original_sql,
                                    model_image="sqlflow:step",
                                    estimator="xgboost.gbtree",
                                    datasource=ds,
                                    select=select,
                                    validation_select=val_select,
                                    model_params=model_params,
                                    train_params=train_params,
                                    feature_column_map=None,
                                    label_column=NumericColumn(
                                        FieldDesc(name=class_name)),
                                    save=save_name)
                self.assertLess(eval_result['train']['merror'][-1], 0.01)
                self.assertLess(eval_result['validate']['merror'][-1], 0.01)

                conn = db.connect_with_data_source(ds)
                try:
                    pred_select = "SELECT * FROM iris.test"
                    pred(ds, pred_select, "iris.predict_result_table",
                         class_name, save_name)

                    self.assertEqual(
                        self.get_table_row_count(conn, "iris.test"),
                        self.get_table_row_count(conn,
                                                 "iris.predict_result_table"))

                    schema1 = self.get_table_schema(conn, "iris.test")
                    schema2 = self.get_table_schema(
                        conn, "iris.predict_result_table")
                    self.assertEqual(len(schema1), len(schema2))
                    for name in schema1:
                        if name == 'class':
                            # The predicted label column is created as
                            # BIGINT regardless of the source column type.
                            self.assertEqual(schema2[name], "BIGINT")
                            continue

                        self.assertTrue(name in schema2)
                        self.assertEqual(schema1[name], schema2[name])

                    diff_schema = schema2.keys() - schema1.keys()
                    self.assertEqual(len(diff_schema), 0)
                finally:
                    # BUGFIX: close the connection even on test failure
                    # (it was never closed in the original).
                    conn.close()
        finally:
            # BUGFIX: restore the working directory even when an assertion
            # fails, so subsequent tests run from the expected cwd.
            os.chdir(old_dir_name)
103+
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)