Skip to content

Commit 5ae82c0

Browse files
authored
Add feature derivator (#771)
* add feature derivator (WIP)
* wip
* wip add test
* wip: infer column data details
* test done
* clean unused code
* fix CI
1 parent ce0310a commit 5ae82c0

14 files changed

Lines changed: 538 additions & 68 deletions

sql/codegen.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,12 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
112112
},
113113
}
114114

115-
trainResolved, err := resolveTrainClause(&pr.trainClause)
115+
var err error
116+
r.connectionConfig, err = newConnectionConfig(db)
117+
if err != nil {
118+
return nil, err
119+
}
120+
trainResolved, err := resolveTrainClause(&pr.trainClause, &pr.standardSelect, r.connectionConfig)
116121
if err != nil {
117122
return nil, err
118123
}
@@ -225,7 +230,6 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
225230
}
226231
}
227232

228-
r.connectionConfig, err = newConnectionConfig(db)
229233
if err == nil && r.Driver == "hive" {
230234
// remove the last ';' which leads to a (hive)ParseException
231235
r.TrainingDatasetSQL = strings.TrimSuffix(r.TrainingDatasetSQL, ";")

sql/codegen_alps.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func modelCreatorCode(resolved *resolvedTrainClause, args []string) (string, str
178178
}
179179

180180
func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *trainAndValDataset) (*alpsFiller, error) {
181-
resolved, err := resolveTrainClause(&pr.trainClause)
181+
resolved, err := resolveTrainClause(&pr.trainClause, &pr.standardSelect, nil)
182182
if err != nil {
183183
return nil, err
184184
}
@@ -675,6 +675,7 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
675675
shape,
676676
userSpec.DType,
677677
userSpec.Delimiter,
678+
nil,
678679
*meta.featureMap}
679680
} else {
680681
output[ct.Name()] = &columns.ColumnSpec{
@@ -683,6 +684,7 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
683684
shape,
684685
"float",
685686
",",
687+
nil,
686688
*meta.featureMap}
687689
}
688690
}
@@ -730,7 +732,7 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, err
730732
column, present := output[*name]
731733
if !present {
732734
shape := make([]int, 0, 1000)
733-
column := &columns.ColumnSpec{*name, true, shape, "int64", "", *meta.featureMap}
735+
column := &columns.ColumnSpec{*name, true, shape, "int64", "", nil, *meta.featureMap}
734736
column.DType = "int64"
735737
output[*name] = column
736738
}

sql/codegen_elasticdl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func newElasticDLDataConversionFiller(pr *extendedSelect, db *DB, recordIODataDi
143143
}
144144

145145
func newElasticDLTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *trainAndValDataset) (*elasticDLFiller, error) {
146-
resolved, err := resolveTrainClause(&pr.trainClause)
146+
resolved, err := resolveTrainClause(&pr.trainClause, &pr.standardSelect, nil)
147147
if err != nil {
148148
return nil, err
149149
}

sql/columns/bucket_column.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ func (bc *BucketColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
3636
strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ","))}, nil
3737
}
3838

39+
// GetKey implements the FeatureColumn interface.
40+
func (bc *BucketColumn) GetKey() string {
41+
return bc.SourceColumn.Key
42+
}
43+
3944
// GetDelimiter implements the FeatureColumn interface.
4045
func (bc *BucketColumn) GetDelimiter() string {
4146
return ""
@@ -46,11 +51,6 @@ func (bc *BucketColumn) GetDtype() string {
4651
return ""
4752
}
4853

49-
// GetKey implements the FeatureColumn interface.
50-
func (bc *BucketColumn) GetKey() string {
51-
return bc.SourceColumn.Key
52-
}
53-
5454
// GetInputShape implements the FeatureColumn interface.
5555
func (bc *BucketColumn) GetInputShape() string {
5656
return bc.SourceColumn.GetInputShape()

sql/columns/category_id_column.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ type SequenceCategoryIDColumn struct {
3232
BucketSize int
3333
Delimiter string
3434
Dtype string
35-
IsSequence bool
3635
}
3736

3837
// GenerateCode implements the FeatureColumn interface.
@@ -41,6 +40,11 @@ func (cc *CategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
4140
cc.Key, cc.BucketSize)}, nil
4241
}
4342

