-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathContextualStochasticArgmax.jl
More file actions
143 lines (119 loc) · 4.18 KB
/
ContextualStochasticArgmax.jl
File metadata and controls
143 lines (119 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
module ContextualStochasticArgmax
using ..Utils
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Dense
using LinearAlgebra: dot
using Random: Random, AbstractRNG, MersenneTwister
using Statistics: mean
"""
$TYPEDEF
Minimal contextual stochastic argmax benchmark.
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `context` of the instance sample).
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
A linear model `Dense(n+d → n; bias=false)` can exactly recover `[I | W]`.
# Fields
$TYPEDFIELDS
"""
struct ContextualStochasticArgmaxBenchmark{M<:AbstractMatrix} <:
       AbstractStochasticBenchmark{true}
    "size of the argmax problem, i.e. how many items can be selected"
    n::Int
    "length of the raw context vector `x_raw`"
    d::Int
    "`n × d` perturbation matrix applied to `x_raw`; fixed and not shown to the learner"
    W::M
    "standard deviation of the Gaussian noise added to every scenario draw"
    noise_std::Float32
end
"""
    ContextualStochasticArgmaxBenchmark(; n=10, d=5, noise_std=0.1f0, seed=nothing)

Construct a benchmark with `n` items and `d` context features.

The perturbation matrix `W` is sampled once, as an `n × d` `Float32` matrix of
standard-normal entries, from `MersenneTwister(seed)`.

# Keywords
- `n::Integer=10`: number of items (converted to `Int`).
- `d::Integer=5`: number of context features (converted to `Int`).
- `noise_std::Real=0.1f0`: noise standard deviation for scenario draws. Any real number is
  accepted and converted to `Float32`, so `noise_std=0.1` works as well as `0.1f0`.
- `seed=nothing`: forwarded unchanged to `MersenneTwister`.
"""
function ContextualStochasticArgmaxBenchmark(;
    n::Integer=10, d::Integer=5, noise_std::Real=0.1f0, seed=nothing
)
    rng = MersenneTwister(seed)
    n_items, n_features = Int(n), Int(d)
    W = randn(rng, Float32, n_items, n_features)
    # Convert explicitly so the arguments match the field types declared on the struct
    # (`Int`, `Int`, `Float32`) when calling the positional constructor.
    return ContextualStochasticArgmaxBenchmark(n_items, n_features, W, Float32(noise_std))
end
"""
    is_minimization_problem(::ContextualStochasticArgmaxBenchmark)

Return `false`: the objective of this benchmark is to be maximized.
"""
function Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark)
    return false
end
"""
    generate_maximizer(::ContextualStochasticArgmaxBenchmark)

Return the maximizer used by this benchmark, `one_hot_argmax`.
"""
function Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark)
    return one_hot_argmax
end
"""
    objective_value(::ContextualStochasticArgmaxBenchmark, sample, y, scenario)

Evaluate decision `y` against a single realized `scenario`, as `dot(scenario, y)`.

`sample` is part of the signature but is not read by this method.
"""
function Utils.objective_value(
    ::ContextualStochasticArgmaxBenchmark, sample::DataSample, y, scenario
)
    value = dot(scenario, y)
    return value
end
"""
    objective_value(bench::ContextualStochasticArgmaxBenchmark, sample, y)

Evaluate decision `y` using the scenario data attached to `sample`.

- If `sample.extra` has a `scenario` property, return the objective for that scenario.
- Otherwise, if it has a `scenarios` property, return the mean objective over them.
- Otherwise, raise an error.
"""
function Utils.objective_value(
    bench::ContextualStochasticArgmaxBenchmark, sample::DataSample, y
)
    # A single stored scenario takes precedence over a collection of scenarios.
    if hasproperty(sample.extra, :scenario)
        return Utils.objective_value(bench, sample, y, sample.scenario)
    end
    if !hasproperty(sample.extra, :scenarios)
        return error("Sample must have scenario or scenarios")
    end
    per_scenario = (Utils.objective_value(bench, sample, y, ξ) for ξ in sample.scenarios)
    return mean(per_scenario)
end
"""
    generate_instance(::ContextualStochasticArgmaxBenchmark, rng; kwargs...)
Draw `c_base ~ U[0,1]^n` and store it in `context`. No solver kwargs are needed
(the maximizer is `one_hot_argmax`, which takes no kwargs).
"""
function Utils.generate_instance(
    bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
)
    # One uniform Float32 draw per item; `kwargs` are accepted but not used here.
    base_utility = rand(rng, Float32, bench.n)
    return DataSample(; c_base=base_utility)
end
"""
    generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)
Draw `x_raw ~ N(0, I_d)` and return a context sample with:
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
- `c_base`, `x_raw` in `context`: spread into [`generate_scenario`](@ref).
"""
function Utils.generate_context(
    bench::ContextualStochasticArgmaxBenchmark,
    rng::AbstractRNG,
    instance_sample::DataSample,
)
    base_utility = instance_sample.c_base
    raw_context = randn(rng, Float32, bench.d)
    # Model input is the base utility stacked on top of the raw context draw.
    features = vcat(base_utility, raw_context)
    return DataSample(; x=features, c_base=base_utility, x_raw=raw_context)
end
"""
    generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)
Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
`c_base` and `x_raw` are spread from `ctx.context` by the framework.
"""
function Utils.generate_scenario(
    bench::ContextualStochasticArgmaxBenchmark,
    rng::AbstractRNG;
    c_base::AbstractVector,
    x_raw::AbstractVector,
    kwargs...,
)
    # Noise-free utilities: base utility shifted by the linear effect of the context.
    mean_utility = c_base + bench.W * x_raw
    # Single RNG draw per scenario: one standard-normal value per item.
    noise = randn(rng, Float32, bench.n)
    return mean_utility + bench.noise_std * noise
end
"""
    generate_statistical_model(bench::ContextualStochasticArgmaxBenchmark; seed=nothing)

Return a bias-free `Dense` layer mapping the `bench.n + bench.d` features to `bench.n`
scores.

# Keywords
- `seed=nothing`: when not `nothing`, the global RNG is reseeded with it through
  `Random.seed!` before the layer is built. When `nothing`, the global RNG is left
  untouched.
"""
function Utils.generate_statistical_model(
    bench::ContextualStochasticArgmaxBenchmark; seed=nothing
)
    # `Random.seed!(nothing)` would reseed the global RNG from fresh entropy and so discard
    # any seed the caller set beforehand; only reseed when a seed was actually requested.
    isnothing(seed) || Random.seed!(seed)
    # NOTE(review): the layer's initial weights presumably come from Flux's default
    # initializer drawing on the global RNG — confirm against the Flux version in use.
    return Dense(bench.n + bench.d => bench.n; bias=false)
end
include("policies.jl")
"""
$TYPEDSIGNATURES
Generates the anticipative solver for the benchmark.
"""
Utils.generate_anticipative_solver(::ContextualStochasticArgmaxBenchmark) =
    AnticipativeSolver()  # NOTE(review): not defined in this file; presumably from policies.jl
"""
$TYPEDSIGNATURES
Generates the parametric anticipative solver for the benchmark.
"""
Utils.generate_parametric_anticipative_solver(::ContextualStochasticArgmaxBenchmark) =
    AnticipativeSolver()  # NOTE(review): not defined in this file; presumably from policies.jl
export ContextualStochasticArgmaxBenchmark
end