Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/tutorials/warcraft_tutorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ x = sample.x
θ_true = sample.θ
# `y` corresponds to the optimal shortest path, encoded as a binary matrix:
y_true = sample.y
# `info` is not used in this benchmark, therefore set to nothing:
isnothing(sample.info)
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
isempty(sample.context)

# For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
plot_data(b, sample)
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicVehicleScheduling/anticipative_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ function anticipative_solver(
compute_features(state, env.instance)
end

return DataSample(; y=y_true, x, state, reward)
return DataSample(; y=y_true, x, instance=state, extra=(; reward))
end

return obj, dataset
Expand Down
12 changes: 6 additions & 6 deletions src/DynamicVehicleScheduling/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ The returned dictionary contains:
This lets plotting code build figures without depending on plotting internals.
"""
function build_plot_data(data_samples::Vector{<:DataSample})
state_data = [build_state_data(sample.info.state) for sample in data_samples]
rewards = [sample.info.reward for sample in data_samples]
state_data = [build_state_data(sample.instance) for sample in data_samples]
rewards = [sample.reward for sample in data_samples]
routess = [sample.y for sample in data_samples]
return [
(; state..., reward, routes) for
Expand Down Expand Up @@ -273,8 +273,8 @@ function plot_epochs(
# Create subplots
plots = map(1:n_epochs) do i
sample = data_samples[i]
state = sample.info.state
reward = sample.info.reward
state = sample.instance
reward = sample.reward

common_kwargs = Dict(
:xlims => xlims,
Expand Down Expand Up @@ -351,7 +351,7 @@ function animate_epochs(
kwargs...,
)
pd = build_plot_data(data_samples)
epoch_costs = [-sample.info.reward for sample in data_samples]
epoch_costs = [-sample.reward for sample in data_samples]

# Calculate global xlims and ylims from all states
x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd)
Expand Down Expand Up @@ -393,7 +393,7 @@ function animate_epochs(
anim = @animate for frame_idx in 1:total_frames
epoch_idx, frame_type = frame_plan[frame_idx]
sample = data_samples[epoch_idx]
state = sample.info.state
state = sample.instance

if frame_type == :routes
fig = plot_routes(
Expand Down
82 changes: 60 additions & 22 deletions src/Utils/data_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@ $TYPEDEF

Data sample data structure.
Its main purpose is to store datasets generated by the benchmarks.
It has 3 main fields: features `x`, cost parameters `θ` and solution `y`.
Additionally, it has an `info` field to store any additional information as a `NamedTuple`, usually the instance, but can be used for anything else.
It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`.
Additionally, it has a `context` field (solver kwargs, spread into the maximizer as
`maximizer(θ; sample.context...)`) and an `extra` field (non-solver data, never passed
to the maximizer).

The separation prevents silent breakage from accidentally passing non-solver data
(e.g. a scenario, reward, or step counter) as unexpected kwargs to the maximizer.

# Fields
$TYPEDFIELDS
"""
struct DataSample{
I<:NamedTuple,
CTX<:NamedTuple,
EX<:NamedTuple,
F<:Union{AbstractArray,Nothing},
S<:Union{AbstractArray,Nothing},
C<:Union{AbstractArray,Nothing},
Expand All @@ -21,53 +27,81 @@ struct DataSample{
θ::C
"output solution (optional)"
y::S
"additional information, usually the instance (optional)"
info::I
"context information as solver kwargs, e.g. instance, graph, etc."
context::CTX
"additional data, never passed to the maximizer, e.g. scenario, objective value, reward,
step count, etc. Can be used for any purpose by the user, such as plotting utilities."
extra::EX
end

"""
$TYPEDSIGNATURES

Constructor for `DataSample` with keyword arguments.

Additional keyword arguments beyond `x`, `θ`, and `y` are stored in the `info` field
and can be accessed directly (e.g., `data.instance` instead of `data.info.instance`).
All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `context`
field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data.

The keys of `context` and `extra` must be disjoint; an error is thrown if they overlap.
Entries of both can be accessed directly as properties of the sample (property forwarding).

# Examples
```julia
# Instance goes in context
d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance")
d.instance # "my_instance"
d.instance # "my_instance" (from context)

# Scenario goes in extra
d = DataSample(x=x, y=y, instance=inst, extra=(; scenario=ξ))
d.scenario # ξ (from extra)

# State goes in context, reward in extra
d = DataSample(x=x, y=y, instance=state, extra=(; reward=-1.5))
d.instance # state (from context)
d.reward # -1.5 (from extra)
```
"""
function DataSample(; x=nothing, θ=nothing, y=nothing, kwargs...)
info = (; kwargs...)
return DataSample(x, θ, y, info)
function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwargs...)
context = (; kwargs...)
overlap = intersect(keys(context), keys(extra))
if !isempty(overlap)
error("Keys $(collect(overlap)) appear in both context and extra of DataSample")
end
return DataSample(x, θ, y, context, extra)
end

"""
$TYPEDSIGNATURES

Extended property access for `DataSample`.

Allows accessing `info` fields directly as properties (e.g., `d.instance` instead of `d.info.instance`).
Allows accessing `context` and `extra` fields directly as properties.
`context` is searched first; if the key is not found there, `extra` is searched.
"""
function Base.getproperty(d::DataSample, name::Symbol)
if name in (:x, :θ, :y, :info)
if name in (:x, :θ, :y, :context, :extra)
return getfield(d, name)
else
return getproperty(getfield(d, :info), name)
ctx = getfield(d, :context)
if haskey(ctx, name)
return getproperty(ctx, name)
end
return getproperty(getfield(d, :extra), name)
end
end

"""
$TYPEDSIGNATURES

Return all property names of a `DataSample`, including both struct fields and `info` fields.
Return all property names of a `DataSample`, including both struct fields and forwarded
fields from `context` and `extra`.

This enables tab completion for all available properties, including those stored in `info`.
This enables tab completion for all available properties.
"""
function Base.propertynames(d::DataSample, private::Bool=false)
return (fieldnames(DataSample)..., propertynames(getfield(d, :info), private)...)
ctx_names = propertynames(getfield(d, :context), private)
extra_names = propertynames(getfield(d, :extra), private)
return (fieldnames(DataSample)..., ctx_names..., extra_names...)
end

"""
Expand All @@ -92,7 +126,11 @@ function Base.show(io::IO, d::DataSample)
y_str = sprint(show, d.y; context=io_limited)
push!(fields, "y_true=$y_str")
end
for (key, value) in pairs(d.info)
for (key, value) in pairs(d.context)
value_str = sprint(show, value; context=io_limited)
push!(fields, "$key=$value_str")
end
for (key, value) in pairs(d.extra)
value_str = sprint(show, value; context=io_limited)
push!(fields, "$key=$value_str")
end
Expand All @@ -116,8 +154,8 @@ Transform the features in the dataset.
"""
function StatsBase.transform(t, dataset::AbstractVector{<:DataSample})
return map(dataset) do d
(; info, x, θ, y) = d
DataSample(StatsBase.transform(t, x), θ, y, info)
(; context, extra, x, θ, y) = d
DataSample(StatsBase.transform(t, x), θ, y, context, extra)
end
end

Expand All @@ -139,8 +177,8 @@ Reconstruct the features in the dataset.
"""
function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample})
return map(dataset) do d
(; info, x, θ, y) = d
DataSample(StatsBase.reconstruct(t, x), θ, y, info)
(; context, extra, x, θ, y) = d
DataSample(StatsBase.reconstruct(t, x), θ, y, context, extra)
end
end

Expand Down
10 changes: 5 additions & 5 deletions src/Utils/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ $TYPEDSIGNATURES
Compute the objective value of given solution `y`.
"""
function objective_value(
bench::AbstractBenchmark, sample::DataSample{I,F,S,C}, y::AbstractArray
) where {I,F,S,C<:AbstractArray}
bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}, y::AbstractArray
) where {CTX,EX,F,S,C<:AbstractArray}
return objective_value(bench, sample.θ, y)
end

