Skip to content

Commit 1f69bf9

Browse files
authored
Merge pull request #67 from JuliaDecisionFocusedLearning/parametric-anticipative-solvers
Implement anticipative solvers for both stochastic benchmarks
2 parents afa05eb + d228d0d commit 1f69bf9

File tree

7 files changed

+178
-22
lines changed

7 files changed

+178
-22
lines changed

src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,26 @@ end
118118

119119
include("policies.jl")
120120

121+
"""
122+
$TYPEDSIGNATURES
123+
124+
Generates the anticipative solver for the benchmark.
125+
"""
126+
function Utils.generate_anticipative_solver(::ContextualStochasticArgmaxBenchmark)
    # The benchmark argument is used for dispatch only; the solver is built with
    # its zero-argument constructor.
    solver = AnticipativeSolver()
    return solver
end
129+
130+
"""
131+
$TYPEDSIGNATURES
132+
133+
Generates the parametric anticipative solver for the benchmark.
134+
"""
135+
function Utils.generate_parametric_anticipative_solver(
    ::ContextualStochasticArgmaxBenchmark
)
    # The benchmark argument is used for dispatch only. The returned object is a
    # default-constructed `AnticipativeSolver`.
    parametric_solver = AnticipativeSolver()
    return parametric_solver
end
140+
121141
export ContextualStochasticArgmaxBenchmark
122142

123143
end

src/ContextualStochasticArgmax/policies.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Statistics: mean
2-
31
"""
42
$TYPEDSIGNATURES
53
@@ -28,3 +26,35 @@ Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
2826
function Utils.generate_baseline_policies(::ContextualStochasticArgmaxBenchmark)
2927
return (; saa=Policy("SAA", "argmax of mean scenarios", csa_saa_policy))
3028
end
29+
30+
"""
31+
$TYPEDEF
32+
33+
A policy that acts with perfect information about the future scenario.
34+
"""
35+
struct AnticipativeSolver
    # No fields: behaviour is provided entirely by the call methods defined on the type.
end
36+
37+
function Base.show(io::IO, ::AnticipativeSolver)
    # Fixed one-line description; the solver instance itself is not inspected.
    description = "Anticipative solver for ContextualStochasticArgmaxBenchmark"
    return print(io, description)
end
40+
41+
"""
42+
$TYPEDSIGNATURES
43+
44+
Evaluate the anticipative policy for a given `scenario`.
45+
Returns the optimal action `one_hot_argmax(scenario)`.
46+
"""
47+
function (::AnticipativeSolver)(scenario; context...)
    # `context` keywords are accepted so callers can splat a sample's context,
    # but they are not read here.
    best_action = one_hot_argmax(scenario)
    return best_action
end
50+
51+
"""
52+
$TYPEDSIGNATURES
53+
54+
Evaluate the anticipative policy with a parametric prediction `θ` and a `scenario`.
55+
Returns the optimal action for the combined signal `one_hot_argmax(scenario + θ)`.
56+
"""
57+
function (::AnticipativeSolver)(θ, scenario; context...)
    # Add the prediction `θ` to the realised scenario and pick the argmax of the sum.
    # `context` keywords are accepted but not read here.
    return one_hot_argmax(scenario + θ)
end

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ include("solution/algorithms/mip.jl")
4747
include("solution/algorithms/column_generation.jl")
4848
include("solution/algorithms/local_search.jl")
4949
include("solution/algorithms/deterministic_mip.jl")
50+
include("solution/algorithms/anticipative_solver.jl")
5051

5152
include("maximizer.jl")
5253

@@ -113,13 +114,29 @@ end
113114
$TYPEDSIGNATURES
114115
115116
Return the anticipative solver: a callable `(scenario::VSPScenario; instance, kwargs...) -> y`
116-
that solves the 1-scenario stochastic VSP via column generation.
117+
that solves the 1-scenario stochastic VSP.
118+
119+
# Keyword Arguments
120+
- `model_builder`: a function returning an empty `JuMP.Model` with a solver attached (defaults to `scip_model`).
117121
"""
118-
function Utils.generate_anticipative_solver(::StochasticVehicleSchedulingBenchmark)
119-
return (scenario::VSPScenario; instance::Instance, kwargs...) -> begin
120-
stochastic_inst = build_stochastic_instance(instance, [scenario])
121-
return column_generation_algorithm(stochastic_inst)
122-
end
122+
function Utils.generate_anticipative_solver(
    ::StochasticVehicleSchedulingBenchmark; model_builder=scip_model
)
    # Only `model_builder` is forwarded; every other solver field keeps the default
    # set by the `AnticipativeSolver` keyword constructor.
    return AnticipativeSolver(; model_builder)
