@@ -34,6 +34,8 @@ func generateStepCodeAndImage(sqlStmt ir.SQLFlowStmt, stepIndex int, session *pb
3434 return generateTrainCodeAndImage (stmt , stepIndex , session )
3535 case * ir.PredictStmt :
3636 return generatePredictCodeAndImage (stmt , stepIndex , session , sqlStmts )
37+ case * ir.EvaluateStmt :
38+ return generateEvaluationCodeAndImage (stmt , stepIndex , session , sqlStmts )
3739 case * ir.NormalStmt :
3840 code , err := generateNormalStmtStep (string (* stmt ), stepIndex , session )
3941 return code , "" , err
@@ -55,9 +57,9 @@ func generateTrainCodeAndImage(trainStmt *ir.TrainStmt, stepIndex int, session *
5557}
5658
5759func generatePredictCodeAndImage (predStmt * ir.PredictStmt , stepIndex int , session * pb.Session , sqlStmts []ir.SQLFlowStmt ) (string , string , error ) {
58- trainStmt := findModelGenerationTrainStmt (predStmt .Using , stepIndex , sqlStmts )
5960 image := ""
6061 isXGBoost := false
62+ trainStmt := findModelGenerationTrainStmt (predStmt .Using , stepIndex , sqlStmts )
6163 if trainStmt != nil {
6264 image = trainStmt .ModelImage
6365 isXGBoost = isXGBoostEstimator (trainStmt .Estimator )
@@ -80,6 +82,32 @@ func generatePredictCodeAndImage(predStmt *ir.PredictStmt, stepIndex int, sessio
8082 return "" , "" , fmt .Errorf ("not implemented model type" )
8183}
8284
85+ func generateEvaluationCodeAndImage (evalStmt * ir.EvaluateStmt , stepIndex int , session * pb.Session , sqlStmts []ir.SQLFlowStmt ) (string , string , error ) {
86+ image := ""
87+ isXGBoost := false
88+ trainStmt := findModelGenerationTrainStmt (evalStmt .ModelName , stepIndex , sqlStmts )
89+ if trainStmt != nil {
90+ image = trainStmt .ModelImage
91+ isXGBoost = isXGBoostEstimator (trainStmt .Estimator )
92+ } else {
93+ meta , err := getModelMetadata (session , evalStmt .ModelName )
94+ if err != nil {
95+ return "" , "" , err
96+ }
97+ image = meta .imageName ()
98+ isXGBoost = meta .isXGBoostModel ()
99+ }
100+
101+ if isXGBoost {
102+ code , err := XGBoostGenerateEvaluation (evalStmt , stepIndex , session )
103+ if err != nil {
104+ return "" , "" , err
105+ }
106+ return code , image , nil
107+ }
108+ return "" , "" , fmt .Errorf ("not implemented model type" )
109+ }
110+
83111// findModelGenerationTrainStmt finds the *ir.TrainStmt that generates the model named `modelName`.
84112// TODO(sneaxiy): find a better way to do this when we have a well designed dependency analysis.
85113func findModelGenerationTrainStmt (modelName string , idx int , sqlStmts []ir.SQLFlowStmt ) * ir.TrainStmt {
@@ -144,7 +172,7 @@ func getModelMetadataFromDB(dbConnStr, table string) (*metadata, error) {
144172 if err != nil {
145173 return nil , err
146174 }
147- if readCnt != int ( length ) {
175+ if uint64 ( readCnt ) != length {
148176 return nil , fmt .Errorf ("invalid model metadata" )
149177 }
150178 json , err := simplejson .NewJson (jsonBytes )
0 commit comments