@@ -67,6 +67,8 @@ type resolvedTrainClause struct {
6767 ColumnSpecs map [string ][]* columns.ColumnSpec
6868 EngineParams engineSpec
6969 CustomModule * gitLabModule
70+ FeatureColumnInfered FeatureColumnMap
71+ ColumnSpecInfered ColumnSpecMap
7072}
7173
7274type resolvedPredictClause struct {
@@ -133,7 +135,7 @@ func getStringsAttr(attrs map[string]*attribute, key string, defaultValue []stri
133135 return defaultValue
134136}
135137
136- func resolveTrainClause (tc * trainClause ) (* resolvedTrainClause , error ) {
138+ func resolveTrainClause (tc * trainClause , slct * standardSelect , connConfig * connectionConfig ) (* resolvedTrainClause , error ) {
137139 modelName := tc .estimator
138140 preMadeModel := ! strings .ContainsAny (modelName , "." )
139141 attrs , err := resolveAttribute (& tc .trainAttrs )
@@ -219,6 +221,11 @@ func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) {
219221 fcMap [target ] = fcs
220222 csMap [target ] = css
221223 }
224+ // TODO(typhoonzero): use the derivated maps for codegen, skip checking error
225+ // since it's not used by codegen yet.
226+ // also, need to clean up what is inside "resolvedTrainClause", keep only
227+ // fcInfered, csInfered
228+ fcInfered , csInfered , err := InferFeatureColumns (slct , fcMap , csMap , connConfig )
222229
223230 return & resolvedTrainClause {
224231 IsPreMadeModel : preMadeModel ,
@@ -246,7 +253,10 @@ func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) {
246253 FeatureColumns : fcMap ,
247254 ColumnSpecs : csMap ,
248255 EngineParams : getEngineSpec (engineParams ),
249- CustomModule : customModel }, nil
256+ CustomModule : customModel ,
257+ FeatureColumnInfered : fcInfered ,
258+ ColumnSpecInfered : csInfered ,
259+ }, nil
250260}
251261
252262func resolvePredictClause (pc * predictClause ) (* resolvedPredictClause , error ) {
@@ -284,7 +294,6 @@ func resolveTrainColumns(columnExprs *exprlist) ([]columns.FeatureColumn, []*col
284294 for _ , expr := range * columnExprs {
285295 if expr .typ != 0 {
286296 // Column identifier like "COLUMN a1,b1"
287- // FIXME(typhoonzero): infer the column spec here.
288297 c := & columns.NumericColumn {
289298 Key : expr .val ,
290299 Shape : []int {1 },
@@ -380,18 +389,6 @@ func expression2string(e interface{}) (string, error) {
380389 return "" , fmt .Errorf ("expression expected to be string, actual: %s" , resolved )
381390}
382391
383- // func generateFeatureColumnCode(fcs []columns.FeatureColumn) (string, error) {
384- // var codes = make([]string, 0, len(fcs))
385- // for _, fc := range fcs {
386- // code, err := fc.GenerateCode()
387- // if err != nil {
388- // return "", nil
389- // }
390- // codes = append(codes, code)
391- // }
392- // return fmt.Sprintf("[%s]", strings.Join(codes, ",")), nil
393- // }
394-
395392func resolveDelimiter (delimiter string ) (string , error ) {
396393 if strings .EqualFold (delimiter , comma ) {
397394 return "," , nil
@@ -476,8 +473,7 @@ func resolveSeqCategoryIDColumn(el *exprlist) (*columns.SequenceCategoryIDColumn
476473 BucketSize : bucketSize ,
477474 Delimiter : delimiter ,
478475 // TODO(typhoonzero): support config dtype
479- Dtype : "int64" ,
480- IsSequence : true }, cs , nil
476+ Dtype : "int64" }, cs , nil
481477}
482478
483479func resolveCategoryIDColumn (el * exprlist ) (* columns.CategoryIDColumn , * columns.ColumnSpec , error ) {
@@ -551,50 +547,73 @@ func resolveCrossColumn(el *exprlist) (*columns.CrossColumn, error) {
551547 HashBucketSize : bucketSize }, nil
552548}
553549
554- func resolveEmbeddingColumn (el * exprlist ) (* columns.EmbeddingColumn , error ) {
550+ func resolveEmbeddingColumn (el * exprlist ) (* columns.EmbeddingColumn , * columns. ColumnSpec , error ) {
555551 if len (* el ) != 4 && len (* el ) != 5 {
556- return nil , fmt .Errorf ("bad EMBEDDING expression format: %s" , * el )
552+ return nil , nil , fmt .Errorf ("bad EMBEDDING expression format: %s" , * el )
557553 }
554+
558555 sourceExprList := (* el )[1 ]
559556 var source columns.FeatureColumn
557+ var cs * columns.ColumnSpec
560558 var err error
559+ var innerCategoryColumnKey string
560+
561+ var catColumnResult interface {}
561562 if sourceExprList .typ == 0 {
562- source , _ , err = resolveColumn (& sourceExprList .sexp )
563+ source , cs , err = resolveColumn (& sourceExprList .sexp )
563564 if err != nil {
564- return nil , err
565+ return nil , nil , err
565566 }
566- } else {
567- return nil , fmt .Errorf ("key of EMBEDDING must be categorical column" )
568- }
569- // TODO(uuleon) support other kinds of categorical column in the future
570- var catColumn interface {}
571- catColumn , ok := source .(* columns.CategoryIDColumn )
572- if ! ok {
573- catColumn , ok = source .(* columns.SequenceCategoryIDColumn )
574- if ! ok {
575- return nil , fmt .Errorf ("key of EMBEDDING must be categorical column" )
567+ // user may write EMBEDDING(SPARSE(...)) or EMBEDDING(DENSE(...))
568+ if cs != nil {
569+ innerCategoryColumnKey = cs .ColumnName
570+ catColumnResult = & columns.CategoryIDColumn {
571+ Key : cs .ColumnName ,
572+ BucketSize : cs .Shape [0 ],
573+ Delimiter : cs .Delimiter ,
574+ Dtype : cs .DType ,
575+ }
576+ } else {
577+ // TODO(uuleon) support other kinds of categorical column in the future
578+ var catColumn interface {}
579+ catColumn , ok := source .(* columns.CategoryIDColumn )
580+ if ! ok {
581+ catColumn , ok = source .(* columns.SequenceCategoryIDColumn )
582+ if ! ok {
583+ return nil , nil , fmt .Errorf ("key of EMBEDDING must be categorical column" )
584+ }
585+ }
586+ // NOTE: to avoid golang multiple assignment compiler restrictions
587+ catColumnResult = catColumn
588+ innerCategoryColumnKey = source .GetKey ()
576589 }
590+ } else {
591+ // generate a default CategoryIDColumn for later feature derivation.
592+ catColumnResult = nil
593+ innerCategoryColumnKey = sourceExprList .val
577594 }
595+
578596 dimension , err := strconv .Atoi ((* el )[2 ].val )
579597 if err != nil {
580- return nil , fmt .Errorf ("bad EMBEDDING dimension: %s, err: %s" , (* el )[2 ].val , err )
598+ return nil , nil , fmt .Errorf ("bad EMBEDDING dimension: %s, err: %s" , (* el )[2 ].val , err )
581599 }
582600 combiner , err := expression2string ((* el )[3 ])
583601 if err != nil {
584- return nil , fmt .Errorf ("bad EMBEDDING combiner: %s, err: %s" , (* el )[3 ], err )
602+ return nil , nil , fmt .Errorf ("bad EMBEDDING combiner: %s, err: %s" , (* el )[3 ], err )
585603 }
586604 initializer := ""
587605 if len (* el ) == 5 {
588606 initializer , err = expression2string ((* el )[4 ])
589607 if err != nil {
590- return nil , fmt .Errorf ("bad EMBEDDING initializer: %s, err: %s" , (* el )[4 ], err )
608+ return nil , nil , fmt .Errorf ("bad EMBEDDING initializer: %s, err: %s" , (* el )[4 ], err )
591609 }
592610 }
593611 return & columns.EmbeddingColumn {
594- CategoryColumn : catColumn ,
612+ Key : innerCategoryColumnKey ,
613+ CategoryColumn : catColumnResult ,
595614 Dimension : dimension ,
596615 Combiner : combiner ,
597- Initializer : initializer }, nil
616+ Initializer : initializer }, cs , nil
598617}
599618
600619func resolveNumericColumn (el * exprlist ) (* columns.NumericColumn , error ) {
@@ -710,8 +729,7 @@ func resolveColumn(el *exprlist) (columns.FeatureColumn, *columns.ColumnSpec, er
710729 case seqCategoryID :
711730 return resolveSeqCategoryIDColumn (el )
712731 case embedding :
713- fc , err := resolveEmbeddingColumn (el )
714- return fc , nil , err
732+ return resolveEmbeddingColumn (el )
715733 default :
716734 return nil , nil , fmt .Errorf ("not supported expr: %s" , head )
717735 }
0 commit comments