end
127+
128+
"""
129+
$TYPEDSIGNATURES
130+
131+
Return the parametric anticipative solver: a callable `(θ, scenario::VSPScenario; instance, kwargs...) -> y`.
132+
133+
# Keyword Arguments
134+
- `model_builder`: a function returning an empty `JuMP.Model` with a solver attached (defaults to `scip_model`).
135+
"""
136+
function Utils.generate_parametric_anticipative_solver(
    ::StochasticVehicleSchedulingBenchmark; model_builder=scip_model
)
    # Only `model_builder` is forwarded; every other solver field keeps the default
    # set by the `AnticipativeSolver` keyword constructor.
    parametric_solver = AnticipativeSolver(; model_builder)
    return parametric_solver
end
124141

125142
"""
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
    AnticipativeSolver(; model_builder=scip_model, single_scenario_algorithm=compact_mip)

Callable solver configuration built from keyword arguments.

# Fields
- `model_builder`: passed as the `model_builder` keyword to `single_scenario_algorithm`.
- `single_scenario_algorithm`: function invoked on the stochastic instance built from a
  single scenario.
"""
Base.@kwdef struct AnticipativeSolver{M,A}
    # Forwarded as the `model_builder` keyword of the algorithm below.
    model_builder::M = scip_model
    # Algorithm run on the one-scenario stochastic instance.
    single_scenario_algorithm::A = compact_mip
end
5+
6+
function Base.show(io::IO, ::AnticipativeSolver)
    # Fixed one-line description; the solver's fields are not printed.
    description = "Anticipative solver for StochasticVehicleSchedulingBenchmark"
    return print(io, description)
end
9+
10+
function (solver::AnticipativeSolver)(scenario; instance::Instance, kwargs...)
    (; single_scenario_algorithm, model_builder) = solver
    # Wrap the lone scenario in a vector to build a one-scenario stochastic instance.
    single_scenario_instance = build_stochastic_instance(instance, [scenario])
    # `kwargs` come after `model_builder`, so a caller-supplied `model_builder` wins.
    return single_scenario_algorithm(single_scenario_instance; model_builder, kwargs...)
end
16+
17+
function (solver::AnticipativeSolver)(θ, scenario; instance::Instance, kwargs...)
    (; single_scenario_algorithm, model_builder) = solver
    # Wrap the lone scenario in a vector to build a one-scenario stochastic instance.
    single_scenario_instance = build_stochastic_instance(instance, [scenario])
    # `θ` is passed as the second positional argument of the algorithm; `kwargs` come
    # after `model_builder`, so a caller-supplied `model_builder` wins.
    return single_scenario_algorithm(
        single_scenario_instance, θ; model_builder, kwargs...
    )
end

