Skip to content

Commit 89a581e

Browse files
authored
Merge pull request #52 from JuliaDecisionFocusedLearning/datasample-redesign
DataSample interface redesign
2 parents 19d15e4 + e2f8ced commit 89a581e

17 files changed

Lines changed: 89 additions & 51 deletions

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ x = sample.x
2929
θ_true = sample.θ
3030
# `y` corresponds to the optimal shortest path, encoded as a binary matrix:
3131
y_true = sample.y
32-
# `info` is not used in this benchmark, therefore set to nothing:
33-
isnothing(sample.info)
32+
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
33+
isempty(sample.context)
3434

3535
# For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
3636
plot_data(b, sample)

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ function anticipative_solver(
224224
compute_features(state, env.instance)
225225
end
226226

227-
return DataSample(; y=y_true, x, state, reward)
227+
return DataSample(; y=y_true, x, instance=state, extra=(; reward))
228228
end
229229

230230
return obj, dataset

src/DynamicVehicleScheduling/plot.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ The returned dictionary contains:
208208
This lets plotting code build figures without depending on plotting internals.
209209
"""
210210
function build_plot_data(data_samples::Vector{<:DataSample})
211-
state_data = [build_state_data(sample.info.state) for sample in data_samples]
212-
rewards = [sample.info.reward for sample in data_samples]
211+
state_data = [build_state_data(sample.instance) for sample in data_samples]
212+
rewards = [sample.reward for sample in data_samples]
213213
routess = [sample.y for sample in data_samples]
214214
return [
215215
(; state..., reward, routes) for
@@ -273,8 +273,8 @@ function plot_epochs(
273273
# Create subplots
274274
plots = map(1:n_epochs) do i
275275
sample = data_samples[i]
276-
state = sample.info.state
277-
reward = sample.info.reward
276+
state = sample.instance
277+
reward = sample.reward
278278

279279
common_kwargs = Dict(
280280
:xlims => xlims,
@@ -351,7 +351,7 @@ function animate_epochs(
351351
kwargs...,
352352
)
353353
pd = build_plot_data(data_samples)
354-
epoch_costs = [-sample.info.reward for sample in data_samples]
354+
epoch_costs = [-sample.reward for sample in data_samples]
355355

356356
# Calculate global xlims and ylims from all states
357357
x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd)
@@ -393,7 +393,7 @@ function animate_epochs(
393393
anim = @animate for frame_idx in 1:total_frames
394394
epoch_idx, frame_type = frame_plan[frame_idx]
395395
sample = data_samples[epoch_idx]
396-
state = sample.info.state
396+
state = sample.instance
397397

398398
if frame_type == :routes
399399
fig = plot_routes(

src/Utils/data_sample.jl

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@ $TYPEDEF
33
44
Data sample data structure.
55
Its main purpose is to store datasets generated by the benchmarks.
6-
It has 3 main fields: features `x`, cost parameters `θ` and solution `y`.
7-
Additionally, it has an `info` field to store any additional information as a `NamedTuple`, usually the instance, but can be used for anything else.
6+
It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`.
7+
Additionally, it has a `context` field (solver kwargs, spread into the maximizer as
8+
`maximizer(θ; sample.context...)`) and an `extra` field (non-solver data, never passed
9+
to the maximizer).
10+
11+
The separation prevents silent breakage from accidentally passing non-solver data
12+
(e.g. a scenario, reward, or step counter) as unexpected kwargs to the maximizer.
813
914
# Fields
1015
$TYPEDFIELDS
1116
"""
1217
struct DataSample{
13-
I<:NamedTuple,
18+
CTX<:NamedTuple,
19+
EX<:NamedTuple,
1420
F<:Union{AbstractArray,Nothing},
1521
S<:Union{AbstractArray,Nothing},
1622
C<:Union{AbstractArray,Nothing},
@@ -21,53 +27,81 @@ struct DataSample{
2127
θ::C
2228
"output solution (optional)"
2329
y::S
24-
"additional information, usually the instance (optional)"
25-
info::I
30+
"context information as solver kwargs, e.g. instance, graph, etc."
31+
context::CTX
32+
"additional data, never passed to the maximizer, e.g. scenario, objective value, reward,
33+
step count, etc. Can be used for any purpose by the user, such as plotting utilities."
34+
extra::EX
2635
end
2736

2837
"""
2938
$TYPEDSIGNATURES
3039
3140
Constructor for `DataSample` with keyword arguments.
3241
33-
Additional keyword arguments beyond `x`, `θ`, and `y` are stored in the `info` field
34-
and can be accessed directly (e.g., `data.instance` instead of `data.info.instance`).
42+
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `context`
43+
field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data.
44+
45+
Fields in `context` and `extra` must be disjoint. An error is thrown if they overlap.
46+
Both can be accessed directly via property forwarding.
3547
3648
# Examples
3749
```julia
50+
# Instance goes in context
3851
d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance")
39-
d.instance # "my_instance"
52+
d.instance # "my_instance" (from context)
53+
54+
# Scenario goes in extra
55+
d = DataSample(x=x, y=y, instance=inst, extra=(; scenario=ξ))
56+
d.scenario # ξ (from extra)
57+
58+
# State goes in context, reward in extra
59+
d = DataSample(x=x, y=y, instance=state, extra=(; reward=-1.5))
60+
d.instance # state (from context)
61+
d.reward # -1.5 (from extra)
4062
```
4163
"""
42-
function DataSample(; x=nothing, θ=nothing, y=nothing, kwargs...)
43-
info = (; kwargs...)
44-
return DataSample(x, θ, y, info)
64+
function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwargs...)
65+
context = (; kwargs...)
66+
overlap = intersect(keys(context), keys(extra))
67+
if !isempty(overlap)
68+
error("Keys $(collect(overlap)) appear in both context and extra of DataSample")
69+
end
70+
return DataSample(x, θ, y, context, extra)
4571
end
4672

4773
"""
4874
$TYPEDSIGNATURES
4975
5076
Extended property access for `DataSample`.
5177
52-
Allows accessing `info` fields directly as properties (e.g., `d.instance` instead of `d.info.instance`).
78+
Allows accessing `context` and `extra` fields directly as properties.
79+
`context` is searched first; if the key is not found there, `extra` is searched.
5380
"""
5481
function Base.getproperty(d::DataSample, name::Symbol)
55-
if name in (:x, :θ, :y, :info)
82+
if name in (:x, :θ, :y, :context, :extra)
5683
return getfield(d, name)
5784
else
58-
return getproperty(getfield(d, :info), name)
85+
ctx = getfield(d, :context)
86+
if haskey(ctx, name)
87+
return getproperty(ctx, name)
88+
end
89+
return getproperty(getfield(d, :extra), name)
5990
end
6091
end
6192

6293
"""
6394
$TYPEDSIGNATURES
6495
65-
Return all property names of a `DataSample`, including both struct fields and `info` fields.
96+
Return all property names of a `DataSample`, including both struct fields and forwarded
97+
fields from `context` and `extra`.
6698
67-
This enables tab completion for all available properties, including those stored in `info`.
99+
This enables tab completion for all available properties.
68100
"""
69101
function Base.propertynames(d::DataSample, private::Bool=false)
70-
return (fieldnames(DataSample)..., propertynames(getfield(d, :info), private)...)
102+
ctx_names = propertynames(getfield(d, :context), private)
103+
extra_names = propertynames(getfield(d, :extra), private)
104+
return (fieldnames(DataSample)..., ctx_names..., extra_names...)
71105
end
72106

73107
"""
@@ -92,7 +126,11 @@ function Base.show(io::IO, d::DataSample)
92126
y_str = sprint(show, d.y; context=io_limited)
93127
push!(fields, "y_true=$y_str")
94128
end
95-
for (key, value) in pairs(d.info)
129+
for (key, value) in pairs(d.context)
130+
value_str = sprint(show, value; context=io_limited)
131+
push!(fields, "$key=$value_str")
132+
end
133+
for (key, value) in pairs(d.extra)
96134
value_str = sprint(show, value; context=io_limited)
97135
push!(fields, "$key=$value_str")
98136
end
@@ -116,8 +154,8 @@ Transform the features in the dataset.
116154
"""
117155
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
118156
return map(dataset) do d
119-
(; info, x, θ, y) = d
120-
DataSample(StatsBase.transform(t, x), θ, y, info)
157+
(; context, extra, x, θ, y) = d
158+
DataSample(StatsBase.transform(t, x), θ, y, context, extra)
121159
end
122160
end
123161

@@ -139,8 +177,8 @@ Reconstruct the features in the dataset.
139177
"""
140178
function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample})
141179
return map(dataset) do d
142-
(; info, x, θ, y) = d
143-
DataSample(StatsBase.reconstruct(t, x), θ, y, info)
180+
(; context, extra, x, θ, y) = d
181+
DataSample(StatsBase.reconstruct(t, x), θ, y, context, extra)
144182
end
145183
end
146184

src/Utils/interface.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ $TYPEDSIGNATURES
111111
Compute the objective value of given solution `y`.
112112
"""
113113
function objective_value(
114-
bench::AbstractBenchmark, sample::DataSample{I,F,S,C}, y::AbstractArray
115-
) where {I,F,S,C<:AbstractArray}
114+
bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}, y::AbstractArray
115+
) where {CTX,EX,F,S,C<:AbstractArray}
116116
return objective_value(bench, sample.θ, y)
117117
end
118118

