From e2f8ced6730a82afdfa4df9040c3861f97d60a6a Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Tue, 3 Mar 2026 11:07:52 +0100 Subject: [PATCH] Rename 'info' field to 'context', and new 'extra' field --- docs/src/tutorials/warcraft_tutorial.jl | 4 +- .../anticipative_solver.jl | 2 +- src/DynamicVehicleScheduling/plot.jl | 12 +-- src/Utils/data_sample.jl | 82 ++++++++++++++----- src/Utils/interface.jl | 10 +-- src/Utils/policy.jl | 2 +- test/argmax.jl | 2 +- test/dynamic_assortment.jl | 2 +- test/dynamic_vsp.jl | 2 +- test/dynamic_vsp_plots.jl | 2 +- test/fixed_size_shortest_path.jl | 2 +- test/maintenance.jl | 2 +- test/portfolio_optimization.jl | 2 +- test/ranking.jl | 2 +- test/subset_selection.jl | 2 +- test/utils.jl | 8 +- test/warcraft.jl | 2 +- 17 files changed, 89 insertions(+), 51 deletions(-) diff --git a/docs/src/tutorials/warcraft_tutorial.jl b/docs/src/tutorials/warcraft_tutorial.jl index 9039d09..8f78db8 100644 --- a/docs/src/tutorials/warcraft_tutorial.jl +++ b/docs/src/tutorials/warcraft_tutorial.jl @@ -29,8 +29,8 @@ x = sample.x θ_true = sample.θ # `y` correspond to the optimal shortest path, encoded as a binary matrix: y_true = sample.y -# `info` is not used in this benchmark, therefore set to nothing: -isnothing(sample.info) +# `context` is not used in this benchmark (no solver kwargs needed), so it is empty: +isempty(sample.context) # For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data: plot_data(b, sample) diff --git a/src/DynamicVehicleScheduling/anticipative_solver.jl b/src/DynamicVehicleScheduling/anticipative_solver.jl index 42293c5..b44f0e8 100644 --- a/src/DynamicVehicleScheduling/anticipative_solver.jl +++ b/src/DynamicVehicleScheduling/anticipative_solver.jl @@ -224,7 +224,7 @@ function anticipative_solver( compute_features(state, env.instance) end - return DataSample(; y=y_true, x, state, reward) + return DataSample(; y=y_true, x, instance=state, extra=(; reward)) end return obj, dataset diff --git a/src/DynamicVehicleScheduling/plot.jl b/src/DynamicVehicleScheduling/plot.jl index 00a76a4..f6e39ad 100644 --- a/src/DynamicVehicleScheduling/plot.jl +++ b/src/DynamicVehicleScheduling/plot.jl @@ -208,8 +208,8 @@ The returned dictionary contains: This lets plotting code build figures without depending on plotting internals. """ function build_plot_data(data_samples::Vector{<:DataSample}) - state_data = [build_state_data(sample.info.state) for sample in data_samples] - rewards = [sample.info.reward for sample in data_samples] + state_data = [build_state_data(sample.instance) for sample in data_samples] + rewards = [sample.reward for sample in data_samples] routess = [sample.y for sample in data_samples] return [ (; state..., reward, routes) for @@ -273,8 +273,8 @@ function plot_epochs( # Create subplots plots = map(1:n_epochs) do i sample = data_samples[i] - state = sample.info.state - reward = sample.info.reward + state = sample.instance + reward = sample.reward common_kwargs = Dict( :xlims => xlims, @@ -351,7 +351,7 @@ function animate_epochs( kwargs..., ) pd = build_plot_data(data_samples) - epoch_costs = [-sample.info.reward for sample in data_samples] + epoch_costs = [-sample.reward for sample in data_samples] # Calculate global xlims and ylims from all states x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd) @@ -393,7 +393,7 @@ function animate_epochs( anim = @animate for frame_idx in 1:total_frames epoch_idx, frame_type = frame_plan[frame_idx] sample = data_samples[epoch_idx] - state = sample.info.state + state = sample.instance if frame_type == :routes fig = plot_routes( diff --git a/src/Utils/data_sample.jl b/src/Utils/data_sample.jl index b043b48..ccdf2c0 100644 --- a/src/Utils/data_sample.jl +++ b/src/Utils/data_sample.jl @@ -3,14 +3,20 @@ $TYPEDEF Data sample data structure. Its main purpose is to store datasets generated by the benchmarks. -It has 3 main fields: features `x`, cost parameters `θ` and solution `y`. -Additionally, it has an `info` field to store any additional information as a `NamedTuple`, usually the instance, but can be used for anything else. +It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`. +Additionally, it has a `context` field (solver kwargs, spread into the maximizer as +`maximizer(θ; sample.context...)`) and an `extra` field (non-solver data, never passed +to the maximizer). + +The separation prevents silent breakage from accidentally passing non-solver data +(e.g. a scenario, reward, or step counter) as unexpected kwargs to the maximizer. # Fields $TYPEDFIELDS """ struct DataSample{ - I<:NamedTuple, + CTX<:NamedTuple, + EX<:NamedTuple, F<:Union{AbstractArray,Nothing}, S<:Union{AbstractArray,Nothing}, C<:Union{AbstractArray,Nothing}, @@ -21,8 +27,11 @@ struct DataSample{ θ::C "output solution (optional)" y::S - "additional information, usually the instance (optional)" - info::I + "context information as solver kwargs, e.g. instance, graph, etc." + context::CTX + "additional data, never passed to the maximizer, e.g. scenario, objective value, reward, + step count, etc. Can be used for any purpose by the user, such as plotting utilities." + extra::EX end """ @@ -30,18 +39,35 @@ $TYPEDSIGNATURES Constructor for `DataSample` with keyword arguments. -Additional keyword arguments beyond `x`, `θ`, and `y` are stored in the `info` field -and can be accessed directly (e.g., `data.instance` instead of `data.info.instance`). +All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `context` +field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data. + +Fields in `context` and `extra` must be disjoint. An error is thrown if they overlap. +Both can be accessed directly via property forwarding. # Examples ```julia +# Instance goes in context d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance") -d.instance # "my_instance" +d.instance # "my_instance" (from context) + +# Scenario goes in extra +d = DataSample(x=x, y=y, instance=inst, extra=(; scenario=ξ)) +d.scenario # ξ (from extra) + +# State goes in context, reward in extra +d = DataSample(x=x, y=y, instance=state, extra=(; reward=-1.5)) +d.instance # state (from context) +d.reward # -1.5 (from extra) ``` """ -function DataSample(; x=nothing, θ=nothing, y=nothing, kwargs...) - info = (; kwargs...) - return DataSample(x, θ, y, info) +function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwargs...) + context = (; kwargs...) + overlap = intersect(keys(context), keys(extra)) + if !isempty(overlap) + error("Keys $(collect(overlap)) appear in both context and extra of DataSample") + end + return DataSample(x, θ, y, context, extra) end """ @@ -49,25 +75,33 @@ $TYPEDSIGNATURES Extended property access for `DataSample`. -Allows accessing `info` fields directly as properties (e.g., `d.instance` instead of `d.info.instance`). +Allows accessing `context` and `extra` fields directly as properties. +`context` is searched first; if the key is not found there, `extra` is searched. """ function Base.getproperty(d::DataSample, name::Symbol) - if name in (:x, :θ, :y, :info) + if name in (:x, :θ, :y, :context, :extra) return getfield(d, name) else - return getproperty(getfield(d, :info), name) + ctx = getfield(d, :context) + if haskey(ctx, name) + return getproperty(ctx, name) + end + return getproperty(getfield(d, :extra), name) end end """ $TYPEDSIGNATURES -Return all property names of a `DataSample`, including both struct fields and `info` fields. +Return all property names of a `DataSample`, including both struct fields and forwarded +fields from `context` and `extra`. -This enables tab completion for all available properties, including those stored in `info`. +This enables tab completion for all available properties. """ function Base.propertynames(d::DataSample, private::Bool=false) - return (fieldnames(DataSample)..., propertynames(getfield(d, :info), private)...) + ctx_names = propertynames(getfield(d, :context), private) + extra_names = propertynames(getfield(d, :extra), private) + return (fieldnames(DataSample)..., ctx_names..., extra_names...) end """ @@ -92,7 +126,11 @@ function Base.show(io::IO, d::DataSample) y_str = sprint(show, d.y; context=io_limited) push!(fields, "y_true=$y_str") end - for (key, value) in pairs(d.info) + for (key, value) in pairs(d.context) + value_str = sprint(show, value; context=io_limited) + push!(fields, "$key=$value_str") + end + for (key, value) in pairs(d.extra) value_str = sprint(show, value; context=io_limited) push!(fields, "$key=$value_str") end @@ -116,8 +154,8 @@ Transform the features in the dataset. """ function StatsBase.transform(t, dataset::AbstractVector{<:DataSample}) return map(dataset) do d - (; info, x, θ, y) = d - DataSample(StatsBase.transform(t, x), θ, y, info) + (; context, extra, x, θ, y) = d + DataSample(StatsBase.transform(t, x), θ, y, context, extra) end end @@ -139,8 +177,8 @@ Reconstruct the features in the dataset. """ function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample}) return map(dataset) do d - (; info, x, θ, y) = d - DataSample(StatsBase.reconstruct(t, x), θ, y, info) + (; context, extra, x, θ, y) = d + DataSample(StatsBase.reconstruct(t, x), θ, y, context, extra) end end diff --git a/src/Utils/interface.jl b/src/Utils/interface.jl index fbe61a9..6d23f12 100644 --- a/src/Utils/interface.jl +++ b/src/Utils/interface.jl @@ -111,8 +111,8 @@ $TYPEDSIGNATURES Compute the objective value of given solution `y`. """ function objective_value( - bench::AbstractBenchmark, sample::DataSample{I,F,S,C}, y::AbstractArray -) where {I,F,S,C<:AbstractArray} + bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}, y::AbstractArray +) where {CTX,EX,F,S,C<:AbstractArray} return objective_value(bench, sample.θ, y) end @@ -122,8 +122,8 @@ $TYPEDSIGNATURES Compute the objective value of the target in the sample (needs to exist). """ function objective_value( - bench::AbstractBenchmark, sample::DataSample{I,F,S,C} -) where {I,F,S<:AbstractArray,C} + bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C} +) where {CTX,EX,F,S<:AbstractArray,C} return objective_value(bench, sample, sample.y) end @@ -155,7 +155,7 @@ function compute_gap( target_obj = objective_value(bench, sample) x = sample.x θ = statistical_model(x) - y = maximizer(θ; sample.info...) + y = maximizer(θ; sample.context...) obj = objective_value(bench, sample, y) Δ = check ? obj - target_obj : target_obj - obj return Δ / abs(target_obj) diff --git a/src/Utils/policy.jl b/src/Utils/policy.jl index 796cbf0..5eb0c6d 100644 --- a/src/Utils/policy.jl +++ b/src/Utils/policy.jl @@ -44,7 +44,7 @@ function evaluate_policy!( features, state = observe(env) state_copy = deepcopy(state) # To avoid mutation issues reward = step!(env, y) - sample = DataSample(; x=features, y=y, state=state_copy, reward=reward) + sample = DataSample(; x=features, y=y, instance=state_copy, extra=(; reward)) if isempty(labeled_dataset) labeled_dataset = typeof(sample)[sample] else diff --git a/test/argmax.jl b/test/argmax.jl index ba5723b..d772c4d 100644 --- a/test/argmax.jl +++ b/test/argmax.jl @@ -24,7 +24,7 @@ @test size(x) == (nb_features, instance_dim) @test length(θ_true) == instance_dim @test length(y_true) == instance_dim - @test length(sample.info) == 0 + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/dynamic_assortment.jl b/test/dynamic_assortment.jl index 97674bf..1504421 100644 --- a/test/dynamic_assortment.jl +++ b/test/dynamic_assortment.jl @@ -321,7 +321,7 @@ end # Test integration with sample data sample = generate_sample(b, MersenneTwister(42)) - @test hasfield(typeof(sample), :info) + @test hasfield(typeof(sample), :context) dataset = generate_dataset(b, 3; seed=42) environments = generate_environments(b, dataset) diff --git a/test/dynamic_vsp.jl b/test/dynamic_vsp.jl index 825b319..39a00aa 100644 --- a/test/dynamic_vsp.jl +++ b/test/dynamic_vsp.jl @@ -50,7 +50,7 @@ anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true) reset!(env; reset_rng=true) cost = sum(step!(env, sample.y) for sample in solution) - cost2 = sum(sample.info.reward for sample in solution) + cost2 = sum(sample.reward for sample in solution) @test isapprox(cost, anticipative_value; atol=1e-5) @test isapprox(cost, cost2; atol=1e-5) end diff --git a/test/dynamic_vsp_plots.jl b/test/dynamic_vsp_plots.jl index e079086..1fc822b 100644 --- a/test/dynamic_vsp_plots.jl +++ b/test/dynamic_vsp_plots.jl @@ -22,7 +22,7 @@ policies = generate_policies(b) lazy = policies[1] _, d = evaluate_policy!(lazy, env) - fig3 = DVSP.plot_routes(d[1].info.state, d[1].y) + fig3 = DVSP.plot_routes(d[1].instance, d[1].y) @test fig3 isa Plots.Plot # Test animation diff --git a/test/fixed_size_shortest_path.jl b/test/fixed_size_shortest_path.jl index e1360e1..eacdd64 100644 --- a/test/fixed_size_shortest_path.jl +++ b/test/fixed_size_shortest_path.jl @@ -25,7 +25,7 @@ @test size(x) == (p,) @test length(θ_true) == A @test length(y_true) == A - @test length(sample.info) == 0 + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) @test length(θ) == length(θ_true) diff --git a/test/maintenance.jl b/test/maintenance.jl index d7183f5..32f90ed 100644 --- a/test/maintenance.jl +++ b/test/maintenance.jl @@ -198,7 +198,7 @@ end # Test integration with sample data sample = generate_sample(b, MersenneTwister(42)) - @test hasfield(typeof(sample), :info) + @test hasfield(typeof(sample), :context) dataset = generate_dataset(b, 3; seed=42) environments = generate_environments(b, dataset) diff --git a/test/portfolio_optimization.jl b/test/portfolio_optimization.jl index 96895b8..a722c5b 100644 --- a/test/portfolio_optimization.jl +++ b/test/portfolio_optimization.jl @@ -16,7 +16,7 @@ @test size(x) == (p,) @test length(θ_true) == d @test length(y_true) == d - @test length(sample.info) == 0 + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/ranking.jl b/test/ranking.jl index 4f9d437..e8e5939 100644 --- a/test/ranking.jl +++ b/test/ranking.jl @@ -21,7 +21,7 @@ @test size(x) == (nb_features, instance_dim) @test length(θ_true) == instance_dim @test length(y_true) == instance_dim - @test length(sample.info) == 0 + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/subset_selection.jl b/test/subset_selection.jl index 952420d..ba0ff10 100644 --- a/test/subset_selection.jl +++ b/test/subset_selection.jl @@ -23,7 +23,7 @@ @test size(x) == (n,) @test length(θ_true) == n @test length(y_true) == n - @test length(sample.info) == 0 + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) # Features and true weights should be equal diff --git a/test/utils.jl b/test/utils.jl index f469db1..17d3373 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -58,7 +58,7 @@ end @test occursin("y_true", s) @test occursin("instance=\"this is an instance\"", s) - @test propertynames(sample) == (:x, :θ, :y, :info, :instance) + @test propertynames(sample) == (:x, :θ, :y, :context, :extra, :instance) # Create a dataset for testing N = 5 @@ -81,7 +81,7 @@ end for i in 1:N @test dataset_zt[i].θ == dataset[i].θ @test dataset_zt[i].y == dataset[i].y - @test dataset_zt[i].info == dataset[i].info + @test dataset_zt[i].context == dataset[i].context end # Check that features are actually transformed @@ -97,7 +97,7 @@ end for i in 1:N @test dataset_copy[i].θ == dataset[i].θ @test dataset_copy[i].y == dataset[i].y - @test dataset_copy[i].info == dataset[i].info + @test dataset_copy[i].context == dataset[i].context end # Test reconstruct (non-mutating) @@ -109,7 +109,7 @@ end @test dataset_reconstructed[i].x ≈ dataset[i].x atol = 1e-10 @test dataset_reconstructed[i].θ == dataset[i].θ @test dataset_reconstructed[i].y == dataset[i].y - @test dataset_reconstructed[i].info == dataset[i].info + @test dataset_reconstructed[i].context == dataset[i].context end # Test reconstruct! (mutating) diff --git a/test/warcraft.jl b/test/warcraft.jl index 5a7a564..6a8ecd9 100644 --- a/test/warcraft.jl +++ b/test/warcraft.jl @@ -24,7 +24,7 @@ y_true = sample.y @test size(x) == (96, 96, 3, 1) @test all(θ_true .<= 0) - @test length(sample.info) == 0 + @test isempty(sample.context) θ = model(x) @test size(θ) == size(θ_true)