Skip to content

Commit 89462c0

Browse files
committed
Support for contextual stochastic optimization problems
1 parent 0f5590e commit 89462c0

22 files changed

Lines changed: 485 additions & 89 deletions

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ x = sample.x
3131
# `y` corresponds to the optimal shortest path, encoded as a binary matrix:
3232
y_true = sample.y
3333
# `instance_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
34-
isempty(sample.context)
34+
isempty(sample.instance_kwargs)
3535

3636
# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
3737
plot_solution(b, sample)

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg
2121
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
2222
return plot_solution(
2323
bench,
24-
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
24+
DataSample(;
25+
sample.instance_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
26+
);
2527
kwargs...,
2628
)
2729
end
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
module ContextualStochasticArgmax
2+
3+
using ..Utils
4+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
5+
using Flux: Dense
6+
using Random: Random, AbstractRNG, MersenneTwister
7+
8+
"""
    one_hot_argmax(z; kwargs...)

Return a vector of the same length and element type as `z` that is `one(R)` at the
position of the largest entry of `z` (the first one on ties) and `zero(R)` elsewhere.

Extra keyword arguments are accepted and ignored.
"""
function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real}
    winner = argmax(z)
    onehot = fill(zero(R), length(z))
    onehot[winner] = one(R)
    return onehot
end
13+
14+
"""
15+
$TYPEDEF
16+
17+
Minimal contextual stochastic argmax benchmark.
18+
19+
Per instance: `c_base ~ U[0,1]^n` (base utility, part of instance kwargs and base features).
20+
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
21+
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
22+
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
23+
24+
A linear model `Dense(n+d → n; bias=false)` can exactly recover `[I | W]`.
25+
26+
# Fields
27+
$TYPEDFIELDS
28+
"""
29+
struct ContextualStochasticArgmaxBenchmark{M<:AbstractMatrix} <:
       AbstractStochasticBenchmark{true}
    "number of items, i.e. the length of the vector passed to the argmax"
    n::Int
    "length of the observable context vector `x_raw`"
    d::Int
    "fixed `n × d` perturbation matrix, hidden from the learner"
    W::M
    "standard deviation of the noise added to each scenario draw"
    noise_std::Float32
end
40+
41+
"""
    ContextualStochasticArgmaxBenchmark(; n=10, d=5, noise_std=0.1f0, seed=nothing)

Build a benchmark with `n` items and `d` context features.

The perturbation matrix `W` is drawn once as an `n × d` standard-normal `Float32`
matrix from a `MersenneTwister` created with `seed`.

# Keywords
- `n::Int=10`: number of items; must be at least 1.
- `d::Int=5`: number of context features; must be non-negative.
- `noise_std::Real=0.1f0`: scenario noise standard deviation; must be non-negative.
  Any real number is accepted and stored as `Float32`.
- `seed=nothing`: forwarded to `MersenneTwister`.

# Throws
- `ArgumentError` if `n < 1`, `d < 0`, or `noise_std` is negative (or `NaN`).
"""
function ContextualStochasticArgmaxBenchmark(;
    n::Int=10, d::Int=5, noise_std::Real=0.1f0, seed=nothing
)
    n >= 1 || throw(ArgumentError("n must be at least 1, got $n"))
    d >= 0 || throw(ArgumentError("d must be non-negative, got $d"))
    noise_std >= 0 ||
        throw(ArgumentError("noise_std must be non-negative, got $noise_std"))
    rng = MersenneTwister(seed)
    W = randn(rng, Float32, n, d)
    # The default outer constructor of a parametric struct does not convert its
    # arguments, so convert here to let e.g. `noise_std=0.1` (a Float64) work.
    return ContextualStochasticArgmaxBenchmark(n, d, W, Float32(noise_std))
end
48+
49+
# The oracle picks the largest entry, so this benchmark is a maximization problem.
function Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark)
    return false
end

# The maximizer is the one-hot argmax defined in this module.
function Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark)
    return one_hot_argmax
end
51+
52+
# Draw the per-instance base utilities: a uniform `Float32` vector of length `bench.n`.
# The same vector is stored both as the features `x` and under the `c_base` keyword,
# which `generate_scenario` later expects as a keyword argument.
# Extra keyword arguments are accepted and ignored.
function Utils.generate_instance(
    bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
)
    base_utility = rand(rng, Float32, bench.n)
    return DataSample(; x=base_utility, c_base=base_utility)
end
59+
60+
# Draw the observable context `x_raw` (standard-normal `Float32` vector of length
# `bench.d`) and return a new sample whose features are `[instance_sample.x; x_raw]`.
# The instance keyword arguments are carried over unchanged, and `x_raw` is stored in
# `extra` so that `generate_scenario` can receive it.
function Utils.generate_context(
    bench::ContextualStochasticArgmaxBenchmark,
    rng::AbstractRNG,
    instance_sample::DataSample,
)
    context_draw = randn(rng, Float32, bench.d)
    features = vcat(instance_sample.x, context_draw)
    return DataSample(;
        x=features, instance_sample.instance_kwargs..., extra=(; x_raw=context_draw)
    )
