1414package experimental
1515
1616import (
17+ "encoding/binary"
1718 "fmt"
19+ "github.com/bitly/go-simplejson"
1820 "net/url"
21+ "sqlflow.org/sqlflow/go/model"
22+ "sqlflow.org/sqlflow/go/sqlfs"
1923 "strings"
2024
2125 "sqlflow.org/sqlflow/go/database"
@@ -24,37 +28,130 @@ import (
2428 pb "sqlflow.org/sqlflow/go/proto"
2529)
2630
27- // TODO(sneaxiy): implement this method to distinguish whether
28- // a model is a XGBoost model.
29- func isTrainedXBoostModel (modelName string ) bool {
30- return true
31- }
32-
33- func generateStepCode (sqlStmt ir.SQLFlowStmt , stepIndex int , session * pb.Session ) (string , error ) {
31+ func generateStepCodeAndImage (sqlStmt ir.SQLFlowStmt , stepIndex int , session * pb.Session , sqlStmts []ir.SQLFlowStmt ) (string , string , error ) {
3432 switch stmt := sqlStmt .(type ) {
3533 case * ir.TrainStmt :
36- return generateTrainCode (stmt , stepIndex , session )
34+ return generateTrainCodeAndImage (stmt , stepIndex , session )
3735 case * ir.PredictStmt :
38- return generatePredictCode (stmt , stepIndex , session )
36+ return generatePredictCodeAndImage (stmt , stepIndex , session , sqlStmts )
3937 case * ir.NormalStmt :
40- return GenerateNormalStmtStep (string (* stmt ), session , stepIndex )
38+ code , err := generateNormalStmtStep (string (* stmt ), stepIndex , session )
39+ return code , "" , err
4140 default :
42- return "" , fmt .Errorf ("not implemented stmt execution type %v" , stmt )
41+ return "" , "" , fmt .Errorf ("not implemented stmt execution type %v" , stmt )
42+ }
43+ }
44+
45+ func generateTrainCodeAndImage (trainStmt * ir.TrainStmt , stepIndex int , session * pb.Session ) (string , string , error ) {
46+ isXGBoost := isXGBoostEstimator (trainStmt .Estimator )
47+ if isXGBoost {
48+ code , err := XGBoostGenerateTrain (trainStmt , stepIndex , session )
49+ if err != nil {
50+ return "" , "" , err
51+ }
52+ return code , trainStmt .ModelImage , nil
53+ }
54+ return "" , "" , fmt .Errorf ("not implemented estimator type %s" , trainStmt .Estimator )
55+ }
56+
57+ func generatePredictCodeAndImage (predStmt * ir.PredictStmt , stepIndex int , session * pb.Session , sqlStmts []ir.SQLFlowStmt ) (string , string , error ) {
58+ trainStmt := findModelGenerationTrainStmt (predStmt .Using , stepIndex , sqlStmts )
59+ image := ""
60+ isXGBoost := false
61+ if trainStmt != nil {
62+ image = trainStmt .ModelImage
63+ isXGBoost = isXGBoostEstimator (trainStmt .Estimator )
64+ } else {
65+ meta , err := getModelMetadata (session , predStmt .Using )
66+ if err != nil {
67+ return "" , "" , err
68+ }
69+ image = meta .imageName ()
70+ isXGBoost = meta .isXGBoostModel ()
4371 }
72+
73+ if isXGBoost {
74+ code , err := XGBoostGeneratePredict (predStmt , stepIndex , session )
75+ if err != nil {
76+ return "" , "" , err
77+ }
78+ return code , image , nil
79+ }
80+ return "" , "" , fmt .Errorf ("not implemented model type" )
4481}
4582
46- func generateTrainCode (trainStmt * ir.TrainStmt , stepIndex int , session * pb.Session ) (string , error ) {
47- if strings .HasPrefix (strings .ToUpper (trainStmt .Estimator ), "XGBOOST." ) {
48- return XGBoostGenerateTrain (trainStmt , stepIndex , session )
83+ // findModelGenerationTrainStmt finds the *ir.TrainStmt that generates the model named `modelName`.
84+ // TODO(sneaxiy): find a better way to do this when we have a well designed dependency analysis.
85+ func findModelGenerationTrainStmt (modelName string , idx int , sqlStmts []ir.SQLFlowStmt ) * ir.TrainStmt {
86+ idx --
87+ for idx >= 0 {
88+ trainStmt , ok := sqlStmts [idx ].(* ir.TrainStmt )
89+ if ok && trainStmt .Into == modelName {
90+ return trainStmt
91+ }
92+ idx --
4993 }
50- return "" , fmt .Errorf ("not implemented estimator type %s" , trainStmt .Estimator )
94+ return nil
95+ }
96+
97+ func isXGBoostEstimator (estimator string ) bool {
98+ return strings .HasPrefix (strings .ToUpper (estimator ), "XGBOOST." )
99+ }
100+
101+ type metadata simplejson.Json
102+
103+ func (m * metadata ) imageName () string {
104+ return (* simplejson .Json )(m ).Get ("model_repo_image" ).MustString ()
51105}
52106
53- func generatePredictCode (predStmt * ir.PredictStmt , stepIndex int , session * pb.Session ) (string , error ) {
54- if isTrainedXBoostModel (predStmt .Using ) {
55- return XGBoostGeneratePredict (predStmt , stepIndex , session )
107+ func (m * metadata ) isXGBoostModel () bool {
108+ return (* simplejson .Json )(m ).Get ("model_type" ).MustInt () == model .XGBOOST
109+ }
110+
111+ func getModelMetadata (session * pb.Session , table string ) (* metadata , error ) {
112+ submitter := getSubmitter (session )
113+ if submitter == "local" {
114+ return getModelMetadataFromDB (session .DbConnStr , table )
115+ }
116+ return nil , fmt .Errorf ("not supported submitter %s" , submitter )
117+ }
118+
119+ func getModelMetadataFromDB (dbConnStr , table string ) (* metadata , error ) {
120+ db , err := database .OpenAndConnectDB (dbConnStr )
121+ if err != nil {
122+ return nil , err
123+ }
124+ defer db .Close ()
125+
126+ fs , err := sqlfs .Open (db .DB , table )
127+ if err != nil {
128+ return nil , err
129+ }
130+ defer fs .Close ()
131+
132+ lengthBytes := make ([]byte , 8 )
133+ readCnt , err := fs .Read (lengthBytes )
134+ if err != nil {
135+ return nil , err
136+ }
137+ if readCnt != 8 {
138+ return nil , fmt .Errorf ("invalid model table" )
139+ }
140+
141+ length := binary .LittleEndian .Uint64 (lengthBytes )
142+ jsonBytes := make ([]byte , length )
143+ readCnt , err = fs .Read (jsonBytes )
144+ if err != nil {
145+ return nil , err
146+ }
147+ if readCnt != int (length ) {
148+ return nil , fmt .Errorf ("invalid model metadata" )
149+ }
150+ json , err := simplejson .NewJson (jsonBytes )
151+ if err != nil {
152+ return nil , err
56153 }
57- return "" , fmt . Errorf ( "not implemented model type" )
154+ return ( * metadata )( json ), nil
58155}
59156
60157func initializeAndCheckAttributes (stmt ir.SQLFlowStmt ) error {
0 commit comments