Merged
12 changes: 12 additions & 0 deletions docs/src/api.md
@@ -36,6 +36,18 @@ Modules = [DecisionFocusedLearningBenchmarks.Argmax]
Public = false
```

## Contextual Stochastic Argmax

```@autodocs
Modules = [DecisionFocusedLearningBenchmarks.ContextualStochasticArgmax]
Private = false
```

```@autodocs
Modules = [DecisionFocusedLearningBenchmarks.ContextualStochasticArgmax]
Public = false
```

## Dynamic Vehicle Scheduling

```@autodocs
37 changes: 37 additions & 0 deletions docs/src/benchmarks/contextual_stochastic_argmax.md
@@ -0,0 +1,37 @@
# Contextual Stochastic Argmax

[`ContextualStochasticArgmaxBenchmark`](@ref) is a minimalist contextual stochastic optimization benchmark problem.

The decision maker selects one item out of ``n``. Item values are uncertain at decision time: they depend on a base utility plus a context-correlated perturbation revealed only after the decision is made. An observable context vector, correlated with the perturbation via a fixed linear map ``W``, allows the learner to anticipate the perturbation and pick the right item.

## Problem Formulation

**Instance**: ``c_{\text{base}} \sim \mathcal{U}[0,1]^n``, base values for ``n`` items.

**Context**: ``x_{\text{raw}} \sim \mathcal{N}(0, I_d)``, a ``d``-dimensional signal correlated with item values. The feature vector passed to the model is ``x = [c_{\text{base}};\, x_{\text{raw}}] \in \mathbb{R}^{n+d}``.

**Scenario**: the realized item values are
```math
\xi = c_{\text{base}} + W x_{\text{raw}} + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, \sigma^2 I_n)
```
where ``W \in \mathbb{R}^{n \times d}`` is a fixed matrix unknown to the learner.

**Decision**: ``y \in \{e_1, \ldots, e_n\}`` (one-hot vector selecting one item).

## Policies

### DFL Policy

```math
\xrightarrow[\text{Features}]{x}
\fbox{Neural network $\varphi_w$}
\xrightarrow[\text{Predicted values}]{\hat{\theta}}
\fbox{\texttt{one\_hot\_argmax}}
\xrightarrow[\text{Decision}]{y}
```

The neural network predicts item values ``\hat{\theta} \in \mathbb{R}^n`` from the feature vector ``x \in \mathbb{R}^{n+d}``. The default architecture is `Dense(n+d => n; bias=false)`, which can exactly recover the optimal linear predictor ``[I_n \mid W]``, so a well-trained model should reach near-zero gap.
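To make the recovery claim concrete, here is a standalone check (plain Julia, no dependency on the package) that the weight matrix ``[I_n \mid W]`` maps the features to the noiseless item values; the variable names mirror the formulation above and are otherwise illustrative:

```julia
using LinearAlgebra, Random

# Standalone check that the linear predictor [I | W] recovers the
# noiseless item values θ = c_base + W * x_raw from x = [c_base; x_raw].
rng = MersenneTwister(0)
n, d = 10, 5
W = randn(rng, Float32, n, d)

c_base = rand(rng, Float32, n)          # base utilities (instance)
x_raw  = randn(rng, Float32, d)         # observable context
x      = vcat(c_base, x_raw)            # features seen by the model

M = hcat(Matrix{Float32}(I, n, n), W)   # weights a Dense(n+d => n) layer can learn
@assert M * x ≈ c_base + W * x_raw      # exact recovery (up to Float32 rounding)
```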

### SAA Policy

``y_{\text{SAA}} = \texttt{one\_hot\_argmax}\bigl(\frac{1}{S}\sum_{s=1}^{S} \xi^{(s)}\bigr)`` — the exact SAA-optimal decision for a linear argmax problem, available via `generate_baseline_policies(bench).saa`.
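As an illustration, the SAA rule can be reproduced in a few lines of plain Julia; `one_hot_argmax` is re-implemented locally so the snippet is self-contained, and the numbers are made up:

```julia
using Statistics, Random

# SAA decision rule: average the sampled scenario values, then argmax.
# Local re-implementation of one_hot_argmax, for self-containedness.
one_hot_argmax(θ) = (y = zeros(length(θ)); y[argmax(θ)] = 1.0; y)

rng = MersenneTwister(0)
θ_true = [0.1, 0.9, 0.3, 0.5]                       # made-up item values
scenarios = [θ_true .+ 0.05 .* randn(rng, 4) for _ in 1:100]

y_saa = one_hot_argmax(mean(scenarios))             # averaging shrinks the noise,
                                                    # so the argmax is item 2 here
```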
4 changes: 2 additions & 2 deletions docs/src/tutorials/warcraft_tutorial.jl
@@ -30,8 +30,8 @@ x = sample.x
θ_true = sample.θ
# `y` corresponds to the optimal shortest path, encoded as a binary matrix:
y_true = sample.y
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
isempty(sample.context)
# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
isempty(sample.maximizer_kwargs)

# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
plot_solution(b, sample)
4 changes: 3 additions & 1 deletion ext/DFLBenchmarksPlotsExt.jl
@@ -21,7 +21,9 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
return plot_solution(
bench,
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
DataSample(;
sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
);
kwargs...,
)
end
105 changes: 105 additions & 0 deletions src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl
@@ -0,0 +1,105 @@
module ContextualStochasticArgmax

using ..Utils
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Dense
using Random: Random, AbstractRNG, MersenneTwister
using Statistics: mean

"""
$TYPEDEF

Minimal contextual stochastic argmax benchmark.

Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample).
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.

A linear model `Dense(n+d → n; bias=false)` can exactly recover `[I | W]`.

# Fields
$TYPEDFIELDS
"""
struct ContextualStochasticArgmaxBenchmark{M<:AbstractMatrix} <:
AbstractStochasticBenchmark{true}
"number of items (argmax dimension)"
n::Int
"number of context features"
d::Int
"fixed perturbation matrix W ∈ R^{n×d}, unknown to the learner"
W::M
"noise std for scenario draws"
noise_std::Float32
end

function ContextualStochasticArgmaxBenchmark(;
n::Int=10, d::Int=5, noise_std::Float32=0.1f0, seed=nothing
)
rng = MersenneTwister(seed)
W = randn(rng, Float32, n, d)
return ContextualStochasticArgmaxBenchmark(n, d, W, noise_std)
end

Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false
Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax

"""
generate_instance(::ContextualStochasticArgmaxBenchmark, rng)

Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed
(the maximizer is `one_hot_argmax`, which takes no kwargs).
"""
function Utils.generate_instance(
bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
)
c_base = rand(rng, Float32, bench.n)
return DataSample(; extra=(; c_base))
end

"""
generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)

Draw `x_raw ~ N(0, I_d)` and return a context sample with:
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
- `extra = (; c_base, x_raw)`: latents forwarded as keyword arguments to [`generate_scenario`](@ref).
"""
function Utils.generate_context(
bench::ContextualStochasticArgmaxBenchmark,
rng::AbstractRNG,
instance_sample::DataSample,
)
c_base = instance_sample.c_base
x_raw = randn(rng, Float32, bench.d)
return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base))
end

"""
generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)

Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
`c_base` and `x_raw` are spread from `ctx.extra` by the framework.
"""
function Utils.generate_scenario(
bench::ContextualStochasticArgmaxBenchmark,
rng::AbstractRNG;
c_base::AbstractVector,
x_raw::AbstractVector,
kwargs...,
)
θ_true = c_base + bench.W * x_raw
return θ_true + bench.noise_std * randn(rng, Float32, bench.n)
end

function Utils.generate_statistical_model(
bench::ContextualStochasticArgmaxBenchmark; seed=nothing
)
Random.seed!(seed)
return Dense(bench.n + bench.d => bench.n; bias=false)
end

include("policies.jl")

export ContextualStochasticArgmaxBenchmark

end
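For reference, the instance → context → scenario pipeline implemented above can be sketched end-to-end in plain Julia (no package dependency; the names mirror the benchmark struct's fields and the three `generate_*` functions, purely for illustration):

```julia
using Random

# End-to-end sketch of the generative pipeline:
# instance (c_base) → context (x_raw, features x) → scenario (ξ).
rng = MersenneTwister(42)
n, d, noise_std = 10, 5, 0.1f0
W = randn(rng, Float32, n, d)        # fixed perturbation matrix, hidden from the learner

c_base = rand(rng, Float32, n)       # generate_instance: base utilities
x_raw  = randn(rng, Float32, d)      # generate_context: observable signal
x      = vcat(c_base, x_raw)         # features passed to the statistical model

ξ = c_base + W * x_raw + noise_std * randn(rng, Float32, n)  # generate_scenario
```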
30 changes: 30 additions & 0 deletions src/ContextualStochasticArgmax/policies.jl
@@ -0,0 +1,30 @@
using Statistics: mean

"""
$TYPEDSIGNATURES
SAA baseline policy: returns `argmax(mean(scenarios))`.
For a linear argmax problem this is the exact SAA-optimal decision.
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
"""
function csa_saa_policy(ctx_sample, scenarios)
y = one_hot_argmax(mean(scenarios))
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
x=ctx_sample.x,
y=y,
extra=(; ctx_sample.extra..., scenarios),
),
]
end

"""
$TYPEDSIGNATURES
Return the named baseline policies for [`ContextualStochasticArgmaxBenchmark`](@ref).
Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
"""
function Utils.generate_baseline_policies(::ContextualStochasticArgmaxBenchmark)
return (; saa=Policy("SAA", "argmax of mean scenarios", csa_saa_policy))
end
6 changes: 5 additions & 1 deletion src/DecisionFocusedLearningBenchmarks.jl
@@ -55,6 +55,7 @@ include("Warcraft/Warcraft.jl")
include("FixedSizeShortestPath/FixedSizeShortestPath.jl")
include("PortfolioOptimization/PortfolioOptimization.jl")
include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
include("ContextualStochasticArgmax/ContextualStochasticArgmax.jl")
include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
include("DynamicAssortment/DynamicAssortment.jl")
include("Maintenance/Maintenance.jl")
@@ -71,8 +72,9 @@ export Policy, evaluate_policy!

export generate_instance,
generate_sample, generate_dataset, generate_environments, generate_environment
export generate_scenario
export generate_scenario, generate_context
export generate_baseline_policies
export SampleAverageApproximation
export generate_statistical_model
export generate_maximizer
export generate_anticipative_solver, generate_parametric_anticipative_solver
@@ -91,6 +93,7 @@ using .Warcraft
using .FixedSizeShortestPath
using .PortfolioOptimization
using .StochasticVehicleScheduling
using .ContextualStochasticArgmax
using .DynamicVehicleScheduling
using .DynamicAssortment
using .Maintenance
@@ -106,5 +109,6 @@ export StochasticVehicleSchedulingBenchmark
export SubsetSelectionBenchmark
export WarcraftBenchmark
export MaintenanceBenchmark
export ContextualStochasticArgmaxBenchmark

end # module DecisionFocusedLearningBenchmarks
@@ -116,7 +116,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=noth
To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):
```julia
policy = sample -> DataSample(; sample.context..., x=sample.x,
policy = sample -> DataSample(; sample.maximizer_kwargs..., x=sample.x,
y=column_generation_algorithm(sample.instance))
dataset = generate_dataset(benchmark, N; target_policy=policy)
```
68 changes: 58 additions & 10 deletions src/StochasticVehicleScheduling/policies.jl
@@ -5,10 +5,17 @@ SAA baseline policy: builds a stochastic instance from all K scenarios and solve
via column generation.
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
"""
function svs_saa_policy(sample, scenarios)
stochastic_inst = build_stochastic_instance(sample.instance, scenarios)
function svs_saa_policy(ctx_sample, scenarios)
stochastic_inst = build_stochastic_instance(ctx_sample.instance, scenarios)
y = column_generation_algorithm(stochastic_inst)
return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
),
]
end

"""
@@ -17,9 +24,16 @@
Deterministic baseline policy: solves the deterministic MIP (ignores scenario delays).
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
"""
function svs_deterministic_policy(sample, scenarios; model_builder=highs_model)
y = deterministic_mip(sample.instance; model_builder)
return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
function svs_deterministic_policy(ctx_sample, scenarios; model_builder=highs_model)
y = deterministic_mip(ctx_sample.instance; model_builder)
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
),
]
end

"""
@@ -29,24 +43,58 @@ Local search baseline policy: builds a stochastic instance from all K scenarios
solves via local search heuristic.
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
"""
function svs_local_search_policy(sample, scenarios)
stochastic_inst = build_stochastic_instance(sample.instance, scenarios)
function svs_local_search_policy(ctx_sample, scenarios)
stochastic_inst = build_stochastic_instance(ctx_sample.instance, scenarios)
y = local_search(stochastic_inst)
return [DataSample(; sample.context..., x=sample.x, y, extra=(; scenarios))]
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
),
]
end

"""
$TYPEDSIGNATURES

Exact SAA MIP policy (linearized): solves the stochastic VSP exactly for the given
scenarios via [`compact_linearized_mip`](@ref).
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.

Prefer this over [`svs_saa_policy`](@ref) when an exact solution is needed; requires
SCIP (default) or Gurobi.
"""
function svs_saa_mip_policy(ctx_sample, scenarios; model_builder=scip_model)
y = compact_linearized_mip(ctx_sample.instance, scenarios; model_builder)
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
),
]
end

"""
$TYPEDSIGNATURES

Return the named baseline policies for [`StochasticVehicleSchedulingBenchmark`](@ref).
Each policy has signature `(sample, scenarios) -> Vector{DataSample}`.
Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
"""
function svs_generate_baseline_policies(::StochasticVehicleSchedulingBenchmark)
return (;
deterministic=Policy(
"Deterministic MIP", "Ignores delays", svs_deterministic_policy
),
saa=Policy("SAA (col gen)", "Stochastic MIP over K scenarios", svs_saa_policy),
saa_mip=Policy(
"SAA (exact MIP)",
"Exact stochastic MIP over K scenarios via compact linearized formulation",
svs_saa_mip_policy,
),
local_search=Policy(
"Local search", "Heuristic with K scenarios", svs_local_search_policy
),
22 changes: 22 additions & 0 deletions src/StochasticVehicleScheduling/solution/algorithms/mip.jl
@@ -84,6 +84,18 @@ end
"""
$TYPEDSIGNATURES

SAA variant: build stochastic instance from `scenarios` then solve via
[`compact_linearized_mip`](@ref).
"""
function compact_linearized_mip(
instance::Instance, scenarios::Vector{VSPScenario}; kwargs...
)
return compact_linearized_mip(build_stochastic_instance(instance, scenarios); kwargs...)
end

"""
$TYPEDSIGNATURES

Returns the optimal solution of the Stochastic VSP instance, by solving the associated compact quadratic MIP.
Note: If you have Gurobi, use `grb_model` as `model_builder` instead of `highs_model`.

@@ -151,3 +163,13 @@ function compact_mip(
sol = solution_from_JuMP_array(solution, graph)
return sol.value
end

"""
$TYPEDSIGNATURES

SAA variant: build stochastic instance from `scenarios` then solve via
[`compact_mip`](@ref).
"""
function compact_mip(instance::Instance, scenarios::Vector{VSPScenario}; kwargs...)
return compact_mip(build_stochastic_instance(instance, scenarios); kwargs...)
end