Support for contextual stochastic optimization problems #60
Merged
5 commits
89462c0 Support for contextual stochastic optimization problems (BatyLeo)
cb96efe fix doc (BatyLeo)
15edb07 rename instance_kwargs to maximizer_kwargs, + other small fixes (BatyLeo)
9443ac7 saa policy for the contextual argmax (BatyLeo)
9acdd14 update documentation (BatyLeo)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| # Contextual Stochastic Argmax | ||
|
|
||
| [`ContextualStochasticArgmaxBenchmark`](@ref) is a minimalist contextual stochastic optimization benchmark problem. | ||
|
|
||
| The decision maker selects one item out of ``n``. Item values are uncertain at decision time: they depend on a base utility plus a context-correlated perturbation revealed only after the decision is made. An observable context vector, correlated with the perturbation via a fixed linear map ``W``, allows the learner to anticipate the perturbation and pick the right item. | ||
|
|
||
| ## Problem Formulation | ||
|
|
||
| **Instance**: ``c_{\text{base}} \sim \mathcal{U}[0,1]^n``, base values for ``n`` items. | ||
|
|
||
| **Context**: ``x_{\text{raw}} \sim \mathcal{N}(0, I_d)``, a ``d``-dimensional signal correlated with item values. The feature vector passed to the model is ``x = [c_{\text{base}};\, x_{\text{raw}}] \in \mathbb{R}^{n+d}``. | ||
|
|
||
| **Scenario**: the realized item values are | ||
| ```math | ||
| \xi = c_{\text{base}} + W x_{\text{raw}} + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, \sigma^2 I_n) | ||
| ``` | ||
| where ``W \in \mathbb{R}^{n \times d}`` is a fixed matrix unknown to the learner. | ||
|
|
||
| **Decision**: ``y \in \{e_1, \ldots, e_n\}`` (one-hot vector selecting one item). | ||
|
|
||
| ## Policies | ||
|
|
||
| ### DFL Policy | ||
|
|
||
| ```math | ||
| \xrightarrow[\text{Features}]{x} | ||
| \fbox{Neural network $\varphi_w$} | ||
| \xrightarrow[\text{Predicted values}]{\hat{\theta}} | ||
| \fbox{\texttt{one\_hot\_argmax}} | ||
| \xrightarrow[\text{Decision}]{y} | ||
| ``` | ||
|
|
||
| The neural network predicts item values ``\hat{\theta} \in \mathbb{R}^n`` from the feature vector ``x \in \mathbb{R}^{n+d}``. The default architecture is `Dense(n+d => n; bias=false)`, which can exactly recover the optimal linear predictor ``[I_n \mid W]``, so a well-trained model should reach near-zero gap. | ||
|
|
||
| ### SAA Policy | ||
|
|
||
| ``y_{\text{SAA}} = \operatorname{argmax}\bigl(\frac{1}{S}\sum_{s=1}^{S} \xi^{(s)}\bigr)``, the argmax of the mean of the ``S`` sampled scenarios. For a linear argmax problem this is the exact SAA-optimal decision. It is accessible via `generate_baseline_policies(bench).saa`. |
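
The two optimality claims in this documentation page can be checked with a short derivation. It is a sketch under one assumption that the page implies but does not state: the objective for a decision ``y`` under scenario ``\xi`` is the value of the selected item, ``\xi^\top y``. The symbols ``\bar{\xi}`` and ``i^\star`` are introduced here for illustration only.

```math
\begin{aligned}
\mathbb{E}[\xi \mid x]
  &= c_{\text{base}} + W x_{\text{raw}}
   = [I_n \mid W]\, x, \\
\max_{y \in \{e_1, \ldots, e_n\}} \frac{1}{S}\sum_{s=1}^{S} \bigl(\xi^{(s)}\bigr)^\top y
  &= \max_{y \in \{e_1, \ldots, e_n\}} \bar{\xi}^\top y
   = \max_{1 \le i \le n} \bar{\xi}_i,
  \qquad \bar{\xi} = \frac{1}{S}\sum_{s=1}^{S} \xi^{(s)}.
\end{aligned}
```

The first line holds because ``\varepsilon`` has zero mean, so a bias-free linear layer with weights ``[I_n \mid W]`` reproduces the conditional mean of the item values. The second line holds because the objective is linear in ``y``, so averaging over scenarios commutes with the inner product, and the maximum is attained at ``y = e_{i^\star}`` with ``i^\star \in \operatorname{argmax}_i \bar{\xi}_i``.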
105 changes: 105 additions & 0 deletions
src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| module ContextualStochasticArgmax | ||
|
|
||
| using ..Utils | ||
| using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES | ||
| using Flux: Dense | ||
| using Random: Random, AbstractRNG, MersenneTwister | ||
| using Statistics: mean | ||
|
|
||
| """ | ||
| $TYPEDEF | ||
|
|
||
| Minimal contextual stochastic argmax benchmark. | ||
|
|
||
| Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample). | ||
| Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`. | ||
| Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`. | ||
| The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`. | ||
|
|
||
| A linear model `Dense(n+d => n; bias=false)` can exactly recover `[I | W]`. | ||
|
|
||
| # Fields | ||
| $TYPEDFIELDS | ||
| """ | ||
| struct ContextualStochasticArgmaxBenchmark{M<:AbstractMatrix} <: | ||
| AbstractStochasticBenchmark{true} | ||
| "number of items (argmax dimension)" | ||
| n::Int | ||
| "number of context features" | ||
| d::Int | ||
| "fixed perturbation matrix W ∈ R^{n×d}, unknown to the learner" | ||
| W::M | ||
| "noise std for scenario draws" | ||
| noise_std::Float32 | ||
| end | ||
|
|
||
| function ContextualStochasticArgmaxBenchmark(; | ||
| n::Int=10, d::Int=5, noise_std::Float32=0.1f0, seed=nothing | ||
| ) | ||
| rng = MersenneTwister(seed) | ||
| W = randn(rng, Float32, n, d) | ||
| return ContextualStochasticArgmaxBenchmark(n, d, W, noise_std) | ||
| end | ||
|
|
||
| Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false | ||
| Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax | ||
|
|
||
| """ | ||
| generate_instance(::ContextualStochasticArgmaxBenchmark, rng) | ||
|
|
||
| Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed | ||
| (the maximizer is `one_hot_argmax`, which takes no kwargs). | ||
| """ | ||
| function Utils.generate_instance( | ||
| bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs... | ||
| ) | ||
| c_base = rand(rng, Float32, bench.n) | ||
| return DataSample(; extra=(; c_base)) | ||
| end | ||
|
|
||
| """ | ||
| generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample) | ||
|
|
||
| Draw `x_raw ~ N(0, I_d)` and return a context sample with: | ||
| - `x = [c_base; x_raw]`: full feature vector seen by the ML model. | ||
| - `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref). | ||
| """ | ||
| function Utils.generate_context( | ||
| bench::ContextualStochasticArgmaxBenchmark, | ||
| rng::AbstractRNG, | ||
| instance_sample::DataSample, | ||
| ) | ||
| c_base = instance_sample.c_base | ||
| x_raw = randn(rng, Float32, bench.d) | ||
| return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base)) | ||
| end | ||
|
|
||
| """ | ||
| generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...) | ||
|
|
||
| Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`. | ||
| `c_base` and `x_raw` are spread from `ctx.extra` by the framework. | ||
| """ | ||
| function Utils.generate_scenario( | ||
| bench::ContextualStochasticArgmaxBenchmark, | ||
| rng::AbstractRNG; | ||
| c_base::AbstractVector, | ||
| x_raw::AbstractVector, | ||
| kwargs..., | ||
| ) | ||
| θ_true = c_base + bench.W * x_raw | ||
| return θ_true + bench.noise_std * randn(rng, Float32, bench.n) | ||
| end | ||
|
|
||
| function Utils.generate_statistical_model( | ||
| bench::ContextualStochasticArgmaxBenchmark; seed=nothing | ||
| ) | ||
| Random.seed!(seed) | ||
| return Dense(bench.n + bench.d => bench.n; bias=false) | ||
| end | ||
|
|
||
| include("policies.jl") | ||
|
|
||
| export ContextualStochasticArgmaxBenchmark | ||
|
|
||
| end | ||
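
The data-generating process defined by this module can be reproduced without the package, which may help when reviewing the claim that a bias-free linear layer can recover `[I | W]`. The following is a minimal standalone sketch, not the package's code path: it uses only Julia's standard libraries, the sizes mirror the constructor defaults, and the names `A` and `max_dev` as well as the seed and the scenario count are illustrative assumptions.

```julia
using LinearAlgebra: I
using Random: MersenneTwister
using Statistics: mean

# Sizes and noise level mirroring the constructor defaults in the diff.
n, d, noise_std = 10, 5, 0.1f0
rng = MersenneTwister(0)               # arbitrary seed, for reproducibility only

W = randn(rng, Float32, n, d)          # fixed map, hidden from the learner
c_base = rand(rng, Float32, n)         # per-instance base utilities in [0, 1)
x_raw = randn(rng, Float32, d)         # per-context observable signal
x = vcat(c_base, x_raw)                # feature vector of length n + d

# Noise-free item values, and the same quantity written as one linear map of x.
θ_true = c_base + W * x_raw
A = hcat(Matrix{Float32}(I, n, n), W)  # the block matrix [I_n | W]
@assert A * x ≈ θ_true

# Scenarios add independent Gaussian noise, so their mean concentrates on θ_true.
scenarios = [θ_true + noise_std * randn(rng, Float32, n) for _ in 1:10_000]
max_dev = maximum(abs.(mean(scenarios) .- θ_true))
println("largest deviation of the scenario mean from θ_true: ", max_dev)
```

The assertion checks the identity `[I | W] * x == c_base + W * x_raw` up to floating-point rounding. The printed deviation shrinks as the number of scenarios grows, since the standard error of each coordinate is `noise_std / sqrt(S)`.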
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| using Statistics: mean | ||
|
|
||
| """ | ||
| $TYPEDSIGNATURES | ||
| SAA baseline policy: returns `argmax(mean(scenarios))`. | ||
| For a linear argmax problem this is the exact SAA-optimal decision. | ||
| Returns a one-element vector containing a labeled [`DataSample`](@ref), with `scenarios` added to its `extra`. | ||
| """ | ||
| function csa_saa_policy(ctx_sample, scenarios) | ||
| y = one_hot_argmax(mean(scenarios)) | ||
| return [ | ||
| DataSample(; | ||
| ctx_sample.maximizer_kwargs..., | ||
| x=ctx_sample.x, | ||
| y=y, | ||
| extra=(; ctx_sample.extra..., scenarios), | ||
| ), | ||
| ] | ||
| end | ||
|
|
||
| """ | ||
| $TYPEDSIGNATURES | ||
| Return the named baseline policies for [`ContextualStochasticArgmaxBenchmark`](@ref). | ||
| Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`. | ||
| """ | ||
| function Utils.generate_baseline_policies(::ContextualStochasticArgmaxBenchmark) | ||
| return (; saa=Policy("SAA", "argmax of mean scenarios", csa_saa_policy)) | ||
| end |
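
The decision rule inside `csa_saa_policy` can be traced by hand on a tiny example. The sketch below is standalone and hedged: `one_hot_argmax_sketch` is a hypothetical stand-in written for this illustration, not the package's `one_hot_argmax`, and the three scenarios are made-up numbers. It leaves out the `DataSample` wrapping entirely.

```julia
using Statistics: mean

# Hypothetical stand-in for `one_hot_argmax`: returns the one-hot vector e_i
# for the index i of the largest entry of θ.
function one_hot_argmax_sketch(θ::AbstractVector{T}) where {T<:Real}
    y = zeros(T, length(θ))
    y[argmax(θ)] = one(T)
    return y
end

# Three hand-written scenarios over n = 3 items.
scenarios = [
    Float32[1, 0, 3],
    Float32[2, 5, 1],
    Float32[0, 4, 1],
]

ξ_mean = mean(scenarios)                # element-wise mean: [1, 3, 5/3]
y_saa = one_hot_argmax_sketch(ξ_mean)   # selects item 2

# Item 3 wins the first scenario, yet item 2 has the best average value.
@assert y_saa == Float32[0, 1, 0]
```

Calling `mean` on a vector of vectors averages element-wise, which is why the policy can pass `scenarios` to it directly.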