Skip to content

Commit 76d3fe3

Browse files
committed
Add tests
1 parent 9a305bf commit 76d3fe3

File tree

4 files changed

+63
-7
lines changed

4 files changed

+63
-7
lines changed

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,10 @@ $TYPEDSIGNATURES
116116
Return the anticipative solver: a callable `(scenario::VSPScenario; instance, kwargs...) -> y`
117117
that solves the 1-scenario stochastic VSP.
118118
"""
119-
function Utils.generate_anticipative_solver(::StochasticVehicleSchedulingBenchmark)
120-
return AnticipativeSolver()
119+
function Utils.generate_anticipative_solver(
120+
::StochasticVehicleSchedulingBenchmark; model_builder=scip_model
121+
)
122+
return AnticipativeSolver(; model_builder=model_builder)
121123
end
122124

123125
"""
@@ -126,9 +128,9 @@ $TYPEDSIGNATURES
126128
Return the parametric anticipative solver: a callable `(θ, scenario::VSPScenario; instance, kwargs...) -> y`.
127129
"""
128130
function Utils.generate_parametric_anticipative_solver(
129-
::StochasticVehicleSchedulingBenchmark
131+
::StochasticVehicleSchedulingBenchmark; model_builder=scip_model
130132
)
131-
return AnticipativeSolver()
133+
return AnticipativeSolver(; model_builder=model_builder)
132134
end
133135

134136
"""
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
@kwdef struct AnticipativeSolver{A}
1+
@kwdef struct AnticipativeSolver{M,A}
2+
model_builder::M = scip_model
23
single_scenario_algorithm::A = compact_mip
34
end
45

@@ -8,10 +9,14 @@ end
89

910
function (solver::AnticipativeSolver)(scenario; instance::Instance, kwargs...)
1011
stochastic_inst = build_stochastic_instance(instance, [scenario])
11-
return solver.single_scenario_algorithm(stochastic_inst)
12+
return solver.single_scenario_algorithm(
13+
stochastic_inst; model_builder=solver.model_builder, kwargs...
14+
)
1215
end
1316

1417
function (solver::AnticipativeSolver)(θ, scenario; instance::Instance, kwargs...)
1518
stochastic_inst = build_stochastic_instance(instance, [scenario])
16-
return solver.single_scenario_algorithm(stochastic_inst, θ)
19+
return solver.single_scenario_algorithm(
20+
stochastic_inst, θ; model_builder=solver.model_builder, kwargs...
21+
)
1722
end

test/contextual_stochastic_argmax.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,31 @@
3838
@test first(d1).x first(d2).x
3939
end
4040

41+
@testset "Parametric Anticipative Solver - ContextualStochasticArgmax" begin
42+
using DecisionFocusedLearningBenchmarks
43+
44+
b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0)
45+
dataset = generate_dataset(b, 2; contexts_per_instance=1, nb_scenarios=1)
46+
sample = first(dataset)
47+
scenario = generate_scenario(b, StableRNG(0); sample.context...)
48+
49+
solver = generate_anticipative_solver(b)
50+
parametric_solver = generate_parametric_anticipative_solver(b)
51+
52+
# 1. Zero perturbation equivalence
53+
θ_zero = zeros(eltype(scenario), size(scenario))
54+
@test parametric_solver(θ_zero, scenario; sample.context...) ==
55+
solver(scenario; sample.context...)
56+
57+
# 2. Extreme perturbation
58+
θ_extreme = zeros(eltype(scenario), size(scenario))
59+
θ_extreme[1] = 1000.0 # Force dimension 1
60+
y_extreme = parametric_solver(θ_extreme, scenario; sample.context...)
61+
62+
@test y_extreme[1] == 1.0 # Only dimension 1 should be active
63+
@test sum(y_extreme) 1.0 # One-hot preserved
64+
end
65+
4166
@testset "csa_saa_policy" begin
4267
using DecisionFocusedLearningBenchmarks
4368

test/vsp.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,28 @@
9595
sample = unlabeled[1]
9696
y_anticipative = anticipative_solver(sample.scenario; sample.context...)
9797
@test y_anticipative isa BitVector
98+
99+
# Extract necessary dependencies
100+
parametric_solver = generate_parametric_anticipative_solver(b)
101+
nb_edges = ne(sample.instance.graph)
102+
103+
# 1. Zero perturbation equivalence
104+
θ_zero = zeros(nb_edges)
105+
y_zero = parametric_solver(θ_zero, sample.scenario; sample.context...)
106+
107+
@test y_zero == y_anticipative
108+
@test y_zero isa BitVector
109+
110+
# 2. Perturbation execution
111+
θ_random = randn(nb_edges)
112+
y_rand = parametric_solver(θ_random, sample.scenario; sample.context...)
113+
114+
@test length(y_rand) == nb_edges
115+
@test y_rand isa BitVector
116+
117+
# 3. High negative perturbation on edge forces activation (it's minimization)
118+
θ_extreme = zeros(nb_edges)
119+
θ_extreme[1] = -100000.0 # large negative pull for edge 1
120+
y_extreme = parametric_solver(θ_extreme, sample.scenario; sample.context...)
121+
@test y_extreme[1] == 1.0 # BitVector
98122
end

0 commit comments

Comments
 (0)