Skip to content

Commit 327c3c1

Browse files
committed
fix: make sure DataSample field cannot overlap with kwargs and extra
1 parent 386e6a5 commit 327c3c1

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

src/Utils/data_sample.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ Constructor for `DataSample` with keyword arguments.
4242
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `maximizer_kwargs`
4343
field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data.
4444
45-
Fields in `maximizer_kwargs` and `extra` must be disjoint. An error is thrown if they overlap.
45+
Fields in `maximizer_kwargs` and `extra` must be disjoint. Neither may use a reserved
46+
struct field name (`x`, `θ`, `y`, `maximizer_kwargs`, `extra`). An error is thrown in
47+
both cases.
4648
Both can be accessed directly via property forwarding.
4749
4850
# Examples
@@ -69,6 +71,19 @@ function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwar
6971
"Keys $(collect(overlap)) appear in both maximizer_kwargs and extra of DataSample",
7072
)
7173
end
74+
reserved = (:x, , :y, :maximizer_kwargs, :extra)
75+
shadowed_ctx = intersect(keys(maximizer_kwargs), reserved)
76+
if !isempty(shadowed_ctx)
77+
error(
78+
"Keys $(collect(shadowed_ctx)) in maximizer_kwargs shadow DataSample struct fields",
79+
)
80+
end
81+
shadowed_extra = intersect(keys(extra), reserved)
82+
if !isempty(shadowed_extra)
83+
error(
84+
"Keys $(collect(shadowed_extra)) in extra shadow DataSample struct fields",
85+
)
86+
end
7287
return DataSample(x, θ, y, maximizer_kwargs, extra)
7388
end
7489

src/Utils/interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,8 @@ meaning (whether uncertainty is independent of decisions).
547547
548548
# Additional optional methods
549549
- [`generate_environment`](@ref)`(bench, rng)`: initialize a single rollout environment.
550+
Must return an [`AbstractEnvironment`](@ref) (see `environment.jl` for the full protocol:
551+
[`reset!`](@ref), [`observe`](@ref), [`step!`](@ref), [`is_terminated`](@ref)).
550552
Implement this instead of overriding [`generate_environments`](@ref) when environments
551553
can be drawn independently.
552554
- [`generate_baseline_policies`](@ref)`(bench)`: returns named baseline callables of
@@ -568,7 +570,7 @@ const ExogenousDynamicBenchmark = AbstractDynamicBenchmark{true}
568570
const EndogenousDynamicBenchmark = AbstractDynamicBenchmark{false}
569571

570572
"""
571-
generate_environment(::AbstractDynamicBenchmark, rng::AbstractRNG; kwargs...)
573+
generate_environment(::AbstractDynamicBenchmark, rng::AbstractRNG; kwargs...) -> AbstractEnvironment
572574
573575
Initialize a single environment for the given dynamic benchmark.
574576
Primary implementation target for the count-based [`generate_environments`](@ref) default.

0 commit comments

Comments
 (0)