Skip to content

Commit 0cf9022

Browse files
committed
Added maintenance benchmark
1 parent 6516cc6 commit 0cf9022

5 files changed

Lines changed: 255 additions & 6 deletions

File tree

src/Maintenance/Maintenance.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function MaintenanceBenchmark(;
6464
p=0.2,
6565
c_f=10.0,
6666
c_m=3.0,
67-
max_steps=10,
67+
max_steps=80,
6868
)
6969
return MaintenanceBenchmark(
7070
N, K, n, p, c_f, c_m, max_steps
@@ -142,7 +142,7 @@ function Utils.generate_policies(::MaintenanceBenchmark)
142142
"policy that maintains components when they are in the last state before failure, up to the maintenance capacity",
143143
greedy_policy,
144144
)
145-
return (greedy)
145+
return (greedy,)
146146
end
147147

148148
export MaintenanceBenchmark

src/Maintenance/environment.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ Draw random degradations for components that are not maintained.
157157
"""
158158
function Utils.step!(env::Environment, maintenance::BitVector)
159159
@assert !Utils.is_terminated(env) "Environment is terminated, cannot act!"
160-
reward = maintenance_cost(env, maintenance) + degradation_cost(env)
160+
cost = maintenance_cost(env, maintenance) + degradation_cost(env)
161161
degrad!(env)
162162
maintain!(env, maintenance)
163163
env.step += 1
164-
return reward
164+
return cost
165165
end
166166

167167

src/Maintenance/policies.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,23 @@ $TYPEDSIGNATURES
55
Greedy policy that maintains components when they are in the last state before failure, up to the maintenance capacity.
66
"""
77
function greedy_policy(env::Environment)
8-
maximizer = generate_maximizer(env.instance.config)
9-
return maximizer(prices(env)[1:item_count(env)])
8+
state = env.degradation_state
9+
N = component_count(env)
10+
K = maintenance_capacity(env)
11+
res = falses(N)
12+
n = degradation_levels(env)
13+
14+
15+
idx_max = findall(==(n), state)
16+
take = first(idx_max, min(K, length(idx_max)))
17+
res[take] .= true
18+
remaining = K - length(take)
19+
20+
if remaining > 0
21+
idx_second = findall(==(n-1), state)
22+
take2 = first(idx_second, min(remaining, length(idx_second)))
23+
res[take2] .= true
24+
end
25+
26+
return res
1027
end

