Skip to content

Commit 6516cc6

Browse files
committed
fixed buggs
1 parent 64226f5 commit 6516cc6

7 files changed

Lines changed: 400 additions & 3 deletions

File tree

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ include("PortfolioOptimization/PortfolioOptimization.jl")
5757
include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
5858
include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
5959
include("DynamicAssortment/DynamicAssortment.jl")
60+
include("Maintenance/Maintenance.jl")
6061

6162
using .Utils
6263

@@ -89,6 +90,7 @@ using .PortfolioOptimization
8990
using .StochasticVehicleScheduling
9091
using .DynamicVehicleScheduling
9192
using .DynamicAssortment
93+
using .Maintenance
9294

9395
export Argmax2DBenchmark
9496
export ArgmaxBenchmark
@@ -100,5 +102,6 @@ export RankingBenchmark
100102
export StochasticVehicleSchedulingBenchmark
101103
export SubsetSelectionBenchmark
102104
export WarcraftBenchmark
105+
export MaintenanceBenchmark
103106

104107
end # module DecisionFocusedLearningBenchmarks

src/DynamicAssortment/environment.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,16 +197,21 @@ Features observed by the agent at current step, as a concatenation of:
197197
- change in hype and saturation features from the starting state
198198
- normalized current step (divided by max steps and multiplied by 10)
199199
All features are normalized by dividing by 10.
200+
201+
State
200202
"""
201203
function Utils.observe(env::Environment)
202204
delta_features = env.features[2:3, :] .- env.instance.starting_hype_and_saturation
203-
return vcat(
205+
features = vcat(
204206
env.features,
205207
env.d_features,
206208
delta_features,
207209
ones(1, item_count(env)) .* (env.step / max_steps(env) * 10),
208-
) ./ 10,
209-
nothing
210+
) ./ 10
211+
212+
state = (env.features, env.purchase_history)
213+
214+
return features, state
210215
end
211216

212217
"""

