Skip to content

Commit b047bc5

Browse files
typhoonzero authored and Yancey0623 committed
Xgboost predict (#789)
* xgboost predict * refine * update * add executor test * fix test by merge
1 parent 3c5c411 commit b047bc5

5 files changed

Lines changed: 144 additions & 34 deletions

File tree

sql/codegen_xgboost.go

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func resolveParamsCfg(attrs map[string]*attribute) (map[string]interface{}, erro
7979
func resolveObjective(pr *extendedSelect) (string, error) {
8080
estimatorParts := strings.Split(pr.estimator, ".")
8181
if len(estimatorParts) != 3 {
82-
return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part")
82+
return "", fmt.Errorf("XGBoost Estimator should be xgboost.first_part.second_part, current: %s", pr.estimator)
8383
}
8484
return strings.Join(estimatorParts[1:], ":"), nil
8585
}
@@ -90,6 +90,7 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
9090
return nil, err
9191
}
9292
training, validation := trainingAndValidationDataset(pr, ds)
93+
isTrain := pr.train
9394
r := &xgbFiller{
9495
Estimator: Estimator{
9596
IsTrain: pr.train,
@@ -99,25 +100,34 @@ func newXGBFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgbFille
99100
xgbTrainConfig: *resolveTrainCfg(attrs),
100101
Save: pr.save,
101102
}
102-
103-
// resolve the attribute keys without any prefix as the XGBoost Paremeters
104-
params, err := resolveParamsCfg(attrs)
105-
if err != nil {
106-
return nil, err
103+
if !isTrain && !pr.analyze {
104+
r.PredictionDatasetSQL = pr.standardSelect.String()
105+
if r.TableName, _, err = parseTableColumn(pr.into); err != nil {
106+
return nil, err
107+
}
108+
r.Save = pr.model
107109
}
108110

109-
// fill learning target
110-
objective, err := resolveObjective(pr)
111-
if err != nil {
112-
return nil, err
113-
}
114-
params["objective"] = objective
111+
if isTrain {
112+
// resolve the attribute keys without any prefix as the XGBoost Paremeters
113+
params, err := resolveParamsCfg(attrs)
114+
if err != nil {
115+
return nil, err
116+
}
115117

116-
paramsJSON, err := json.Marshal(params)
117-
if err != nil {
118-
return nil, err
118+
// fill learning target
119+
objective, err := resolveObjective(pr)
120+
if err != nil {
121+
return nil, err
122+
}
123+
params["objective"] = objective
124+
125+
paramsJSON, err := json.Marshal(params)
126+
if err != nil {
127+
return nil, err
128+
}
129+
r.ParamsCfgJSON = string(paramsJSON)
119130
}
120-
r.ParamsCfgJSON = string(paramsJSON)
121131

122132
if r.connectionConfig, err = newConnectionConfig(db); err != nil {
123133
return nil, err
@@ -161,7 +171,11 @@ func genXGBoost(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fie
161171
if pr.train {
162172
return xgbTrainTemplate.Execute(w, r)
163173
}
164-
return fmt.Errorf("xgboost prediction codegen has not been implemented")
174+
if e := createPredictionTable(pr, db); e != nil {
175+
return fmt.Errorf("failed to create prediction table: %v", e)
176+
}
177+
return xgbPredictTemplate.Execute(w, r)
165178
}
166179

167180
var xgbTrainTemplate = template.Must(template.New("codegenXGBTrain").Parse(xgbTrainTemplateText))
181+
var xgbPredictTemplate = template.Must(template.New("codegenXGBPredict").Parse(xgbPredictTemplateText))

sql/codegen_xgboost_test.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ SELECT *
2525
FROM iris.train
2626
TRAIN xgb.multi.softprob
2727
WITH
28-
train.num_boost_round = 30,
29-
eta = 3.1,
30-
num_class = 3
28+
train.num_boost_round = 30,
29+
eta = 3.1,
30+
num_class = 3
3131
COLUMN sepal_length, sepal_width, petal_length, petal_width
3232
LABEL class
3333
INTO sqlflow_models.my_xgboost_model;
@@ -38,6 +38,13 @@ ANALYZE sqlflow_models.my_xgboost_model
3838
USING TreeExplainer;
3939
`
4040

41+
const testXGBoostPredictIris = `
42+
SELECT *
43+
FROM iris.test
44+
PREDICT iris.predict.class
45+
USING sqlflow_models.my_xgboost_model;
46+
`
47+
4148
func TestXGBFiller(t *testing.T) {
4249
a := assert.New(t)
4350
parser := newParser()
@@ -56,3 +63,17 @@ func TestXGBFiller(t *testing.T) {
5663
a.NoError(err)
5764
a.Equal(filler.ParamsCfgJSON, string(paramsJSON))
5865
}
66+
67+
func TestXGBFillerPredict(t *testing.T) {
68+
a := assert.New(t)
69+
parser := newParser()
70+
r, e := parser.Parse(testXGBoostPredictIris)
71+
a.NoError(e)
72+
filler, e := newXGBFiller(r, nil, testDB)
73+
a.NoError(e)
74+
a.False(filler.IsTrain)
75+
a.Equal(filler.TableName, "iris.predict")
76+
a.Equal(filler.Save, "sqlflow_models.my_xgboost_model")
77+
a.Equal(filler.PredictionDatasetSQL, `SELECT *
78+
FROM iris.test`)
79+
}

sql/executor.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,10 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
463463
if e := genAntXGBoost(&buf, pr, nil, fts, db); e != nil {
464464
return fmt.Errorf("genAntXGBoost %v", e)
465465
}
466+
} else if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGB.`) {
467+
if e := genXGBoost(&buf, pr, nil, fts, db); e != nil {
468+
return fmt.Errorf("genXGBoost %v", e)
469+
}
466470
} else {
467471
if e := genTF(&buf, pr, nil, fts, db); e != nil {
468472
return fmt.Errorf("genTF %v", e)

sql/executor_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ func TestExecutorXGBoost(t *testing.T) {
112112
a.True(goodStream(stream.ReadAll()))
113113
stream = runExtendedSQL(testAnalyzeTreeModelSelectIris, testDB, modelDir, nil)
114114
a.True(goodStream(stream.ReadAll()))
115+
stream = runExtendedSQL(testXGBoostPredictIris, testDB, modelDir, nil)
116+
a.True(goodStream(stream.ReadAll()))
115117
})
116118
}
117119

sql/template_xgboost.go

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ num_boost_round = {{.NumBoostRound}}
3434
maximize = True if "{{.Maximize}}" == "true" else False
3535
early_stopping_rounds = {{.EarlyStoppingRounds}}
3636
if early_stopping_rounds == -1:
37-
early_stopping_rounds = None
37+
early_stopping_rounds = None
3838
3939
{{if ne .ParamsCfgJSON ""}}
4040
params = {{.ParamsCfgJSON}}
@@ -58,22 +58,20 @@ feature_specs["{{$value.FeatureName}}"] = {
5858
}
5959
{{end}}
6060
61-
62-
6361
conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
6462
6563
def xgb_dataset(fn, dataset_sql):
66-
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
67-
with open(fn, 'w') as f:
68-
for item in gen():
69-
features, label = item
70-
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
71-
f.write("\t".join(row_data) + "\n")
72-
# TODO(yancey1989): genearte group and weight text file if necessary
73-
return xgb.DMatrix(fn)
74-
75-
dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}")
76-
dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}")
64+
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
65+
with open(fn, 'w') as f:
66+
for item in gen():
67+
features, label = item
68+
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
69+
f.write("\t".join(row_data) + "\n")
70+
# TODO(yancey1989): genearte group and weight text file if necessary
71+
return xgb.DMatrix(fn)
72+
73+
dtrain = xgb_dataset('train.txt', """{{.TrainingDatasetSQL}}""")
74+
dtest = xgb_dataset('test.txt', """{{.ValidationDatasetSQL}}""")
7775
7876
train_args = {}
7977
train_args["num_boost_round"] = num_boost_round
@@ -84,3 +82,74 @@ train_args["evals"] = [(dtrain, "train"), (dtest, "validation")]
8482
bst = xgb.train(params, dtrain, **train_args)
8583
bst.save_model("{{.Save}}")
8684
`
85+
86+
const xgbPredictTemplateText = `
87+
import xgboost as xgb
88+
import numpy as np
89+
from sqlflow_submitter.db import connect, db_generator, buffered_db_writer
90+
91+
driver="{{.Driver}}"
92+
93+
{{if ne .Database ""}}
94+
database="{{.Database}}"
95+
{{else}}
96+
database=""
97+
{{end}}
98+
99+
session_cfg = {}
100+
{{ range $k, $v := .Session }}
101+
session_cfg["{{$k}}"] = "{{$v}}"
102+
{{end}}
103+
104+
feature_column_names = [{{range .X}}
105+
"{{.FeatureName}}",
106+
{{end}}]
107+
108+
{{/* Convert go side featureSpec to python dict for input_fn */}}
109+
feature_specs = dict()
110+
{{ range $value := .X }}
111+
feature_specs["{{$value.FeatureName}}"] = {
112+
"feature_name": "{{$value.FeatureName}}",
113+
"dtype": "{{$value.Dtype}}",
114+
"delimiter": "{{$value.Delimiter}}",
115+
"shape": {{$value.InputShape}},
116+
"is_sparse": "{{$value.IsSparse}}" == "true"
117+
}
118+
{{end}}
119+
120+
conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
121+
122+
def xgb_dataset(fn, dataset_sql):
123+
gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "", feature_specs)
124+
with open(fn, 'w') as f:
125+
for item in gen():
126+
features, label = item
127+
row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
128+
f.write("\t".join(row_data) + "\n")
129+
# TODO(yancey1989): genearte group and weight text file if necessary
130+
return xgb.DMatrix(fn)
131+
132+
dpred = xgb_dataset('predict.txt', """{{.PredictionDatasetSQL}}""")
133+
134+
bst = xgb.Booster({'nthread': 4}) # init model
135+
bst.load_model("{{.Save}}") # load data
136+
preds = bst.predict(dpred)
137+
# TODO(typhoonzero): regression models may have different behavior
138+
pred_classes = np.argmax(np.array(preds), axis=1)
139+
140+
feature_file_read = open("predict.txt", "r")
141+
142+
result_column_names = feature_column_names
143+
result_column_names.append("{{.Y.FeatureName}}")
144+
145+
line_no = 0
146+
with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100) as w:
147+
while True:
148+
line = feature_file_read.readline()
149+
if not line:
150+
break
151+
row = [i.split(":")[1] for i in line.replace("\n", "").split("\t")[1:]]
152+
row.append(pred_classes[line_no])
153+
w.write(row)
154+
line_no += 1
155+
`

0 commit comments

Comments
 (0)