@@ -19,16 +19,21 @@ import (
1919 "strings"
2020)
2121
22+ const (
23+ shapSummaryAttributePrefix = "shap_summary"
24+ )
25+
2226type analyzeFiller struct {
2327 * connectionConfig
24- X []* FeatureMeta
25- Label string
26- AnalyzeDatasetSQL string
27- PlotType string
28- ModelFile string // path/to/model_file
28+ X []* FeatureMeta
29+ Label string
30+ AnalyzeDatasetSQL string
31+ PlotType string
32+ ShapSummaryParames map [string ]interface {}
33+ ModelFile string // path/to/model_file
2934}
3035
31- func newAnalyzeFiller (pr * extendedSelect , db * DB , fms []* FeatureMeta , label , modelPath , plotType string ) (* analyzeFiller , error ) {
36+ func newAnalyzeFiller (pr * extendedSelect , db * DB , fms []* FeatureMeta , label , modelPath string , summaryAttrs map [ string ] interface {} ) (* analyzeFiller , error ) {
3237 conn , err := newConnectionConfig (db )
3338 if err != nil {
3439 return nil , err
@@ -39,9 +44,9 @@ func newAnalyzeFiller(pr *extendedSelect, db *DB, fms []*FeatureMeta, label, mod
3944 Label : label ,
4045 // TODO(weiguo): test if it needs TrimSuffix(SQL, ";") on hive,
4146 // or we trim it in pr(*extendedSelect)
42- AnalyzeDatasetSQL : pr .standardSelect .String (),
43- ModelFile : modelPath ,
44- PlotType : plotType ,
47+ AnalyzeDatasetSQL : pr .standardSelect .String (),
48+ ModelFile : modelPath ,
49+ ShapSummaryParames : summaryAttrs ,
4550 }, nil
4651}
4752
@@ -71,13 +76,19 @@ func readXGBFeatures(pr *extendedSelect, db *DB) ([]*FeatureMeta, string, error)
7176 return xs , fr .Y .FeatureName , nil
7277}
7378
74- func readPlotType (pr * extendedSelect ) string {
75- v , ok := pr .analyzeAttrs ["shap.plot_type" ]
76- if ! ok {
77- // using shap default value
78- return `""`
79+ func resolveAnalyzeSummaryParames (atts * attrs ) (map [string ]interface {}, error ) {
80+ parames , err := resolveAttribute (atts )
81+ if err != nil {
82+ return nil , err
7983 }
80- return v .val
84+
85+ summaryAttrs := make (map [string ]interface {})
86+ for _ , v := range parames {
87+ if v .Prefix == shapSummaryAttributePrefix {
88+ summaryAttrs [v .Name ] = v .Value
89+ }
90+ }
91+ return summaryAttrs , nil
8192}
8293
8394func genAnalyzer (pr * extendedSelect , db * DB , cwd , modelDir string ) (* bytes.Buffer , error ) {
@@ -89,13 +100,17 @@ func genAnalyzer(pr *extendedSelect, db *DB, cwd, modelDir string) (*bytes.Buffe
89100 return nil , fmt .Errorf ("analyzer: model[%s] not supported" , pr .estimator )
90101 }
91102 // We untar the XGBoost.{pr.trainedModel}.tar.gz and get three files.
92- plotType := readPlotType (pr )
103+ summaryAttrs , err := resolveAnalyzeSummaryParames (& pr .analyzeAttrs )
104+ if err != nil {
105+ return nil , err
106+ }
107+
93108 xs , label , err := readXGBFeatures (pr , db )
94109 if err != nil {
95110 return nil , err
96111 }
97112
98- fr , err := newAnalyzeFiller (pr , db , xs , label , pr .trainedModel , plotType )
113+ fr , err := newAnalyzeFiller (pr , db , xs , label , pr .trainedModel , summaryAttrs )
99114 if err != nil {
100115 return nil , fmt .Errorf ("create analyze filler failed: %v" , err )
101116 }
0 commit comments