Skip to content

Commit f33fd39

Browse files
committed
change evaluate_policy behaviour
1 parent baefa9a commit f33fd39

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

src/Utils/policy.jl

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -64,18 +64,18 @@ By default, the environment is reset before running the policy.
64 64
function evaluate_policy!(
65 65
policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env), kwargs...
66 66
)
67-
total_reward = 0.0
67+
rewards = zeros(Float64, episodes)
68 68
datasets = map(1:episodes) do _i
69 69
if _i == 1
70 70
reset!(env; reset_rng=true, seed=seed)
71 71
else
72 72
reset!(env; reset_rng=false)
73 73
end
74 74
reward, dataset = evaluate_policy!(policy, env; reset_env=false, kwargs...)
75-
total_reward += reward
75+
rewards[_i] = reward
76 76
return dataset
77 77
end
78-
return total_reward / episodes, vcat(datasets...)
78+
return rewards, datasets
79 79
end
80 80

81 81
"""
@@ -90,8 +90,9 @@ function evaluate_policy!(
90 90
E = length(envs)
91 91
rewards = zeros(Float64, E)
92 92
datasets = map(1:E) do e
93-
reward, dataset = evaluate_policy!(policy, envs[e], episodes; kwargs...)
94-
rewards[e] = reward
93+
rewards, datasets = evaluate_policy!(policy, envs[e], episodes; kwargs...)
94+
rewards[e] = sum(reward) / episodes
95+
dataset = vcat(datasets...)
95 96
return dataset
96 97
end
97 98
return rewards, vcat(datasets...)

0 commit comments

Comments
 (0)