end
73+
74+
# Sample one scenario ξ = c_base + W * x_raw + noise, where the noise is
# `bench.noise_std` times a standard-normal `Float32` vector of length `bench.n`.
# `c_base` and `x_raw` are required keywords; any other keywords are ignored.
function Utils.generate_scenario(
    bench::ContextualStochasticArgmaxBenchmark,
    rng::AbstractRNG;
    c_base::AbstractVector,
    x_raw::AbstractVector,
    kwargs...,
)
    mean_utility = c_base + bench.W * x_raw
    perturbation = bench.noise_std * randn(rng, Float32, bench.n)
    return mean_utility + perturbation
end
85+
86+
# Seed the default RNG with `seed`, then build a bias-free linear layer that maps the
# `n + d` features to `n` scores.
function Utils.generate_statistical_model(
    bench::ContextualStochasticArgmaxBenchmark; seed=nothing
)
    Random.seed!(seed)
    n_features = bench.n + bench.d
    n_items = bench.n
    return Dense(n_features => n_items; bias=false)
end
92+
93+
export ContextualStochasticArgmaxBenchmark
94+
95+
end

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ include("Warcraft/Warcraft.jl")
5555
include("FixedSizeShortestPath/FixedSizeShortestPath.jl")
5656
include("PortfolioOptimization/PortfolioOptimization.jl")
5757
include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
58+
include("ContextualStochasticArgmax/ContextualStochasticArgmax.jl")
5859
include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
5960
include("DynamicAssortment/DynamicAssortment.jl")
6061
include("Maintenance/Maintenance.jl")
@@ -71,8 +72,9 @@ export Policy, evaluate_policy!
7172

7273
export generate_instance,
7374
generate_sample, generate_dataset, generate_environments, generate_environment
74-
export generate_scenario
75+
export generate_scenario, generate_context
7576
export generate_baseline_policies
77+
export SAA
7678
export generate_statistical_model
7779
export generate_maximizer
7880
export generate_anticipative_solver, generate_parametric_anticipative_solver
@@ -91,6 +93,7 @@ using .Warcraft
9193
using .FixedSizeShortestPath
9294
using .PortfolioOptimization
9395
using .StochasticVehicleScheduling
96+
using .ContextualStochasticArgmax
9497
using .DynamicVehicleScheduling
9598
using .DynamicAssortment
9699
using .Maintenance
@@ -106,5 +109,6 @@ export StochasticVehicleSchedulingBenchmark
106109
export SubsetSelectionBenchmark
107110
export WarcraftBenchmark
108111
export MaintenanceBenchmark
112+
export ContextualStochasticArgmaxBenchmark
109113

110114
end # module DecisionFocusedLearningBenchmarks

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=noth
116116
To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):
117117
118118
```julia
119-
policy = sample -> DataSample(; sample.context..., x=sample.x,
119+
policy = sample -> DataSample(; sample.instance_kwargs..., x=sample.x,
120120
y=column_generation_algorithm(sample.instance))
121121
dataset = generate_dataset(benchmark, N; target_policy=policy)
122122
```

