Skip to content

Commit a17aeab

Browse files
authored
Fix explain into when shap values is not list (#2853)
* fix explain into when shap values is not list * update * add e2e test * update * update * clean
1 parent f828911 commit a17aeab

7 files changed

Lines changed: 57 additions & 43 deletions

File tree

go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -378,16 +378,16 @@ func CasePAIMaxComputeTrainXGBoost(t *testing.T) {
378378
a := assert.New(t)
379379

380380
trainSQL := fmt.Sprintf(`SELECT * FROM %s
381-
TO TRAIN xgboost.gbtree
382-
WITH
383-
objective="multi:softprob",
384-
train.num_boost_round = 30,
385-
eta = 0.4,
386-
num_class = 3,
387-
train.batch_size=10,
388-
validation.select="select * from %s"
389-
LABEL class
390-
INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable)
381+
TO TRAIN xgboost.gbtree
382+
WITH
383+
objective="multi:softprob",
384+
train.num_boost_round = 30,
385+
eta = 0.4,
386+
num_class = 3,
387+
train.batch_size=10,
388+
validation.select="select * from %s"
389+
LABEL class
390+
INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable)
391391
_, _, _, err := connectAndRunSQL(trainSQL)
392392
a.NoError(err, "Run trainSQL error.")
393393

@@ -405,13 +405,20 @@ INTO %s.e2etest_xgb_evaluate_result;`, caseTestTable, caseDB)
405405
_, _, _, err = connectAndRunSQL(evalSQL)
406406
a.NoError(err, "Run evalSQL error.")
407407

408-
explainSQL := fmt.Sprintf(`SELECT * FROM %s
409-
TO EXPLAIN e2etest_xgb_classi_model
410-
WITH label_col=class
411-
USING TreeExplainer
412-
INTO %s.e2etest_xgb_explain_result;`, caseTrainTable, caseDB)
413-
_, _, _, err = connectAndRunSQL(explainSQL)
414-
a.NoError(err, "Run explainSQL error.")
408+
titanicTrain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train
409+
TO TRAIN xgboost.gbtree
410+
WITH objective="binary:logistic"
411+
LABEL survived
412+
INTO e2etest_xgb_titanic;`, caseDB)
413+
_, _, _, err = connectAndRunSQL(titanicTrain)
414+
a.NoError(err, "Run titanicTrain error.")
415+
416+
titanicExplain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train
417+
TO EXPLAIN e2etest_xgb_titanic
418+
WITH label_col=survived
419+
INTO %s.e2etest_titanic_explain_result;`, caseDB, caseDB)
420+
_, _, _, err = connectAndRunSQL(titanicExplain)
421+
a.NoError(err, "Run titanicExplain error.")
415422
}
416423

