Skip to content

Commit ecd25bb

Browse files
committed
better seed handling
1 parent a569b3d commit ecd25bb

5 files changed

Lines changed: 25 additions & 20 deletions

File tree

src/Utils/data_sample.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ $TYPEDEF
44
Data sample data structure.
55
Its main purpose is to store datasets generated by the benchmarks.
66
It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`.
7+
Currently, all three are restricted to `AbstractArray` or `nothing`.
78
Additionally, it has a `context` field (solver and scenario-generation context, spread into the
89
maximizer as `maximizer(θ; sample.context...)`) and an `extra` field (non-solver data, never passed
910
to the maximizer).
@@ -131,11 +132,11 @@ function Base.show(io::IO, d::DataSample)
131132
end
132133
if !isnothing(d.θ)
133134
θ_str = sprint(show, d.θ; context=io_limited)
134-
push!(fields, "θ_true=$θ_str")
135+
push!(fields, "θ=$θ_str")
135136
end
136137
if !isnothing(d.y)
137138
y_str = sprint(show, d.y; context=io_limited)
138-
push!(fields, "y_true=$y_str")
139+
push!(fields, "y=$y_str")
139140
end
140141
for (key, value) in pairs(d.context)
141142
value_str = sprint(show, value; context=io_limited)
@@ -166,7 +167,7 @@ Transform the features in the dataset.
166167
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
167168
return map(dataset) do d
168169
(; context, extra, x, θ, y) = d
169-
DataSample(StatsBase.transform(t, x), θ, y, context, extra)
170+
DataSample(; x=StatsBase.transform(t, x), θ, y, context..., extra)
170171
end
171172
end
172173

src/Utils/interface.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ function generate_dataset(
8383
rng=MersenneTwister(seed),
8484
kwargs...,
8585
)
86-
Random.seed!(rng, seed)
8786
return [
8887
begin
8988
sample = generate_sample(bench, rng; kwargs...)
@@ -422,7 +421,6 @@ function generate_dataset(
422421
rng=MersenneTwister(seed),
423422
kwargs...,
424423
)
425-
Random.seed!(rng, seed)
426424
return reduce(
427425
vcat,
428426
(
@@ -500,7 +498,6 @@ function generate_dataset(
500498
rng=MersenneTwister(seed),
501499
kwargs...,
502500
)
503-
Random.seed!(rng, seed)
504501
return reduce(
505502
vcat, (generate_sample(saa, rng; target_policy, kwargs...) for _ in 1:nb_instances)
506503
)
@@ -561,6 +558,20 @@ meaning (whether uncertainty is independent of decisions).
561558
"""
562559
abstract type AbstractDynamicBenchmark{exogenous} <: AbstractStochasticBenchmark{exogenous} end
563560

561+
"""
562+
$TYPEDSIGNATURES
563+
564+
Intercepts accidental calls to `generate_sample` on dynamic benchmarks and throws a
565+
descriptive error pointing at the correct entry point.
566+
"""
567+
function generate_sample(bench::AbstractDynamicBenchmark, rng; kwargs...)
568+
return error(
569+
"`generate_sample` is not supported for dynamic benchmarks ($(typeof(bench))). " *
570+
"Use `generate_environments` and " *
571+
"`generate_dataset(bench, environments; target_policy=...)` instead.",
572+
)
573+
end
574+
564575
"Alias for [`AbstractDynamicBenchmark`](@ref)`{true}`. Uncertainty is independent of decisions."
565576
const ExogenousDynamicBenchmark = AbstractDynamicBenchmark{true}
566577

@@ -591,7 +602,6 @@ function generate_environments(
591602
rng=MersenneTwister(seed),
592603
kwargs...,
593604
)
594-
Random.seed!(rng, seed)
595605
return [generate_environment(bench, rng; kwargs...) for _ in 1:n]
596606
end
597607

@@ -612,14 +622,8 @@ to obtain standard baseline callables (e.g. the anticipative solver).
612622
- `rng`: random number generator.
613623
"""
614624
function generate_dataset(
615-
bench::ExogenousDynamicBenchmark,
616-
environments::AbstractVector;
617-
target_policy,
618-
seed=nothing,
619-
rng=MersenneTwister(seed),
620-
kwargs...,
625+
bench::ExogenousDynamicBenchmark, environments::AbstractVector; target_policy, kwargs...
621626
)
622-
Random.seed!(rng, seed)
623627
return reduce(vcat, (target_policy(env) for env in environments))
624628
end
625629

test/dynamic_assortment.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,14 +321,14 @@ end
321321

322322
# vector-of-environments overload
323323
dataset = generate_dataset(b, envs; target_policy=target_policy)
324-
@test dataset isa Vector{D} where D <: DataSample
324+
@test dataset isa Vector{D} where {D<:DataSample}
325325
@test !isempty(dataset)
326326
@test all(!isnothing(s.x) for s in dataset)
327327
@test all(!isnothing(s.y) for s in dataset)
328328

329329
# count-based wrapper
330330
dataset2 = generate_dataset(b, 3; seed=7, target_policy=target_policy)
331-
@test dataset2 isa Vector{D} where D <: DataSample
331+
@test dataset2 isa Vector{D} where {D<:DataSample}
332332
@test !isempty(dataset2)
333333
end
334334

test/dynamic_vsp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ end
6363

6464
# vector-of-environments overload
6565
dataset = generate_dataset(b, envs; target_policy=target_policy)
66-
@test dataset isa Vector{D} where D <: DataSample
66+
@test dataset isa Vector{D} where {D<:DataSample}
6767
@test !isempty(dataset)
6868
@test all(!isnothing(s.x) for s in dataset)
6969
@test all(!isnothing(s.y) for s in dataset)
7070

7171
# count-based wrapper
7272
dataset2 = generate_dataset(b, 3; seed=1, target_policy=target_policy)
73-
@test dataset2 isa Vector{D} where D <: DataSample
73+
@test dataset2 isa Vector{D} where {D<:DataSample}
7474
@test !isempty(dataset2)
7575

7676
# seed keyword is forwarded: same seed → same dataset

test/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ end
5454
show(io, sample)
5555
s = String(take!(io))
5656
@test occursin("DataSample(", s)
57-
@test occursin("θ_true", s)
58-
@test occursin("y_true", s)
57+
@test occursin("θ", s)
58+
@test occursin("y", s)
5959
@test occursin("instance=\"this is an instance\"", s)
6060

6161
@test propertynames(sample) == (:x, , :y, :context, :extra, :instance)

0 commit comments

Comments
 (0)