Skip to content

Commit ce0310a

Browse files
authored
support summary_plot(args) (#798)
1 parent ced5708 commit ce0310a

2 files changed

Lines changed: 38 additions & 21 deletions

File tree

sql/codegen_analyze.go

Lines changed: 32 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,16 +19,21 @@ import (
1919
"strings"
2020
)
2121

22+
const (
23+
shapSummaryAttributePrefix = "shap_summary"
24+
)
25+
2226
type analyzeFiller struct {
2327
*connectionConfig
24-
X []*FeatureMeta
25-
Label string
26-
AnalyzeDatasetSQL string
27-
PlotType string
28-
ModelFile string // path/to/model_file
28+
X []*FeatureMeta
29+
Label string
30+
AnalyzeDatasetSQL string
31+
PlotType string
32+
ShapSummaryParames map[string]interface{}
33+
ModelFile string // path/to/model_file
2934
}
3035

31-
func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, modelPath, plotType string) (*analyzeFiller, error) {
36+
func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, modelPath string, summaryAttrs map[string]interface{}) (*analyzeFiller, error) {
3237
conn, err := newConnectionConfig(db)
3338
if err != nil {
3439
return nil, err
@@ -39,9 +44,9 @@ func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, mod
3944
Label: label,
4045
// TODO(weiguo): test if it needs TrimSuffix(SQL, ";") on hive,
4146
// or we trim it in pr(*extendedSelect)
42-
AnalyzeDatasetSQL: pr.standardSelect.String(),
43-
ModelFile: modelPath,
44-
PlotType: plotType,
47+
AnalyzeDatasetSQL: pr.standardSelect.String(),
48+
ModelFile: modelPath,
49+
ShapSummaryParames: summaryAttrs,
4550
}, nil
4651
}
4752

@@ -71,13 +76,19 @@ func readXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error)
7176
return xs, fr.Y.FeatureName, nil
7277
}
7378

74-
func readPlotType(pr *extendedSelect) string {
75-
v, ok := pr.analyzeAttrs["shap.plot_type"]
76-
if !ok {
77-
// using shap default value
78-
return `""`
79+
func resolveAnalyzeSummaryParames(atts *attrs) (map[string]interface{}, error) {
80+
parames, err := resolveAttribute(atts)
81+
if err != nil {
82+
return nil, err
7983
}
80-
return v.val
84+
85+
summaryAttrs := make(map[string]interface{})
86+
for _, v := range parames {
87+
if v.Prefix == shapSummaryAttributePrefix {
88+
summaryAttrs[v.Name] = v.Value
89+
}
90+
}
91+
return summaryAttrs, nil
8192
}
8293

8394
func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffer, error) {
@@ -89,13 +100,17 @@ func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffe
89100
return nil, fmt.Errorf("analyzer: model[%s] not supported", pr.estimator)
90101
}
91102
// We untar the XGBoost.{pr.trainedModel}.tar.gz and get three files.
92-
plotType := readPlotType(pr)
103+
summaryAttrs, err := resolveAnalyzeSummaryParames(&pr.analyzeAttrs)
104+
if err != nil {
105+
return nil, err
106+
}
107+
93108
xs, label, err := readXGBFeatures(pr, db)
94109
if err != nil {
95110
return nil, err
96111
}
97112

98-
fr, err := newAnalyzeFiller(pr, db, xs, label, pr.trainedModel, plotType)
113+
fr, err := newAnalyzeFiller(pr, db, xs, label, pr.trainedModel, summaryAttrs)
99114
if err != nil {
100115
return nil, fmt.Errorf("create analyze filler failed: %v", err)
101116
}

sql/template_analyze.go

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -66,9 +66,11 @@ def analyzer_dataset():
6666
6767
# 2. load the model
6868
model_path = "{{.ModelFile}}"
69-
ptype = {{.PlotType}}
70-
if len(ptype) == 0:
71-
ptype = None
69+
70+
summaryAttrs = {}
71+
{{ range $k, $v := .ShapSummaryParames }}
72+
summaryAttrs["{{$k}}"] = {{$v}}
73+
{{end}}
7274
7375
X,y = analyzer_dataset()
7476
@@ -77,7 +79,7 @@ bst.load_model(fname=model_path)
7779
explainer = shap.TreeExplainer(bst)
7880
shap_values = explainer.shap_values(X)
7981
80-
shap.summary_plot(shap_values, X, plot_type=ptype)
82+
shap.summary_plot(shap_values, X, **summaryAttrs)
8183
plt.savefig('summary')
8284
`
8385

0 commit comments

Comments
 (0)