43+
// GetKey implements the FeatureColumn interface.
44+
func (cc *CategoryIDColumn) GetKey() string {
45+
return cc.Key
46+
}
47+
4448
// GetDelimiter implements the FeatureColumn interface.
4549
func (cc *CategoryIDColumn) GetDelimiter() string {
4650
return cc.Delimiter
@@ -51,11 +55,6 @@ func (cc *CategoryIDColumn) GetDtype() string {
5155
return cc.Dtype
5256
}
5357

54-
// GetKey implements the FeatureColumn interface.
55-
func (cc *CategoryIDColumn) GetKey() string {
56-
return cc.Key
57-
}
58-
5958
// GetInputShape implements the FeatureColumn interface.
6059
func (cc *CategoryIDColumn) GetInputShape() string {
6160
return fmt.Sprintf("[%d]", cc.BucketSize)

sql/columns/column_spec.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type ColumnSpec struct {
3232
Shape []int
3333
DType string
3434
Delimiter string
35+
Vocabulary map[string]string // use a map to generate a list without duplication
3536
FeatureMap FeatureMap
3637
}
3738

sql/columns/cross_column.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ func (cc *CrossColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
5151
strings.Join(keysGenerated, ","), cc.HashBucketSize)}, nil
5252
}
5353

54+
// GetKey implements the FeatureColumn interface.
55+
func (cc *CrossColumn) GetKey() string {
56+
// NOTE: cross column is a feature on multiple column keys.
57+
return ""
58+
}
59+
5460
// GetDelimiter implements the FeatureColumn interface.
5561
func (cc *CrossColumn) GetDelimiter() string {
5662
return ""
@@ -61,12 +67,6 @@ func (cc *CrossColumn) GetDtype() string {
6167
return ""
6268
}
6369

64-
// GetKey implements the FeatureColumn interface.
65-
func (cc *CrossColumn) GetKey() string {
66-
// NOTE: cross column is a feature on multiple column keys.
67-
return ""
68-
}
69-
7070
// GetInputShape implements the FeatureColumn interface.
7171
func (cc *CrossColumn) GetInputShape() string {
7272
// NOTE: return empty since crossed column input shape is already determined

sql/columns/embedding_column.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,21 @@ import (
1919

2020
// EmbeddingColumn is the wrapper of `tf.feature_column.embedding_column`
2121
type EmbeddingColumn struct {
22+
Key string // only used when CategoryColumn = nil, feature derivation will fill up the details
2223
CategoryColumn interface{}
2324
Dimension int
2425
Combiner string
2526
Initializer string
2627
}
2728

29+
// GetKey implements the FeatureColumn interface.
30+
func (ec *EmbeddingColumn) GetKey() string {
31+
if ec.CategoryColumn != nil {
32+
return ec.CategoryColumn.(FeatureColumn).GetKey()
33+
}
34+
return ec.Key
35+
}
36+
2837
// GetDelimiter implements the FeatureColumn interface.
2938
func (ec *EmbeddingColumn) GetDelimiter() string {
3039
return ec.CategoryColumn.(FeatureColumn).GetDelimiter()
@@ -35,11 +44,6 @@ func (ec *EmbeddingColumn) GetDtype() string {
3544
return ec.CategoryColumn.(FeatureColumn).GetDtype()
3645
}
3746

38-
// GetKey implements the FeatureColumn interface.
39-
func (ec *EmbeddingColumn) GetKey() string {
40-
return ec.CategoryColumn.(FeatureColumn).GetKey()
41-
}
42-
4347
// GetInputShape implements the FeatureColumn interface.
4448
func (ec *EmbeddingColumn) GetInputShape() string {
4549
return ec.CategoryColumn.(FeatureColumn).GetInputShape()

sql/columns/feature_column.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ type FeatureColumn interface {
3535
// feature_column code. And we maybe use one compound column's data to generate
3636
// multiple feature columns, so return a list of strings.
3737
GenerateCode(cs *ColumnSpec) ([]string, error)
38+
GetKey() string
39+
3840
// FIXME(typhoonzero): remove delimiter, dtype shape from feature column
3941
// get these from column spec claused or by feature derivation.
4042
GetDelimiter() string
4143
GetDtype() string
42-
GetKey() string
43-
// return input shape json string, like "[2,3]"
4444
GetInputShape() string
45+
4546
GetColumnType() int
4647
}

sql/expression_resolver.go

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ type resolvedTrainClause struct {
6767
ColumnSpecs map[string][]*columns.ColumnSpec
6868
EngineParams engineSpec
6969
CustomModule *gitLabModule
70+
FeatureColumnInfered FeatureColumnMap
71+
ColumnSpecInfered ColumnSpecMap
7072
}
7173

7274
type resolvedPredictClause struct {
@@ -133,7 +135,7 @@ func getStringsAttr(attrs map[string]*attribute, key string, defaultValue []stri
133135
return defaultValue
134136
}
135137

136-
func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) {
138+
func resolveTrainClause(tc *trainClause, slct *standardSelect, connConfig *connectionConfig) (*resolvedTrainClause, error) {
137139
modelName := tc.estimator
138140
preMadeModel := !strings.ContainsAny(modelName, ".")
139141
attrs, err := resolveAttribute(&tc.trainAttrs)
@@ -219,6 +221,11 @@ func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) {
219221
fcMap[target] = fcs
220222
csMap[target] = css
221223
}
224+
// TODO(typhoonzero): use the derived maps for codegen, skip checking error
225+
// since it's not used by codegen yet.
226+
// also, need to clean up what is inside "resolvedTrainClause", keep only
227+
// fcInfered, csInfered
228+
fcInfered, csInfered, err := InferFeatureColumns(slct, fcMap, csMap, connConfig)
222229

223230
return &resolvedTrainClause{
224231
IsPreMadeModel: preMadeModel,
@@ -246,7 +253,10 @@ func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) {
246253
FeatureColumns: fcMap,
247254
ColumnSpecs: csMap,
248255
EngineParams: getEngineSpec(engineParams),
249-
CustomModule: customModel}, nil
256+
CustomModule: customModel,
257+
FeatureColumnInfered: fcInfered,
258+
ColumnSpecInfered: csInfered,
259+
}, nil
250260
}
251261

252262
func resolvePredictClause(pc *predictClause) (*resolvedPredictClause, error) {
@@ -284,7 +294,6 @@ func resolveTrainColumns(columnExprs *exprlist) ([]columns.FeatureColumn, []*col
284294
for _, expr := range *columnExprs {
285295
if expr.typ != 0 {
286296
// Column identifier like "COLUMN a1,b1"
287-
// FIXME(typhoonzero): infer the column spec here.
288297
c := &columns.NumericColumn{
289298
Key: expr.val,
290299
Shape: []int{1},
@@ -380,18 +389,6 @@ func expression2string(e interface{}) (string, error) {
380389
return "", fmt.Errorf("expression expected to be string, actual: %s", resolved)
381390
}
382391

383-
// func generateFeatureColumnCode(fcs []columns.FeatureColumn) (string, error) {
384-
// var codes = make([]string, 0, len(fcs))
385-
// for _, fc := range fcs {
386-
// code, err := fc.GenerateCode()
387-
// if err != nil {
388-
// return "", nil
389-
// }
390-
// codes = append(codes, code)
391-
// }
392-
// return fmt.Sprintf("[%s]", strings.Join(codes, ",")), nil
393-
// }
394-
395392
func resolveDelimiter(delimiter string) (string, error) {
396393
if strings.EqualFold(delimiter, comma) {
397394
return ",", nil
@@ -476,8 +473,7 @@ func resolveSeqCategoryIDColumn(el *exprlist) (*columns.SequenceCategoryIDColumn
476473
BucketSize: bucketSize,
477474
Delimiter: delimiter,
478475
// TODO(typhoonzero): support config dtype
479-
Dtype: "int64",
480-
IsSequence: true}, cs, nil
476+
Dtype: "int64"}, cs, nil
481477
}
482478

483479
func resolveCategoryIDColumn(el *exprlist) (*columns.CategoryIDColumn, *columns.ColumnSpec, error) {
@@ -551,50 +547,73 @@ func resolveCrossColumn(el *exprlist) (*columns.CrossColumn, error) {
551547
HashBucketSize: bucketSize}, nil
552548
}
553549

554-
func resolveEmbeddingColumn(el *exprlist) (*columns.EmbeddingColumn, error) {
550+
func resolveEmbeddingColumn(el *exprlist) (*columns.EmbeddingColumn, *columns.ColumnSpec, error) {
555551
if len(*el) != 4 && len(*el) != 5 {
556-
return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el)
552+
return nil, nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el)
557553
}
554+
558555
sourceExprList := (*el)[1]
559556
var source columns.FeatureColumn
557+
var cs *columns.ColumnSpec
560558
var err error
559+
var innerCategoryColumnKey string
560+
561+
var catColumnResult interface{}
561562
if sourceExprList.typ == 0 {
562-
source, _, err = resolveColumn(&sourceExprList.sexp)
563+
source, cs, err = resolveColumn(&sourceExprList.sexp)
563564
if err != nil {
564-
return nil, err
565+
return nil, nil, err
565566
}
566-
} else {
567-
return nil, fmt.Errorf("key of EMBEDDING must be categorical column")
568-
}
569-
// TODO(uuleon) support other kinds of categorical column in the future
570-
var catColumn interface{}
571-
catColumn, ok := source.(*columns.CategoryIDColumn)
572-
if !ok {
573-
catColumn, ok = source.(*columns.SequenceCategoryIDColumn)
574-
if !ok {
575-
return nil, fmt.Errorf("key of EMBEDDING must be categorical column")
567+
// user may write EMBEDDING(SPARSE(...)) or EMBEDDING(DENSE(...))
568+
if cs != nil {
569+
innerCategoryColumnKey = cs.ColumnName
570+
catColumnResult = &columns.CategoryIDColumn{
571+
Key: cs.ColumnName,
572+
BucketSize: cs.Shape[0],
573+
Delimiter: cs.Delimiter,
574+
Dtype: cs.DType,
575+
}
576+
} else {
577+
// TODO(uuleon) support other kinds of categorical column in the future
578+
var catColumn interface{}
579+
catColumn, ok := source.(*columns.CategoryIDColumn)
580+
if !ok {
581+
catColumn, ok = source.(*columns.SequenceCategoryIDColumn)
582+
if !ok {
583+
return nil, nil, fmt.Errorf("key of EMBEDDING must be categorical column")
584+
}
585+
}
586+
// NOTE: to avoid golang multiple assignment compiler restrictions
587+
catColumnResult = catColumn
588+
innerCategoryColumnKey = source.GetKey()
576589
}
590+
} else {
591+
// leave the category column nil and record only the key; feature derivation is expected to fill in the details later.
592+
catColumnResult = nil
593+
innerCategoryColumnKey = sourceExprList.val
577594
}
595+
578596
dimension, err := strconv.Atoi((*el)[2].val)
579597
if err != nil {
580-
return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err)
598+
return nil, nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err)
581599
}
582600
combiner, err := expression2string((*el)[3])
583601
if err != nil {
584-
return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err)
602+
return nil, nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err)
585603
}
586604
initializer := ""
587605
if len(*el) == 5 {
588606
initializer, err = expression2string((*el)[4])
589607
if err != nil {
590-
return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err)
608+
return nil, nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err)
591609
}
592610
}
593611
return &columns.EmbeddingColumn{
594-
CategoryColumn: catColumn,
612+
Key: innerCategoryColumnKey,
613+
CategoryColumn: catColumnResult,
595614
Dimension: dimension,
596615
Combiner: combiner,
597-
Initializer: initializer}, nil
616+
Initializer: initializer}, cs, nil
598617
}
599618

600619
func resolveNumericColumn(el *exprlist) (*columns.NumericColumn, error) {
@@ -710,8 +729,7 @@ func resolveColumn(el *exprlist) (columns.FeatureColumn, *columns.ColumnSpec, er
710729
case seqCategoryID:
711730
return resolveSeqCategoryIDColumn(el)
712731
case embedding:
713-
fc, err := resolveEmbeddingColumn(el)
714-
return fc, nil, err
732+
return resolveEmbeddingColumn(el)
715733
default:
716734
return nil, nil, fmt.Errorf("not supported expr: %s", head)
717735
}

0 commit comments

Comments
 (0)