Skip to content

Commit b5e09a8

Browse files
authored
Support custom train loop model prediction and evaluation (#2759)
* support custom train loop model prediction and evaluation * update models repo * update
1 parent 67b5833 commit b5e09a8

8 files changed

Lines changed: 82 additions & 27 deletions

File tree

docker/dev/build.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ cp target/*.jar $SQLFLOW_BIN
6868
echo "Build model zoo ..."
6969
cd $SQLFLOW_BIN
7070
if [[ ! -d models ]]; then
71-
git clone https://github.com/sql-machine-learning/models
71+
git clone https://github.com/sql-machine-learning/models.git
7272
fi
7373
cd models
7474
git fetch origin # The residual local repo might not be on a branch.
75-
git checkout v0.0.5 -b v0.0.5
75+
git checkout v0.0.6 -b v0.0.6
7676
python setup.py bdist_wheel -q --dist-dir $SQLFLOW_BIN > /dev/null
7777

7878
echo "Convert tutorials from Markdown to IPython notebooks ..."

go/attribute/attribute_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func TestParamsDocs(t *testing.T) {
143143

144144
a.Equal(11, len(PremadeModelParamsDocs))
145145
ExtractSQLFlowModelsSymbolOnce()
146-
a.Equal(20, len(PremadeModelParamsDocs))
146+
a.Equal(21, len(PremadeModelParamsDocs))
147147
a.Equal(len(PremadeModelParamsDocs["DNNClassifier"]), 12)
148148
a.NotContains(PremadeModelParamsDocs["DNNClassifier"], "feature_columns")
149149
a.Contains(PremadeModelParamsDocs["DNNClassifier"], "optimizer")

go/cmd/sqlflow/main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ func TestComplete(t *testing.T) {
636636

637637
p.InsertText(`RAIN `, false, true)
638638
c = s.completer(*p.Document())
639-
a.Equal(20, len(c))
639+
a.Equal(21, len(c))
640640
a.Equal("BoostedTreesClassifier", c[0].Text)
641641

642642
p.InsertText(`DNN`, false, true)

go/cmd/sqlflowserver/e2e_mysql_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,34 @@ import (
2626
server "sqlflow.org/sqlflow/go/sqlflowserver"
2727
)
2828

29+
func caseCustomLoopModel(t *testing.T) {
30+
a := assert.New(t)
31+
trainSQL := fmt.Sprintf(`SELECT * FROM %s
32+
TO TRAIN sqlflow_models.CustomClassifier
33+
LABEL class
34+
INTO sqlflow_models.custom_loop_model;`, caseTrainTable)
35+
_, _, _, err := connectAndRunSQL(trainSQL)
36+
if err != nil {
37+
a.Fail("Run trainSQL error: %v", err)
38+
}
39+
predSQL := fmt.Sprintf(`SELECT * FROM %s
40+
TO PREDICT sqlflow_models.custom_loop_model_pred_result.class
41+
USING sqlflow_models.custom_loop_model;`, caseTrainTable)
42+
_, _, _, err = connectAndRunSQL(predSQL)
43+
if err != nil {
44+
a.Fail("Run trainSQL error: %v", err)
45+
}
46+
evalSQL := fmt.Sprintf(`SELECT * FROM %s
47+
TO EVALUATE sqlflow_models.custom_loop_model
48+
WITH validation.metrics="Accuracy"
49+
LABEL class
50+
INTO sqlflow_models.custom_loop_model_eval_result;`, caseTrainTable)
51+
_, _, _, err = connectAndRunSQL(evalSQL)
52+
if err != nil {
53+
a.Fail("Run trainSQL error: %v", err)
54+
}
55+
}
56+
2957
func TestEnd2EndMySQL(t *testing.T) {
3058
if os.Getenv("SQLFLOW_TEST_DB") != "mysql" {
3159
t.Skip("Skipping mysql tests")
@@ -54,6 +82,7 @@ func TestEnd2EndMySQL(t *testing.T) {
5482
t.Run("CaseCoverage", CaseCoverageMysql)
5583
t.Run("CaseTrainWithCommaSeparatedLabel", CaseTrainWithCommaSeparatedLabel)
5684
t.Run("CaseTrainCustomModelFunctional", CaseTrainCustomModelFunctional)
85+
t.Run("CaseCustomLoopModel", caseCustomLoopModel)
5786
t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin)
5887
t.Run("CaseTrainRegression", caseTrainRegression)
5988

python/runtime/tensorflow/evaluate.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def evaluate(datasource,
6060
result_metrics = estimator_evaluate(estimator, eval_dataset,
6161
validation_metrics)
6262
else:
63-
keras_model = init_model_with_feature_column(estimator, model_params)
63+
keras_model = init_model_with_feature_column(estimator_cls,
64+
model_params)
6465
keras_model_pkg = sys.modules[estimator_cls.__module__]
6566
result_metrics = keras_evaluate(keras_model, eval_dataset, save,
6667
keras_model_pkg, validation_metrics)
@@ -119,10 +120,12 @@ def keras_evaluate(keras_model, eval_dataset_fn, save, keras_model_pkg,
119120
else:
120121
# default
121122
keras_metrics = metrics.get_keras_metrics(["Accuracy"])
123+
has_custom_evaluate_func = hasattr(keras_model, 'sqlflow_evaluate_loop')
122124

123-
# compile the model with default arguments only for evaluation (run forward
124-
# only).
125-
keras_model.compile(loss=keras_model_pkg.loss, metrics=keras_metrics)
125+
if not has_custom_evaluate_func:
126+
# compile the model with default arguments only for evaluation
127+
# (run forward only).
128+
keras_model.compile(loss=keras_model_pkg.loss, metrics=keras_metrics)
126129

127130
eval_dataset = eval_dataset_fn()
128131

@@ -131,12 +134,17 @@ def get_features(sample, label):
131134

132135
eval_dataset_x = eval_dataset.map(get_features)
133136

134-
one_batch = next(iter(eval_dataset_x))
135-
# NOTE: must run predict one batch to initialize parameters
136-
# see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
137-
keras_model.predict_on_batch(one_batch)
138-
keras_model.load_weights(save)
139-
result = keras_model.evaluate(eval_dataset)
137+
if has_custom_evaluate_func:
138+
result = keras_model.sqlflow_evaluate_loop(eval_dataset,
139+
validation_metrics)
140+
else:
141+
one_batch = next(iter(eval_dataset_x))
142+
# NOTE: must run predict one batch to initialize parameters
143+
# see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
144+
keras_model.predict_on_batch(one_batch)
145+
keras_model.load_weights(save)
146+
result = keras_model.evaluate(eval_dataset)
147+
140148
assert (len(result) == len(validation_metrics) + 1)
141149
result_metrics = dict()
142150
for idx, m in enumerate(["loss"] + validation_metrics):

python/runtime/tensorflow/predict.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,16 @@ def eval_input_fn(batch_size, cache=False):
6666
dataset = dataset.cache()
6767
return dataset
6868

69-
# NOTE: always use batch_size=1 when predicting to get the pairs of
70-
# features and predict results to insert into result table.
71-
pred_dataset = eval_input_fn(1)
72-
one_batch = next(iter(pred_dataset))
73-
# NOTE: must run predict one batch to initialize parameters. See:
74-
# https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
75-
classifier.predict_on_batch(one_batch)
76-
classifier.load_weights(save)
69+
if not hasattr(classifier, 'sqlflow_predict_one'):
70+
# NOTE: load_weights should be called by keras models only.
71+
# NOTE: always use batch_size=1 when predicting to get the pairs of
72+
# features and predict results to insert into result table.
73+
pred_dataset = eval_input_fn(1)
74+
one_batch = next(iter(pred_dataset))
75+
# NOTE: must run predict one batch to initialize parameters. See:
76+
# https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
77+
classifier.predict_on_batch(one_batch)
78+
classifier.load_weights(save)
7779
pred_dataset = eval_input_fn(1, cache=True).make_one_shot_iterator()
7880

7981
column_names = selected_cols[:]
@@ -89,7 +91,10 @@ def eval_input_fn(batch_size, cache=False):
8991
hdfs_namenode_addr, hive_location, hdfs_user,
9092
hdfs_pass) as w:
9193
for features in pred_dataset:
92-
result = classifier.predict_on_batch(features)
94+
if hasattr(classifier, 'sqlflow_predict_one'):
95+
result = classifier.sqlflow_predict_one(features)
96+
else:
97+
result = classifier.predict_on_batch(features)
9398
# FIXME(typhoonzero): determine the predict result is
9499
# classification by adding the prediction result together
95100
# to see if it is close to 1.0.

python/runtime/tensorflow/train_keras.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,19 @@ def keras_train_and_save(estimator, model_params, save, is_pai,
9999
verbose, metric_names, validation_steps,
100100
load_pretrained_model, model_meta):
101101
print("Start training using keras model...")
102-
classifier, has_none_optimizer = keras_compile(estimator, model_params,
103-
save, metric_names)
102+
try:
103+
classifier, has_none_optimizer = keras_compile(estimator, model_params,
104+
save, metric_names)
105+
except Exception as e:
106+
if hasattr(estimator, "sqlflow_train_loop"):
107+
sys.stderr.write(
108+
"compile keras model failed, ignoring this error since the model seems to defined sqlflow_train_loop."
109+
)
110+
classifier = init_model_with_feature_column(
111+
estimator, model_params, has_none_optimizer=True)
112+
has_none_optimizer = True
113+
else:
114+
raise e
104115

105116
train_dataset = train_dataset_fn()
106117
if val_dataset_fn is not None:
@@ -165,9 +176,11 @@ def keras_train_compiled(classifier, save, train_dataset, validate_dataset,
165176
model_meta["evaluation"] = val_metrics
166177

167178
try:
168-
classifier.save_weights(save, save_format="h5")
169179
# write model metadata to model_meta.json
170180
save_model_metadata("model_meta.json", model_meta)
181+
# NOTE: classifier.save_weights may fail if the model has sqlflow_train_loop
182+
# and does not have Keras layers defined. So save metadata before calling save_weights.
183+
classifier.save_weights(save, save_format="h5")
171184
except: # noqa: E722
172185
if has_none_optimizer:
173186
warnings.warn("Saving model with None optimizer fails")

scripts/test/prepare.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ python -m pip install --quiet \
4444

4545
git clone https://github.com/sql-machine-learning/models.git
4646
(cd models && git fetch origin && \
47-
git checkout v0.0.5 -b v0.0.5 && \
47+
git checkout v0.0.6 -b v0.0.6 && \
4848
python setup.py install)
4949

5050
# 3. install java parser

0 commit comments

Comments
 (0)