@@ -64,18 +64,18 @@ By default, the environment is reset before running the policy.
function evaluate_policy!(
    policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env), kwargs...
)
    # Run `policy` on `env` for `episodes` consecutive episodes.
    #
    # Returns a tuple `(rewards, datasets)`:
    #   - `rewards`:  `Vector{Float64}` of length `episodes`; entry `k` is the reward
    #                 returned by the single-episode `evaluate_policy!` call for episode `k`.
    #   - `datasets`: one entry per episode, in episode order, holding the dataset
    #                 returned by that same call (not concatenated).
    #
    # Callers that want a mean reward or a single concatenated dataset must aggregate
    # these per-episode results themselves.
    #
    # Extra keyword arguments are forwarded to the single-episode method.

    # Writing into a preallocated vector (rather than accumulating into a captured
    # scalar) keeps the `do` closure from reassigning an outer local.
    rewards = zeros(Float64, episodes)
    datasets = map(1:episodes) do episode
        if episode == 1
            # Only the first episode resets the RNG with `seed`; later episodes reset
            # the environment but leave the RNG stream untouched.
            reset!(env; reset_rng=true, seed=seed)
        else
            reset!(env; reset_rng=false)
        end
        # The environment was just reset above, so the inner call must not reset again.
        reward, dataset = evaluate_policy!(policy, env; reset_env=false, kwargs...)
        rewards[episode] = reward
        return dataset
    end
    return rewards, datasets
end
8080
8181"""
@@ -90,8 +90,9 @@ function evaluate_policy!(
9090 E = length (envs)
9191 rewards = zeros (Float64, E)
9292 datasets = map (1 : E) do e
93- reward, dataset = evaluate_policy! (policy, envs[e], episodes; kwargs... )
94- rewards[e] = reward
93+ rewards, datasets = evaluate_policy! (policy, envs[e], episodes; kwargs... )
94+ rewards[e] = sum (reward) / episodes
95+ dataset = vcat (datasets... )
9596 return dataset
9697 end
9798 return rewards, vcat (datasets... )
0 commit comments