-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpolicy.jl
More file actions
99 lines (90 loc) · 2.72 KB
/
policy.jl
File metadata and controls
99 lines (90 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
$TYPEDEF
Wrapper type representing a decision policy in decision-focused learning benchmarks.

A `Policy` pairs a human-readable `name`/`description` with a callable `policy`
field; calling the `Policy` object forwards to that callable.
"""
struct Policy{P}
    "human-readable policy name, used when displaying the policy"
    name::String
    "one-line description of what the policy does"
    description::String
    "callable run on an environment/instance to produce the next decision"
    policy::P
end
# Compact display: "<name>: <description>" followed by a newline.
function Base.show(io::IO, p::Policy)
    print(io, p.name, ": ", p.description, '\n')
    return nothing
end
"""
$TYPEDSIGNATURES
Make `Policy` objects callable: forward every positional and keyword argument to
the wrapped `policy` callable and return the resulting decision.
"""
(p::Policy)(args...; kwargs...) = p.policy(args...; kwargs...)
"""
$TYPEDSIGNATURES
Roll out the policy on `env` until termination and return the cumulated reward
together with a vector of `DataSample` observations (one per decision step).
Unless `reset_env=false`, the environment is reset (re-seeded with `seed`) first.
"""
function evaluate_policy!(
    policy, env::AbstractEnvironment; reset_env=true, seed=get_seed(env), kwargs...
)
    reset_env && reset!(env; reset_rng=true, seed=seed)
    cumulated_reward = 0.0
    samples = DataSample[]
    while !is_terminated(env)
        decision = policy(env; kwargs...)
        features, state = observe(env)
        frozen_state = deepcopy(state)  # snapshot: step! may mutate the state in place
        reward = step!(env, decision)
        sample = DataSample(; x=features, y=decision, instance=frozen_state, extra=(; reward))
        # Narrow the vector to the concrete sample type on the first push so the
        # returned dataset has a concrete element type for downstream consumers.
        samples = isempty(samples) ? typeof(sample)[sample] : push!(samples, sample)
        cumulated_reward += reward
    end
    return cumulated_reward, samples
end
"""
$TYPEDSIGNATURES
Run the policy for `episodes` consecutive episodes on `env`, returning a vector of
per-episode total rewards and a vector of per-episode observation datasets.
The environment RNG is re-seeded with `seed` before the first episode only; later
episodes continue from the same random stream.
"""
function evaluate_policy!(
    policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env), kwargs...
)
    episode_rewards = zeros(Float64, episodes)
    episode_datasets = map(1:episodes) do ep
        # Seed once at the very first episode; subsequent resets keep the RNG state.
        if ep == 1
            reset!(env; reset_rng=true, seed=seed)
        else
            reset!(env; reset_rng=false)
        end
        total_reward, dataset = evaluate_policy!(policy, env; reset_env=false, kwargs...)
        episode_rewards[ep] = total_reward
        return dataset
    end
    return episode_rewards, episode_datasets
end
"""
$TYPEDSIGNATURES
Run the policy for `episodes` episodes on each environment in `envs`.
Return a vector with the average per-episode reward of each environment, together
with the concatenation of all collected observation datasets.
By default, the environments are reset before running the policy.
"""
function evaluate_policy!(
    policy, envs::Vector{<:AbstractEnvironment}, episodes::Int=1; kwargs...
)
    n_envs = length(envs)
    avg_rewards = zeros(Float64, n_envs)
    # Use names distinct from the outer binding: the original code destructured into
    # `rewards, datasets` inside the closure while also assigning `datasets` outside,
    # which shadows/boxes the captured variable and hurts readability and inference.
    env_datasets = map(1:n_envs) do e
        episode_rewards, episode_datasets = evaluate_policy!(
            policy, envs[e], episodes; kwargs...
        )
        avg_rewards[e] = sum(episode_rewards) / episodes
        return vcat(episode_datasets...)
    end
    return avg_rewards, vcat(env_datasets...)
end