417424
func CasePAIMaxComputeTrainCustomModel(t *testing.T) {

go/codegen/xgboost/codegen_explain.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ func Explain(explainStmt *ir.ExplainStmt, session *pb.Session) (string, error) {
5757
FeatureColumnNames: fs,
5858
FeatureColumnCode: featureColumnCode,
5959
LabelJSON: string(l),
60+
ResultTable: explainStmt.Into,
6061
IsPAI: tf.IsPAI(),
6162
PAIExplainTable: explainStmt.TmpExplainTable,
6263
}

go/codegen/xgboost/template_explain.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type explainFiller struct {
2525
FeatureColumnNames []string
2626
FeatureColumnCode string
2727
LabelJSON string
28+
ResultTable string
2829
IsPAI bool
2930
PAIExplainTable string
3031
}
@@ -48,10 +49,11 @@ transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(feature
4849
explain(
4950
datasource='''{{.DataSource}}''',
5051
select='''{{.DatasetSQL}}''',
51-
feature_field_meta=feature_field_meta,
52-
feature_column_names=feature_column_names,
52+
feature_field_meta=feature_field_meta,
53+
feature_column_names=feature_column_names,
5354
label_meta=label_meta,
5455
summary_params=summary_params,
56+
result_table="{{.ResultTable}}",
5557
is_pai="{{.IsPAI}}" == "true",
5658
pai_explain_table="{{.PAIExplainTable}}",
5759
transform_fn=transform_fn,

go/executor/executor.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,20 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
272272
return err
273273
}
274274
defer db.Close()
275+
276+
var modelType int
275277
if cl.TrainStmt.GetModelKind() == ir.XGBoost {
276278
code, err = xgboost.Explain(cl, s.Session)
277-
// TODO(typhoonzero): deal with XGBoost model explain result table creation.
279+
modelType = pai.ModelTypeXGBoost
278280
} else {
279281
code, err = tensorflow.Explain(cl, s.Session)
280-
if cl.Into != "" {
281-
err := createExplainResultTable(db, cl, cl.Into, pai.ModelTypeTF, cl.TrainStmt.Estimator)
282-
if err != nil {
283-
return err
284-
}
282+
modelType = pai.ModelTypeTF
283+
}
284+
285+
if cl.Into != "" {
286+
err := createExplainResultTable(db, cl, cl.Into, modelType, cl.TrainStmt.Estimator)
287+
if err != nil {
288+
return err
285289
}
286290
}
287291

go/executor/pai.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ import (
2323
"path"
2424
"path/filepath"
2525
"regexp"
26-
"sqlflow.org/sqlflow/go/verifier"
2726
"strings"
2827

28+
"sqlflow.org/sqlflow/go/verifier"
29+
2930
"sqlflow.org/sqlflow/go/codegen/optimize"
3031

3132
"github.com/aliyun/aliyun-oss-go-sdk/oss"
@@ -636,9 +637,13 @@ func getCreateShapResultSQL(db *database.DB, tableName string, selectStmt string
636637
return "", err
637638
}
638639
columnDefList := []string{}
640+
columnType := "STRING"
641+
if db.DriverName == "mysql" {
642+
columnType = "VARCHAR(255)"
643+
}
639644
for _, fieldName := range flds {
640645
if fieldName != labelCol {
641-
columnDefList = append(columnDefList, fmt.Sprintf("%s STRING", fieldName))
646+
columnDefList = append(columnDefList, fmt.Sprintf("%s %s", fieldName, columnType))
642647
}
643648
}
644649
createStmt := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (%s);`, tableName, strings.Join(columnDefList, ","))

python/runtime/xgboost/dataset.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,20 +252,15 @@ def pai_dataset(filename,
252252
from subprocess import Popen, PIPE
253253
from multiprocessing.dummy import Pool # ThreadPool
254254
import queue
255-
256255
dname = filename
257256
if single_file:
258257
dname = filename + '.dir'
259-
260258
if os.path.exists(dname):
261259
shutil.rmtree(dname, ignore_errors=True)
262260

263261
os.mkdir(dname)
264-
265262
slice_count = get_pai_table_slice_count(pai_table, nworkers, batch_size)
266-
267263
thread_num = min(int(slice_count / nworkers), 128)
268-
269264
pool = Pool(thread_num)
270265
complete_queue = queue.Queue()
271266

@@ -337,7 +332,7 @@ def pai_download_table_data_worker(dname, feature_metas, feature_column_names,
337332
feature_column_names, *feature_column_transformers)
338333

339334
conn = PaiIOConnection.from_table(pai_table, slice_id, slice_count)
340-
gen = db.db_generator(conn, None)()
335+
gen = db.db_generator(conn, None, label_meta=label_meta)()
341336
selected_cols = db.selected_cols(conn, None)
342337
filename = "{}/{}.txt".format(dname, slice_id)
343338
dump_dmatrix(filename,

python/runtime/xgboost/explain.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,23 +173,23 @@ def explain(datasource,
173173
pai_explain_table,
174174
transform_fn=transform_fn,
175175
feature_column_code=feature_column_code)
176-
177176
shap_values, shap_interaction_values, expected_value = xgb_shap_values(x)
178-
179177
if result_table != "":
180178
if is_pai:
181179
from runtime.dbapi.paiio import PaiIOConnection
182180
conn = PaiIOConnection.from_table(result_table)
183-
# TODO(typhoonzero): the shape of shap_values is
184-
# (3, num_samples, num_features), use the first
185-
# dimension here, should find out how to use
186-
# the other two.
187181
else:
188182
conn = db.connect_with_data_source(datasource)
189-
190-
write_shap_values(shap_values[0], conn, result_table,
191-
feature_column_names)
192-
return
183+
# TODO(typhoonzero): the shap_values is may be a
184+
# list of shape [3, num_samples, num_features],
185+
# use the first dimension here, should find out
186+
# when to use the other two. When shap_values is
187+
# not a list it can be directly used.
188+
if isinstance(shap_values, list):
189+
to_write = shap_values[0]
190+
else:
191+
to_write = shap_values
192+
write_shap_values(to_write, conn, result_table, feature_column_names)
193193

194194
if summary_params.get("plot_type") == "decision":
195195
explainer.plot_and_save(

0 commit comments

Comments
 (0)