test/maintenance.jl

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
const maintenance = DecisionFocusedLearningBenchmarks.Maintenance
2+
3+
@testset "Maintenance - Benchmark Construction" begin
4+
# Test default constructor
5+
b = MaintenanceBenchmark()
6+
@test b.N == 2
7+
@test b.K == 1
8+
@test b.n == 3
9+
@test b.p == 0.2
10+
@test b.c_f == 10.0
11+
@test b.c_m == 3.0
12+
@test b.max_steps == 80
13+
@test is_exogenous(b)
14+
@test !is_endogenous(b)
15+
16+
# Test custom constructor
17+
b_custom = MaintenanceBenchmark(; N=10, K=3, n=5, p=0.3, c_f=5.0, c_m=3.0, max_steps=50)
18+
@test b_custom.N == 10
19+
@test b_custom.K == 3
20+
@test b_custom.n == 5
21+
@test b_custom.p == 0.3
22+
@test b_custom.c_f == 5.0
23+
@test b_custom.c_m == 3.0
24+
@test b_custom.max_steps == 50
25+
26+
# Test accessor functions
27+
@test maintenance.component_count(b) == 2
28+
@test maintenance.maintenance_capacity(b) == 1
29+
@test maintenance.degradation_levels(b) == 3
30+
@test maintenance.degradation_probability(b) == 0.2
31+
@test maintenance.failure_cost(b) == 10.0
32+
@test maintenance.maintenance_cost(b) == 3.0
33+
@test maintenance.max_steps(b) == 80
34+
end
35+
36+
@testset "Maintenance - Instance Generation" begin
37+
b = MaintenanceBenchmark(; N=10, K=3, n=5, p=0.3, c_f=5.0, c_m=3.0, max_steps=50)
38+
rng = MersenneTwister(42)
39+
40+
instance = maintenance.Instance(b, rng)
41+
42+
# test state is randomly initialized
43+
state1 = maintenance.starting_state(instance)
44+
rng2 = MersenneTwister(43)
45+
instance2 = maintenance.Instance(b, rng2)
46+
state2 = maintenance.starting_state(instance2)
47+
@test state1 != state2
48+
49+
# Test instance structure
50+
@test length(instance.starting_state) == 10
51+
@test all(1.0 s 5 for s in instance.starting_state)
52+
53+
# Test accessor functions
54+
@test maintenance.component_count(instance) == 10
55+
@test maintenance.maintenance_capacity(instance) == 3
56+
@test maintenance.degradation_levels(instance) == 5
57+
@test maintenance.degradation_probability(instance) == 0.3
58+
@test maintenance.failure_cost(instance) == 5.0
59+
@test maintenance.maintenance_cost(instance) == 3.0
60+
@test maintenance.max_steps(instance) == 50
61+
end
62+
63+
@testset "Maintenance - Environment Initialization" begin
64+
b = MaintenanceBenchmark()
65+
instance = maintenance.Instance(b, MersenneTwister(42))
66+
67+
env = maintenance.Environment(instance; seed=123)
68+
69+
# Test initial state
70+
@test env.step == 1
71+
@test env.seed == 123
72+
@test !is_terminated(env)
73+
74+
# Test accessor functions
75+
@test maintenance.component_count(env) == 2
76+
@test maintenance.maintenance_capacity(env) == 1
77+
@test maintenance.degradation_levels(env) == 3
78+
@test maintenance.degradation_probability(env) == 0.2
79+
@test maintenance.failure_cost(env) == 10.0
80+
@test maintenance.maintenance_cost(env) == 3.0
81+
@test maintenance.max_steps(env) == 80
82+
end
83+
84+
@testset "Maintenance - Environment Reset" begin
85+
b = MaintenanceBenchmark()
86+
instance = maintenance.Instance(b, MersenneTwister(42))
87+
env = maintenance.Environment(instance; seed=123)
88+
89+
# Modify environment state
90+
env.step = 3
91+
92+
# Reset environment
93+
reset!(env)
94+
95+
# Check reset state
96+
@test env.step == 1
97+
end
98+
99+
@testset "Maintenance - Cost" begin
100+
b = MaintenanceBenchmark()
101+
instance = maintenance.Instance(b, MersenneTwister(42))
102+
env = maintenance.Environment(instance; seed=123)
103+
104+
env.degradation_state = [1,1]
105+
@test maintenance.maintenance_cost(env, BitVector([false, false])) == 0.0
106+
@test maintenance.maintenance_cost(env, BitVector([false, true])) == 3.0
107+
@test maintenance.maintenance_cost(env, BitVector([true, true])) == 6.0
108+
109+
@test maintenance.degradation_cost(env) == 0.0
110+
env.degradation_state = [2,2]
111+
@test maintenance.degradation_cost(env) == 0.0
112+
env.degradation_state = [3,2]
113+
@test maintenance.degradation_cost(env) == 10.0
114+
env.degradation_state = [3,3]
115+
@test maintenance.degradation_cost(env) == 20.0
116+
end
117+
118+
@testset "Maintenance - Environment Step" begin
119+
b = MaintenanceBenchmark()
120+
instance = maintenance.Instance(b, MersenneTwister(42))
121+
env = maintenance.Environment(instance; seed=123)
122+
123+
maintenance_vect = BitVector([false, false])
124+
125+
initial_step = env.step
126+
# Take a step
127+
reward = step!(env, maintenance_vect)
128+
129+
# Check step progression
130+
@test env.step == initial_step + 1
131+
@test reward 0.0 # Reward should be non-negative
132+
133+
# Test termination
134+
for _ in 1:(maintenance.max_steps(env) - 1)
135+
if !is_terminated(env)
136+
step!(env, maintenance_vect)
137+
end
138+
end
139+
@test is_terminated(env)
140+
141+
# Test error on terminated environment
142+
@test_throws AssertionError step!(env, maintenance_vect)
143+
end
144+
145+
@testset "Maintenance - Observation" begin
146+
b = MaintenanceBenchmark()
147+
instance = maintenance.Instance(b, MersenneTwister(42))
148+
env = maintenance.Environment(instance; seed=123)
149+
env.degradation_state = [1,1]
150+
151+
state, features = observe(env)
152+
153+
@test state == [1,1]
154+
@test features === state
155+
156+
env.degradation_state = [2,3]
157+
state2, _ = observe(env)
158+
159+
@test state != state2 # Observations should differ after purchase
160+
end
161+
162+
163+
@testset "Maintenance - Policies" begin
164+
using Statistics: mean
165+
166+
b = MaintenanceBenchmark()
167+
168+
# Generate test data
169+
dataset = generate_dataset(b, 10; seed=0)
170+
environments = generate_environments(b, dataset)
171+
172+
# Get policies
173+
policies = generate_policies(b)
174+
greedy = policies[1]
175+
176+
@test greedy.name == "Greedy"
177+
178+
# Test policy evaluation
179+
r_greedy, _ = evaluate_policy!(greedy, environments, 10)
180+
181+
@test length(r_greedy) == length(environments)
182+
@test all(r_greedy .≥ 0.0)
183+
184+
# Test policy output format
185+
env = environments[1]
186+
reset!(env)
187+
188+
greedy_action = greedy(env)
189+
@test greedy_action isa BitVector && length(greedy_action) == 2
190+
end
191+
192+
193+
@testset "Maintenance - Model and Maximizer Integration" begin
194+
b = MaintenanceBenchmark()
195+
196+
# Test statistical model generation
197+
model = generate_statistical_model(b; seed=42)
198+
# Test maximizer generation
199+
maximizer = generate_maximizer(b)
200+
201+
# Test integration with sample data
202+
sample = generate_sample(b, MersenneTwister(42))
203+
@test hasfield(typeof(sample), :info)
204+
205+
dataset = generate_dataset(b, 3; seed=42)
206+
environments = generate_environments(b, dataset)
207+
208+
# Evaluate policy to get data samples
209+
policies = generate_policies(b)
210+
_, data_samples = evaluate_policy!(policies[1], environments)
211+
212+
# Test model-maximizer pipeline
213+
sample = data_samples[1]
214+
x = sample.x
215+
θ = model(x)
216+
y = maximizer(θ)
217+
218+
@test length(θ) == 2
219+
220+
θ = [1,2]
221+
@test maximizer(θ) == BitVector([false, true])
222+
223+
b = MaintenanceBenchmark(; N=10, K=3, n=5, p=0.3, c_f=5.0, c_m=3.0, max_steps=50)
224+
θ = [i for i in 1:10]
225+
maximizer = generate_maximizer(b)
226+
@test maximizer(θ) == BitVector([false, false, false, false, false, false, false, true, true, true])
227+
228+
229+
230+
#test maximizer output
231+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using Random
1414
include("ranking.jl")
1515
include("subset_selection.jl")
1616
include("fixed_size_shortest_path.jl")
17+
include("maintenance.jl")
1718
include("warcraft.jl")
1819
include("vsp.jl")
1920
include("portfolio_optimization.jl")

0 commit comments

Comments
 (0)