@@ -3,6 +3,7 @@ module ContextualStochasticArgmax
33using .. Utils
44using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55using Flux: Dense
6+ using LinearAlgebra: dot
67using Random: Random, AbstractRNG, MersenneTwister
78using Statistics: mean
89
@@ -11,7 +12,7 @@ $TYPEDEF
1112
1213Minimal contextual stochastic argmax benchmark.
1314
14- Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra ` of the instance sample).
15+ Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `context ` of the instance sample).
1516Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
1617Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
1718The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
4445Utils. is_minimization_problem (:: ContextualStochasticArgmaxBenchmark ) = false
4546Utils. generate_maximizer (:: ContextualStochasticArgmaxBenchmark ) = one_hot_argmax
4647
48+ function Utils. objective_value (
49+ :: ContextualStochasticArgmaxBenchmark , sample:: DataSample , y, scenario
50+ )
51+ return dot (scenario, y)
52+ end
53+
54+ function Utils. objective_value (
55+ bench:: ContextualStochasticArgmaxBenchmark , sample:: DataSample , y
56+ )
57+ if hasproperty (sample. extra, :scenario )
58+ return Utils. objective_value (bench, sample, y, sample. scenario)
59+ elseif hasproperty (sample. extra, :scenarios )
60+ return mean (Utils. objective_value (bench, sample, y, ξ) for ξ in sample. scenarios)
61+ end
62+ return error (" Sample must have scenario or scenarios" )
63+ end
64+
4765"""
4866 generate_instance(::ContextualStochasticArgmaxBenchmark, rng)
4967
50- Draw `c_base ~ U[0,1]^n` and store it in `extra `. No solver kwargs are needed
68+ Draw `c_base ~ U[0,1]^n` and store it in `context `. No solver kwargs are needed
5169(the maximizer is `one_hot_argmax`, which takes no kwargs).
5270"""
5371function Utils. generate_instance (
5472 bench:: ContextualStochasticArgmaxBenchmark , rng:: AbstractRNG ; kwargs...
5573)
5674 c_base = rand (rng, Float32, bench. n)
57- return DataSample (; extra = (; c_base) )
75+ return DataSample (; c_base)
5876end
5977
6078"""
6179 generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)
6280
6381Draw `x_raw ~ N(0, I_d)` and return a context sample with:
6482- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
65- - `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref).
83+ - `c_base`, ` x_raw` in `context`: spread into [`generate_scenario`](@ref).
6684"""
6785function Utils. generate_context (
6886 bench:: ContextualStochasticArgmaxBenchmark ,
@@ -71,14 +89,14 @@ function Utils.generate_context(
7189)
7290 c_base = instance_sample. c_base
7391 x_raw = randn (rng, Float32, bench. d)
74- return DataSample (; x= vcat (c_base, x_raw), extra = (; x_raw, c_base) )
92+ return DataSample (; x= vcat (c_base, x_raw), c_base, x_raw )
7593end
7694
7795"""
7896 generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)
7997
8098Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
81- `c_base` and `x_raw` are spread from `ctx.extra ` by the framework.
99+ `c_base` and `x_raw` are spread from `ctx.context ` by the framework.
82100"""
83101function Utils. generate_scenario (
84102 bench:: ContextualStochasticArgmaxBenchmark ,
0 commit comments