1818import numpy as np
1919import pandas as pd
2020import seaborn as sns
21+ import shap
2122import tensorflow as tf
2223from sqlflow_submitter import explainer
2324from sqlflow_submitter .db import buffered_db_writer , connect_with_data_source
@@ -83,8 +84,28 @@ def _input_fn():
8384
8485 model_params .update (feature_columns )
8586 estimator = estimator_cls (** model_params )
86- result = estimator .experimental_predict_with_explanations (
87- lambda : _input_fn ())
87+ if estimator_cls in (tf .estimator .BoostedTreesClassifier ,
88+ tf .estimator .BoostedTreesRegressor ):
89+ explain_boosted_trees (datasource , estimator , _input_fn , plot_type ,
90+ result_table , feature_column_names ,
91+ hdfs_namenode_addr , hive_location , hdfs_user ,
92+ hdfs_pass )
93+ else :
94+ shap_dataset = pd .DataFrame (columns = feature_column_names )
95+ for i , (features , label ) in enumerate (_input_fn ()):
96+ shap_dataset .loc [i ] = [
97+ item .numpy ()[0 ][0 ] for item in features .values ()
98+ ]
99+ explain_dnns (datasource , estimator , shap_dataset , plot_type ,
100+ result_table , feature_column_names , hdfs_namenode_addr ,
101+ hive_location , hdfs_user , hdfs_pass )
102+
103+
104+ def explain_boosted_trees (datasource , estimator , input_fn , plot_type ,
105+ result_table , feature_column_names ,
106+ hdfs_namenode_addr , hive_location , hdfs_user ,
107+ hdfs_pass ):
108+ result = estimator .experimental_predict_with_explanations (input_fn )
88109 pred_dicts = list (result )
89110 df_dfc = pd .DataFrame ([pred ['dfc' ] for pred in pred_dicts ])
90111 dfc_mean = df_dfc .abs ().mean ()
@@ -98,6 +119,23 @@ def _input_fn():
98119 explainer .plot_and_save (lambda : eval (plot_type )(df_dfc ))
99120
100121
def explain_dnns(datasource, estimator, shap_dataset, plot_type, result_table,
                 feature_column_names, hdfs_namenode_addr, hive_location,
                 hdfs_user, hdfs_pass):
    """Explain an estimator's predictions with SHAP and save the plot.

    Treats the model as a black box via ``shap.KernelExplainer``: the
    explainer repeatedly calls ``predict`` on perturbed copies of
    ``shap_dataset`` to estimate per-feature contributions, then renders
    a summary plot through ``explainer.plot_and_save``.

    Args:
        datasource: data source connection string. Unused here; kept so
            the signature mirrors ``explain_boosted_trees``.
        estimator: a trained ``tf.estimator.Estimator`` instance.
        shap_dataset: ``pandas.DataFrame`` of background samples, one
            column per feature name.
        plot_type: plot style forwarded to ``shap.summary_plot``
            (e.g. "bar" or "dot").
        result_table: destination table name. Unused here; signature
            parity with ``explain_boosted_trees``.
        feature_column_names: feature name list. Unused here; signature
            parity with ``explain_boosted_trees``.
        hdfs_namenode_addr, hive_location, hdfs_user, hdfs_pass: HDFS
            write parameters. Unused here; signature parity with
            ``explain_boosted_trees``.
    """

    def predict(sample_matrix):
        """Adapt a 2-D sample array into a vector of model outputs."""

        def input_fn():
            # Rebuild a DataFrame so columns line up with the original
            # feature names, then feed the samples one row at a time.
            frame = pd.DataFrame(sample_matrix, columns=shap_dataset.columns)
            return tf.data.Dataset.from_tensor_slices(dict(frame)).batch(1)

        # Classifier estimators yield a 'probabilities' entry; regressor
        # estimators (e.g. DNNRegressor) yield 'predictions' instead.
        # The original code indexed 'probabilities' unconditionally, which
        # raised KeyError for regressors even though the caller routes all
        # non-boosted-trees estimators here. Fall back so both kinds work.
        return np.array([
            pred['probabilities'][0]
            if 'probabilities' in pred else pred['predictions'][0]
            for pred in estimator.predict(input_fn)
        ])

    shap_values = shap.KernelExplainer(predict,
                                       shap_dataset).shap_values(shap_dataset)
    explainer.plot_and_save(lambda: shap.summary_plot(
        shap_values, shap_dataset, show=False, plot_type=plot_type))
138+
101139def create_explain_result_table (conn , result_table ):
102140 column_clause = ""
103141 if conn .driver == "mysql" :
0 commit comments