Skip to content

Commit a569b3d

Browse files
committed
rename maximizer_kwargs to context
1 parent 62c8cca commit a569b3d

20 files changed

Lines changed: 116 additions & 135 deletions

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ x = sample.x
3030
θ_true = sample.θ
3131
# `y` correspond to the optimal shortest path, encoded as a binary matrix:
3232
y_true = sample.y
33-
# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
34-
isempty(sample.maximizer_kwargs)
33+
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
34+
isempty(sample.context)
3535

3636
# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
3737
plot_solution(b, sample)

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg
2121
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
2222
return plot_solution(
2323
bench,
24-
DataSample(;
25-
sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
26-
);
24+
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
2725
kwargs...,
2826
)
2927
end

src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ $TYPEDEF
1212
1313
Minimal contextual stochastic argmax benchmark.
1414
15-
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample).
15+
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `context` of the instance sample).
1616
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
1717
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
1818
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
@@ -55,34 +55,32 @@ function Utils.objective_value(
5555
bench::ContextualStochasticArgmaxBenchmark, sample::DataSample, y
5656
)
5757
if hasproperty(sample.extra, :scenario)
58-
return Utils.objective_value(bench, sample, y, sample.extra.scenario)
58+
return Utils.objective_value(bench, sample, y, sample.scenario)
5959
elseif hasproperty(sample.extra, :scenarios)
60-
return mean(
61-
Utils.objective_value(bench, sample, y, ξ) for ξ in sample.extra.scenarios
62-
)
60+
return mean(Utils.objective_value(bench, sample, y, ξ) for ξ in sample.scenarios)
6361
end
6462
return error("Sample must have scenario or scenarios")
6563
end
6664

6765
"""
6866
generate_instance(::ContextualStochasticArgmaxBenchmark, rng)
6967
70-
Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed
68+
Draw `c_base ~ U[0,1]^n` and store it in `context`. No solver kwargs are needed
7169
(the maximizer is `one_hot_argmax`, which takes no kwargs).
7270
"""
7371
function Utils.generate_instance(
7472
bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
7573
)
7674
c_base = rand(rng, Float32, bench.n)
77-
return DataSample(; extra=(; c_base))
75+
return DataSample(; c_base)
7876
end
7977

8078
"""
8179
generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)
8280
8381
Draw `x_raw ~ N(0, I_d)` and return a context sample with:
8482
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
85-
- `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref).
83+
- `c_base`, `x_raw` in `context`: spread into [`generate_scenario`](@ref).
8684
"""
8785
function Utils.generate_context(
8886
bench::ContextualStochasticArgmaxBenchmark,
@@ -91,14 +89,14 @@ function Utils.generate_context(
9189
)
9290
c_base = instance_sample.c_base
9391
x_raw = randn(rng, Float32, bench.d)
94-
return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base))
92+
return DataSample(; x=vcat(c_base, x_raw), c_base, x_raw)
9593
end
9694