src/StochasticVehicleScheduling/solution/algorithms/mip.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ Quadratic constraints are linearized using Mc Cormick linearization.
66
Note: If you have Gurobi, use `grb_model` as `model_builder` instead of the default `scip_model`.
77
"""
88
function compact_linearized_mip(
9-
instance::Instance; scenario_range=nothing, model_builder=scip_model, silent=true
9+
instance::Instance,
10+
θ=nothing;
11+
scenario_range=nothing,
12+
model_builder=scip_model,
13+
silent=true,
1014
)
1115
(; graph, slacks, intrinsic_delays, vehicle_cost, delay_cost) = instance
1216
nb_nodes = nv(graph)
@@ -28,13 +32,17 @@ function compact_linearized_mip(
2832
@variable(model, R[v in nodes, ω in Ω] >= 0) # propagated delay of job v
2933
@variable(model, yR[u in nodes, v in nodes, ω in Ω; has_edge(graph, u, v)] >= 0) # yR[u, v] = y[u, v] * R[u, ω]
3034

31-
@objective(
32-
model,
33-
Min,
35+
obj = (
3436
delay_cost * sum(sum(R[v, ω] for v in job_indices) for ω in Ω) / nb_scenarios # average total delay
35-
+
36-
vehicle_cost * sum(y[1, v] for v in job_indices) # nb_vehicles
37+
+
38+
vehicle_cost * sum(y[1, v] for v in job_indices) # nb_vehicles
3739
)
40+
if !isnothing(θ)
41+
@assert length(θ) == ne(graph)
42+
obj += sum(θ[a] * y[src(edge), dst(edge)] for (a, edge) in enumerate(edges(graph)))
43+
end
44+
45+
@objective(model, Min, obj)
3846

3947
# Flow constraints
4048
@constraint(
@@ -103,7 +111,11 @@ Note: If you have Gurobi, use `grb_model` as `model_builder` instead of `highs_m
103111
You need to use a solver that supports quadratic constraints to use this method.
104112
"""
105113
function compact_mip(
106-
instance::Instance; scenario_range=nothing, model_builder=scip_model, silent=true
114+
instance::Instance,
115+
θ=nothing;
116+
scenario_range=nothing,
117+
model_builder=scip_model,
118+
silent=true,
107119
)
108120
(; graph, slacks, intrinsic_delays, vehicle_cost, delay_cost) = instance
109121
nb_nodes = nv(graph)
@@ -124,13 +136,17 @@ function compact_mip(
124136
@variable(model, R[v in nodes, ω in Ω] >= 0) # propagated delay of job v
125137
@variable(model, yR[u in nodes, v in nodes, ω in Ω; has_edge(graph, u, v)] >= 0) # yR[u, v] = y[u, v] * R[u, ω]
126138

127-
@objective(
128-
model,
129-
Min,
139+
obj = (
130140
delay_cost * sum(sum(R[v, ω] for v in job_indices) for ω in Ω) / nb_scenarios # average total delay
131-
+
132-
vehicle_cost * sum(y[1, v] for v in job_indices) # nb_vehicles
141+
+
142+
vehicle_cost * sum(y[1, v] for v in job_indices) # nb_vehicles
133143
)
144+
if !isnothing(θ)
145+
@assert length(θ) == ne(graph)
146+
obj += sum(θ[a] * y[src(edge), dst(edge)] for (a, edge) in enumerate(edges(graph)))
147+
end
148+
149+
@objective(model, Min, obj)
134150

135151
# Flow constraints
136152
@constraint(
@@ -170,6 +186,8 @@ $TYPEDSIGNATURES
170186
SAA variant: build stochastic instance from `scenarios` then solve via
171187
[`compact_mip`](@ref).
172188
"""
173-
function compact_mip(instance::Instance, scenarios::Vector{VSPScenario}; kwargs...)
174-
return compact_mip(build_stochastic_instance(instance, scenarios); kwargs...)
189+
function compact_mip(
    instance::Instance, scenarios::Vector{VSPScenario}, θ=nothing; kwargs...
)
    # Build the stochastic instance from the given scenarios first, then delegate to the
    # `compact_mip` method that takes that instance, forwarding `θ` and all keywords.
    stochastic_instance = build_stochastic_instance(instance, scenarios)
    return compact_mip(stochastic_instance, θ; kwargs...)
end

test/contextual_stochastic_argmax.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,31 @@
3838
@test first(d1).x ≈ first(d2).x
3939
end
4040

41+
@testset "Parametric Anticipative Solver - ContextualStochasticArgmax" begin
    using DecisionFocusedLearningBenchmarks

    # Small benchmark with one context and one scenario per instance.
    b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0)
    dataset = generate_dataset(b, 2; contexts_per_instance=1, nb_scenarios=1)
    sample = first(dataset)
    # Seeded RNG keeps the drawn scenario reproducible across runs.
    scenario = generate_scenario(b, StableRNG(0); sample.context...)

    solver = generate_anticipative_solver(b)
    parametric_solver = generate_parametric_anticipative_solver(b)

    # 1. Zero perturbation equivalence: θ = 0 must reproduce the plain solver output.
    θ_zero = zeros(eltype(scenario), size(scenario))
    @test parametric_solver(θ_zero, scenario; sample.context...) ==
        solver(scenario; sample.context...)

    # 2. Extreme perturbation: a large θ on the first entry should make it the argmax.
    θ_extreme = zeros(eltype(scenario), size(scenario))
    θ_extreme[1] = 1000.0 # Force dimension 1
    y_extreme = parametric_solver(θ_extreme, scenario; sample.context...)

    @test y_extreme[1] == 1.0 # Only dimension 1 should be active
    # The comparison operator was missing on this line; `≈` restores a valid test.
    @test sum(y_extreme) ≈ 1.0 # One-hot preserved
end
65+
4166
@testset "csa_saa_policy" begin
4267
using DecisionFocusedLearningBenchmarks
4368

test/vsp.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,28 @@
9595
sample = unlabeled[1]
9696
y_anticipative = anticipative_solver(sample.scenario; sample.context...)
9797
@test y_anticipative isa BitVector
98+
99+
# Extract necessary dependencies
100+
parametric_solver = generate_parametric_anticipative_solver(b)
101+
nb_edges = ne(sample.instance.graph)
102+
103+
# 1. Zero perturbation equivalence
104+
θ_zero = zeros(nb_edges)
105+
y_zero = parametric_solver(θ_zero, sample.scenario; sample.context...)
106+
107+
@test y_zero == y_anticipative
108+
@test y_zero isa BitVector
109+
110+
# 2. Perturbation execution
111+
θ_random = randn(nb_edges)
112+
y_rand = parametric_solver(θ_random, sample.scenario; sample.context...)
113+
114+
@test length(y_rand) == nb_edges
115+
@test y_rand isa BitVector
116+
117+
# 3. High negative perturbation on edge forces activation (it's minimization)
118+
θ_extreme = zeros(nb_edges)
119+
θ_extreme[1] = -100000.0 # large negative pull for edge 1
120+
y_extreme = parametric_solver(θ_extreme, sample.scenario; sample.context...)
121+
@test y_extreme[1] == 1.0 # BitVector
98122
end

0 commit comments

Comments
 (0)