Skip to content

Commit a8a804a

Browse files
authored
support pai random forests predict and explain (#1706)
1 parent 9b21371 commit a8a804a

6 files changed

Lines changed: 308 additions & 7 deletions

File tree

cmd/sqlflowserver/main_test.go

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package main
1515

1616
import (
17+
"bytes"
1718
"context"
1819
"fmt"
1920
"io"
@@ -1280,16 +1281,62 @@ USING %s;`, caseTestTable, casePredictTable, caseInto)
12801281

12811282
}
12821283

1284+
func dropPAIModel(dataSource, modelName string) error {
1285+
code := fmt.Sprintf(`import subprocess
1286+
import sqlflow_submitter.db
1287+
1288+
driver, dsn = "%s".split("://")
1289+
assert driver == "maxcompute"
1290+
user, passwd, address, database = sqlflow_submitter.db.parseMaxComputeDSN(dsn)
1291+
1292+
cmd = "drop offlinemodel if exists %s"
1293+
subprocess.run(["odpscmd", "-u", user,
1294+
"-p", passwd,
1295+
"--project", database,
1296+
"--endpoint", address,
1297+
"-e", cmd],
1298+
check=True)
1299+
`, dataSource, modelName)
1300+
cmd := exec.Command("python", "-u")
1301+
cmd.Stdin = bytes.NewBufferString(code)
1302+
if e := cmd.Run(); e != nil {
1303+
return e
1304+
}
1305+
return nil
1306+
}
1307+
12831308
func CaseTrainPAIRandomForests(t *testing.T) {
12841309
a := assert.New(t)
1285-
trainSQL := fmt.Sprintf(`
1286-
SELECT * FROM %s.%s
1310+
err := dropPAIModel(dbConnStr, "my_rf_model")
1311+
a.NoError(err)
1312+
1313+
trainSQL := fmt.Sprintf(`SELECT * FROM %s
12871314
TO TRAIN randomforests
12881315
WITH tree_num = 3
12891316
LABEL class
12901317
INTO my_rf_model;
1291-
`, caseDB, caseTrainTable)
1292-
_, _, err := connectAndRunSQL(trainSQL)
1318+
`, caseTrainTable)
1319+
_, _, err = connectAndRunSQL(trainSQL)
1320+
if err != nil {
1321+
a.Fail("Run trainSQL error: %v", err)
1322+
}
1323+
1324+
predSQL := fmt.Sprintf(`SELECT * FROM %s
1325+
TO PREDICT %s.class
1326+
USING my_rf_model;
1327+
`, caseTestTable, casePredictTable)
1328+
_, _, err = connectAndRunSQL(predSQL)
1329+
if err != nil {
1330+
a.Fail("Run trainSQL error: %v", err)
1331+
}
1332+
1333+
explainSQL := fmt.Sprintf(`SELECT * FROM %s
1334+
TO EXPLAIN my_rf_model
1335+
WITH label_column = class
1336+
USING TreeExplainer
1337+
INTO %s.rf_model_explain;
1338+
`, caseTestTable, caseDB)
1339+
_, _, err = connectAndRunSQL(explainSQL)
12931340
if err != nil {
12941341
a.Fail("Run trainSQL error: %v", err)
12951342
}

pkg/sql/codegen/pai/codegen.go

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,19 @@ import (
1717
"bytes"
1818
"encoding/json"
1919
"fmt"
20+
"log"
2021
"os"
2122
"path/filepath"
2223
"strconv"
2324
"strings"
2425
"text/template"
2526

27+
"github.com/aliyun/aliyun-oss-go-sdk/oss"
28+
"sqlflow.org/sqlflow/pkg/database"
2629
pb "sqlflow.org/sqlflow/pkg/proto"
2730
"sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow"
2831
"sqlflow.org/sqlflow/pkg/sql/ir"
32+
"sqlflow.org/sqlflow/pkg/verifier"
2933
)
3034

3135
const entryFile = "entry.py"
@@ -160,7 +164,6 @@ func trainRandomForests(ir *ir.TrainStmt, session *pb.Session) (string, error) {
160164
}
161165
filler := &randomForestsTrainFiller{
162166
DataSource: session.DbConnStr,
163-
Select: ir.Select,
164167
TmpTrainTable: ir.TmpTrainTable,
165168
FeatureColumns: featureCols,
166169
LabelColumn: ir.Label.GetFieldDesc()[0].Name,
@@ -172,10 +175,44 @@ func trainRandomForests(ir *ir.TrainStmt, session *pb.Session) (string, error) {
172175
if err := tpl.Execute(&rfCode, filler); err != nil {
173176
return "", err
174177
}
175-
fmt.Println(rfCode.String())
176178
return rfCode.String(), nil
177179
}
178180

181+
// getColumnTypes is quiet like verify but accept a SQL string as input, and returns
182+
// an ordered list of the field types.
183+
// FIXME(typhoonzero): copied from executor_ir.go
184+
func getColumnTypes(slct string, db *database.DB) ([]string, []string, error) {
185+
rows, err := db.Query(slct)
186+
if err != nil {
187+
return nil, nil, err
188+
}
189+
defer rows.Close()
190+
191+
if !rows.Next() {
192+
return nil, nil, fmt.Errorf("query %s gives 0 row", slct)
193+
}
194+
195+
if rows.Err() != nil {
196+
return nil, nil, err
197+
}
198+
199+
columnTypes, err := rows.ColumnTypes()
200+
if err != nil {
201+
return nil, nil, err
202+
}
203+
204+
ft := []string{}
205+
flds := []string{}
206+
for _, ct := range columnTypes {
207+
_, fld := verifier.Decomp(ct.Name())
208+
typeName := ct.DatabaseTypeName()
209+
flds = append(flds, fld)
210+
ft = append(ft, typeName)
211+
}
212+
213+
return flds, ft, nil
214+
}
215+
179216
// Train generates a Python program for train a TensorFlow model.
180217
func Train(ir *ir.TrainStmt, session *pb.Session, modelName, cwd string) (string, error) {
181218
if strings.ToLower(ir.Estimator) == "randomforests" {
@@ -214,8 +251,74 @@ func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelName string) (st
214251
return code + saveCode.String(), nil
215252
}
216253

254+
func ossFileExists(modelName string) (bool, error) {
255+
endpoint := os.Getenv("SQLFLOW_OSS_ENDPOINT")
256+
ak := os.Getenv("SQLFLOW_OSS_AK")
257+
sk := os.Getenv("SQLFLOW_OSS_SK")
258+
// NOTE(typhoonzero): PAI Tensorflow need SQLFLOW_OSS_CHECKPOINT_DIR, get bucket name from it
259+
ossCheckpointDir := os.Getenv("SQLFLOW_OSS_CHECKPOINT_DIR")
260+
ckptParts := strings.Split(ossCheckpointDir, "?")
261+
if len(ckptParts) != 2 {
262+
return false, fmt.Errorf("SQLFLOW_OSS_CHECKPOINT_DIR got wrong format")
263+
}
264+
urlParts := strings.Split(ckptParts[0], "://")
265+
if len(urlParts) != 2 {
266+
return false, fmt.Errorf("SQLFLOW_OSS_CHECKPOINT_DIR got wrong format")
267+
}
268+
bucketName := strings.Split(urlParts[1], "/")[0]
269+
270+
cli, err := oss.New(endpoint, ak, sk)
271+
if err != nil {
272+
return false, err
273+
}
274+
bucket, err := cli.Bucket(bucketName)
275+
if err != nil {
276+
return false, err
277+
}
278+
return bucket.IsObjectExist(modelName + "/sqlflow_model_desc")
279+
}
280+
281+
func predictRandomForests(ir *ir.PredictStmt, session *pb.Session) (string, error) {
282+
// NOTE(typhoonzero): for PAI random forests predicting, we can not load the TrainStmt
283+
// since the model saving is fully done by PAI. We directly use the columns in SELECT
284+
// statement for prediction, error will be reported by PAI job if the columns not match.
285+
db, err := database.OpenAndConnectDB(session.DbConnStr)
286+
if err != nil {
287+
return "", err
288+
}
289+
flds, _, err := getColumnTypes(ir.Select, db)
290+
if err != nil {
291+
return "", err
292+
}
293+
// drop result table if exists
294+
db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", ir.ResultTable))
295+
filler := &randomForestsPredictFiller{
296+
DataSource: session.DbConnStr,
297+
TmpPredictTable: ir.TmpPredictTable,
298+
FeatureColumns: flds,
299+
Save: ir.Using,
300+
ResultTable: ir.ResultTable,
301+
}
302+
var tpl = template.Must(template.New("RandomForestsPredict").Parse(randomForestsPredictTemplate))
303+
var rfCode bytes.Buffer
304+
if err := tpl.Execute(&rfCode, filler); err != nil {
305+
return "", err
306+
}
307+
return rfCode.String(), nil
308+
}
309+
217310
// Predict generates a Python program for train a TensorFlow model.
218311
func Predict(ir *ir.PredictStmt, session *pb.Session, modelName, cwd string) (string, error) {
312+
// FIXME(typhoonzero): if the model not exist on OSS, assume it's a random forest model
313+
// should use a general method to fetch the model and see the model type.
314+
exists, err := ossFileExists(modelName)
315+
if err != nil {
316+
return "", err
317+
}
318+
if !exists {
319+
log.Printf("predicting using pai random forests")
320+
return predictRandomForests(ir, session)
321+
}
219322
cc, err := GetClusterConfig(ir.Attributes)
220323
if err != nil {
221324
return "", err
@@ -253,3 +356,53 @@ func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelName string)
253356
}
254357
return code.String(), nil
255358
}
359+
360+
func explainRandomForests(ir *ir.ExplainStmt, session *pb.Session) (string, error) {
361+
// NOTE(typhoonzero): for PAI random forests predicting, we can not load the TrainStmt
362+
// since the model saving is fully done by PAI. We directly use the columns in SELECT
363+
// statement for prediction, error will be reported by PAI job if the columns not match.
364+
db, err := database.OpenAndConnectDB(session.DbConnStr)
365+
if err != nil {
366+
return "", err
367+
}
368+
flds, _, err := getColumnTypes(ir.Select, db)
369+
if err != nil {
370+
return "", err
371+
}
372+
// drop result table if exists
373+
db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", ir.Into))
374+
labelCol, ok := ir.Attributes["label_column"]
375+
if !ok {
376+
return "", fmt.Errorf("must specify WITH label_column when using pai random forest to explain models")
377+
}
378+
featureFileds := []string{}
379+
for _, f := range flds {
380+
if f != labelCol {
381+
featureFileds = append(featureFileds, f)
382+
}
383+
}
384+
385+
filler := &randomForestsExplainFiller{
386+
DataSource: session.DbConnStr,
387+
TmpExplainTable: ir.TmpExplainTable,
388+
FeatureColumns: featureFileds,
389+
LabelColumn: labelCol.(string),
390+
Save: ir.ModelName,
391+
ResultTable: ir.Into,
392+
}
393+
var tpl = template.Must(template.New("RandomForestsExplain").Parse(randomForestsExplainTemplate))
394+
var rfCode bytes.Buffer
395+
if err := tpl.Execute(&rfCode, filler); err != nil {
396+
return "", err
397+
}
398+
return rfCode.String(), nil
399+
}
400+
401+
// Explain generates a Python program for train a TensorFlow model.
402+
func Explain(ir *ir.ExplainStmt, session *pb.Session, modelName, cwd string) (string, error) {
403+
// NOTE(typhoonzero): only support random forests explain.
404+
if ir.Into == "" {
405+
return "", fmt.Errorf("explain PAI random forests model need INTO clause to output the explain result to a table")
406+
}
407+
return explainRandomForests(ir, session)
408+
}

pkg/sql/codegen/pai/template_random_forests.go

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,30 @@ package pai
1515

1616
type randomForestsTrainFiller struct {
1717
DataSource string
18-
Select string
1918
TmpTrainTable string
2019
FeatureColumns []string
2120
LabelColumn string
2221
Save string
2322
TreeNum int
2423
}
2524

25+
type randomForestsPredictFiller struct {
26+
DataSource string
27+
TmpPredictTable string
28+
FeatureColumns []string
29+
Save string
30+
ResultTable string
31+
}
32+
33+
type randomForestsExplainFiller struct {
34+
DataSource string
35+
TmpExplainTable string
36+
FeatureColumns []string
37+
LabelColumn string
38+
Save string
39+
ResultTable string
40+
}
41+
2642
const randomForestsTrainTemplate = `
2743
import os
2844
import subprocess
@@ -48,3 +64,55 @@ subprocess.run(["odpscmd", "-u", user,
4864
"-e", pai_cmd],
4965
check=True)
5066
`
67+
68+
const randomForestsPredictTemplate = `
69+
import os
70+
import subprocess
71+
import sqlflow_submitter.db
72+
73+
driver, dsn = "{{.DataSource}}".split("://")
74+
assert driver == "maxcompute"
75+
user, passwd, address, database = sqlflow_submitter.db.parseMaxComputeDSN(dsn)
76+
77+
column_names = []
78+
{{ range $colname := .FeatureColumns }}
79+
column_names.append("{{$colname}}")
80+
{{ end }}
81+
pai_cmd = 'pai -name prediction -project algo_public -DmodelName="{{.Save}}" -DinputTableName="{{.TmpPredictTable}}" -DoutputTableName="{{.ResultTable}}" -DfeatureColNames="%s" ' % (
82+
",".join(column_names)
83+
)
84+
85+
# Submit the tarball to PAI
86+
subprocess.run(["odpscmd", "-u", user,
87+
"-p", passwd,
88+
"--project", database,
89+
"--endpoint", address,
90+
"-e", pai_cmd],
91+
check=True)
92+
`
93+
94+
const randomForestsExplainTemplate = `
95+
import os
96+
import subprocess
97+
import sqlflow_submitter.db
98+
99+
driver, dsn = "{{.DataSource}}".split("://")
100+
assert driver == "maxcompute"
101+
user, passwd, address, database = sqlflow_submitter.db.parseMaxComputeDSN(dsn)
102+
103+
column_names = []
104+
{{ range $colname := .FeatureColumns }}
105+
column_names.append("{{$colname}}")
106+
{{ end }}
107+
pai_cmd = 'pai -name feature_importance -project algo_public -DmodelName="{{.Save}}" -DinputTableName="{{.TmpExplainTable}}" -DoutputTableName="{{.ResultTable}}" -DlabelColName="{{.LabelColumn}}" -DfeatureColNames="%s" ' % (
108+
",".join(column_names)
109+
)
110+
111+
# Submit the tarball to PAI
112+
subprocess.run(["odpscmd", "-u", user,
113+
"-p", passwd,
114+
"--project", database,
115+
"--endpoint", address,
116+
"-e", pai_cmd],
117+
check=True)
118+
`

pkg/sql/ir/ir.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,13 @@ type ExplainStmt struct {
182182
Attributes map[string]interface{}
183183
// Explainer types. For example TreeExplainer.
184184
Explainer string
185+
// ModelName is the model to be explained, e.g. TO EXPLAIN model_name
186+
ModelName string
185187
// Into stores the model explain result. Note that this field is optional.
186188
Into string
189+
// When SQLFLOW_submitter == "pai", tmp tables will be created for predicting task
190+
// see: pai_submitter.go
191+
TmpExplainTable string
187192
// TrainStmt is the TrainStmt used for generating the training job of the corresponding model
188193
TrainStmt *TrainStmt
189194
}

pkg/sql/ir_generator.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ func generateExplainStmt(slct *parser.SQLFlowSelectStmt, connStr, modelDir strin
305305
Attributes: attrs,
306306
Explainer: slct.Explainer,
307307
TrainStmt: trainStmt,
308+
ModelName: slct.TrainedModel,
308309
Into: slct.ExplainInto,
309310
}
310311

0 commit comments

Comments
 (0)