@@ -60,7 +60,8 @@ def evaluate(datasource,
6060 result_metrics = estimator_evaluate (estimator , eval_dataset ,
6161 validation_metrics )
6262 else :
63- keras_model = init_model_with_feature_column (estimator , model_params )
63+ keras_model = init_model_with_feature_column (estimator_cls ,
64+ model_params )
6465 keras_model_pkg = sys .modules [estimator_cls .__module__ ]
6566 result_metrics = keras_evaluate (keras_model , eval_dataset , save ,
6667 keras_model_pkg , validation_metrics )
@@ -119,10 +120,12 @@ def keras_evaluate(keras_model, eval_dataset_fn, save, keras_model_pkg,
119120 else :
120121 # default
121122 keras_metrics = metrics .get_keras_metrics (["Accuracy" ])
123+ has_custom_evaluate_func = hasattr (keras_model , 'sqlflow_evaluate_loop' )
122124
123- # compile the model with default arguments only for evaluation (run forward
124- # only).
125- keras_model .compile (loss = keras_model_pkg .loss , metrics = keras_metrics )
125+ if not has_custom_evaluate_func :
126+ # compile the model with default arguments only for evaluation
127+ # (run forward only).
128+ keras_model .compile (loss = keras_model_pkg .loss , metrics = keras_metrics )
126129
127130 eval_dataset = eval_dataset_fn ()
128131
@@ -131,12 +134,17 @@ def get_features(sample, label):
131134
132135 eval_dataset_x = eval_dataset .map (get_features )
133136
134- one_batch = next (iter (eval_dataset_x ))
135- # NOTE: must run predict one batch to initialize parameters
136- # see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
137- keras_model .predict_on_batch (one_batch )
138- keras_model .load_weights (save )
139- result = keras_model .evaluate (eval_dataset )
137+ if has_custom_evaluate_func :
138+ result = keras_model .sqlflow_evaluate_loop (eval_dataset ,
139+ validation_metrics )
140+ else :
141+ one_batch = next (iter (eval_dataset_x ))
142+ # NOTE: must run predict one batch to initialize parameters
143+ # see: https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing#saving_subclassed_models # noqa: E501
144+ keras_model .predict_on_batch (one_batch )
145+ keras_model .load_weights (save )
146+ result = keras_model .evaluate (eval_dataset )
147+
140148 assert (len (result ) == len (validation_metrics ) + 1 )
141149 result_metrics = dict ()
142150 for idx , m in enumerate (["loss" ] + validation_metrics ):
0 commit comments