src/Maintenance/Maintenance.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
module Maintenance
2+
3+
using ..Utils
4+
5+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES, SIGNATURES
6+
using Distributions: Uniform, Categorical
7+
using Flux: Chain, Dense
8+
using LinearAlgebra: dot
9+
using Random: Random, AbstractRNG, MersenneTwister
10+
using Statistics: mean
11+
12+
using Combinatorics: combinations
13+
14+
"""
15+
$TYPEDEF
16+
17+
Benchmark for a standard maintenance problem with resource constraints.
18+
Components are identical and degrade idependently over time.
19+
A high cost is incurred for each component that reaches the final degradation level.
20+
A cost is also incurred for maintaining a component.
21+
The number of simultaneous maintenance operations is limited by a maintenance capacity constraint.
22+
23+
# Fields
24+
$TYPEDFIELDS
25+
26+
"""
27+
struct MaintenanceBenchmark <: AbstractDynamicBenchmark{true}
28+
"number of components"
29+
N::Int
30+
"maximum number of components that can be maintained simultaneously"
31+
K::Int
32+
"number of degradation states per component"
33+
n::Int
34+
"degradation probability"
35+
p::Float64
36+
"failure cost"
37+
c_f::Float64
38+
"maintenance cost"
39+
c_m::Float64
40+
"number of steps per episode"
41+
max_steps::Int
42+
end
43+
44+
"""
45+
MaintenanceBenchmark(;
46+
N=2,
47+
K=1,
48+
n=3,
49+
p=0.2
50+
c_f=10.0,
51+
c_m=3.0,
52+
max_steps=10,
53+
)
54+
55+
Constructor for [`MaintenanceBenchmark`](@ref).
56+
By default, the benchmark has 2 components, maintenance capacity 1, number of degradation levels 3,
57+
degradation probability 0.2, failure cost 10.0, maintenance cost 3.0, 10 steps per episode, and is exogenous.
58+
"""
59+
60+
function MaintenanceBenchmark(;
61+
N=2,
62+
K=1,
63+
n=3,
64+
p=0.2,
65+
c_f=10.0,
66+
c_m=3.0,
67+
max_steps=10,
68+
)
69+
return MaintenanceBenchmark(
70+
N, K, n, p, c_f, c_m, max_steps
71+
)
72+
end
73+
74+
# Accessor functions
75+
component_count(b::MaintenanceBenchmark) = b.N
76+
maintenance_capacity(b::MaintenanceBenchmark) = b.K
77+
degradation_levels(b::MaintenanceBenchmark) = b.n
78+
degradation_probability(b::MaintenanceBenchmark) = b.p
79+
failure_cost(b::MaintenanceBenchmark) = b.c_f
80+
maintenance_cost(b::MaintenanceBenchmark) = b.c_m
81+
max_steps(b::MaintenanceBenchmark) = b.max_steps
82+
83+
include("instance.jl")
84+
include("environment.jl")
85+
include("policies.jl")
86+
include("maximizer.jl")
87+
88+
"""
89+
$TYPEDSIGNATURES
90+
91+
Outputs a data sample containing an [`Instance`](@ref).
92+
"""
93+
function Utils.generate_sample(
94+
b::MaintenanceBenchmark, rng::AbstractRNG=MersenneTwister(0)
95+
)
96+
return DataSample(; instance=Instance(b, rng))
97+
end
98+
99+
"""
100+
$TYPEDSIGNATURES
101+
102+
Generates a statistical model for the maintenance benchmark.
103+
The model is a small neural network with one hidden layer no activation function.
104+
"""
105+
function Utils.generate_statistical_model(b::MaintenanceBenchmark; seed=nothing)
106+
Random.seed!(seed)
107+
N = component_count(b)
108+
return Chain(Dense(N => N), Dense(N => N), vec)
109+
end
110+
111+
"""
112+
$TYPEDSIGNATURES
113+
114+
Outputs a top k maximizer, with k being the maintenance capacity of the benchmark.
115+
"""
116+
function Utils.generate_maximizer(b::MaintenanceBenchmark)
117+
return TopKPositiveMaximizer(maintenance_capacity(b))
118+
end
119+
120+
"""
121+
$TYPEDSIGNATURES
122+
123+
Creates an [`Environment`](@ref) from an [`Instance`](@ref) of the maintenance benchmark.
124+
The seed of the environment is randomly generated using the provided random number generator.
125+
"""
126+
function Utils.generate_environment(
127+
::MaintenanceBenchmark, instance::Instance, rng::AbstractRNG; kwargs...
128+
)
129+
seed = rand(rng, 1:typemax(Int))
130+
return Environment(instance; seed)
131+
end
132+
133+
"""
134+
$TYPEDSIGNATURES
135+
136+
Returns two policies for the dynamic assortment benchmark:
137+
- `Greedy`: maintains components when they are in the last state before failure, up to the maintenance capacity
138+
"""
139+
function Utils.generate_policies(::MaintenanceBenchmark)
140+
greedy = Policy(
141+
"Greedy",
142+
"policy that maintains components when they are in the last state before failure, up to the maintenance capacity",
143+
greedy_policy,
144+
)
145+
return (greedy)
146+
end
147+
148+
export MaintenanceBenchmark
149+
150+
end

