Skip to content

Commit 6cd7b24

Browse files
authored
Add XGBoost evaluation codegen (#2867)
* add evaluate codegen * update * update __init__.py fix flake8 * polish py db code
1 parent aac2130 commit 6cd7b24

File tree

9 files changed

+161
-29
lines changed

9 files changed

+161
-29
lines changed

go/cmd/sqlflowserver/e2e_workflow_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,15 @@ TO PREDICT iris.test_result_table.class
409409
USING sqlflow_models.xgb_classification;
410410
411411
SELECT * FROM iris.test_result_table;
412+
413+
SELECT * FROM iris.test
414+
TO EVALUATE sqlflow_models.xgb_classification
415+
WITH
416+
validation.metrics="accuracy_score"
417+
LABEL class
418+
INTO iris.evaluate_result_table;
419+
420+
SELECT * FROM iris.evaluate_result_table;
412421
`
413422
testMain(extraTrainSQLProgram + sqlProgram)
414423
testMain(sqlProgram)

go/codegen/experimental/codegen_couler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func GenerateCodeCouler(sqlProgram string, session *pb.Session) (string, error)
5252
if err != nil {
5353
return "", err
5454
}
55-
stepList := make([]*stepContext, 0)
55+
var stepList []*stepContext
5656
for idx, stmt := range stmts {
5757
stepCode, image, err := generateStepCodeAndImage(stmt, idx, session, stmts)
5858
if err != nil {

go/codegen/experimental/codegen_step.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ func generateStepCodeAndImage(sqlStmt ir.SQLFlowStmt, stepIndex int, session *pb
3434
return generateTrainCodeAndImage(stmt, stepIndex, session)
3535
case *ir.PredictStmt:
3636
return generatePredictCodeAndImage(stmt, stepIndex, session, sqlStmts)
37+
case *ir.EvaluateStmt:
38+
return generateEvaluationCodeAndImage(stmt, stepIndex, session, sqlStmts)
3739
case *ir.NormalStmt:
3840
code, err := generateNormalStmtStep(string(*stmt), stepIndex, session)
3941
return code, "", err
@@ -55,9 +57,9 @@ func generateTrainCodeAndImage(trainStmt *ir.TrainStmt, stepIndex int, session *
5557
}
5658

5759
func generatePredictCodeAndImage(predStmt *ir.PredictStmt, stepIndex int, session *pb.Session, sqlStmts []ir.SQLFlowStmt) (string, string, error) {
58-
trainStmt := findModelGenerationTrainStmt(predStmt.Using, stepIndex, sqlStmts)
5960
image := ""
6061
isXGBoost := false
62+
trainStmt := findModelGenerationTrainStmt(predStmt.Using, stepIndex, sqlStmts)
6163
if trainStmt != nil {
6264
image = trainStmt.ModelImage
6365
isXGBoost = isXGBoostEstimator(trainStmt.Estimator)
@@ -80,6 +82,32 @@ func generatePredictCodeAndImage(predStmt *ir.PredictStmt, stepIndex int, sessio
8082
return "", "", fmt.Errorf("not implemented model type")
8183
}
8284

85+
func generateEvaluationCodeAndImage(evalStmt *ir.EvaluateStmt, stepIndex int, session *pb.Session, sqlStmts []ir.SQLFlowStmt) (string, string, error) {
86+
image := ""
87+
isXGBoost := false
88+
trainStmt := findModelGenerationTrainStmt(evalStmt.ModelName, stepIndex, sqlStmts)
89+
if trainStmt != nil {
90+
image = trainStmt.ModelImage
91+
isXGBoost = isXGBoostEstimator(trainStmt.Estimator)
92+
} else {
93+
meta, err := getModelMetadata(session, evalStmt.ModelName)
94+
if err != nil {
95+
return "", "", err
96+
}
97+
image = meta.imageName()
98+
isXGBoost = meta.isXGBoostModel()
99+
}
100+
101+
if isXGBoost {
102+
code, err := XGBoostGenerateEvaluation(evalStmt, stepIndex, session)
103+
if err != nil {
104+
return "", "", err
105+
}
106+
return code, image, nil
107+
}
108+
return "", "", fmt.Errorf("not implemented model type")
109+
}
110+
83111
// findModelGenerationTrainStmt finds the *ir.TrainStmt that generates the model named `modelName`.
84112
// TODO(sneaxiy): find a better way to do this when we have a well designed dependency analysis.
85113
func findModelGenerationTrainStmt(modelName string, idx int, sqlStmts []ir.SQLFlowStmt) *ir.TrainStmt {
@@ -144,7 +172,7 @@ func getModelMetadataFromDB(dbConnStr, table string) (*metadata, error) {
144172
if err != nil {
145173
return nil, err
146174
}
147-
if readCnt != int(length) {
175+
if uint64(readCnt) != length {
148176
return nil, fmt.Errorf("invalid model metadata")
149177
}
150178
json, err := simplejson.NewJson(jsonBytes)

go/codegen/experimental/xgboost.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,77 @@ def step_entry_{{.StepIndex}}():
208208
load='''{{.Load}}''')
209209
`
210210

211+
type xgbEvaluateFiller struct {
212+
StepIndex int
213+
DataSource string
214+
Select string
215+
ResultTable string
216+
PredLabelName string
217+
Load string
218+
ValidationMetrics string
219+
Submitter string
220+
}
221+
222+
// XGBoostGenerateEvaluation generates the XGBoost evaluation code
223+
func XGBoostGenerateEvaluation(evalStmt *ir.EvaluateStmt, stepIndex int, session *pb.Session) (string, error) {
224+
ds, err := GeneratePyDbConnStr(session)
225+
if err != nil {
226+
return "", err
227+
}
228+
229+
labelName := ""
230+
if nc, ok := evalStmt.Label.(*ir.NumericColumn); ok {
231+
labelName = nc.FieldDesc.Name
232+
} else {
233+
return "", fmt.Errorf("unsupported label type %T", evalStmt.Label)
234+
}
235+
236+
metricList := []string{"accuracy_score"}
237+
if m, ok := evalStmt.Attributes["validation.metrics"]; ok {
238+
if metricStr, ok := m.(string); ok {
239+
metricList = []string{}
240+
for _, s := range strings.Split(metricStr, ",") {
241+
metricList = append(metricList, strings.TrimSpace(s))
242+
}
243+
} else {
244+
return "", fmt.Errorf("validation.metrics must be of type string")
245+
}
246+
}
247+
metricPyStr := ir.AttrToPythonValue(metricList)
248+
249+
filler := &xgbEvaluateFiller{
250+
StepIndex: stepIndex,
251+
DataSource: ds,
252+
Select: replaceNewLineRuneAndTrimSpace(evalStmt.Select),
253+
ResultTable: evalStmt.Into,
254+
PredLabelName: labelName,
255+
Load: evalStmt.ModelName,
256+
ValidationMetrics: metricPyStr,
257+
Submitter: getSubmitter(session),
258+
}
259+
260+
var program bytes.Buffer
261+
tpl := template.Must(template.New("Evaluate").Parse(xgbEvaluateTemplate))
262+
if err := tpl.Execute(&program, filler); err != nil {
263+
return "", err
264+
}
265+
return program.String(), nil
266+
}
267+
268+
const xgbEvaluateTemplate = `
269+
def step_entry_{{.StepIndex}}():
270+
import runtime.temp_file as temp_file
271+
from runtime.{{.Submitter}} import evaluate
272+
273+
with temp_file.TemporaryDirectory(as_cwd=True):
274+
evaluate(datasource='''{{.DataSource}}''',
275+
select='''{{.Select}}''',
276+
result_table='''{{.ResultTable}}''',
277+
pred_label_name='''{{.PredLabelName}}''',
278+
load='''{{.Load}}''',
279+
validation_metrics={{.ValidationMetrics}})
280+
`
281+
211282
func getSubmitter(session *pb.Session) string {
212283
if session.Submitter != "" {
213284
return session.Submitter

python/runtime/local/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
from runtime.local.submitter import submit_local_evaluate as evaluate # noqa: F401, E501
1415
from runtime.local.submitter import submit_local_pred as pred # noqa: F401
1516
from runtime.local.submitter import submit_local_train as train # noqa: F401

python/runtime/local/submitter.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
from runtime.local.xgboost_submitter.evaluate import \
15+
evaluate as xgboost_evaluate
1416
from runtime.local.xgboost_submitter.predict import pred as xgboost_pred
1517
from runtime.local.xgboost_submitter.train import train as xgboost_train
1618
from runtime.model.model import EstimatorType, Model
@@ -74,5 +76,16 @@ def submit_local_pred(datasource, select, result_table, pred_label_name, load):
7476
if model.get_type() == EstimatorType.XGBOOST:
7577
xgboost_pred(datasource, select, result_table, pred_label_name, model)
7678
else:
77-
raise NotImplementedError("not implemented model type: %s" %
78-
model.get_type())
79+
raise NotImplementedError("not implemented model type: {}".format(
80+
model.get_type()))
81+
82+
83+
def submit_local_evaluate(datasource, select, result_table, pred_label_name,
84+
load, validation_metrics):
85+
model = Model.load_from_db(datasource, load)
86+
if model.get_type() == EstimatorType.XGBOOST:
87+
xgboost_evaluate(datasource, select, result_table, model,
88+
pred_label_name, validation_metrics)
89+
else:
90+
raise NotImplementedError("not implemented model type: {}".format(
91+
model.get_type()))

python/runtime/local/xgboost_submitter/evaluate.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import runtime.temp_file as temp_file
1818
import runtime.xgboost as xgboost_extended
19+
import six
1920
import sklearn.metrics
2021
import xgboost as xgb
2122
from runtime import db
@@ -53,7 +54,7 @@
5354
def evaluate(datasource,
5455
select,
5556
result_table,
56-
load,
57+
model,
5758
pred_label_name=None,
5859
validation_metrics=["accuracy_score"]):
5960
"""
@@ -63,14 +64,19 @@ def evaluate(datasource,
6364
datasource (str): the database connection string.
6465
select (str): the input data to predict.
6566
result_table (str): the output data table.
66-
load (str): where the trained model stores.
67+
model (Model|str): the model object or where to load the model.
6768
pred_label_name (str): the label column name.
6869
validation_metrics (list[str]): the evaluation metric names.
6970
7071
Returns:
7172
None.
7273
"""
73-
model = Model.load_from_db(datasource, load)
74+
if isinstance(model, six.string_types):
75+
model = Model.load_from_db(datasource, model)
76+
else:
77+
assert isinstance(model,
78+
Model), "not supported model type %s" % type(model)
79+
7480
model_params = model.get_meta("attributes")
7581
train_fc_map = model.get_meta("features")
7682
train_label_desc = model.get_meta("label").get_field_desc()[0]

python/runtime/model/db.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ def _read_metadata(reader):
128128
return json.loads(metadata_json, cls=JSONDecoderWithFeatureColumn)
129129

130130

131-
def write_with_generator(datasource, table, gen, metadata):
131+
def write_with_generator_and_metadata(datasource, table, gen, metadata):
132132
"""Write data into a table, the written data
133-
comes from the input generator.
133+
comes from the input generator and metadata.
134134
135135
Args:
136136
datasource: string
@@ -176,9 +176,9 @@ def read_metadata_from_db(datasource, table):
176176
return metadata
177177

178178

179-
def read_with_generator(datasource, table, buff_size=256):
179+
def read_with_generator_and_metadata(datasource, table, buff_size=256):
180180
"""Read data from a table, this function returns
181-
a generator to yield the data.
181+
a generator to yield the data, and the metadata dict.
182182
183183
Args:
184184
datasource: string
@@ -188,20 +188,23 @@ def read_with_generator(datasource, table, buff_size=256):
188188
buff_size: int
189189
The buffer size to read data.
190190
191-
Returns: Generator
192-
the generator yield row data of the table.
191+
Returns: tuple(Generator, dict)
192+
the generator yield row data of the table,
193+
and the model metadata dict.
193194
"""
195+
conn = connect_with_data_source(datasource)
196+
r = SQLFSReader(conn, table)
197+
metadata = _read_metadata(r)
198+
194199
def reader():
195-
conn = connect_with_data_source(datasource)
196-
with SQLFSReader(conn, table) as r:
197-
_read_metadata(r)
198-
while True:
199-
buffer = r.read(buff_size)
200-
if not buffer:
201-
break
200+
while True:
201+
buffer = r.read(buff_size)
202+
if not buffer:
203+
break
202204

203-
yield buffer
205+
yield buffer
204206

207+
r.close()
205208
conn.close()
206209

207-
return reader
210+
return reader, metadata

python/runtime/model/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from runtime.feature.column import (JSONDecoderWithFeatureColumn,
2020
JSONEncoderWithFeatureColumn)
2121
from runtime.model import oss
22-
from runtime.model.db import (read_metadata_from_db, read_with_generator,
23-
write_with_generator)
22+
from runtime.model.db import (read_with_generator_and_metadata,
23+
write_with_generator_and_metadata)
2424
from runtime.model.tar import unzip_dir, zip_dir
2525

2626
# archive the current work director into a tarball
@@ -177,8 +177,9 @@ def _gen():
177177

178178
return _gen
179179

180-
write_with_generator(datasource, table, _bytes_reader(tarball),
181-
self._to_dict())
180+
write_with_generator_and_metadata(datasource, table,
181+
_bytes_reader(tarball),
182+
self._to_dict())
182183

183184
@staticmethod
184185
def load_from_db(datasource, table, local_dir=None):
@@ -199,14 +200,14 @@ def load_from_db(datasource, table, local_dir=None):
199200

200201
with temp_file.TemporaryDirectory() as tmp_dir:
201202
tarball = os.path.join(tmp_dir, TARBALL_NAME)
202-
gen = read_with_generator(datasource, table)
203+
gen, metadata = read_with_generator_and_metadata(datasource, table)
203204
with open(tarball, "wb") as f:
204205
for data in gen():
205206
f.write(bytes(data))
206207

207208
Model._unzip(local_dir, tarball, load_from_db=True)
208209

209-
return Model._from_dict(read_metadata_from_db(datasource, table))
210+
return Model._from_dict(metadata)
210211

211212
def save_to_oss(self, oss_model_dir, local_dir=None):
212213
"""

0 commit comments

Comments
 (0)