9795
"""
9896
generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)
9997
10098
Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
101-
`c_base` and `x_raw` are spread from `ctx.extra` by the framework.
99+
`c_base` and `x_raw` are spread from `ctx.context` by the framework.
102100
"""
103101
function Utils.generate_scenario(
104102
bench::ContextualStochasticArgmaxBenchmark,

src/ContextualStochasticArgmax/policies.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function csa_saa_policy(ctx_sample, scenarios)
1111
y = one_hot_argmax(mean(scenarios))
1212
return [
1313
DataSample(;
14-
ctx_sample.maximizer_kwargs...,
14+
ctx_sample.context...,
1515
x=ctx_sample.x,
1616
y=y,
1717
extra=(; ctx_sample.extra..., scenarios),

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=noth
131131
To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):
132132
133133
```julia
134-
policy = sample -> DataSample(; sample.maximizer_kwargs..., x=sample.x,
134+
policy = sample -> DataSample(; sample.context..., x=sample.x,
135135
y=column_generation_algorithm(sample.instance))
136136
dataset = generate_dataset(benchmark, N; target_policy=policy)
137137
```

src/StochasticVehicleScheduling/policies.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function svs_saa_policy(ctx_sample, scenarios)
1010
y = column_generation_algorithm(stochastic_inst)
1111
return [
1212
DataSample(;
13-
ctx_sample.maximizer_kwargs...,
13+
ctx_sample.context...,
1414
x=ctx_sample.x,
1515
y,
1616
extra=(; ctx_sample.extra..., scenarios),
@@ -28,7 +28,7 @@ function svs_deterministic_policy(ctx_sample, scenarios; model_builder=highs_mod
2828
y = deterministic_mip(ctx_sample.instance; model_builder)
2929
return [
3030
DataSample(;
31-
ctx_sample.maximizer_kwargs...,
31+
ctx_sample.context...,
3232
x=ctx_sample.x,
3333
y,
3434
extra=(; ctx_sample.extra..., scenarios),
@@ -48,7 +48,7 @@ function svs_local_search_policy(ctx_sample, scenarios)
4848
y = local_search(stochastic_inst)
4949
return [
5050
DataSample(;
51-
ctx_sample.maximizer_kwargs...,
51+
ctx_sample.context...,
5252
x=ctx_sample.x,
5353
y,
5454
extra=(; ctx_sample.extra..., scenarios),
@@ -70,7 +70,7 @@ function svs_saa_mip_policy(ctx_sample, scenarios; model_builder=scip_model)
7070
y = compact_linearized_mip(ctx_sample.instance, scenarios; model_builder)
7171
return [
7272
DataSample(;
73-
ctx_sample.maximizer_kwargs...,
73+
ctx_sample.context...,
7474
x=ctx_sample.x,
7575
y,
7676
extra=(; ctx_sample.extra..., scenarios),

src/Utils/data_sample.jl

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ $TYPEDEF
44
Data sample data structure.
55
Its main purpose is to store datasets generated by the benchmarks.
66
It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`.
7-
Additionally, it has a `maximizer_kwargs` field (solver kwargs, spread into the maximizer as
8-
`maximizer(θ; sample.maximizer_kwargs...)`) and an `extra` field (non-solver data, never passed
7+
Additionally, it has a `context` field (solver and scenario-generation context, spread into the
8+
maximizer as `maximizer(θ; sample.context...)`) and an `extra` field (non-solver data, never passed
99
to the maximizer).
1010
1111
The separation prevents silent breakage from accidentally passing non-solver data
@@ -27,8 +27,8 @@ struct DataSample{
2727
θ::C
2828
"output solution (optional)"
2929
y::S
30-
"solver kwargs, e.g. instance, graph, etc."
31-
maximizer_kwargs::K
30+
"solver and scenario-generation context, e.g. instance, graph, contextual information"
31+
context::K
3232
"additional data, never passed to the maximizer, e.g. scenario, objective value, reward,
3333
step count, etc. Can be used for any purpose by the user, such as plotting utilities."
3434
extra::E
@@ -39,65 +39,61 @@ $TYPEDSIGNATURES
3939
4040
Constructor for `DataSample` with keyword arguments.
4141
42-
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `maximizer_kwargs`
43-
field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data.
42+
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `context`
43+
field (solver and scenario-generation context). The `extra` keyword accepts a `NamedTuple` of non-solver data.
4444
45-
Fields in `maximizer_kwargs` and `extra` must be disjoint. Neither may use a reserved
46-
struct field name (`x`, `θ`, `y`, `maximizer_kwargs`, `extra`). An error is thrown in
45+
Fields in `context` and `extra` must be disjoint. Neither may use a reserved
46+
struct field name (`x`, `θ`, `y`, `context`, `extra`). An error is thrown in
4747
both cases.
4848
Both can be accessed directly via property forwarding.
4949
5050
# Examples
5151
```julia
52-
# Instance goes in maximizer_kwargs
52+
# Instance goes in context
5353
d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance")
54-
d.instance # "my_instance" (from maximizer_kwargs)
54+
d.instance # "my_instance" (from context)
5555
5656
# Scenario goes in extra
5757
d = DataSample(x=x, y=y, instance=inst, extra=(; scenario=ξ))
5858
d.scenario # ξ (from extra)
5959
60-
# State goes in maximizer_kwargs, reward in extra
60+
# State goes in context, reward in extra
6161
d = DataSample(x=x, y=y, instance=state, extra=(; reward=-1.5))
62-
d.instance # state (from maximizer_kwargs)
62+
d.instance # state (from context)
6363
d.reward # -1.5 (from extra)
6464
```
6565
"""
6666
function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwargs...)
67-
maximizer_kwargs = (; kwargs...)
68-
overlap = intersect(keys(maximizer_kwargs), keys(extra))
67+
context = (; kwargs...)
68+
overlap = intersect(keys(context), keys(extra))
6969
if !isempty(overlap)
70-
error(
71-
"Keys $(collect(overlap)) appear in both maximizer_kwargs and extra of DataSample",
72-
)
70+
error("Keys $(collect(overlap)) appear in both context and extra of DataSample")
7371
end
74-
reserved = (:x, , :y, :maximizer_kwargs, :extra)
75-
shadowed_ctx = intersect(keys(maximizer_kwargs), reserved)
72+
reserved = (:x, , :y, :context, :extra)
73+
shadowed_ctx = intersect(keys(context), reserved)
7674
if !isempty(shadowed_ctx)
77-
error(
78-
"Keys $(collect(shadowed_ctx)) in maximizer_kwargs shadow DataSample struct fields",
79-
)
75+
error("Keys $(collect(shadowed_ctx)) in context shadow DataSample struct fields")
8076
end
8177
shadowed_extra = intersect(keys(extra), reserved)
8278
if !isempty(shadowed_extra)
8379
error("Keys $(collect(shadowed_extra)) in extra shadow DataSample struct fields")
8480
end
85-
return DataSample(x, θ, y, maximizer_kwargs, extra)
81+
return DataSample(x, θ, y, context, extra)
8682
end
8783

8884
"""
8985
$TYPEDSIGNATURES
9086
9187
Extended property access for `DataSample`.
9288
93-
Allows accessing `maximizer_kwargs` and `extra` fields directly as properties.
94-
`maximizer_kwargs` is searched first; if the key is not found there, `extra` is searched.
89+
Allows accessing `context` and `extra` fields directly as properties.
90+
`context` is searched first; if the key is not found there, `extra` is searched.
9591
"""
9692
function Base.getproperty(d::DataSample, name::Symbol)
97-
if name in (:x, , :y, :maximizer_kwargs, :extra)
93+
if name in (:x, , :y, :context, :extra)
9894
return getfield(d, name)
9995
else
100-
ctx = getfield(d, :maximizer_kwargs)
96+
ctx = getfield(d, :context)
10197
if haskey(ctx, name)
10298
return getproperty(ctx, name)
10399
end
@@ -109,12 +105,12 @@ end
109105
$TYPEDSIGNATURES
110106
111107
Return all property names of a `DataSample`, including both struct fields and forwarded
112-
fields from `maximizer_kwargs` and `extra`.
108+
fields from `context` and `extra`.
113109
114110
This enables tab completion for all available properties.
115111
"""
116112
function Base.propertynames(d::DataSample, private::Bool=false)
117-
ctx_names = propertynames(getfield(d, :maximizer_kwargs), private)
113+
ctx_names = propertynames(getfield(d, :context), private)
118114
extra_names = propertynames(getfield(d, :extra), private)
119115
return (fieldnames(DataSample)..., ctx_names..., extra_names...)
120116
end
@@ -141,7 +137,7 @@ function Base.show(io::IO, d::DataSample)
141137
y_str = sprint(show, d.y; context=io_limited)
142138
push!(fields, "y_true=$y_str")
143139
end
144-
for (key, value) in pairs(d.maximizer_kwargs)
140+
for (key, value) in pairs(d.context)
145141
value_str = sprint(show, value; context=io_limited)
146142
push!(fields, "$key=$value_str")
147143
end
@@ -169,8 +165,8 @@ Transform the features in the dataset.
169165
"""
170166
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
171167
return map(dataset) do d
172-
(; maximizer_kwargs, extra, x, θ, y) = d
173-
DataSample(StatsBase.transform(t, x), θ, y, maximizer_kwargs, extra)
168+
(; context, extra, x, θ, y) = d
169+
DataSample(StatsBase.transform(t, x), θ, y, context, extra)
174170
end
175171
end
176172

@@ -192,8 +188,8 @@ Reconstruct the features in the dataset.
192188
"""
193189
function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample})
194190
return map(dataset) do d
195-
(; maximizer_kwargs, extra, x, θ, y) = d
196-
DataSample(StatsBase.reconstruct(t, x), θ, y, maximizer_kwargs, extra)
191+
(; context, extra, x, θ, y) = d
192+
DataSample(StatsBase.reconstruct(t, x), θ, y, context, extra)
197193
end
198194
end
199195

0 commit comments

Comments
 (0)