Skip to content

Commit 15edb07

Browse files
committed
rename instance_kwargs to maximizer_kwargs, + other small fixes
1 parent cb96efe commit 15edb07

21 files changed

Lines changed: 139 additions & 113 deletions

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ x = sample.x
3030
θ_true = sample.θ
3131
# `y` correspond to the optimal shortest path, encoded as a binary matrix:
3232
y_true = sample.y
33-
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
34-
isempty(sample.instance_kwargs)
33+
# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
34+
isempty(sample.maximizer_kwargs)
3535

3636
# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
3737
plot_solution(b, sample)

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs..
2222
return plot_solution(
2323
bench,
2424
DataSample(;
25-
sample.instance_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
25+
sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
2626
);
2727
kwargs...,
2828
)

src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,12 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS
55
using Flux: Dense
66
using Random: Random, AbstractRNG, MersenneTwister
77

8-
function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real}
9-
e = zeros(R, length(z))
10-
e[argmax(z)] = one(R)
11-
return e
12-
end
13-
148
"""
159
$TYPEDEF
1610
1711
Minimal contextual stochastic argmax benchmark.
1812
19-
Per instance: `c_base ~ U[0,1]^n` (base utility, part of instance kwargs and base features).
13+
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample).
2014
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
2115
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
2216
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
@@ -49,29 +43,42 @@ end
4943
Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false
5044
Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax
5145

52-
# c_base: base features (in x) and solver kwarg (in instance_kwargs for generate_scenario)
46+
"""
47+
generate_instance(::ContextualStochasticArgmaxBenchmark, rng)
48+
49+
Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed
50+
(the maximizer is `one_hot_argmax`, which takes no kwargs).
51+
"""
5352
function Utils.generate_instance(
5453
bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
5554
)
5655
c_base = rand(rng, Float32, bench.n)
57-
return DataSample(; x=c_base, c_base=c_base)
56+
return DataSample(; extra=(; c_base))
5857
end
5958

60-
# Enriches instance_sample: x = [c_base; x_raw], x_raw in extra for generate_scenario
59+
"""
60+
generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)
61+
62+
Draw `x_raw ~ N(0, I_d)` and return a context sample with:
63+
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
64+
- `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref).
65+
"""
6166
function Utils.generate_context(
6267
bench::ContextualStochasticArgmaxBenchmark,
6368
rng::AbstractRNG,
6469
instance_sample::DataSample,
6570
)
71+
c_base = instance_sample.c_base
6672
x_raw = randn(rng, Float32, bench.d)
67-
return DataSample(;
68-
x=vcat(instance_sample.x, x_raw),
69-
instance_sample.instance_kwargs...,
70-
extra=(; x_raw),
71-
)
73+
return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base))
7274
end
7375

74-
# ξ = c_base + W * x_raw + noise (c_base from instance_kwargs, x_raw from ctx.extra)
76+
"""
77+
generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)
78+
79+
Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
80+
`c_base` and `x_raw` are spread from `ctx.extra` by the framework.
81+
"""
7582
function Utils.generate_scenario(
7683
bench::ContextualStochasticArgmaxBenchmark,
7784
rng::AbstractRNG;

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ export generate_instance,
7474
generate_sample, generate_dataset, generate_environments, generate_environment
7575
export generate_scenario, generate_context
7676
export generate_baseline_policies
77-
export SAA
77+
export SampleAverageApproximation
7878
export generate_statistical_model
7979
export generate_maximizer
8080
export generate_anticipative_solver, generate_parametric_anticipative_solver

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=noth
116116
To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):
117117
118118
```julia
119-
policy = sample -> DataSample(; sample.instance_kwargs..., x=sample.x,
119+
policy = sample -> DataSample(; sample.maximizer_kwargs..., x=sample.x,
120120
y=column_generation_algorithm(sample.instance))
121121
dataset = generate_dataset(benchmark, N; target_policy=policy)
122122
```

src/StochasticVehicleScheduling/policies.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ SAA baseline policy: builds a stochastic instance from all K scenarios and solve
55
via column generation.
66
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
77
"""
8-
function svs_saa_policy(instance_sample, ctx_sample, scenarios)
9-
stochastic_inst = build_stochastic_instance(instance_sample.instance, scenarios)
8+
function svs_saa_policy(ctx_sample, scenarios)
9+
stochastic_inst = build_stochastic_instance(ctx_sample.instance, scenarios)
1010
y = column_generation_algorithm(stochastic_inst)
1111
return [
1212
DataSample(;
13-
instance_sample.instance_kwargs...,
13+
ctx_sample.maximizer_kwargs...,
1414
x=ctx_sample.x,
1515
y,
1616
extra=(; ctx_sample.extra..., scenarios),
@@ -24,13 +24,11 @@ $TYPEDSIGNATURES
2424
Deterministic baseline policy: solves the deterministic MIP (ignores scenario delays).
2525
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
2626
"""
27-
function svs_deterministic_policy(
28-
instance_sample, ctx_sample, scenarios; model_builder=highs_model
29-
)
30-
y = deterministic_mip(instance_sample.instance; model_builder)
27+
function svs_deterministic_policy(ctx_sample, scenarios; model_builder=highs_model)
28+
y = deterministic_mip(ctx_sample.instance; model_builder)
3129
return [
3230
DataSample(;
33-
instance_sample.instance_kwargs...,
31+
ctx_sample.maximizer_kwargs...,
3432
x=ctx_sample.x,
3533
y,
3634
extra=(; ctx_sample.extra..., scenarios),
@@ -45,12 +43,12 @@ Local search baseline policy: builds a stochastic instance from all K scenarios
4543
solves via local search heuristic.
4644
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
4745
"""
48-
function svs_local_search_policy(instance_sample, ctx_sample, scenarios)
49-
stochastic_inst = build_stochastic_instance(instance_sample.instance, scenarios)
46+
function svs_local_search_policy(ctx_sample, scenarios)
47+
stochastic_inst = build_stochastic_instance(ctx_sample.instance, scenarios)
5048
y = local_search(stochastic_inst)
5149
return [
5250
DataSample(;
53-
instance_sample.instance_kwargs...,
51+
ctx_sample.maximizer_kwargs...,
5452
x=ctx_sample.x,
5553
y,
5654
extra=(; ctx_sample.extra..., scenarios),
@@ -68,13 +66,11 @@ Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
6866
Prefer this over [`svs_saa_policy`](@ref) when an exact solution is needed; requires
6967
SCIP (default) or Gurobi.
7068
"""
71-
function svs_saa_mip_policy(
72-
instance_sample, ctx_sample, scenarios; model_builder=scip_model
73-
)
74-
y = compact_linearized_mip(instance_sample.instance, scenarios; model_builder)
69+
function svs_saa_mip_policy(ctx_sample, scenarios; model_builder=scip_model)
70+
y = compact_linearized_mip(ctx_sample.instance, scenarios; model_builder)
7571
return [
7672
DataSample(;
77-
instance_sample.instance_kwargs...,
73+
ctx_sample.maximizer_kwargs...,
7874
x=ctx_sample.x,
7975
y,
8076
extra=(; ctx_sample.extra..., scenarios),
@@ -86,7 +82,7 @@ end
8682
$TYPEDSIGNATURES
8783
8884
Return the named baseline policies for [`StochasticVehicleSchedulingBenchmark`](@ref).
89-
Each policy has signature `(instance_sample, ctx_sample, scenarios) -> Vector{DataSample}`.
85+
Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
9086
"""
9187
function svs_generate_baseline_policies(::StochasticVehicleSchedulingBenchmark)
9288
return (;

src/Utils/Utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ include("model_builders.jl")
2222

2323
export DataSample, Policy
2424
export evaluate_policy!
25-
export TopKMaximizer
25+
export TopKMaximizer, one_hot_argmax
2626

2727
export AbstractEnvironment, get_seed, is_terminated, observe, reset!, step!
2828

@@ -33,7 +33,7 @@ export generate_instance, generate_sample, generate_dataset
3333
export generate_statistical_model, generate_maximizer
3434
export generate_scenario, generate_context
3535
export generate_environment, generate_environments
36-
export SAA
36+
export SampleAverageApproximation
3737
export generate_baseline_policies
3838
export generate_anticipative_solver, generate_parametric_anticipative_solver
3939

src/Utils/data_sample.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ $TYPEDEF
44
Data sample data structure.
55
Its main purpose is to store datasets generated by the benchmarks.
66
It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`.
7-
Additionally, it has an `instance_kwargs` field (solver kwargs, spread into the maximizer as
8-
`maximizer(θ; sample.instance_kwargs...)`) and an `extra` field (non-solver data, never passed
7+
Additionally, it has an `maximizer_kwargs` field (solver kwargs, spread into the maximizer as
8+
`maximizer(θ; sample.maximizer_kwargs...)`) and an `extra` field (non-solver data, never passed
99
to the maximizer).
1010
1111
The separation prevents silent breakage from accidentally passing non-solver data
@@ -15,8 +15,8 @@ The separation prevents silent breakage from accidentally passing non-solver dat
1515
$TYPEDFIELDS
1616
"""
1717
struct DataSample{
18-
CTX<:NamedTuple,
19-
EX<:NamedTuple,
18+
K<:NamedTuple,
19+
E<:NamedTuple,
2020
F<:Union{AbstractArray,Nothing},
2121
S<:Union{AbstractArray,Nothing},
2222
C<:Union{AbstractArray,Nothing},
@@ -28,63 +28,63 @@ struct DataSample{
2828
"output solution (optional)"
2929
y::S
3030
"solver kwargs, e.g. instance, graph, etc."
31-
instance_kwargs::CTX
31+
maximizer_kwargs::K
3232
"additional data, never passed to the maximizer, e.g. scenario, objective value, reward,
3333
step count, etc. Can be used for any purpose by the user, such as plotting utilities."
34-
extra::EX
34+
extra::E
3535
end
3636

3737
"""
3838
$TYPEDSIGNATURES
3939
4040
Constructor for `DataSample` with keyword arguments.
4141
42-
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `instance_kwargs`
42+
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `maximizer_kwargs`
4343
field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data.
4444
45-
Fields in `instance_kwargs` and `extra` must be disjoint. An error is thrown if they overlap.
45+
Fields in `maximizer_kwargs` and `extra` must be disjoint. An error is thrown if they overlap.
4646
Both can be accessed directly via property forwarding.
4747
4848
# Examples
4949
```julia
50-
# Instance goes in instance_kwargs
50+
# Instance goes in maximizer_kwargs
5151
d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance")
52-
d.instance # "my_instance" (from instance_kwargs)
52+
d.instance # "my_instance" (from maximizer_kwargs)
5353
5454
# Scenario goes in extra
5555
d = DataSample(x=x, y=y, instance=inst, extra=(; scenario=ξ))
5656
d.scenario # ξ (from extra)
5757
58-
# State goes in instance_kwargs, reward in extra
58+
# State goes in maximizer_kwargs, reward in extra
5959
d = DataSample(x=x, y=y, instance=state, extra=(; reward=-1.5))
60-
d.instance # state (from instance_kwargs)
60+
d.instance # state (from maximizer_kwargs)
6161
d.reward # -1.5 (from extra)
6262
```
6363
"""
6464
function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwargs...)
65-
instance_kwargs = (; kwargs...)
66-
overlap = intersect(keys(instance_kwargs), keys(extra))
65+
maximizer_kwargs = (; kwargs...)
66+
overlap = intersect(keys(maximizer_kwargs), keys(extra))
6767
if !isempty(overlap)
6868
error(
69-
"Keys $(collect(overlap)) appear in both instance_kwargs and extra of DataSample",
69+
"Keys $(collect(overlap)) appear in both maximizer_kwargs and extra of DataSample",
7070
)
7171
end
72-
return DataSample(x, θ, y, instance_kwargs, extra)
72+
return DataSample(x, θ, y, maximizer_kwargs, extra)
7373
end
7474

7575
"""
7676
$TYPEDSIGNATURES
7777
7878
Extended property access for `DataSample`.
7979
80-
Allows accessing `instance_kwargs` and `extra` fields directly as properties.
81-
`instance_kwargs` is searched first; if the key is not found there, `extra` is searched.
80+
Allows accessing `maximizer_kwargs` and `extra` fields directly as properties.
81+
`maximizer_kwargs` is searched first; if the key is not found there, `extra` is searched.
8282
"""
8383
function Base.getproperty(d::DataSample, name::Symbol)
84-
if name in (:x, , :y, :instance_kwargs, :extra)
84+
if name in (:x, , :y, :maximizer_kwargs, :extra)
8585
return getfield(d, name)
8686
else
87-
ctx = getfield(d, :instance_kwargs)
87+
ctx = getfield(d, :maximizer_kwargs)
8888
if haskey(ctx, name)
8989
return getproperty(ctx, name)
9090
end
@@ -96,12 +96,12 @@ end
9696
$TYPEDSIGNATURES
9797
9898
Return all property names of a `DataSample`, including both struct fields and forwarded
99-
fields from `instance_kwargs` and `extra`.
99+
fields from `maximizer_kwargs` and `extra`.
100100
101101
This enables tab completion for all available properties.
102102
"""
103103
function Base.propertynames(d::DataSample, private::Bool=false)
104-
ctx_names = propertynames(getfield(d, :instance_kwargs), private)
104+
ctx_names = propertynames(getfield(d, :maximizer_kwargs), private)
105105
extra_names = propertynames(getfield(d, :extra), private)
106106
return (fieldnames(DataSample)..., ctx_names..., extra_names...)
107107
end
@@ -128,7 +128,7 @@ function Base.show(io::IO, d::DataSample)
128128
y_str = sprint(show, d.y; context=io_limited)
129129
push!(fields, "y_true=$y_str")
130130
end
131-
for (key, value) in pairs(d.instance_kwargs)
131+
for (key, value) in pairs(d.maximizer_kwargs)
132132
value_str = sprint(show, value; context=io_limited)
133133
push!(fields, "$key=$value_str")
134134
end
@@ -156,8 +156,8 @@ Transform the features in the dataset.
156156
"""
157157
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
158158
return map(dataset) do d
159-
(; instance_kwargs, extra, x, θ, y) = d
160-
DataSample(StatsBase.transform(t, x), θ, y, instance_kwargs, extra)
159+
(; maximizer_kwargs, extra, x, θ, y) = d
160+
DataSample(StatsBase.transform(t, x), θ, y, maximizer_kwargs, extra)
161161
end
162162
end
163163

@@ -179,8 +179,8 @@ Reconstruct the features in the dataset.
179179
"""
180180
function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample})
181181
return map(dataset) do d
182-
(; instance_kwargs, extra, x, θ, y) = d
183-
DataSample(StatsBase.reconstruct(t, x), θ, y, instance_kwargs, extra)
182+
(; maximizer_kwargs, extra, x, θ, y) = d
183+
DataSample(StatsBase.reconstruct(t, x), θ, y, maximizer_kwargs, extra)
184184
end
185185
end
186186

0 commit comments

Comments
 (0)