|
21 | 21 | from sqlflow_submitter.tensorflow.train import train |
22 | 22 |
|
23 | 23 | if __name__ == "__main__": |
| 24 | + # Train and explain BoostedTreesClassifier |
24 | 25 | train(datasource=datasource, |
25 | 26 | estimator=tf.estimator.BoostedTreesClassifier, |
26 | 27 | select="SELECT * FROM iris.train where class!=2", |
|
39 | 40 | batch_size=100, |
40 | 41 | epochs=20, |
41 | 42 | verbose=0) |
| 43 | + |
42 | 44 | explain(datasource=datasource, |
43 | 45 | estimator_cls=tf.estimator.BoostedTreesClassifier, |
44 | 46 | select="SELECT * FROM iris.test where class!=2", |
|
56 | 58 | is_pai=False, |
57 | 59 | plot_type='bar', |
58 | 60 | result_table="iris.explain_result") |
| 61 | + |
| 62 | + # Train and explain DNNClassifier |
| 63 | + train(datasource=datasource, |
| 64 | + estimator=tf.estimator.DNNClassifier, |
| 65 | + select="SELECT * FROM iris.train", |
| 66 | + validate_select="SELECT * FROM iris.test", |
| 67 | + feature_columns=feature_columns, |
| 68 | + feature_column_names=feature_column_names, |
| 69 | + feature_metas=feature_metas, |
| 70 | + label_meta=label_meta, |
| 71 | + model_params={ |
| 72 | + "n_classes": 3, |
| 73 | + "hidden_units": [100, 100], |
| 74 | + }, |
| 75 | + save="dnnmodel", |
| 76 | + batch_size=100, |
| 77 | + epochs=20, |
| 78 | + verbose=0) |
| 79 | + |
| 80 | + explain(datasource=datasource, |
| 81 | + estimator_cls=tf.estimator.DNNClassifier, |
| 82 | + select="SELECT * FROM iris.test LIMIT 10", |
| 83 | + feature_columns=feature_columns, |
| 84 | + feature_column_names=feature_column_names, |
| 85 | + feature_metas=feature_metas, |
| 86 | + label_meta=label_meta, |
| 87 | + model_params={ |
| 88 | + "n_classes": 3, |
| 89 | + "hidden_units": [100, 100], |
| 90 | + }, |
| 91 | + save="dnnmodel", |
| 92 | + is_pai=False, |
| 93 | + plot_type='bar', |
| 94 | + result_table="iris.explain_result") |
0 commit comments