Skip to content

Commit 9132741

Browse files
authored
Support explaining DNNs using SHAP (#1694)
* Support explaining DNNs using SHAP
* Reformatting to satisfy isort and yapf
* Fix CI
* Fix CI
* Add unit test
* Remove unnecessary Sprintf
1 parent 13989b5 commit 9132741

4 files changed

Lines changed: 47 additions & 17 deletions

File tree

cmd/sqlflowserver/main_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ func TestEnd2EndMySQL(t *testing.T) {
293293
t.Run("CasePredictXGBoostRegression", CasePredictXGBoostRegression)
294294
t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
295295
t.Run("CaseTrainDeepWideModelOptimizer", CaseTrainDeepWideModelOptimizer)
296+
t.Run("CaseTrainAdaNetAndExplain", CaseTrainAdaNetAndExplain)
296297

297298
// Cases using feature derivation
298299
t.Run("CaseTrainTextClassificationIR", CaseTrainTextClassificationIR)
@@ -378,7 +379,7 @@ func TestEnd2EndHive(t *testing.T) {
378379
t.Run("CaseTrainSQLWithMetrics", CaseTrainSQLWithMetrics)
379380
t.Run("CaseTrainRegression", CaseTrainRegression)
380381
t.Run("CaseTrainCustomModel", CaseTrainCustomModel)
381-
t.Run("CaseTrainAdaNet", CaseTrainAdaNet)
382+
t.Run("CaseTrainAdaNetAndExplain", CaseTrainAdaNetAndExplain)
382383
t.Run("CaseTrainOptimizer", CaseTrainOptimizer)
383384
t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
384385
t.Run("CaseTrainDeepWideModelOptimizer", CaseTrainDeepWideModelOptimizer)
@@ -939,16 +940,17 @@ INTO sqlflow_models.my_dnn_linear_model;`
939940
}
940941
}
941942

942-
func CaseTrainAdaNet(t *testing.T) {
943+
func CaseTrainAdaNetAndExplain(t *testing.T) {
943944
a := assert.New(t)
944945
trainSQL := `SELECT * FROM iris.train
945-
TO TRAIN sqlflow_models.AutoClassifier WITH model.n_classes = 3
946-
LABEL class
947-
INTO sqlflow_models.my_adanet_model;`
946+
TO TRAIN sqlflow_models.AutoClassifier WITH model.n_classes = 3 LABEL class INTO sqlflow_models.my_adanet_model;`
948947
_, _, err := connectAndRunSQL(trainSQL)
949948
if err != nil {
950949
a.Fail("run trainSQL error: %v", err)
951950
}
951+
explainSQL := `SELECT * FROM iris.test LIMIT 10 TO EXPLAIN sqlflow_models.my_adanet_model;`
952+
_, _, err = connectAndRunSQL(explainSQL)
953+
a.NoError(err)
952954
}
953955

954956
func CaseTrainDeepWideModelOptimizer(t *testing.T) {

pkg/sql/codegen/tensorflow/codegen.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,6 @@ func Pred(predStmt *ir.PredictStmt, session *pb.Session) (string, error) {
432432

433433
// Explain generates a Python program to explain a trained model.
434434
func Explain(stmt *ir.ExplainStmt, session *pb.Session) (string, error) {
435-
if !strings.HasPrefix(stmt.TrainStmt.Estimator, "BoostedTrees") {
436-
return "", fmt.Errorf("unsupported model %s", stmt.TrainStmt.Estimator)
437-
}
438-
439435
modelParams, featureColumnsCode, fieldDescs, err := restoreModel(stmt.TrainStmt)
440436
if err != nil {
441437
return "", err

python/sqlflow_submitter/tensorflow/explain.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import pandas as pd
2020
import seaborn as sns
21+
import shap
2122
import tensorflow as tf
2223
from sqlflow_submitter import explainer
2324
from sqlflow_submitter.db import buffered_db_writer, connect_with_data_source
@@ -83,8 +84,28 @@ def _input_fn():
8384

8485
model_params.update(feature_columns)
8586
estimator = estimator_cls(**model_params)
86-
result = estimator.experimental_predict_with_explanations(
87-
lambda: _input_fn())
87+
if estimator_cls in (tf.estimator.BoostedTreesClassifier,
88+
tf.estimator.BoostedTreesRegressor):
89+
explain_boosted_trees(datasource, estimator, _input_fn, plot_type,
90+
result_table, feature_column_names,
91+
hdfs_namenode_addr, hive_location, hdfs_user,
92+
hdfs_pass)
93+
else:
94+
shap_dataset = pd.DataFrame(columns=feature_column_names)
95+
for i, (features, label) in enumerate(_input_fn()):
96+
shap_dataset.loc[i] = [
97+
item.numpy()[0][0] for item in features.values()
98+
]
99+
explain_dnns(datasource, estimator, shap_dataset, plot_type,
100+
result_table, feature_column_names, hdfs_namenode_addr,
101+
hive_location, hdfs_user, hdfs_pass)
102+
103+
104+
def explain_boosted_trees(datasource, estimator, input_fn, plot_type,
105+
result_table, feature_column_names,
106+
hdfs_namenode_addr, hive_location, hdfs_user,
107+
hdfs_pass):
108+
result = estimator.experimental_predict_with_explanations(input_fn)
88109
pred_dicts = list(result)
89110
df_dfc = pd.DataFrame([pred['dfc'] for pred in pred_dicts])
90111
dfc_mean = df_dfc.abs().mean()
@@ -98,6 +119,23 @@ def _input_fn():
98119
explainer.plot_and_save(lambda: eval(plot_type)(df_dfc))
99120

100121

122+
def explain_dnns(datasource, estimator, shap_dataset, plot_type, result_table,
123+
feature_column_names, hdfs_namenode_addr, hive_location,
124+
hdfs_user, hdfs_pass):
125+
def predict(d):
126+
def input_fn():
127+
return tf.data.Dataset.from_tensor_slices(
128+
dict(pd.DataFrame(d, columns=shap_dataset.columns))).batch(1)
129+
130+
return np.array(
131+
[p['probabilities'][0] for p in estimator.predict(input_fn)])
132+
133+
shap_values = shap.KernelExplainer(predict,
134+
shap_dataset).shap_values(shap_dataset)
135+
explainer.plot_and_save(lambda: shap.summary_plot(
136+
shap_values, shap_dataset, show=False, plot_type=plot_type))
137+
138+
101139
def create_explain_result_table(conn, result_table):
102140
column_clause = ""
103141
if conn.driver == "mysql":

python/sqlflow_submitter/xgboost/explain.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,6 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import json
15-
import sys
16-
17-
import matplotlib
18-
import matplotlib.pyplot as plt
19-
import numpy as np
2014
import pandas as pd
2115
import shap
2216
import xgboost as xgb

0 commit comments

Comments
 (0)