Skip to content

Commit c65ac81

Browse files
authored
Add unit test in python for dnn explanation (#1714)
1 parent a8a804a commit c65ac81

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

python/sqlflow_submitter/tensorflow/explain_example.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sqlflow_submitter.tensorflow.train import train
2222

2323
if __name__ == "__main__":
24+
# Train and explain BoostedTreesClassifier
2425
train(datasource=datasource,
2526
estimator=tf.estimator.BoostedTreesClassifier,
2627
select="SELECT * FROM iris.train where class!=2",
@@ -39,6 +40,7 @@
3940
batch_size=100,
4041
epochs=20,
4142
verbose=0)
43+
4244
explain(datasource=datasource,
4345
estimator_cls=tf.estimator.BoostedTreesClassifier,
4446
select="SELECT * FROM iris.test where class!=2",
@@ -56,3 +58,37 @@
5658
is_pai=False,
5759
plot_type='bar',
5860
result_table="iris.explain_result")
61+
62+
# Train and explain DNNClassifier
63+
train(datasource=datasource,
64+
estimator=tf.estimator.DNNClassifier,
65+
select="SELECT * FROM iris.train",
66+
validate_select="SELECT * FROM iris.test",
67+
feature_columns=feature_columns,
68+
feature_column_names=feature_column_names,
69+
feature_metas=feature_metas,
70+
label_meta=label_meta,
71+
model_params={
72+
"n_classes": 3,
73+
"hidden_units": [100, 100],
74+
},
75+
save="dnnmodel",
76+
batch_size=100,
77+
epochs=20,
78+
verbose=0)
79+
80+
explain(datasource=datasource,
81+
estimator_cls=tf.estimator.DNNClassifier,
82+
select="SELECT * FROM iris.test LIMIT 10",
83+
feature_columns=feature_columns,
84+
feature_column_names=feature_column_names,
85+
feature_metas=feature_metas,
86+
label_meta=label_meta,
87+
model_params={
88+
"n_classes": 3,
89+
"hidden_units": [100, 100],
90+
},
91+
save="dnnmodel",
92+
is_pai=False,
93+
plot_type='bar',
94+
result_table="iris.explain_result")

0 commit comments

Comments
 (0)