Skip to content

Commit f0f1d30

Browse files
authored
Add feature column compilation (#2772)
* add feature_column compile * follow comments * update
1 parent 89e1f6e commit f0f1d30

4 files changed

Lines changed: 299 additions & 4 deletions

File tree

python/runtime/feature/column.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def num_class(self):
128128
return self.bucket_size
129129

130130

131-
class CategoryHashColumn(CategoryIDColumn):
131+
class CategoryHashColumn(CategoryColumn):
132132
"""
133133
CategoryHashColumn represents a categorical hash feature column.
134134
@@ -151,7 +151,7 @@ def num_class(self):
151151
return self.bucket_size
152152

153153

154-
class SeqCategoryIDColumn(CategoryIDColumn):
154+
class SeqCategoryIDColumn(CategoryColumn):
155155
"""
156156
SeqCategoryIDColumn represents a sequential categorical id feature column.
157157

python/runtime/feature/compile.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import numpy as np
15+
import six
16+
from runtime.feature.column import (BucketColumn, CategoryHashColumn,
17+
CategoryIDColumn, CrossColumn,
18+
EmbeddingColumn, IndicatorColumn,
19+
NumericColumn, SeqCategoryIDColumn)
20+
from runtime.feature.field_desc import DataType
21+
from runtime.model import EstimatorType
22+
23+
# Public API of this module: only the map-level compilation entry point.
__all__ = ['compile_ir_feature_columns']
26+
27+
28+
def to_package_dtype(dtype, package):
    """
    Map an IR data type to the matching dtype object of a feature column
    implementation package (TensorFlow or XGBoost).

    Args:
        dtype (DataType): one of INT, FLOAT and STRING.
        package (module): the Python package that provides the feature
            column implementation; its ``dtypes`` attribute is looked up.

    Returns:
        ``package.dtypes.int64`` for INT, ``package.dtypes.float32`` for
        FLOAT and ``package.dtypes.string`` for STRING.

    Raises:
        ValueError: if dtype is none of INT, FLOAT and STRING.
    """
    if dtype == DataType.INT:
        attr_name = "int64"
    elif dtype == DataType.FLOAT:
        attr_name = "float32"
    elif dtype == DataType.STRING:
        attr_name = "string"
    else:
        raise ValueError("unsupported data type {}".format(dtype))

    # Only the attribute for the requested dtype is touched, so a package
    # does not need to define dtypes that are never asked for.
    return getattr(package.dtypes, attr_name)
52+
53+
54+
def compile_feature_column(ir_fc, model_type, package):
    """
    Compile an IR FeatureColumn object to a runtime feature column object.

    Args:
        ir_fc (FeatureColumn): the IR FeatureColumn object.
        model_type (EstimatorType): one of TENSORFLOW and XGBOOST.
        package (module): the Python package corresponding to the model_type.
            Its ``feature_column`` attribute supplies the column factories.

    Returns:
        A runtime feature column object.

    Raises:
        AssertionError: if a SEQ_CATEGORY_ID, CROSS or EMBEDDING column is
            compiled for an XGBoost model, or if a numeric key of a CROSS
            column has a shape with more than one element.
        ValueError: if ir_fc (or a key of a CROSS column) has an
            unsupported type.
    """
    fc_package = package.feature_column

    if isinstance(ir_fc, NumericColumn):
        fd = ir_fc.get_field_desc()[0]
        return fc_package.numeric_column(fd.name, shape=fd.shape)

    if isinstance(ir_fc, BucketColumn):
        source_fc = compile_feature_column(ir_fc.source_column, model_type,
                                           package)
        return fc_package.bucketized_column(source_fc,
                                            boundaries=ir_fc.boundaries)

    if isinstance(ir_fc, CategoryIDColumn):
        fd = ir_fc.get_field_desc()[0]
        # A non-empty vocabulary takes precedence over the bucket size.
        if fd.vocabulary:
            return fc_package.categorical_column_with_vocabulary_list(
                key=fd.name, vocabulary_list=list(fd.vocabulary))
        return fc_package.categorical_column_with_identity(
            key=fd.name, num_buckets=ir_fc.bucket_size)

    if isinstance(ir_fc, SeqCategoryIDColumn):
        _reject_xgboost(model_type, "SEQ_CATEGORY_ID")
        fd = ir_fc.get_field_desc()[0]
        return fc_package.sequence_categorical_column_with_identity(
            key=fd.name, num_buckets=ir_fc.bucket_size)

    if isinstance(ir_fc, CategoryHashColumn):
        fd = ir_fc.get_field_desc()[0]
        dtype = to_package_dtype(fd.dtype, package)
        return fc_package.categorical_column_with_hash_bucket(
            key=fd.name, hash_bucket_size=ir_fc.bucket_size, dtype=dtype)

    if isinstance(ir_fc, CrossColumn):
        _reject_xgboost(model_type, "CROSS")
        return fc_package.crossed_column(
            _cross_key_names(ir_fc.keys),
            hash_bucket_size=ir_fc.hash_bucket_size)

    if isinstance(ir_fc, EmbeddingColumn):
        _reject_xgboost(model_type, "EMBEDDING")
        category_column = compile_feature_column(ir_fc.category_column,
                                                 model_type, package)
        return fc_package.embedding_column(category_column,
                                           dimension=ir_fc.dimension,
                                           combiner=ir_fc.combiner)

    if isinstance(ir_fc, IndicatorColumn):
        category_column = compile_feature_column(ir_fc.category_column,
                                                 model_type, package)
        return fc_package.indicator_column(category_column)

    raise ValueError("unsupported FeatureColumn %s" % type(ir_fc))


def _reject_xgboost(model_type, column_kind):
    """Raise AssertionError if column_kind is compiled for XGBoost."""
    # Raised explicitly instead of using an assert statement so that the
    # check still runs when Python is started with -O.
    if model_type == EstimatorType.XGBOOST:
        raise AssertionError(
            "{} is not supported in XGBoost models".format(column_kind))


def _cross_key_names(keys):
    """Return the field name of every key of a CROSS column, in order."""
    names = []
    for key in keys:
        if isinstance(key, six.string_types):
            names.append(key)
        elif isinstance(key, NumericColumn):
            fd = key.get_field_desc()[0]
            # An empty/absent shape is treated as a single element.
            size = np.prod(fd.shape) if fd.shape else 1
            if size != 1:
                raise AssertionError(
                    "CROSS does not support shape not equal to 1")
            names.append(fd.name)
        else:
            raise ValueError(
                "field in CROSS must be of FeatureColumn or string type")
    return names
134+
135+
136+
def compile_ir_feature_columns(ir_features, model_type):
    """
    Compile an IR FeatureColumn map to a runtime feature column map.

    Args:
        ir_features (dict[str -> list[FeatureColumn]]): the IR FeatureColumn
            map, where the key is the target name, e.g. "feature_columns",
            and the element inside the list is the IR FeatureColumn object.
        model_type (EstimatorType): one of TENSORFLOW and XGBOOST.

    Returns:
        A runtime feature column map, whose type is
        dict[str -> list[RuntimeFeatureColumn]].

    Raises:
        AssertionError: if model_type is XGBOOST and ir_features is not
            exactly the single target "feature_columns".
        ValueError: if model_type is neither TENSORFLOW nor XGBOOST.
    """
    # The implementation package is imported lazily so that only the
    # package of the requested model type has to be importable.
    if model_type == EstimatorType.TENSORFLOW:
        import tensorflow
        package = tensorflow
    elif model_type == EstimatorType.XGBOOST:
        import runtime.xgboost
        package = runtime.xgboost
        # Raised explicitly instead of using an assert statement so that
        # the check still runs when Python is started with -O.
        if len(ir_features) != 1 or "feature_columns" not in ir_features:
            raise AssertionError(
                "XGBoost only supports 'feature_columns' as the feature "
                "target")
    else:
        raise ValueError("only support TensorFlow and XGBoost model")

    return {
        target: [
            compile_feature_column(fc, model_type, package) for fc in fc_list
        ]
        for target, fc_list in ir_features.items()
    }
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import unittest
15+
16+
from runtime.feature.column import (BucketColumn, CategoryHashColumn,
17+
CategoryIDColumn, CrossColumn,
18+
EmbeddingColumn, IndicatorColumn,
19+
NumericColumn, SeqCategoryIDColumn)
20+
from runtime.feature.compile import compile_ir_feature_columns
21+
from runtime.feature.field_desc import DataType, FieldDesc
22+
from runtime.model import EstimatorType
23+
24+
TENSORFLOW = EstimatorType.TENSORFLOW
25+
XGBOOST = EstimatorType.XGBOOST
26+
27+
28+
class TestFeatureColumnCompilation(unittest.TestCase):
    """Tests compiling each IR feature column type for TensorFlow/XGBoost."""

    def compile_fc(self, fc, model_type):
        """
        Compile a single IR feature column under the "feature_columns"
        target and return the one compiled column, checking the shape of
        the returned map on the way.
        """
        fc_dict = {"feature_columns": [fc]}
        rt_fc_dict = compile_ir_feature_columns(fc_dict, model_type)
        self.assertEqual(len(rt_fc_dict), 1)
        self.assertIn("feature_columns", rt_fc_dict)
        fc_list = rt_fc_dict.get("feature_columns")
        self.assertEqual(len(fc_list), 1)
        return fc_list[0]

    def test_numeric_column(self):
        nc = NumericColumn(FieldDesc(name='c1', shape=(2, 3)))

        for model_type in [TENSORFLOW, XGBOOST]:
            compiled_nc = self.compile_fc(nc, model_type)
            self.assertEqual(compiled_nc.key, 'c1')
            self.assertEqual(compiled_nc.shape, (2, 3))

    def test_bucket_column(self):
        nc = NumericColumn(FieldDesc(name='c1', shape=(1, )))
        bc = BucketColumn(nc, (-10, -5, 3, 7))

        for model_type in [TENSORFLOW, XGBOOST]:
            compiled_bc = self.compile_fc(bc, model_type)
            self.assertEqual(compiled_bc.source_column.key, 'c1')
            self.assertEqual(compiled_bc.boundaries, (-10, -5, 3, 7))

    def test_category_id_column(self):
        # Without a vocabulary the column is compiled by bucket size.
        cc = CategoryIDColumn(FieldDesc(name='c1'), 128)

        for model_type in [TENSORFLOW, XGBOOST]:
            compiled_cc = self.compile_fc(cc, model_type)
            self.assertEqual(compiled_cc.key, 'c1')
            self.assertEqual(compiled_cc.num_buckets, 128)

        # With a vocabulary the column is compiled from the vocabulary.
        cc = CategoryIDColumn(FieldDesc(name='c1', vocabulary={'a', 'b'}),
                              128)
        for model_type in [TENSORFLOW, XGBOOST]:
            compiled_cc = self.compile_fc(cc, model_type)
            # The vocabulary is a set, so compare in sorted order.
            vocab = sorted(compiled_cc.vocabulary_list)
            self.assertEqual(vocab, ['a', 'b'])

    def test_seq_category_id_column(self):
        scc = SeqCategoryIDColumn(FieldDesc(name='c1'), 64)
        compiled_scc = self.compile_fc(scc, TENSORFLOW)
        # NOTE: the key is not asserted here because the TensorFlow
        # SeqCategoryIDColumn does not have a key attribute.
        self.assertEqual(compiled_scc.num_buckets, 64)

        with self.assertRaises(AssertionError):
            self.compile_fc(scc, XGBOOST)

    def test_category_hash_column(self):
        chc = CategoryHashColumn(FieldDesc(name='c1', dtype=DataType.STRING),
                                 32)
        for model_type in [TENSORFLOW, XGBOOST]:
            compiled_chc = self.compile_fc(chc, model_type)
            self.assertEqual(compiled_chc.key, 'c1')
            self.assertEqual(compiled_chc.hash_bucket_size, 32)

    def test_cross_column(self):
        cc = CrossColumn(['c1', NumericColumn(FieldDesc(name='c2'))], 4096)
        compiled_cc = self.compile_fc(cc, TENSORFLOW)
        self.assertEqual(list(compiled_cc.keys), ['c1', 'c2'])
        self.assertEqual(compiled_cc.hash_bucket_size, 4096)

        with self.assertRaises(AssertionError):
            self.compile_fc(cc, XGBOOST)

    def test_embedding_column(self):
        chc = CategoryHashColumn(FieldDesc(name='c1', dtype=DataType.STRING),
                                 32)
        ec = EmbeddingColumn(category_column=chc, combiner='sum', dimension=23)

        compiled_ec = self.compile_fc(ec, TENSORFLOW)
        self.assertEqual(compiled_ec.combiner, 'sum')
        self.assertEqual(compiled_ec.dimension, 23)

        compiled_chc = compiled_ec.categorical_column
        self.assertEqual(compiled_chc.key, 'c1')
        self.assertEqual(compiled_chc.hash_bucket_size, 32)

        with self.assertRaises(AssertionError):
            self.compile_fc(ec, XGBOOST)

    def test_indicator_column(self):
        cc = CategoryIDColumn(FieldDesc(name='c1'), 128)
        ic = IndicatorColumn(category_column=cc)

        for model_type in [TENSORFLOW, XGBOOST]:
            compiled_ic = self.compile_fc(ic, model_type)
            compiled_cc = compiled_ic.categorical_column
            self.assertEqual(compiled_cc.key, 'c1')
            self.assertEqual(compiled_cc.num_buckets, 128)
123+
124+
125+
# Run the test cases only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

python/runtime/xgboost/feature_column.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ def numeric_column(key, shape=(1, )):
8989

9090
class BucketizedColumnTransformer(CategoricalColumnTransformer):
9191
def __init__(self, source_column, boundaries):
92-
assert boundaries == sorted(
93-
boundaries), "Boundaries must be sorted in ascending order"
92+
for i in six.moves.range(len(boundaries) - 1):
93+
assert boundaries[i] < boundaries[i+1], \
94+
"Boundaries must be sorted in ascending order"
9495
self.source_column = source_column
9596
self.boundaries = boundaries
9697

0 commit comments

Comments
 (0)