@@ -17,15 +17,19 @@ import (
1717 "bytes"
1818 "encoding/json"
1919 "fmt"
20+ "log"
2021 "os"
2122 "path/filepath"
2223 "strconv"
2324 "strings"
2425 "text/template"
2526
27+ "github.com/aliyun/aliyun-oss-go-sdk/oss"
28+ "sqlflow.org/sqlflow/pkg/database"
2629 pb "sqlflow.org/sqlflow/pkg/proto"
2730 "sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow"
2831 "sqlflow.org/sqlflow/pkg/sql/ir"
32+ "sqlflow.org/sqlflow/pkg/verifier"
2933)
3034
3135const entryFile = "entry.py"
@@ -160,7 +164,6 @@ func trainRandomForests(ir *ir.TrainStmt, session *pb.Session) (string, error) {
160164 }
161165 filler := & randomForestsTrainFiller {
162166 DataSource : session .DbConnStr ,
163- Select : ir .Select ,
164167 TmpTrainTable : ir .TmpTrainTable ,
165168 FeatureColumns : featureCols ,
166169 LabelColumn : ir .Label .GetFieldDesc ()[0 ].Name ,
@@ -172,10 +175,44 @@ func trainRandomForests(ir *ir.TrainStmt, session *pb.Session) (string, error) {
172175 if err := tpl .Execute (& rfCode , filler ); err != nil {
173176 return "" , err
174177 }
175- fmt .Println (rfCode .String ())
176178 return rfCode .String (), nil
177179}
178180
181+ // getColumnTypes is quiet like verify but accept a SQL string as input, and returns
182+ // an ordered list of the field types.
183+ // FIXME(typhoonzero): copied from executor_ir.go
184+ func getColumnTypes (slct string , db * database.DB ) ([]string , []string , error ) {
185+ rows , err := db .Query (slct )
186+ if err != nil {
187+ return nil , nil , err
188+ }
189+ defer rows .Close ()
190+
191+ if ! rows .Next () {
192+ return nil , nil , fmt .Errorf ("query %s gives 0 row" , slct )
193+ }
194+
195+ if rows .Err () != nil {
196+ return nil , nil , err
197+ }
198+
199+ columnTypes , err := rows .ColumnTypes ()
200+ if err != nil {
201+ return nil , nil , err
202+ }
203+
204+ ft := []string {}
205+ flds := []string {}
206+ for _ , ct := range columnTypes {
207+ _ , fld := verifier .Decomp (ct .Name ())
208+ typeName := ct .DatabaseTypeName ()
209+ flds = append (flds , fld )
210+ ft = append (ft , typeName )
211+ }
212+
213+ return flds , ft , nil
214+ }
215+
179216// Train generates a Python program for training a TensorFlow model.
180217func Train (ir * ir.TrainStmt , session * pb.Session , modelName , cwd string ) (string , error ) {
181218 if strings .ToLower (ir .Estimator ) == "randomforests" {
@@ -214,8 +251,74 @@ func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelName string) (st
214251 return code + saveCode .String (), nil
215252}
216253
254+ func ossFileExists (modelName string ) (bool , error ) {
255+ endpoint := os .Getenv ("SQLFLOW_OSS_ENDPOINT" )
256+ ak := os .Getenv ("SQLFLOW_OSS_AK" )
257+ sk := os .Getenv ("SQLFLOW_OSS_SK" )
258+ // NOTE(typhoonzero): PAI Tensorflow need SQLFLOW_OSS_CHECKPOINT_DIR, get bucket name from it
259+ ossCheckpointDir := os .Getenv ("SQLFLOW_OSS_CHECKPOINT_DIR" )
260+ ckptParts := strings .Split (ossCheckpointDir , "?" )
261+ if len (ckptParts ) != 2 {
262+ return false , fmt .Errorf ("SQLFLOW_OSS_CHECKPOINT_DIR got wrong format" )
263+ }
264+ urlParts := strings .Split (ckptParts [0 ], "://" )
265+ if len (urlParts ) != 2 {
266+ return false , fmt .Errorf ("SQLFLOW_OSS_CHECKPOINT_DIR got wrong format" )
267+ }
268+ bucketName := strings .Split (urlParts [1 ], "/" )[0 ]
269+
270+ cli , err := oss .New (endpoint , ak , sk )
271+ if err != nil {
272+ return false , err
273+ }
274+ bucket , err := cli .Bucket (bucketName )
275+ if err != nil {
276+ return false , err
277+ }
278+ return bucket .IsObjectExist (modelName + "/sqlflow_model_desc" )
279+ }
280+
281+ func predictRandomForests (ir * ir.PredictStmt , session * pb.Session ) (string , error ) {
282+ // NOTE(typhoonzero): for PAI random forests predicting, we can not load the TrainStmt
283+ // since the model saving is fully done by PAI. We directly use the columns in SELECT
284+ // statement for prediction, error will be reported by PAI job if the columns not match.
285+ db , err := database .OpenAndConnectDB (session .DbConnStr )
286+ if err != nil {
287+ return "" , err
288+ }
289+ flds , _ , err := getColumnTypes (ir .Select , db )
290+ if err != nil {
291+ return "" , err
292+ }
293+ // drop result table if exists
294+ db .Exec (fmt .Sprintf ("DROP TABLE IF EXISTS %s;" , ir .ResultTable ))
295+ filler := & randomForestsPredictFiller {
296+ DataSource : session .DbConnStr ,
297+ TmpPredictTable : ir .TmpPredictTable ,
298+ FeatureColumns : flds ,
299+ Save : ir .Using ,
300+ ResultTable : ir .ResultTable ,
301+ }
302+ var tpl = template .Must (template .New ("RandomForestsPredict" ).Parse (randomForestsPredictTemplate ))
303+ var rfCode bytes.Buffer
304+ if err := tpl .Execute (& rfCode , filler ); err != nil {
305+ return "" , err
306+ }
307+ return rfCode .String (), nil
308+ }
309+
217310// Predict generates a Python program for predicting using a trained model.
218311func Predict (ir * ir.PredictStmt , session * pb.Session , modelName , cwd string ) (string , error ) {
312+ // FIXME(typhoonzero): if the model not exist on OSS, assume it's a random forest model
313+ // should use a general method to fetch the model and see the model type.
314+ exists , err := ossFileExists (modelName )
315+ if err != nil {
316+ return "" , err
317+ }
318+ if ! exists {
319+ log .Printf ("predicting using pai random forests" )
320+ return predictRandomForests (ir , session )
321+ }
219322 cc , err := GetClusterConfig (ir .Attributes )
220323 if err != nil {
221324 return "" , err
@@ -253,3 +356,53 @@ func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelName string)
253356 }
254357 return code .String (), nil
255358}
359+
360+ func explainRandomForests (ir * ir.ExplainStmt , session * pb.Session ) (string , error ) {
361+ // NOTE(typhoonzero): for PAI random forests predicting, we can not load the TrainStmt
362+ // since the model saving is fully done by PAI. We directly use the columns in SELECT
363+ // statement for prediction, error will be reported by PAI job if the columns not match.
364+ db , err := database .OpenAndConnectDB (session .DbConnStr )
365+ if err != nil {
366+ return "" , err
367+ }
368+ flds , _ , err := getColumnTypes (ir .Select , db )
369+ if err != nil {
370+ return "" , err
371+ }
372+ // drop result table if exists
373+ db .Exec (fmt .Sprintf ("DROP TABLE IF EXISTS %s;" , ir .Into ))
374+ labelCol , ok := ir .Attributes ["label_column" ]
375+ if ! ok {
376+ return "" , fmt .Errorf ("must specify WITH label_column when using pai random forest to explain models" )
377+ }
378+ featureFileds := []string {}
379+ for _ , f := range flds {
380+ if f != labelCol {
381+ featureFileds = append (featureFileds , f )
382+ }
383+ }
384+
385+ filler := & randomForestsExplainFiller {
386+ DataSource : session .DbConnStr ,
387+ TmpExplainTable : ir .TmpExplainTable ,
388+ FeatureColumns : featureFileds ,
389+ LabelColumn : labelCol .(string ),
390+ Save : ir .ModelName ,
391+ ResultTable : ir .Into ,
392+ }
393+ var tpl = template .Must (template .New ("RandomForestsExplain" ).Parse (randomForestsExplainTemplate ))
394+ var rfCode bytes.Buffer
395+ if err := tpl .Execute (& rfCode , filler ); err != nil {
396+ return "" , err
397+ }
398+ return rfCode .String (), nil
399+ }
400+
401+ // Explain generates a Python program for train a TensorFlow model.
402+ func Explain (ir * ir.ExplainStmt , session * pb.Session , modelName , cwd string ) (string , error ) {
403+ // NOTE(typhoonzero): only support random forests explain.
404+ if ir .Into == "" {
405+ return "" , fmt .Errorf ("explain PAI random forests model need INTO clause to output the explain result to a table" )
406+ }
407+ return explainRandomForests (ir , session )
408+ }
0 commit comments