Expand All @@ -122,8 +122,8 @@ $TYPEDSIGNATURES
Compute the objective value of the target in the sample (needs to exist).
"""
function objective_value(
bench::AbstractBenchmark, sample::DataSample{I,F,S,C}
) where {I,F,S<:AbstractArray,C}
bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}
) where {CTX,EX,F,S<:AbstractArray,C}
return objective_value(bench, sample, sample.y)
end

Expand Down Expand Up @@ -155,7 +155,7 @@ function compute_gap(
target_obj = objective_value(bench, sample)
x = sample.x
θ = statistical_model(x)
y = maximizer(θ; sample.info...)
y = maximizer(θ; sample.context...)
obj = objective_value(bench, sample, y)
Δ = check ? obj - target_obj : target_obj - obj
return Δ / abs(target_obj)
Expand Down
2 changes: 1 addition & 1 deletion src/Utils/policy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function evaluate_policy!(
features, state = observe(env)
state_copy = deepcopy(state) # To avoid mutation issues
reward = step!(env, y)
sample = DataSample(; x=features, y=y, state=state_copy, reward=reward)
sample = DataSample(; x=features, y=y, instance=state_copy, extra=(; reward))
if isempty(labeled_dataset)
labeled_dataset = typeof(sample)[sample]
else
Expand Down
2 changes: 1 addition & 1 deletion test/argmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@test size(x) == (nb_features, instance_dim)
@test length(θ_true) == instance_dim
@test length(y_true) == instance_dim
@test length(sample.info) == 0
@test isempty(sample.context)
@test all(y_true .== maximizer(θ_true))

θ = model(x)
Expand Down
2 changes: 1 addition & 1 deletion test/dynamic_assortment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ end

# Test integration with sample data
sample = generate_sample(b, MersenneTwister(42))
@test hasfield(typeof(sample), :info)
@test hasfield(typeof(sample), :context)

dataset = generate_dataset(b, 3; seed=42)
environments = generate_environments(b, dataset)
Expand Down
2 changes: 1 addition & 1 deletion test/dynamic_vsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true)
reset!(env; reset_rng=true)
cost = sum(step!(env, sample.y) for sample in solution)
cost2 = sum(sample.info.reward for sample in solution)
cost2 = sum(sample.reward for sample in solution)
@test isapprox(cost, anticipative_value; atol=1e-5)
@test isapprox(cost, cost2; atol=1e-5)
end
2 changes: 1 addition & 1 deletion test/dynamic_vsp_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
policies = generate_policies(b)
lazy = policies[1]
_, d = evaluate_policy!(lazy, env)
fig3 = DVSP.plot_routes(d[1].info.state, d[1].y)
fig3 = DVSP.plot_routes(d[1].instance, d[1].y)
@test fig3 isa Plots.Plot

# Test animation
Expand Down
2 changes: 1 addition & 1 deletion test/fixed_size_shortest_path.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
@test size(x) == (p,)
@test length(θ_true) == A
@test length(y_true) == A
@test length(sample.info) == 0
@test isempty(sample.context)
@test all(y_true .== maximizer(θ_true))
θ = model(x)
@test length(θ) == length(θ_true)
Expand Down
2 changes: 1 addition & 1 deletion test/maintenance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ end

# Test integration with sample data
sample = generate_sample(b, MersenneTwister(42))
@test hasfield(typeof(sample), :info)
@test hasfield(typeof(sample), :context)

dataset = generate_dataset(b, 3; seed=42)
environments = generate_environments(b, dataset)
Expand Down
2 changes: 1 addition & 1 deletion test/portfolio_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@test size(x) == (p,)
@test length(θ_true) == d
@test length(y_true) == d
@test length(sample.info) == 0
@test isempty(sample.context)
@test all(y_true .== maximizer(θ_true))

θ = model(x)
Expand Down
2 changes: 1 addition & 1 deletion test/ranking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@test size(x) == (nb_features, instance_dim)
@test length(θ_true) == instance_dim
@test length(y_true) == instance_dim
@test length(sample.info) == 0
@test isempty(sample.context)
@test all(y_true .== maximizer(θ_true))

θ = model(x)
Expand Down
2 changes: 1 addition & 1 deletion test/subset_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
@test size(x) == (n,)
@test length(θ_true) == n
@test length(y_true) == n
@test length(sample.info) == 0
@test isempty(sample.context)
@test all(y_true .== maximizer(θ_true))

# Features and true weights should be equal
Expand Down
8 changes: 4 additions & 4 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
@test occursin("y_true", s)
@test occursin("instance=\"this is an instance\"", s)

@test propertynames(sample) == (:x, :θ, :y, :info, :instance)
@test propertynames(sample) == (:x, :θ, :y, :context, :extra, :instance)

# Create a dataset for testing
N = 5
Expand All @@ -81,7 +81,7 @@ end
for i in 1:N
@test dataset_zt[i].θ == dataset[i].θ
@test dataset_zt[i].y == dataset[i].y
@test dataset_zt[i].info == dataset[i].info
@test dataset_zt[i].context == dataset[i].context
end

# Check that features are actually transformed
Expand All @@ -97,7 +97,7 @@ end
for i in 1:N
@test dataset_copy[i].θ == dataset[i].θ
@test dataset_copy[i].y == dataset[i].y
@test dataset_copy[i].info == dataset[i].info
@test dataset_copy[i].context == dataset[i].context
end

# Test reconstruct (non-mutating)
Expand All @@ -109,7 +109,7 @@ end
@test dataset_reconstructed[i].x ≈ dataset[i].x atol = 1e-10
@test dataset_reconstructed[i].θ == dataset[i].θ
@test dataset_reconstructed[i].y == dataset[i].y
@test dataset_reconstructed[i].info == dataset[i].info
@test dataset_reconstructed[i].context == dataset[i].context
end

# Test reconstruct! (mutating)
Expand Down
2 changes: 1 addition & 1 deletion test/warcraft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
y_true = sample.y
@test size(x) == (96, 96, 3, 1)
@test all(θ_true .<= 0)
@test length(sample.info) == 0
@test isempty(sample.context)

θ = model(x)
@test size(θ) == size(θ_true)
Expand Down