src/Maintenance/environment.jl

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""
2+
$TYPEDEF
3+
4+
Environment for the maintenance problem.
5+
6+
# Fields
7+
$TYPEDFIELDS
8+
"""
9+
@kwdef mutable struct Environment{I<:Instance,R<:AbstractRNG,S<:Union{Nothing,Int}} <:
10+
Utils.AbstractEnvironment
11+
"associated instance"
12+
instance::I
13+
"current step"
14+
step::Int
15+
"degradation state"
16+
degradation_state::Vector{Int}
17+
"rng"
18+
rng::R
19+
"seed for RNG"
20+
seed::S
21+
end
22+
23+
"""
24+
$TYPEDSIGNATURES
25+
26+
Creates an [`Environment`](@ref) from an [`Instance`](@ref) of the maintenance benchmark.
27+
"""
28+
function Environment(instance::Instance; seed=0, rng::AbstractRNG=MersenneTwister(seed))
29+
degradation_state = starting_state(instance)
30+
env = Environment(;
31+
instance,
32+
step=1,
33+
degradation_state,
34+
rng=rng,
35+
seed=seed,
36+
)
37+
Utils.reset!(env; reset_rng=true)
38+
return env
39+
end
40+
41+
component_count(env::Environment) = component_count(env.instance)
42+
maintenance_capacity(env::Environment) = maintenance_capacity(env.instance)
43+
degradation_levels(env::Environment) = degradation_levels(env.instance)
44+
degradation_probability(env::Environment) = degradation_probability(env.instance)
45+
failure_cost(env::Environment) = failure_cost(env.instance)
46+
maintenance_cost(env::Environment) = maintenance_cost(env.instance)
47+
max_steps(env::Environment) = max_steps(env.instance)
48+
starting_state(env::Environment) = starting_state(env.instance)
49+
50+
51+
"""
52+
$TYPEDSIGNATURES
53+
Draw random degradations for all components.
54+
"""
55+
56+
function degrad!(env::Environment)
57+
N = component_count(env)
58+
n = degradation_levels(env)
59+
p = degradation_probability(env)
60+
61+
for i in 1:N
62+
if env.degradation_state[i] < n && rand() < p
63+
env.degradation_state[i] += 1
64+
end
65+
end
66+
67+
return env.degradation_state
68+
end
69+
70+
"""
71+
$TYPEDSIGNATURES
72+
Maintain components.
73+
"""
74+
75+
function maintain!(env::Environment, maintenance::BitVector)
76+
N = component_count(env)
77+
78+
for i in 1:N
79+
if maintenance[i]
80+
env.degradation_state[i] = 1
81+
end
82+
end
83+
84+
return env.degradation_state
85+
end
86+
87+
"""
88+
$TYPEDSIGNATURES
89+
90+
Compute maintenance cost.
91+
"""
92+
function maintenance_cost(env::Environment, maintenance::BitVector)
93+
return maintenance_cost(env) * sum(maintenance)
94+
end
95+
96+
"""
97+
$TYPEDSIGNATURES
98+
99+
Compute degradation cost.
100+
"""
101+
function degradation_cost(env::Environment)
102+
N = component_count(env)
103+
n = degradation_levels(env)
104+
return failure_cost(env) * count(==(n), env.degradation_state)
105+
end
106+
107+
108+
"""
109+
$TYPEDSIGNATURES
110+
111+
Outputs the seed of the environment.
112+
"""
113+
Utils.get_seed(env::Environment) = env.seed
114+
115+
"""
116+
$TYPEDSIGNATURES
117+
118+
Resets the environment to the initial state:
119+
- reset the rng if `reset_rng` is true
120+
- reset the step to 1
121+
- reset the degradation state to the starting state
122+
"""
123+
function Utils.reset!(env::Environment; reset_rng=false, seed=env.seed)
124+
reset_rng && Random.seed!(env.rng, seed)
125+
env.step = 1
126+
env.degradation_state .= starting_state(env)
127+
return nothing
128+
end
129+
130+
"""
131+
$TYPEDSIGNATURES
132+
133+
Checks if the environment has reached the maximum number of steps.
134+
"""
135+
function Utils.is_terminated(env::Environment)
136+
return env.step > max_steps(env)
137+
end
138+
139+
"""
140+
$TYPEDSIGNATURES
141+
142+
Returns features, state tuple.
143+
The features observed by the agent at current step are the degradation states of all components.
144+
It is also the internal state, so we return the same thing twice.
145+
146+
"""
147+
function Utils.observe(env::Environment)
148+
state = env.degradation_state
149+
return state, state
150+
end
151+
152+
"""
153+
$TYPEDSIGNATURES
154+
155+
Performs one step in the environment given a maintenance.
156+
Draw random degradations for components that are not maintained.
157+
"""
158+
function Utils.step!(env::Environment, maintenance::BitVector)
159+
@assert !Utils.is_terminated(env) "Environment is terminated, cannot act!"
160+
reward = maintenance_cost(env, maintenance) + degradation_cost(env)
161+
degrad!(env)
162+
maintain!(env, maintenance)
163+
env.step += 1
164+
return reward
165+
end
166+
167+

0 commit comments

Comments
 (0)