src/StochasticVehicleScheduling/policies.jl

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@ SAA baseline policy: builds a stochastic instance from all K scenarios and solve
55
via column generation.
66
Returns a single labeled [`DataSample`](@ref) with `extra=(; ctx_sample.extra..., scenarios)`.
77
"""
8-
function svs_saa_policy(sample, scenarios)
9-
stochastic_inst = build_stochastic_instance(sample.instance, scenarios)
8+
function svs_saa_policy(instance_sample, ctx_sample, scenarios)
9+
stochastic_inst = build_stochastic_instance(instance_sample.instance, scenarios)
1010
y = column_generation_algorithm(stochastic_inst)
11-
return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
11+
return [
12+
DataSample(;
13+
instance_sample.instance_kwargs...,
14+
x=ctx_sample.x,
15+
y,
16+
extra=(; ctx_sample.extra..., scenarios),
17+
),
18+
]
1219
end
1320

1421
"""
@@ -17,9 +24,18 @@ $TYPEDSIGNATURES
1724
Deterministic baseline policy: solves the deterministic MIP (ignores scenario delays).
1825
Returns a single labeled [`DataSample`](@ref) with `extra=(; ctx_sample.extra..., scenarios)`.
1926
"""
20-
function svs_deterministic_policy(sample, scenarios; model_builder=highs_model)
21-
y = deterministic_mip(sample.instance; model_builder)
22-
return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
27+
function svs_deterministic_policy(
28+
instance_sample, ctx_sample, scenarios; model_builder=highs_model
29+
)
30+
y = deterministic_mip(instance_sample.instance; model_builder)
31+
return [
32+
DataSample(;
33+
instance_sample.instance_kwargs...,
34+
x=ctx_sample.x,
35+
y,
36+
extra=(; ctx_sample.extra..., scenarios),
37+
),
38+
]
2339
end
2440

2541
"""
@@ -29,24 +45,60 @@ Local search baseline policy: builds a stochastic instance from all K scenarios
2945
solves via local search heuristic.
3046
Returns a single labeled [`DataSample`](@ref) with `extra=(; ctx_sample.extra..., scenarios)`.
3147
"""
32-
function svs_local_search_policy(sample, scenarios)
33-
stochastic_inst = build_stochastic_instance(sample.instance, scenarios)
48+
function svs_local_search_policy(instance_sample, ctx_sample, scenarios)
49+
stochastic_inst = build_stochastic_instance(instance_sample.instance, scenarios)
3450
y = local_search(stochastic_inst)
35-
return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
51+
return [
52+
DataSample(;
53+
instance_sample.instance_kwargs...,
54+
x=ctx_sample.x,
55+
y,
56+
extra=(; ctx_sample.extra..., scenarios),
57+
),
58+
]
59+
end
60+
61+
"""
62+
$TYPEDSIGNATURES
63+
64+
Exact SAA MIP policy (linearized): solves the stochastic VSP exactly for the given
65+
scenarios via [`compact_linearized_mip`](@ref).
66+
Returns a single labeled [`DataSample`](@ref) with `extra=(; ctx_sample.extra..., scenarios)`.
67+
68+
Prefer this over [`svs_saa_policy`](@ref) when an exact solution is needed; requires
69+
SCIP (default) or Gurobi.
70+
"""
71+
function svs_saa_mip_policy(
    instance_sample, ctx_sample, scenarios; model_builder=scip_model
)
    # Solve over the given scenarios with the compact linearized formulation.
    y = compact_linearized_mip(instance_sample.instance, scenarios; model_builder)
    # Keep whatever the context sample already stored in `extra`, and add the scenarios.
    extra = (; ctx_sample.extra..., scenarios)
    labeled = DataSample(; instance_sample.instance_kwargs..., x=ctx_sample.x, y, extra)
    return [labeled]
end
3784

3885
"""
3986
$TYPEDSIGNATURES
4087
4188
Return the named baseline policies for [`StochasticVehicleSchedulingBenchmark`](@ref).
42-
Each policy has signature `(sample, scenarios) -> Vector{DataSample}`.
89+
Each policy has signature `(instance_sample, ctx_sample, scenarios) -> Vector{DataSample}`.
4390
"""
4491
function svs_generate_baseline_policies(::StochasticVehicleSchedulingBenchmark)
4592
return (;
4693
deterministic=Policy(
4794
"Deterministic MIP", "Ignores delays", svs_deterministic_policy
4895
),
4996
saa=Policy("SAA (col gen)", "Stochastic MIP over K scenarios", svs_saa_policy),
97+
saa_mip=Policy(
98+
"SAA (exact MIP)",
99+
"Exact stochastic MIP over K scenarios via compact linearized formulation",
100+
svs_saa_mip_policy,
101+
),
50102
local_search=Policy(
51103
"Local search", "Heuristic with K scenarios", svs_local_search_policy
52104
),

src/StochasticVehicleScheduling/solution/algorithms/mip.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ end
8484
"""
8585
$TYPEDSIGNATURES
8686
87+
SAA variant: build stochastic instance from `scenarios` then solve via
88+
[`compact_linearized_mip`](@ref).
89+
"""
90+
function compact_linearized_mip(
    instance::Instance, scenarios::Vector{VSPScenario}; kwargs...
)
    # Fold the scenarios into the instance, then defer to the one-argument method,
    # forwarding all keyword arguments untouched.
    stochastic_instance = build_stochastic_instance(instance, scenarios)
    return compact_linearized_mip(stochastic_instance; kwargs...)
end
95+
96+
"""
97+
$TYPEDSIGNATURES
98+
8799
Returns the optimal solution of the Stochastic VSP instance, by solving the associated compact quadratic MIP.
88100
Note: If you have Gurobi, use `grb_model` as `model_builder` instead of `highs_model`.
89101
@@ -151,3 +163,13 @@ function compact_mip(
151163
sol = solution_from_JuMP_array(solution, graph)
152164
return sol.value
153165
end
166+
167+
"""
168+
$TYPEDSIGNATURES
169+
170+
SAA variant: build stochastic instance from `scenarios` then solve via
171+
[`compact_mip`](@ref).
172+
"""
173+
function compact_mip(instance::Instance, scenarios::Vector{VSPScenario}; kwargs...)
    # Fold the scenarios into the instance, then defer to the one-argument method,
    # forwarding all keyword arguments untouched.
    stochastic_instance = build_stochastic_instance(instance, scenarios)
    return compact_mip(stochastic_instance; kwargs...)
end

src/Utils/Utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ export ExogenousStochasticBenchmark,
3131
EndogenousStochasticBenchmark, ExogenousDynamicBenchmark, EndogenousDynamicBenchmark
3232
export generate_instance, generate_sample, generate_dataset
3333
export generate_statistical_model, generate_maximizer
34-
export generate_scenario
34+
export generate_scenario, generate_context
3535
export generate_environment, generate_environments
36+
export SAA
3637
export generate_baseline_policies
3738
export generate_anticipative_solver, generate_parametric_anticipative_solver
3839

0 commit comments

Comments
 (0)