Skip to content

Commit 4552193

Browse files
committed
rename nb_contexts to contexts_per_instance
1 parent ecd25bb commit 4552193

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

src/Utils/interface.jl

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ end
173173
$TYPEDSIGNATURES
174174
175175
Check if the benchmark is a minimization problem.
176+
177+
Defaults to `true`. **Maximization benchmarks must override this method**, forgetting to do
178+
so will cause `compute_gap` to compute the gap with the wrong sign without any error or warning.
176179
"""
177180
function is_minimization_problem(::AbstractBenchmark)
178181
return true
@@ -237,14 +240,14 @@ anticipative targets and compute objective values.
237240
238241
# Dataset generation (exogenous only)
239242
[`generate_dataset`](@ref) is specialised for [`ExogenousStochasticBenchmark`](@ref) and
240-
supports all three standard structures via `nb_scenarios` and `nb_contexts`:
243+
supports all three standard structures via `nb_scenarios` and `contexts_per_instance`:
241244
242245
| Setting | Call |
243246
|---------|------|
244247
| 1 instance with K scenarios | `generate_dataset(bench, 1; nb_scenarios=K)` |
245248
| N instances with 1 scenario | `generate_dataset(bench, N)` (default) |
246249
| N instances with K scenarios | `generate_dataset(bench, N; nb_scenarios=K)` |
247-
| N instances with M contexts × K scenarios | `generate_dataset(bench, N; nb_contexts=M, nb_scenarios=K)` |
250+
| N instances with M contexts × K scenarios | `generate_dataset(bench, N; contexts_per_instance=M, nb_scenarios=K)` |
248251
249252
By default (no `target_policy`), each [`DataSample`](@ref) has `context` holding
250253
the solver kwargs and `extra=(; scenario)` holding one scenario.
@@ -306,13 +309,17 @@ end
306309
"""
307310
generate_anticipative_solver(::AbstractBenchmark) -> callable
308311
309-
Return a callable that computes the anticipative solution.
312+
Return a callable that computes the anticipative (oracle) solution.
313+
The calling convention differs by benchmark category:
314+
315+
**Stochastic benchmarks** ([`AbstractStochasticBenchmark`](@ref)):
316+
Returns `(scenario; context...) -> y`.
317+
Called once per scenario to obtain the optimal label.
310318
311-
- For [`AbstractStochasticBenchmark`](@ref): returns `(scenario; context...) -> y`.
312-
- For [`AbstractDynamicBenchmark`](@ref): returns
313-
`(env; reset_env=true, kwargs...) -> Vector{DataSample}`, a full training trajectory.
314-
`reset_env=true` resets the env before solving (initial dataset building);
315-
`reset_env=false` starts from the current env state.
319+
**Dynamic benchmarks** ([`AbstractDynamicBenchmark`](@ref)):
320+
Returns `(env; reset_env=true, kwargs...) -> Vector{DataSample}`, a full trajectory.
321+
`reset_env=true` resets the environment before solving (used for initial dataset building);
322+
`reset_env=false` starts from the current environment state (used inside DAgger rollouts).
316323
"""
317324
function generate_anticipative_solver end
318325

@@ -333,7 +340,7 @@ Default [`generate_sample`](@ref) for exogenous stochastic benchmarks.
333340
334341
Calls [`generate_instance`](@ref), then [`generate_context`](@ref) (default: identity),
335342
draws scenarios via [`generate_scenario`](@ref), then:
336-
- Without `target_policy`: returns M×K unlabeled samples (`nb_contexts` contexts ×
343+
- Without `target_policy`: returns M×K unlabeled samples (`contexts_per_instance` contexts ×
337344
`nb_scenarios` scenarios each), each with one scenario in `extra=(; scenario=ξ)`.
338345
- With `target_policy`: calls `target_policy(ctx_sample, scenarios)`
339346
per context and returns the result.
@@ -355,7 +362,7 @@ function generate_sample(
355362
rng;
356363
target_policy=nothing,
357364
nb_scenarios::Int=1,
358-
nb_contexts::Int=1,
365+
contexts_per_instance::Int=1,
359366
kwargs...,
360367
)
361368
instance_sample = generate_instance(bench, rng; kwargs...)
@@ -382,7 +389,7 @@ function generate_sample(
382389
]
383390
target_policy(ctx, scenarios)
384391
end
385-
end for _ in 1:nb_contexts
392+
end for _ in 1:contexts_per_instance
386393
),
387394
)
388395
end
@@ -392,7 +399,7 @@ $TYPEDSIGNATURES
392399
393400
Specialised [`generate_dataset`](@ref) for exogenous stochastic benchmarks.
394401
395-
Generates `nb_instances` problem instances, each with `nb_contexts` context draws
402+
Generates `nb_instances` problem instances, each with `contexts_per_instance` context draws
396403
and `nb_scenarios` scenario draws per context. The scenario→sample mapping is controlled
397404
by the `target_policy`:
398405
- Without `target_policy` (default): M contexts × K scenarios produce M×K unlabeled
@@ -403,7 +410,7 @@ by the `target_policy`:
403410
404411
# Keyword arguments
405412
- `nb_scenarios::Int = 1`: scenarios per context (K).
406-
- `nb_contexts::Int = 1`: context draws per instance (M).
413+
- `contexts_per_instance::Int = 1`: context draws per instance (M).
407414
- `target_policy`: when provided, called as
408415
`target_policy(ctx_sample, scenarios)` to compute labels.
409416
Defaults to `nothing` (unlabeled samples).
@@ -416,7 +423,7 @@ function generate_dataset(
416423
nb_instances::Int;
417424
target_policy=nothing,
418425
nb_scenarios::Int=1,
419-
nb_contexts::Int=1,
426+
contexts_per_instance::Int=1,
420427
seed=nothing,
421428
rng=MersenneTwister(seed),
422429
kwargs...,
@@ -425,7 +432,7 @@ function generate_dataset(
425432
vcat,
426433
(
427434
generate_sample(
428-
bench, rng; target_policy, nb_scenarios, nb_contexts, kwargs...
435+
bench, rng; target_policy, nb_scenarios, contexts_per_instance, kwargs...
429436
) for _ in 1:nb_instances
430437
),
431438
)
@@ -441,6 +448,12 @@ For each (instance, context) pair, draws `nb_scenarios` fixed scenarios. These a
441448
in the sample and used for feature computation, target labeling (via `target_policy`),
442449
and gap evaluation.
443450
451+
!!! note
452+
`SampleAverageApproximation <: AbstractBenchmark`, not `AbstractStochasticBenchmark`.
453+
This is intentional: after wrapping, the scenarios are fixed at dataset-generation time
454+
and the benchmark behaves as a static problem. Functions dispatching on
455+
`AbstractStochasticBenchmark` (e.g. `is_exogenous`) will not match SAA instances.
456+
444457
# Fields
445458
$TYPEDFIELDS
446459
"""

test/contextual_stochastic_argmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0)
55

66
# Unlabeled: N instances × M contexts × K scenarios = N*M*K samples
7-
dataset = generate_dataset(b, 10; nb_contexts=2, nb_scenarios=4)
7+
dataset = generate_dataset(b, 10; contexts_per_instance=2, nb_scenarios=4)
88
@test length(dataset) == 80
99
sample = first(dataset)
1010
@test size(sample.x) == (8,) # n+d

0 commit comments

Comments
 (0)