Commit 4282039

Merge pull request #60 from JuliaDecisionFocusedLearning/contextual

Support for contextual stochastic optimization problems

2 parents 0f5590e + 9acdd14

26 files changed: +610 −94 lines
docs/src/api.md

Lines changed: 12 additions & 0 deletions

@@ -36,6 +36,18 @@ Modules = [DecisionFocusedLearningBenchmarks.Argmax]
 Public = false
 ```
 
+## Contextual Stochastic Argmax
+
+```@autodocs
+Modules = [DecisionFocusedLearningBenchmarks.ContextualStochasticArgmax]
+Private = false
+```
+
+```@autodocs
+Modules = [DecisionFocusedLearningBenchmarks.ContextualStochasticArgmax]
+Public = false
+```
+
 ## Dynamic Vehicle Scheduling
 
 ```@autodocs
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@

# Contextual Stochastic Argmax

[`ContextualStochasticArgmaxBenchmark`](@ref) is a minimalist contextual stochastic optimization benchmark problem.

The decision maker selects one item out of ``n``. Item values are uncertain at decision time: they depend on a base utility plus a context-correlated perturbation revealed only after the decision is made. An observable context vector, correlated with the perturbation via a fixed linear map ``W``, allows the learner to anticipate the perturbation and pick the right item.

## Problem Formulation

**Instance**: ``c_{\text{base}} \sim \mathcal{U}[0,1]^n``, base values for ``n`` items.

**Context**: ``x_{\text{raw}} \sim \mathcal{N}(0, I_d)``, a ``d``-dimensional signal correlated with item values. The feature vector passed to the model is ``x = [c_{\text{base}};\, x_{\text{raw}}] \in \mathbb{R}^{n+d}``.

**Scenario**: the realized item values are
```math
\xi = c_{\text{base}} + W x_{\text{raw}} + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, \sigma^2 I_n)
```
where ``W \in \mathbb{R}^{n \times d}`` is a fixed matrix unknown to the learner.

**Decision**: ``y \in \{e_1, \ldots, e_n\}`` (one-hot vector selecting one item).

## Policies

### DFL Policy

```math
\xrightarrow[\text{Features}]{x}
\fbox{Neural network $\varphi_w$}
\xrightarrow[\text{Predicted values}]{\hat{\theta}}
\fbox{\texttt{one\_hot\_argmax}}
\xrightarrow[\text{Decision}]{y}
```

The neural network predicts item values ``\hat{\theta} \in \mathbb{R}^n`` from the feature vector ``x \in \mathbb{R}^{n+d}``. The default architecture is `Dense(n+d => n; bias=false)`, which can exactly recover the optimal linear predictor ``[I_n \mid W]``, so a well-trained model should reach near-zero gap.

### SAA Policy

``y_{\text{SAA}} = \operatorname{argmax}\bigl(\frac{1}{S}\sum_s \xi^{(s)}\bigr)`` — the exact SAA-optimal decision for linear argmax, accessible via `generate_baseline_policies(bench).saa`.
docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 2 additions & 2 deletions

@@ -30,8 +30,8 @@ x = sample.x
 θ_true = sample.θ
 # `y` correspond to the optimal shortest path, encoded as a binary matrix:
 y_true = sample.y
-# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
-isempty(sample.context)
+# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
+isempty(sample.maximizer_kwargs)
 
 # For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
 plot_solution(b, sample)

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 3 additions & 1 deletion

@@ -21,7 +21,9 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg
 function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
     return plot_solution(
         bench,
-        DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
+        DataSample(;
+            sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
+        );
         kwargs...,
     )
 end
Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@

module ContextualStochasticArgmax

using ..Utils
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Dense
using Random: Random, AbstractRNG, MersenneTwister
using Statistics: mean

"""
$TYPEDEF

Minimal contextual stochastic argmax benchmark.

Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample).
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.

A linear model `Dense(n+d → n; bias=false)` can exactly recover `[I | W]`.

# Fields
$TYPEDFIELDS
"""
struct ContextualStochasticArgmaxBenchmark{M<:AbstractMatrix} <:
       AbstractStochasticBenchmark{true}
    "number of items (argmax dimension)"
    n::Int
    "number of context features"
    d::Int
    "fixed perturbation matrix W ∈ R^{n×d}, unknown to the learner"
    W::M
    "noise std for scenario draws"
    noise_std::Float32
end

function ContextualStochasticArgmaxBenchmark(;
    n::Int=10, d::Int=5, noise_std::Float32=0.1f0, seed=nothing
)
    rng = MersenneTwister(seed)
    W = randn(rng, Float32, n, d)
    return ContextualStochasticArgmaxBenchmark(n, d, W, noise_std)
end

Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false
Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax

"""
    generate_instance(::ContextualStochasticArgmaxBenchmark, rng)

Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed
(the maximizer is `one_hot_argmax`, which takes no kwargs).
"""
function Utils.generate_instance(
    bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
)
    c_base = rand(rng, Float32, bench.n)
    return DataSample(; extra=(; c_base))
end

"""
    generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)

Draw `x_raw ~ N(0, I_d)` and return a context sample with:
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
- `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref).
"""
function Utils.generate_context(
    bench::ContextualStochasticArgmaxBenchmark,
    rng::AbstractRNG,
    instance_sample::DataSample,
)
    c_base = instance_sample.c_base
    x_raw = randn(rng, Float32, bench.d)
    return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base))
end

"""
    generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)

Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
`c_base` and `x_raw` are spread from `ctx.extra` by the framework.
"""
function Utils.generate_scenario(
    bench::ContextualStochasticArgmaxBenchmark,
    rng::AbstractRNG;
    c_base::AbstractVector,
    x_raw::AbstractVector,
    kwargs...,
)
    θ_true = c_base + bench.W * x_raw
    return θ_true + bench.noise_std * randn(rng, Float32, bench.n)
end

function Utils.generate_statistical_model(
    bench::ContextualStochasticArgmaxBenchmark; seed=nothing
)
    Random.seed!(seed)
    return Dense(bench.n + bench.d => bench.n; bias=false)
end

include("policies.jl")

export ContextualStochasticArgmaxBenchmark

end
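The docstring's claim that the bias-free linear model can recover `[I | W]` is easy to verify numerically. In this illustrative NumPy sketch (not part of the package), ordinary least squares stands in for training the `Dense` layer on simulated `(x, ξ)` pairs:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, sigma, N = 10, 5, 0.05, 5000

W = rng.normal(size=(n, d))
A_true = np.concatenate([np.eye(n), W], axis=1)  # optimal linear predictor [I_n | W]

# Simulate N (feature, realized-value) pairs following the benchmark's model:
# x = [c_base; x_raw],  xi = c_base + W x_raw + noise.
c_base = rng.uniform(size=(N, n))
x_raw = rng.normal(size=(N, d))
X = np.concatenate([c_base, x_raw], axis=1)              # shape (N, n + d)
Xi = c_base + x_raw @ W.T + sigma * rng.normal(size=(N, n))

# Least-squares fit as a stand-in for gradient training of the Dense layer.
A_hat = np.linalg.lstsq(X, Xi, rcond=None)[0].T          # shape (n, n + d)

print(np.abs(A_hat - A_true).max())  # shrinks toward 0 as N grows
```

Because the scenario model is exactly linear in the features, the fitted weights converge to `[I_n | W]`, which is why the benchmark docs expect a well-trained model to reach near-zero gap.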
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@

using Statistics: mean

"""
$TYPEDSIGNATURES

SAA baseline policy: returns `argmax(mean(scenarios))`.
For a linear argmax problem this is the exact SAA-optimal decision.
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
"""
function csa_saa_policy(ctx_sample, scenarios)
    y = one_hot_argmax(mean(scenarios))
    return [
        DataSample(;
            ctx_sample.maximizer_kwargs...,
            x=ctx_sample.x,
            y=y,
            extra=(; ctx_sample.extra..., scenarios),
        ),
    ]
end

"""
$TYPEDSIGNATURES

Return the named baseline policies for [`ContextualStochasticArgmaxBenchmark`](@ref).
Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
"""
function Utils.generate_baseline_policies(::ContextualStochasticArgmaxBenchmark)
    return (; saa=Policy("SAA", "argmax of mean scenarios", csa_saa_policy))
end

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 5 additions & 1 deletion

@@ -55,6 +55,7 @@ include("Warcraft/Warcraft.jl")
 include("FixedSizeShortestPath/FixedSizeShortestPath.jl")
 include("PortfolioOptimization/PortfolioOptimization.jl")
 include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
+include("ContextualStochasticArgmax/ContextualStochasticArgmax.jl")
 include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
 include("DynamicAssortment/DynamicAssortment.jl")
 include("Maintenance/Maintenance.jl")

@@ -71,8 +72,9 @@ export Policy, evaluate_policy!
 
 export generate_instance,
     generate_sample, generate_dataset, generate_environments, generate_environment
-export generate_scenario
+export generate_scenario, generate_context
 export generate_baseline_policies
+export SampleAverageApproximation
 export generate_statistical_model
 export generate_maximizer
 export generate_anticipative_solver, generate_parametric_anticipative_solver

@@ -91,6 +93,7 @@ using .Warcraft
 using .FixedSizeShortestPath
 using .PortfolioOptimization
 using .StochasticVehicleScheduling
+using .ContextualStochasticArgmax
 using .DynamicVehicleScheduling
 using .DynamicAssortment
 using .Maintenance

@@ -106,5 +109,6 @@ export StochasticVehicleSchedulingBenchmark
 export SubsetSelectionBenchmark
 export WarcraftBenchmark
 export MaintenanceBenchmark
+export ContextualStochasticArgmaxBenchmark
 
 end # module DecisionFocusedLearningBenchmarks

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion

@@ -116,7 +116,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=nothing`
 To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):
 
 ```julia
-policy = sample -> DataSample(; sample.context..., x=sample.x,
+policy = sample -> DataSample(; sample.maximizer_kwargs..., x=sample.x,
     y=column_generation_algorithm(sample.instance))
 dataset = generate_dataset(benchmark, N; target_policy=policy)
 ```

src/StochasticVehicleScheduling/policies.jl

Lines changed: 58 additions & 10 deletions

@@ -5,10 +5,17 @@ SAA baseline policy: builds a stochastic instance from all K scenarios and solves
 via column generation.
 Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
 """
-function svs_saa_policy(sample, scenarios)
-    stochastic_inst = build_stochastic_instance(sample.instance, scenarios)
+function svs_saa_policy(ctx_sample, scenarios)
+    stochastic_inst = build_stochastic_instance(ctx_sample.instance, scenarios)
     y = column_generation_algorithm(stochastic_inst)
-    return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
+    return [
+        DataSample(;
+            ctx_sample.maximizer_kwargs...,
+            x=ctx_sample.x,
+            y,
+            extra=(; ctx_sample.extra..., scenarios),
+        ),
+    ]
 end

@@ -17,9 +24,16 @@ $TYPEDSIGNATURES
 
 Deterministic baseline policy: solves the deterministic MIP (ignores scenario delays).
 Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
 """
-function svs_deterministic_policy(sample, scenarios; model_builder=highs_model)
-    y = deterministic_mip(sample.instance; model_builder)
-    return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
+function svs_deterministic_policy(ctx_sample, scenarios; model_builder=highs_model)
+    y = deterministic_mip(ctx_sample.instance; model_builder)
+    return [
+        DataSample(;
+            ctx_sample.maximizer_kwargs...,
+            x=ctx_sample.x,
+            y,
+            extra=(; ctx_sample.extra..., scenarios),
+        ),
+    ]
 end

@@ -29,24 +43,58 @@ Local search baseline policy: builds a stochastic instance from all K scenarios and
 solves via local search heuristic.
 Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
 """
-function svs_local_search_policy(sample, scenarios)
-    stochastic_inst = build_stochastic_instance(sample.instance, scenarios)
+function svs_local_search_policy(ctx_sample, scenarios)
+    stochastic_inst = build_stochastic_instance(ctx_sample.instance, scenarios)
     y = local_search(stochastic_inst)
-    return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
+    return [
+        DataSample(;
+            ctx_sample.maximizer_kwargs...,
+            x=ctx_sample.x,
+            y,
+            extra=(; ctx_sample.extra..., scenarios),
+        ),
+    ]
+end
+
+"""
+$TYPEDSIGNATURES
+
+Exact SAA MIP policy (linearized): solves the stochastic VSP exactly for the given
+scenarios via [`compact_linearized_mip`](@ref).
+Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
+
+Prefer this over [`svs_saa_policy`](@ref) when an exact solution is needed; requires
+SCIP (default) or Gurobi.
+"""
+function svs_saa_mip_policy(ctx_sample, scenarios; model_builder=scip_model)
+    y = compact_linearized_mip(ctx_sample.instance, scenarios; model_builder)
+    return [
+        DataSample(;
+            ctx_sample.maximizer_kwargs...,
+            x=ctx_sample.x,
+            y,
+            extra=(; ctx_sample.extra..., scenarios),
+        ),
+    ]
 end
 
 """
 $TYPEDSIGNATURES
 
 Return the named baseline policies for [`StochasticVehicleSchedulingBenchmark`](@ref).
-Each policy has signature `(sample, scenarios) -> Vector{DataSample}`.
+Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
 """
 function svs_generate_baseline_policies(::StochasticVehicleSchedulingBenchmark)
     return (;
         deterministic=Policy(
             "Deterministic MIP", "Ignores delays", svs_deterministic_policy
         ),
         saa=Policy("SAA (col gen)", "Stochastic MIP over K scenarios", svs_saa_policy),
+        saa_mip=Policy(
+            "SAA (exact MIP)",
+            "Exact stochastic MIP over K scenarios via compact linearized formulation",
+            svs_saa_mip_policy,
+        ),
         local_search=Policy(
             "Local search", "Heuristic with K scenarios", svs_local_search_policy
         ),
src/StochasticVehicleScheduling/solution/algorithms/mip.jl

Lines changed: 22 additions & 0 deletions

@@ -84,6 +84,18 @@ end
 """
 $TYPEDSIGNATURES
 
+SAA variant: build stochastic instance from `scenarios` then solve via
+[`compact_linearized_mip`](@ref).
+"""
+function compact_linearized_mip(
+    instance::Instance, scenarios::Vector{VSPScenario}; kwargs...
+)
+    return compact_linearized_mip(build_stochastic_instance(instance, scenarios); kwargs...)
+end
+
+"""
+$TYPEDSIGNATURES
+
 Returns the optimal solution of the Stochastic VSP instance, by solving the associated compact quadratic MIP.
 Note: If you have Gurobi, use `grb_model` as `model_builder` instead of `highs_model`.

@@ -151,3 +163,13 @@ function compact_mip(
     sol = solution_from_JuMP_array(solution, graph)
     return sol.value
 end
+
+"""
+$TYPEDSIGNATURES
+
+SAA variant: build stochastic instance from `scenarios` then solve via
+[`compact_mip`](@ref).
+"""
+function compact_mip(instance::Instance, scenarios::Vector{VSPScenario}; kwargs...)
+    return compact_mip(build_stochastic_instance(instance, scenarios); kwargs...)
+end