@@ -122,8 +122,8 @@ $TYPEDSIGNATURES
122122
Compute the objective value of the target in the sample (needs to exist).
123123
"""
124124
function objective_value(
125-
bench::AbstractBenchmark, sample::DataSample{I,F,S,C}
126-
) where {I,F,S<:AbstractArray,C}
125+
bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}
126+
) where {CTX,EX,F,S<:AbstractArray,C}
127127
return objective_value(bench, sample, sample.y)
128128
end
129129

@@ -155,7 +155,7 @@ function compute_gap(
155155
target_obj = objective_value(bench, sample)
156156
x = sample.x
157157
θ = statistical_model(x)
158-
y = maximizer(θ; sample.info...)
158+
y = maximizer(θ; sample.context...)
159159
obj = objective_value(bench, sample, y)
160160
Δ = check ? obj - target_obj : target_obj - obj
161161
return Δ / abs(target_obj)

src/Utils/policy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function evaluate_policy!(
4444
features, state = observe(env)
4545
state_copy = deepcopy(state) # To avoid mutation issues
4646
reward = step!(env, y)
47-
sample = DataSample(; x=features, y=y, state=state_copy, reward=reward)
47+
sample = DataSample(; x=features, y=y, instance=state_copy, extra=(; reward))
4848
if isempty(labeled_dataset)
4949
labeled_dataset = typeof(sample)[sample]
5050
else

test/argmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@test size(x) == (nb_features, instance_dim)
2525
@test length(θ_true) == instance_dim
2626
@test length(y_true) == instance_dim
27-
@test length(sample.info) == 0
27+
@test isempty(sample.context)
2828
@test all(y_true .== maximizer(θ_true))
2929

3030
θ = model(x)

test/dynamic_assortment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ end
321321

322322
# Test integration with sample data
323323
sample = generate_sample(b, MersenneTwister(42))
324-
@test hasfield(typeof(sample), :info)
324+
@test hasfield(typeof(sample), :context)
325325

326326
dataset = generate_dataset(b, 3; seed=42)
327327
environments = generate_environments(b, dataset)

test/dynamic_vsp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true)
5151
reset!(env; reset_rng=true)
5252
cost = sum(step!(env, sample.y) for sample in solution)
53-
cost2 = sum(sample.info.reward for sample in solution)
53+
cost2 = sum(sample.reward for sample in solution)
5454
@test isapprox(cost, anticipative_value; atol=1e-5)
5555
@test isapprox(cost, cost2; atol=1e-5)
5656
end

test/dynamic_vsp_plots.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
policies = generate_policies(b)
2323
lazy = policies[1]
2424
_, d = evaluate_policy!(lazy, env)
25-
fig3 = DVSP.plot_routes(d[1].info.state, d[1].y)
25+
fig3 = DVSP.plot_routes(d[1].instance, d[1].y)
2626
@test fig3 isa Plots.Plot
2727

2828
# Test animation

0 commit comments

Comments
 (0)