Skip to content

Commit d1598f4

Browse files
committed
Enforce info field to be a NamedTuple
1 parent 2ccbcc7 commit d1598f4

21 files changed

Lines changed: 121 additions & 70 deletions

src/Argmax2D/Argmax2D.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG)
6262
θ_true ./= 2 * norm(θ_true)
6363
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
6464
y_true = maximizer(θ_true; instance)
65-
return DataSample(; x=x, θ=θ_true, y=y_true, info=instance)
65+
return DataSample(; x=x, θ=θ_true, y=y_true, instance=instance)
6666
end
6767

6868
"""
@@ -88,11 +88,11 @@ function Utils.generate_statistical_model(
8888
return model
8989
end
9090

91-
function Utils.plot_data(::Argmax2DBenchmark; info, θ, kwargs...)
91+
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
9292
pl = init_plot()
93-
plot_polytope!(pl, info)
93+
plot_polytope!(pl, instance)
9494
plot_objective!(pl, θ)
95-
return plot_maximizer!(pl, θ, info, maximizer)
95+
return plot_maximizer!(pl, θ, instance, maximizer)
9696
end
9797

9898
"""
@@ -101,9 +101,13 @@ $TYPEDSIGNATURES
101101
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
102102
"""
103103
function Utils.plot_data(
104-
bench::Argmax2DBenchmark, sample::DataSample; info=sample.info, θ=sample.θ, kwargs...
104+
bench::Argmax2DBenchmark,
105+
sample::DataSample;
106+
instance=sample.instance,
107+
θ=sample.θ,
108+
kwargs...,
105109
)
106-
return Utils.plot_data(bench; info, θ, kwargs...)
110+
return Utils.plot_data(bench; instance, θ, kwargs...)
107111
end
108112

109113
export Argmax2DBenchmark

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ export generate_sample, generate_dataset, generate_environments, generate_enviro
7070
export generate_scenario
7171
export generate_policies
7272
export generate_statistical_model
73-
export generate_maximizer, maximizer_kwargs
73+
export generate_maximizer
7474
export generate_anticipative_solution
7575
export is_exogenous, is_endogenous
7676

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Outputs a data sample containing an [`Instance`](@ref).
8383
function Utils.generate_sample(
8484
b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0)
8585
)
86-
return DataSample(; info=Instance(b, rng))
86+
return DataSample(; instance=Instance(b, rng))
8787
end
8888

8989
"""

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si
7070
dataset_size = min(dataset_size, length(files))
7171
return [
7272
DataSample(;
73-
info=Instance(
73+
instance=Instance(
7474
read_vsp_instance(files[i]);
7575
max_requests_per_epoch,
7676
Δ_dispatch,

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function anticipative_solver(
222222
compute_features(state, env.instance)
223223
end
224224

225-
return DataSample(; info=(; state, reward), y=y_true, x)
225+
return DataSample(; y=y_true, x, state, reward)
226226
end
227227

228228
return obj, dataset

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ end
6767
function Utils.objective_value(
6868
::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector
6969
)
70-
return evaluate_solution(y, sample.info)
70+
return evaluate_solution(y, sample.instance)
7171
end
7272

7373
"""
@@ -98,7 +98,7 @@ function Utils.generate_sample(
9898
else
9999
nothing
100100
end
101-
return DataSample(; x, info=instance, y=y_true)
101+
return DataSample(; x, instance, y=y_true)
102102
end
103103

104104
"""
@@ -145,11 +145,12 @@ end
145145
$TYPEDSIGNATURES
146146
"""
147147
function plot_instance(
148-
::StochasticVehicleSchedulingBenchmark, sample::DataSample{<:Instance{City}}; kwargs...
148+
::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs...
149149
)
150-
(; tasks, district_width, width) = sample.info.city
150+
@assert hasproperty(sample.instance, :city) "Sample does not contain city information."
151+
(; tasks, district_width, width) = sample.instance.city
151152
ticks = 0:district_width:width
152-
max_time = maximum(t.end_time for t in sample.info.city.tasks[1:(end - 1)])
153+
max_time = maximum(t.end_time for t in sample.instance.city.tasks[1:(end - 1)])
153154
fig = plot(;
154155
xlabel="x",
155156
ylabel="y",
@@ -204,11 +205,12 @@ end
204205
$TYPEDSIGNATURES
205206
"""
206207
function plot_solution(
207-
::StochasticVehicleSchedulingBenchmark, sample::DataSample{<:Instance{City}}; kwargs...
208+
::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs...
208209
)
209-
(; tasks, district_width, width) = sample.info.city
210+
@assert hasproperty(sample.instance, :city) "Sample does not contain city information."
211+
(; tasks, district_width, width) = sample.instance.city
210212
ticks = 0:district_width:width
211-
solution = Solution(sample.y, sample.info)
213+
solution = Solution(sample.y, sample.instance)
212214
path_list = compute_path_list(solution)
213215
fig = plot(;
214216
xlabel="x",

src/Utils/Utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ export generate_policies
3535
export generate_anticipative_solution
3636

3737
export plot_data, compute_gap
38-
export maximizer_kwargs
3938
export grid_graph, get_path, path_to_matrix
4039
export neg_tensor, squeeze_last_dims, average_tensor
4140
export scip_model, highs_model

src/Utils/data_sample.jl

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,99 @@
22
$TYPEDEF
33
44
Data sample data structure.
5+
Its main purpose is to store datasets generated by the benchmarks.
6+
It has 3 main fields: features `x`, cost parameters `θ` and solution `y`.
7+
Additionally, it has an `info` field to store any additional information as a `NamedTuple`, usually the instance, but can be used for anything else.
58
69
# Fields
710
$TYPEDFIELDS
811
"""
9-
@kwdef struct DataSample{
10-
I,
12+
struct DataSample{
13+
I<:NamedTuple,
1114
F<:Union{AbstractArray,Nothing},
1215
S<:Union{AbstractArray,Nothing},
1316
C<:Union{AbstractArray,Nothing},
1417
}
1518
"input features (optional)"
16-
x::F = nothing
19+
x::F
1720
"intermediate cost parameters (optional)"
18-
θ::C = nothing
21+
θ::C
1922
"output solution (optional)"
20-
y::S = nothing
23+
y::S
2124
"additional information, usually the instance (optional)"
22-
info::I = nothing
25+
info::I
2326
end
2427

28+
"""
29+
$TYPEDSIGNATURES
30+
31+
Constructor for `DataSample` with keyword arguments.
32+
33+
Additional keyword arguments beyond `x`, `θ`, and `y` are stored in the `info` field
34+
and can be accessed directly (e.g., `data.instance` instead of `data.info.instance`).
35+
36+
# Examples
37+
```julia
38+
d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance")
39+
d.instance # "my_instance"
40+
```
41+
"""
42+
function DataSample(; x=nothing, θ=nothing, y=nothing, kwargs...)
43+
info = (; kwargs...)
44+
return DataSample(x, θ, y, info)
45+
end
46+
47+
"""
48+
$TYPEDSIGNATURES
49+
50+
Extended property access for `DataSample`.
51+
52+
Allows accessing `info` fields directly as properties (e.g., `d.instance` instead of `d.info.instance`).
53+
"""
54+
function Base.getproperty(d::DataSample, name::Symbol)
55+
if name in (:x, , :y, :info)
56+
return getfield(d, name)
57+
else
58+
return getproperty(getfield(d, :info), name)
59+
end
60+
end
61+
62+
"""
63+
$TYPEDSIGNATURES
64+
65+
Return all property names of a `DataSample`, including both struct fields and `info` fields.
66+
67+
This enables tab completion for all available properties, including those stored in `info`.
68+
"""
69+
function Base.propertynames(d::DataSample, private::Bool=false)
70+
return (fieldnames(DataSample)..., propertynames(getfield(d, :info), private)...)
71+
end
72+
73+
"""
74+
$TYPEDSIGNATURES
75+
76+
Display a `DataSample` with truncated array representations for readability.
77+
78+
Large arrays are automatically truncated with ellipsis (`...`), similar to standard Julia array printing.
79+
"""
2580
function Base.show(io::IO, d::DataSample)
2681
fields = String[]
82+
io_limited = IOContext(io, :limit => true, :compact => true)
2783
if !isnothing(d.x)
28-
push!(fields, "x=$(d.x)")
84+
x_str = sprint(show, d.x; context=io_limited)
85+
push!(fields, "x=$x_str")
2986
end
3087
if !isnothing(d.θ)
31-
push!(fields, "θ_true=$(d.θ)")
88+
θ_str = sprint(show, d.θ; context=io_limited)
89+
push!(fields, "θ_true=$θ_str")
3290
end
3391
if !isnothing(d.y)
34-
push!(fields, "y_true=$(d.y)")
92+
y_str = sprint(show, d.y; context=io_limited)
93+
push!(fields, "y_true=$y_str")
3594
end
36-
if !isnothing(d.info)
37-
push!(fields, "instance=$(d.info)")
95+
for (key, value) in pairs(d.info)
96+
value_str = sprint(show, value; context=io_limited)
97+
push!(fields, "$key=$value_str")
3898
end
3999
return print(io, "DataSample(", join(fields, ", "), ")")
40100
end
@@ -57,7 +117,7 @@ Transform the features in the dataset.
57117
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
58118
return map(dataset) do d
59119
(; info, x, θ, y) = d
60-
DataSample(; info, x=StatsBase.transform(t, x), θ, y)
120+
DataSample(StatsBase.transform(t, x), θ, y, info)
61121
end
62122
end
63123

@@ -80,7 +140,7 @@ Reconstruct the features in the dataset.
80140
function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample})
81141
return map(dataset) do d
82142
(; info, x, θ, y) = d
83-
DataSample(; info, x=StatsBase.reconstruct(t, x), θ, y)
143+
DataSample(StatsBase.reconstruct(t, x), θ, y, info)
84144
end
85145
end
86146

src/Utils/interface.jl

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -99,26 +99,6 @@ function compute_gap end
9999
"""
100100
$TYPEDSIGNATURES
101101
102-
For simple benchmarks where there is no instance object, maximizer does not need any keyword arguments.
103-
"""
104-
function maximizer_kwargs(
105-
::AbstractBenchmark, sample::DataSample{Nothing,F,S,C}
106-
) where {F,S,C}
107-
return NamedTuple()
108-
end
109-
110-
"""
111-
$TYPEDSIGNATURES
112-
113-
For benchmarks where there is an instance object, maximizer needs the instance object as a keyword argument.
114-
"""
115-
function maximizer_kwargs(::AbstractBenchmark, sample::DataSample)
116-
return (; instance=sample.info)
117-
end
118-
119-
"""
120-
$TYPEDSIGNATURES
121-
122102
Default behaviour of `objective_value`.
123103
"""
124104
function objective_value(::AbstractBenchmark, θ::AbstractArray, y::AbstractArray)
@@ -175,7 +155,7 @@ function compute_gap(
175155
target_obj = objective_value(bench, sample)
176156
x = sample.x
177157
θ = statistical_model(x)
178-
y = maximizer(θ; maximizer_kwargs(bench, sample)...)
158+
y = maximizer(θ; sample.info...)
179159
obj = objective_value(bench, sample, y)
180160
Δ = check ? obj - target_obj : target_obj - obj
181161
return Δ / abs(target_obj)
@@ -234,7 +214,7 @@ Uses the info field of the sample as the instance.
234214
function generate_environment(
235215
bench::AbstractDynamicBenchmark, sample::DataSample, rng::AbstractRNG; kwargs...
236216
)
237-
return generate_environment(bench, sample.info, rng; kwargs...)
217+
return generate_environment(bench, sample.instance, rng; kwargs...)
238218
end
239219

240220
"""
@@ -250,7 +230,7 @@ function generate_environments(
250230
kwargs...,
251231
)
252232
Random.seed!(rng, seed)
253-
return map(dataset) do instance
254-
generate_environment(bench, instance, rng; kwargs...)
233+
return map(dataset) do sample
234+
generate_environment(bench, sample, rng; kwargs...)
255235
end
256236
end

src/Utils/policy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function evaluate_policy!(
4444
features, state = observe(env)
4545
state_copy = deepcopy(state) # To avoid mutation issues
4646
reward = step!(env, y)
47-
sample = DataSample(; x=features, y=y, info=(; state=state_copy, reward))
47+
sample = DataSample(; x=features, y=y, state=state_copy, reward=reward)
4848
if @isdefined labeled_dataset
4949
push!(labeled_dataset, sample)
5050
else

0 commit comments

Comments
 (0)