From 3828e5f953799f58063c64678b2eac49fd038f9c Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Tue, 30 Sep 2025 10:27:17 +0200 Subject: [PATCH 1/4] Rename DataSample fields to x/theta/y/info --- src/Argmax/Argmax.jl | 6 +- src/Argmax2D/Argmax2D.jl | 16 +-- src/DynamicAssortment/DynamicAssortment.jl | 2 +- .../DynamicVehicleScheduling.jl | 2 +- .../anticipative_solver.jl | 25 ++-- .../FixedSizeShortestPath.jl | 6 +- .../PortfolioOptimization.jl | 6 +- src/Ranking/Ranking.jl | 6 +- src/ShortestPath/ShortestPath.jl | 16 --- src/ShortestPath/shortest_paths.jl | 130 ------------------ .../StochasticVehicleScheduling.jl | 12 +- src/SubsetSelection/SubsetSelection.jl | 6 +- src/Utils/Utils.jl | 2 +- src/Utils/data_sample.jl | 38 ++--- src/Utils/interface.jl | 24 +++- 15 files changed, 80 insertions(+), 217 deletions(-) delete mode 100644 src/ShortestPath/ShortestPath.jl delete mode 100644 src/ShortestPath/shortest_paths.jl diff --git a/src/Argmax/Argmax.jl b/src/Argmax/Argmax.jl index fa6ddda..60f37c5 100644 --- a/src/Argmax/Argmax.jl +++ b/src/Argmax/Argmax.jl @@ -76,9 +76,9 @@ function Utils.generate_sample( ) (; instance_dim, nb_features, encoder) = bench features = randn(rng, Float32, nb_features, instance_dim) - costs = encoder(features) - noisy_solution = one_hot_argmax(costs + noise_std * randn(rng, Float32, instance_dim)) - return DataSample(; x=features, θ_true=costs, y_true=noisy_solution) + θ_true = encoder(features) + noisy_y_true = one_hot_argmax(θ_true + noise_std * randn(rng, Float32, instance_dim)) + return DataSample(; x=features, θ=θ_true, y=noisy_y_true) end """ diff --git a/src/Argmax2D/Argmax2D.jl b/src/Argmax2D/Argmax2D.jl index 169c403..5f67c8e 100644 --- a/src/Argmax2D/Argmax2D.jl +++ b/src/Argmax2D/Argmax2D.jl @@ -62,7 +62,7 @@ function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG) θ_true ./= 2 * norm(θ_true) instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng)) y_true = maximizer(θ_true; instance) - return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance) + return DataSample(; x=x, θ=θ_true, y=y_true, info=instance) end """ @@ -88,11 +88,11 @@ function Utils.generate_statistical_model( return model end -function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...) +function Utils.plot_data(::Argmax2DBenchmark; info, θ, kwargs...) pl = init_plot() - plot_polytope!(pl, instance) + plot_polytope!(pl, info) plot_objective!(pl, θ) - return plot_maximizer!(pl, θ, instance, maximizer) + return plot_maximizer!(pl, θ, info, maximizer) end """ @@ -101,13 +101,9 @@ $TYPEDSIGNATURES Plot the data sample for the [`Argmax2DBenchmark`](@ref). """ function Utils.plot_data( - bench::Argmax2DBenchmark, - sample::DataSample; - instance=sample.instance, - θ=sample.θ_true, - kwargs..., + bench::Argmax2DBenchmark, sample::DataSample; info=sample.info, θ=sample.θ, kwargs... ) - return Utils.plot_data(bench; instance, θ, kwargs...) + return Utils.plot_data(bench; info, θ, kwargs...) end export Argmax2DBenchmark diff --git a/src/DynamicAssortment/DynamicAssortment.jl b/src/DynamicAssortment/DynamicAssortment.jl index c943dba..14ef63c 100644 --- a/src/DynamicAssortment/DynamicAssortment.jl +++ b/src/DynamicAssortment/DynamicAssortment.jl @@ -83,7 +83,7 @@ Outputs a data sample containing an [`Instance`](@ref). function Utils.generate_sample( b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0) ) - return DataSample(; instance=Instance(b, rng)) + return DataSample(; info=Instance(b, rng)) end """ diff --git a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl index b940743..7e8c23e 100644 --- a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl +++ b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl @@ -63,7 +63,7 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si dataset_size = min(dataset_size, length(files)) return [ DataSample(; - instance=Instance( + info=Instance( read_vsp_instance(files[i]); max_requests_per_epoch, Δ_dispatch, diff --git a/src/DynamicVehicleScheduling/anticipative_solver.jl b/src/DynamicVehicleScheduling/anticipative_solver.jl index 7178532..4c48eca 100644 --- a/src/DynamicVehicleScheduling/anticipative_solver.jl +++ b/src/DynamicVehicleScheduling/anticipative_solver.jl @@ -92,14 +92,14 @@ function anticipative_solver( job_indices = 2:nb_nodes epoch_indices = T - @variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true) + @variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true) @objective( model, Max, sum( - -duration[i, j] * y[i, j, t] for - i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices + -duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes, + t in epoch_indices ) ) @@ -171,12 +171,14 @@ function anticipative_solver( routes = epoch_routes[i] epoch_customers = epoch_indices[i] - y_true = VSPSolution( - Vector{Int}[ - map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes - ]; - max_index=length(epoch_customers), - ).edge_matrix + y_true = + VSPSolution( + Vector{Int}[ + map(idx -> findfirst(==(idx), epoch_customers), route) for + route in routes + ]; + max_index=length(epoch_customers), + ).edge_matrix location_indices = customer_index[epoch_customers] new_coordinates = env.instance.static_instance.coordinate[location_indices] @@ -200,8 +202,7 @@ function anticipative_solver( is_must_dispatch[2:end] .= true else is_must_dispatch[2:end] .= - planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> - new_start_time[2:end] + planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end] end is_postponable[2:end] .= .!is_must_dispatch[2:end] # TODO: avoid code duplication with add_new_customers! @@ -222,7 +223,7 @@ function anticipative_solver( compute_features(state, env.instance) end - return DataSample(; instance=(; state, reward), y_true, x) + return DataSample(; info=(; state, reward), y, x) end return obj, dataset diff --git a/src/FixedSizeShortestPath/FixedSizeShortestPath.jl b/src/FixedSizeShortestPath/FixedSizeShortestPath.jl index 46a22fe..3a350e5 100644 --- a/src/FixedSizeShortestPath/FixedSizeShortestPath.jl +++ b/src/FixedSizeShortestPath/FixedSizeShortestPath.jl @@ -121,11 +121,11 @@ function Utils.generate_sample( else rand(rng, Uniform{type}(1 - ν, 1 + ν), E) end - costs = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ + θ_true = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ maximizer = Utils.generate_maximizer(bench) - solution = maximizer(costs) - return DataSample(; x=features, θ_true=costs, y_true=solution) + y_true = maximizer(θ_true) + return DataSample(; x=features, θ=θ_true, y=y_true) end """ diff --git a/src/PortfolioOptimization/PortfolioOptimization.jl b/src/PortfolioOptimization/PortfolioOptimization.jl index 5fc7e8f..f79f488 100644 --- a/src/PortfolioOptimization/PortfolioOptimization.jl +++ b/src/PortfolioOptimization/PortfolioOptimization.jl @@ -94,12 +94,12 @@ function Utils.generate_sample( features = randn(rng, type, p) B = rand(rng, Bernoulli(0.5), d, p) c̄ = (0.05 / type(sqrt(p)) .* B * features .+ 0.1^(1 / deg)) .^ deg - costs = c̄ .+ L * f .+ 0.01 * ν * randn(rng, type, d) + θ_true = c̄ .+ L * f .+ 0.01 * ν * randn(rng, type, d) maximizer = Utils.generate_maximizer(bench) - solution = maximizer(costs) + y_true = maximizer(θ_true) - return DataSample(; x=features, θ_true=costs, y_true=solution) + return DataSample(; x=features, θ=θ_true, y=y_true) end """ diff --git a/src/Ranking/Ranking.jl b/src/Ranking/Ranking.jl index c6ec398..269b98c 100644 --- a/src/Ranking/Ranking.jl +++ b/src/Ranking/Ranking.jl @@ -68,9 +68,9 @@ function Utils.generate_sample( ) (; instance_dim, nb_features, encoder) = bench features = randn(rng, Float32, nb_features, instance_dim) - costs = encoder(features) - noisy_solution = ranking(costs .+ noise_std * randn(rng, Float32, instance_dim)) - return DataSample(; x=features, θ_true=costs, y_true=noisy_solution) + θ_true = encoder(features) + noisy_y_true = ranking(θ_true .+ noise_std * randn(rng, Float32, instance_dim)) + return DataSample(; x=features, θ=θ_true, y=noisy_y_true) end """ diff --git a/src/ShortestPath/ShortestPath.jl b/src/ShortestPath/ShortestPath.jl deleted file mode 100644 index 2cc8812..0000000 --- a/src/ShortestPath/ShortestPath.jl +++ /dev/null @@ -1,16 +0,0 @@ -module ShortestPath - -using ..Utils -using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES -using Distributions -using Flux: Chain, Dense -using Graphs -using LinearAlgebra -using Random -using SparseArrays - -include("shortest_paths.jl") - -export ShortestPathBenchmark - -end diff --git a/src/ShortestPath/shortest_paths.jl b/src/ShortestPath/shortest_paths.jl deleted file mode 100644 index bb422de..0000000 --- a/src/ShortestPath/shortest_paths.jl +++ /dev/null @@ -1,130 +0,0 @@ -""" -$TYPEDEF - -Benchmark problem for the shortest path problem. -In this benchmark, all graphs are acyclic directed grids, all of the same size `grid_size`. -Features are given at instance level (one dimensional vector of length `p` for each graph). - -Data is generated using the process described in: . - -# Fields -$TYPEDFIELDS -""" -struct ShortestPathBenchmark <: AbstractBenchmark - "grid graph instance" - graph::SimpleDiGraph{Int64} - "grid size of graphs" - grid_size::Tuple{Int,Int} - "size of feature vectors" - p::Int - "degree of formula between features and true weights" - deg::Int - "multiplicative noise for true weights sampled between [1-ν, 1+ν], should be between 0 and 1" - ν::Float32 -end - -function Base.show(io::IO, bench::ShortestPathBenchmark) - (; grid_size, p, deg, ν) = bench - return print(io, "ShortestPathBenchmark(grid_size=$grid_size, p=$p, deg=$deg, ν=$ν)") -end - -""" -$TYPEDSIGNATURES - -Constructor for [`ShortestPathBenchmark`](@ref). -""" -function ShortestPathBenchmark(; - grid_size::Tuple{Int,Int}=(5, 5), p::Int=5, deg::Int=1, ν=0.0f0 -) - @assert ν >= 0.0 && ν <= 1.0 - g = DiGraph(collect(edges(Graphs.grid(grid_size)))) - return ShortestPathBenchmark(g, grid_size, p, deg, ν) -end - -""" -$TYPEDSIGNATURES - -Outputs a function that computes the longest path on the grid graph, given edge weights θ as input. - -```julia -maximizer = generate_maximizer(bench) -maximizer(θ) -``` -""" -function Utils.generate_maximizer(bench::ShortestPathBenchmark; use_dijkstra=true) - g = bench.graph - V = Graphs.nv(g) - E = Graphs.ne(g) - - I = [src(e) for e in edges(g)] - J = [dst(e) for e in edges(g)] - algo = - use_dijkstra ? Graphs.dijkstra_shortest_paths : Graphs.bellman_ford_shortest_paths - - function shortest_path_maximizer(θ; kwargs...) - weights = sparse(I, J, -θ, V, V) - parents = algo(g, 1, weights).parents - y = falses(V, V) - u = V - while u != 1 - prev = parents[u] - y[prev, u] = true - u = prev - end - - solution = falses(E) - for (i, edge) in enumerate(edges(g)) - if y[src(edge), dst(edge)] - solution[i] = true - end - end - return solution - end - - return shortest_path_maximizer -end - -""" -$TYPEDSIGNATURES - -Generate dataset for the shortest path problem. -""" -function Utils.generate_dataset( - bench::ShortestPathBenchmark, dataset_size::Int=10; seed::Int=0, type::Type=Float32 -) - # Set seed - rng = MersenneTwister(seed) - (; graph, p, deg, ν) = bench - - E = Graphs.ne(graph) - - # Features - features = [randn(rng, type, p) for _ in 1:dataset_size] - - # True weights - B = rand(rng, Bernoulli(0.5), E, p) - ξ = if ν == 0.0 - [ones(type, E) for _ in 1:dataset_size] - else - [rand(rng, Uniform{type}(1 - ν, 1 + ν), E) for _ in 1:dataset_size] - end - costs = [ - (1 .+ (3 .+ B * zᵢ ./ type(sqrt(p))) .^ deg) .* ξᵢ for (ξᵢ, zᵢ) in zip(ξ, features) - ] - - shortest_path_maximizer = Utils.generate_maximizer(bench) - - # Label solutions - solutions = shortest_path_maximizer.(.-costs) - return [DataSample(; x=x, θ=θ, y=y) for (x, θ, y) in zip(features, costs, solutions)] -end - -""" -$TYPEDSIGNATURES - -Initialize a linear model for `bench` using `Flux`. -""" -function Utils.generate_statistical_model(bench::ShortestPathBenchmark) - (; p, graph) = bench - return Chain(Dense(p, ne(graph))) -end diff --git a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl index 41801c5..e1b3fa1 100644 --- a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl +++ b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl @@ -67,7 +67,7 @@ end function Utils.objective_value( ::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector ) - return evaluate_solution(y, sample.instance) + return evaluate_solution(y, sample.info) end """ @@ -98,7 +98,7 @@ function Utils.generate_sample( else nothing end - return DataSample(; x, instance, y_true) + return DataSample(; x, info=instance, y=y_true) end """ @@ -147,9 +147,9 @@ $TYPEDSIGNATURES function plot_instance( ::StochasticVehicleSchedulingBenchmark, sample::DataSample{<:Instance{City}}; kwargs... ) - (; tasks, district_width, width) = sample.instance.city + (; tasks, district_width, width) = sample.info.city ticks = 0:district_width:width - max_time = maximum(t.end_time for t in sample.instance.city.tasks[1:(end - 1)]) + max_time = maximum(t.end_time for t in sample.info.city.tasks[1:(end - 1)]) fig = plot(; xlabel="x", ylabel="y", @@ -206,9 +206,9 @@ $TYPEDSIGNATURES function plot_solution( ::StochasticVehicleSchedulingBenchmark, sample::DataSample{<:Instance{City}}; kwargs... ) - (; tasks, district_width, width) = sample.instance.city + (; tasks, district_width, width) = sample.info.city ticks = 0:district_width:width - solution = Solution(sample.y_true, sample.instance) + solution = Solution(sample.y_true, sample.info) path_list = compute_path_list(solution) fig = plot(; xlabel="x", diff --git a/src/SubsetSelection/SubsetSelection.jl b/src/SubsetSelection/SubsetSelection.jl index 085324d..416745f 100644 --- a/src/SubsetSelection/SubsetSelection.jl +++ b/src/SubsetSelection/SubsetSelection.jl @@ -66,9 +66,9 @@ Generate a labeled instance for the subset selection problem. function Utils.generate_sample(bench::SubsetSelectionBenchmark, rng::AbstractRNG) (; n, k, mapping) = bench features = randn(rng, Float32, n) - costs = mapping(features) - solution = top_k(costs, k) - return DataSample(; x=features, θ_true=costs, y_true=solution) + θ_true = mapping(features) + y_true = top_k(θ_true, k) + return DataSample(; x=features, θ=θ_true, y=y_true) end """ diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index d738e31..1b4767f 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -5,7 +5,7 @@ using Flux: softplus using HiGHS: HiGHS using JuMP: Model using LinearAlgebra: dot -using Random: Random, MersenneTwister +using Random: Random, MersenneTwister, AbstractRNG using SCIP: SCIP using SimpleWeightedGraphs: SimpleWeightedDiGraph using StatsBase: StatsBase diff --git a/src/Utils/data_sample.jl b/src/Utils/data_sample.jl index d0cccc6..f445219 100644 --- a/src/Utils/data_sample.jl +++ b/src/Utils/data_sample.jl @@ -12,14 +12,14 @@ $TYPEDFIELDS S<:Union{AbstractArray,Nothing}, C<:Union{AbstractArray,Nothing}, } - "features" + "input features (optional)" x::F = nothing - "target cost parameters (optional)" - θ_true::C = nothing - "target solution (optional)" - y_true::S = nothing - "instance object (optional)" - instance::I = nothing + "intermediate cost parameters (optional)" + θ::C = nothing + "output solution (optional)" + y::S = nothing + "additional information, usually the instance (optional)" + info::I = nothing end function Base.show(io::IO, d::DataSample) @@ -27,14 +27,14 @@ function Base.show(io::IO, d::DataSample) if !isnothing(d.x) push!(fields, "x=$(d.x)") end - if !isnothing(d.θ_true) - push!(fields, "θ_true=$(d.θ_true)") + if !isnothing(d.θ) + push!(fields, "θ_true=$(d.θ)") end - if !isnothing(d.y_true) - push!(fields, "y_true=$(d.y_true)") + if !isnothing(d.y) + push!(fields, "y_true=$(d.y)") end - if !isnothing(d.instance) - push!(fields, "instance=$(d.instance)") + if !isnothing(d.info) + push!(fields, "instance=$(d.info)") end return print(io, "DataSample(", join(fields, ", "), ")") end @@ -56,15 +56,15 @@ Transform the features in the dataset. """ function StatsBase.transform(t, dataset::AbstractVector{<:DataSample}) return map(dataset) do d - (; instance, x, θ_true, y_true) = d - DataSample(; instance, x=StatsBase.transform(t, x), θ_true, y_true) + (; info, x, θ, y) = d + DataSample(; info, x=StatsBase.transform(t, x), θ, y) end end """ $TYPEDSIGNATURES -Transform the features in the dataset in place. +Transform the features in the dataset, in place. """ function StatsBase.transform!(t, dataset::AbstractVector{<:DataSample}) for d in dataset @@ -79,15 +79,15 @@ Reconstruct the features in the dataset. """ function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample}) return map(dataset) do d - (; instance, x, θ_true, y_true) = d - DataSample(; instance, x=StatsBase.reconstruct(t, x), θ_true, y_true) + (; info, x, θ, y) = d + DataSample(; info, x=StatsBase.reconstruct(t, x), θ, y) end end """ $TYPEDSIGNATURES -Reconstruct the features in the dataset in place. +Reconstruct the features in the dataset, in place. """ function StatsBase.reconstruct!(t, dataset::AbstractVector{<:DataSample}) for d in dataset diff --git a/src/Utils/interface.jl b/src/Utils/interface.jl index 1aabda6..19689e5 100644 --- a/src/Utils/interface.jl +++ b/src/Utils/interface.jl @@ -113,7 +113,7 @@ $TYPEDSIGNATURES For benchmarks where there is an instance object, maximizer needs the instance object as a keyword argument. """ function maximizer_kwargs(::AbstractBenchmark, sample::DataSample) - return (; instance=sample.instance) + return (; instance=sample.info) end """ @@ -133,7 +133,7 @@ Compute the objective value of given solution `y`. function objective_value( bench::AbstractBenchmark, sample::DataSample{I,F,S,C}, y::AbstractArray ) where {I,F,S,C<:AbstractArray} - return objective_value(bench, sample.θ_true, y) + return objective_value(bench, sample.θ, y) end """ @@ -144,7 +144,7 @@ Compute the objective value of the target in the sample (needs to exist). function objective_value( bench::AbstractBenchmark, sample::DataSample{I,F,S,C} ) where {I,F,S<:AbstractArray,C} - return objective_value(bench, sample, sample.y_true) + return objective_value(bench, sample, sample.y) end """ @@ -228,17 +228,29 @@ function generate_environment end """ $TYPEDSIGNATURES +Default behaviour of `generate_environment` applied to a data sample. +Uses the info field of the sample as the instance. +""" +function generate_environment( + bench::AbstractDynamicBenchmark, sample::DataSample, rng::AbstractRNG; kwargs... +) + return generate_environment(bench, sample.info, rng; kwargs...) +end + +""" +$TYPEDSIGNATURES + Generate a vector of environments for the given dynamic benchmark and dataset. """ function generate_environments( bench::AbstractDynamicBenchmark, - dataset::AbstractArray{<:DataSample}; + dataset::AbstractArray; seed=nothing, rng=MersenneTwister(seed), kwargs..., ) Random.seed!(rng, seed) - return map(dataset) do sample - generate_environment(bench, sample.instance, rng; kwargs...) + return map(dataset) do instance + generate_environment(bench, instance, rng; kwargs...) end end From 0168942d7fec45b563fc3fac52941ed0c859c2e0 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Tue, 30 Sep 2025 11:27:04 +0200 Subject: [PATCH 2/4] Adjust tests, and fix errors --- .../anticipative_solver.jl | 25 +++++---- src/DynamicVehicleScheduling/plot.jl | 18 +++---- .../StochasticVehicleScheduling.jl | 2 +- src/Utils/policy.jl | 2 +- src/Warcraft/Warcraft.jl | 6 +-- src/Warcraft/utils.jl | 2 +- test/argmax.jl | 6 ++- test/argmax_2d.jl | 7 ++- test/dynamic_assortment.jl | 52 +++++++++---------- test/dynamic_vsp.jl | 6 ++- test/dynamic_vsp_plots.jl | 4 +- test/fixed_size_shortest_path.jl | 6 ++- test/portfolio_optimization.jl | 6 ++- test/ranking.jl | 6 ++- test/subset_selection.jl | 6 ++- test/utils.jl | 25 ++++----- test/vsp.jl | 2 +- test/warcraft.jl | 6 ++- 18 files changed, 100 insertions(+), 87 deletions(-) diff --git a/src/DynamicVehicleScheduling/anticipative_solver.jl b/src/DynamicVehicleScheduling/anticipative_solver.jl index 4c48eca..c2eae89 100644 --- a/src/DynamicVehicleScheduling/anticipative_solver.jl +++ b/src/DynamicVehicleScheduling/anticipative_solver.jl @@ -92,14 +92,14 @@ function anticipative_solver( job_indices = 2:nb_nodes epoch_indices = T - @variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true) + @variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true) @objective( model, Max, sum( - -duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes, - t in epoch_indices + -duration[i, j] * y[i, j, t] for + i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices ) ) @@ -171,14 +171,12 @@ function anticipative_solver( routes = epoch_routes[i] epoch_customers = epoch_indices[i] - y_true = - VSPSolution( - Vector{Int}[ - map(idx -> findfirst(==(idx), epoch_customers), route) for - route in routes - ]; - max_index=length(epoch_customers), - ).edge_matrix + y_true = VSPSolution( + Vector{Int}[ + map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes + ]; + max_index=length(epoch_customers), + ).edge_matrix location_indices = customer_index[epoch_customers] new_coordinates = env.instance.static_instance.coordinate[location_indices] @@ -202,7 +200,8 @@ function anticipative_solver( is_must_dispatch[2:end] .= true else is_must_dispatch[2:end] .= - planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end] + planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> + new_start_time[2:end] end is_postponable[2:end] .= .!is_must_dispatch[2:end] # TODO: avoid code duplication with add_new_customers! @@ -223,7 +222,7 @@ function anticipative_solver( compute_features(state, env.instance) end - return DataSample(; info=(; state, reward), y, x) + return DataSample(; info=(; state, reward), y=y_true, x) end return obj, dataset diff --git a/src/DynamicVehicleScheduling/plot.jl b/src/DynamicVehicleScheduling/plot.jl index 33dd940..00a76a4 100644 --- a/src/DynamicVehicleScheduling/plot.jl +++ b/src/DynamicVehicleScheduling/plot.jl @@ -208,9 +208,9 @@ The returned dictionary contains: This lets plotting code build figures without depending on plotting internals. """ function build_plot_data(data_samples::Vector{<:DataSample}) - state_data = [build_state_data(sample.instance.state) for sample in data_samples] - rewards = [sample.instance.reward for sample in data_samples] - routess = [sample.y_true for sample in data_samples] + state_data = [build_state_data(sample.info.state) for sample in data_samples] + rewards = [sample.info.reward for sample in data_samples] + routess = [sample.y for sample in data_samples] return [ (; state..., reward, routes) for (state, reward, routes) in zip(state_data, rewards, routess) @@ -273,8 +273,8 @@ function plot_epochs( # Create subplots plots = map(1:n_epochs) do i sample = data_samples[i] - state = sample.instance.state - reward = sample.instance.reward + state = sample.info.state + reward = sample.info.reward common_kwargs = Dict( :xlims => xlims, @@ -292,7 +292,7 @@ function plot_epochs( if plot_routes_flag fig = plot_routes( state, - sample.y_true; + sample.y; reward=reward, show_route_labels=false, common_kwargs..., @@ -351,7 +351,7 @@ function animate_epochs( kwargs..., ) pd = build_plot_data(data_samples) - epoch_costs = [-sample.instance.reward for sample in data_samples] + epoch_costs = [-sample.info.reward for sample in data_samples] # Calculate global xlims and ylims from all states x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd) @@ -393,12 +393,12 @@ function animate_epochs( anim = @animate for frame_idx in 1:total_frames epoch_idx, frame_type = frame_plan[frame_idx] sample = data_samples[epoch_idx] - state = sample.instance.state + state = sample.info.state if frame_type == :routes fig = plot_routes( state, - sample.y_true; + sample.y; xlims=xlims, ylims=ylims, clims=clims, diff --git a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl index e1b3fa1..a2b2374 100644 --- a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl +++ b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl @@ -208,7 +208,7 @@ function plot_solution( ) (; tasks, district_width, width) = sample.info.city ticks = 0:district_width:width - solution = Solution(sample.y_true, sample.info) + solution = Solution(sample.y, sample.info) path_list = compute_path_list(solution) fig = plot(; xlabel="x", diff --git a/src/Utils/policy.jl b/src/Utils/policy.jl index 7fdf582..81d79f3 100644 --- a/src/Utils/policy.jl +++ b/src/Utils/policy.jl @@ -44,7 +44,7 @@ function evaluate_policy!( features, state = observe(env) state_copy = deepcopy(state) # To avoid mutation issues reward = step!(env, y) - sample = DataSample(; x=features, y_true=y, instance=(; state=state_copy, reward)) + sample = DataSample(; x=features, y=y, info=(; state=state_copy, reward)) if @isdefined labeled_dataset push!(labeled_dataset, sample) else diff --git a/src/Warcraft/Warcraft.jl b/src/Warcraft/Warcraft.jl index c4dcbae..6452d33 100644 --- a/src/Warcraft/Warcraft.jl +++ b/src/Warcraft/Warcraft.jl @@ -91,14 +91,14 @@ The keyword argument `θ_true` is used to set the color range of the weights plo function Utils.plot_data( ::WarcraftBenchmark, sample::DataSample; - θ_true=sample.θ_true, + θ_true=sample.θ, θ_title="Weights", y_title="Path", kwargs..., ) x = sample.x - y = sample.y_true - θ = sample.θ_true + y = sample.y + θ = sample.θ im = dropdims(x; dims=4) img = convert_image_for_plot(im) p1 = Plots.plot( diff --git a/src/Warcraft/utils.jl b/src/Warcraft/utils.jl index b6c040a..1244c9f 100644 --- a/src/Warcraft/utils.jl +++ b/src/Warcraft/utils.jl @@ -40,7 +40,7 @@ function create_dataset(decompressed_path::String, nb_samples::Int) ] Y = [BitMatrix(terrain_labels[:, :, i]) for i in 1:N] WG = [-terrain_weights[:, :, i] for i in 1:N] - return [DataSample(; x, y_true, θ_true) for (x, y_true, θ_true) in zip(X, Y, WG)] + return [DataSample(; x, y=y_true, θ=θ_true) for (x, y_true, θ_true) in zip(X, Y, WG)] end """ diff --git a/test/argmax.jl b/test/argmax.jl index 44e8c3a..cce896c 100644 --- a/test/argmax.jl +++ b/test/argmax.jl @@ -18,11 +18,13 @@ @test gap >= 0 for (i, sample) in enumerate(dataset) - (; x, θ_true, y_true) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y @test size(x) == (nb_features, instance_dim) @test length(θ_true) == instance_dim @test length(y_true) == instance_dim - @test isnothing(sample.instance) + @test isnothing(sample.info) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/argmax_2d.jl b/test/argmax_2d.jl index 06ab2b5..75351b8 100644 --- a/test/argmax_2d.jl +++ b/test/argmax_2d.jl @@ -21,11 +21,14 @@ @test figure isa Plots.Plot for (i, sample) in enumerate(dataset) - (; x, θ_true, y_true, instance) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y + instance = sample.info @test length(x) == nb_features @test length(θ_true) == 2 @test length(y_true) == 2 - @test !isnothing(sample.instance) + @test !isnothing(instance) @test instance isa Vector{Vector{Float64}} @test all(length(vertex) == 2 for vertex in instance) @test y_true in instance diff --git a/test/dynamic_assortment.jl b/test/dynamic_assortment.jl index b057c44..de8bf41 100644 --- a/test/dynamic_assortment.jl +++ b/test/dynamic_assortment.jl @@ -2,7 +2,7 @@ const DAP = DecisionFocusedLearningBenchmarks.DynamicAssortment end -@testitem "DynamicAssortment - Benchmark Construction" setup=[Imports, DAPSetup] begin +@testitem "DynamicAssortment - Benchmark Construction" setup = [Imports, DAPSetup] begin # Test default constructor b = DynamicAssortmentBenchmark() @test b.N == 20 @@ -13,7 +13,7 @@ end @test !is_exogenous(b) # Test custom constructor - b_custom = DynamicAssortmentBenchmark(N=10, d=3, K=2, max_steps=50, exogenous=true) + b_custom = DynamicAssortmentBenchmark(; N=10, d=3, K=2, max_steps=50, exogenous=true) @test b_custom.N == 10 @test b_custom.d == 3 @test b_custom.K == 2 @@ -28,8 +28,8 @@ end @test DAP.max_steps(b) == 80 end -@testitem "DynamicAssortment - Instance Generation" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=5, d=3, K=2) +@testitem "DynamicAssortment - Instance Generation" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=5, d=3, K=2) rng = MersenneTwister(42) instance = DAP.Instance(b, rng) @@ -53,8 +53,8 @@ end @test DAP.prices(instance) == instance.prices end -@testitem "DynamicAssortment - Environment Initialization" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=5, d=2, K=2, max_steps=10) +@testitem "DynamicAssortment - Environment Initialization" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=5, d=2, K=2, max_steps=10) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -80,8 +80,8 @@ end @test DAP.prices(env) == instance.prices end -@testitem "DynamicAssortment - Environment Reset" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=3, d=1, K=2, max_steps=5) +@testitem "DynamicAssortment - Environment Reset" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=3, d=1, K=2, max_steps=5) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -107,8 +107,8 @@ end @test env.features ≈ expected_features end -@testitem "DynamicAssortment - Hype Update Logic" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=5, d=1, K=2) +@testitem "DynamicAssortment - Hype Update Logic" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=5, d=1, K=2) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -135,8 +135,8 @@ end @test all(hype .== 1.0) # Should not affect any item hype end -@testitem "DynamicAssortment - Choice Probabilities" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=3, d=1, K=2) +@testitem "DynamicAssortment - Choice Probabilities" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=3, d=1, K=2) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -167,8 +167,8 @@ end @test probs[4] ≈ 1.0 # Only no-purchase available end -@testitem "DynamicAssortment - Expected Revenue" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=3, d=1, K=2) +@testitem "DynamicAssortment - Expected Revenue" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=3, d=1, K=2) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -183,8 +183,8 @@ end @test revenue == 0.0 # Only no-purchase available with price 0 end -@testitem "DynamicAssortment - Environment Step" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=3, d=1, K=2, max_steps=5) +@testitem "DynamicAssortment - Environment Step" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=3, d=1, K=2, max_steps=5) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -219,9 +219,9 @@ end @test_throws AssertionError step!(env, assortment) end -@testitem "DynamicAssortment - Endogenous vs Exogenous" setup=[Imports, DAPSetup] begin +@testitem "DynamicAssortment - Endogenous vs Exogenous" setup = [Imports, DAPSetup] begin # Test endogenous environment (features change with purchases) - b_endo = DynamicAssortmentBenchmark(N=3, d=1, K=2, exogenous=false) + b_endo = DynamicAssortmentBenchmark(; N=3, d=1, K=2, exogenous=false) instance_endo = DAP.Instance(b_endo, MersenneTwister(42)) env_endo = DAP.Environment(instance_endo; seed=123) @@ -232,7 +232,7 @@ end @test any(env_endo.d_features .!= 0.0) # Delta features should be non-zero # Test exogenous environment (features don't change with purchases) - b_exo = DynamicAssortmentBenchmark(N=3, d=1, K=2, exogenous=true) + b_exo = DynamicAssortmentBenchmark(; N=3, d=1, K=2, exogenous=true) instance_exo = DAP.Instance(b_exo, MersenneTwister(42)) env_exo = DAP.Environment(instance_exo; seed=123) @@ -243,8 +243,8 @@ end @test all(env_exo.d_features .== 0.0) # Delta features should remain zero end -@testitem "DynamicAssortment - Observation" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=3, d=2, max_steps=10) +@testitem "DynamicAssortment - Observation" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=3, d=2, max_steps=10) instance = DAP.Instance(b, MersenneTwister(42)) env = DAP.Environment(instance; seed=123) @@ -266,10 +266,10 @@ end @test obs1 != obs2 # Observations should differ after purchase end -@testitem "DynamicAssortment - Policies" setup=[Imports, DAPSetup] begin +@testitem "DynamicAssortment - Policies" setup = [Imports, DAPSetup] begin using Statistics: mean - b = DynamicAssortmentBenchmark(N=5, d=2, K=3, max_steps=20) + b = DynamicAssortmentBenchmark(; N=5, d=2, K=3, max_steps=20) # Generate test data dataset = generate_dataset(b, 10; seed=0) @@ -307,8 +307,8 @@ end @test sum(greedy_action) == DAP.assortment_size(env) end -@testitem "DynamicAssortment - Model and Maximizer Integration" setup=[Imports, DAPSetup] begin - b = DynamicAssortmentBenchmark(N=4, d=3, K=2) +@testitem "DynamicAssortment - Model and Maximizer Integration" setup = [Imports, DAPSetup] begin + b = DynamicAssortmentBenchmark(; N=4, d=3, K=2) # Test statistical model generation model = generate_statistical_model(b; seed=42) @@ -317,7 +317,7 @@ end # Test integration with sample data sample = generate_sample(b, MersenneTwister(42)) - @test hasfield(typeof(sample), :instance) + @test hasfield(typeof(sample), :info) dataset = generate_dataset(b, 3; seed=42) environments = generate_environments(b, dataset) diff --git a/test/dynamic_vsp.jl b/test/dynamic_vsp.jl index c2ea1f3..131ca89 100644 --- a/test/dynamic_vsp.jl +++ b/test/dynamic_vsp.jl @@ -26,7 +26,7 @@ @test mean(r_lazy) <= mean(r_greedy) env = environments[1] - instance = dataset[1].instance + instance = dataset[1].info scenario = generate_scenario(b, instance) v, y = generate_anticipative_solution(b, env, scenario; nb_epochs=2, reset_env=true) @@ -49,6 +49,8 @@ anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true) reset!(env; reset_rng=true) - cost = sum(step!(env, sample.y_true) for sample in solution) + cost = sum(step!(env, sample.y) for sample in solution) + cost2 = sum(sample.info.reward for sample in solution) @test isapprox(cost, anticipative_value; atol=1e-5) + @test isapprox(cost, cost2; atol=1e-5) end diff --git a/test/dynamic_vsp_plots.jl b/test/dynamic_vsp_plots.jl index 185c91a..5bba734 100644 --- a/test/dynamic_vsp_plots.jl +++ b/test/dynamic_vsp_plots.jl @@ -13,7 +13,7 @@ fig1 = DVSP.plot_instance(env) @test fig1 isa Plots.Plot - instance = dataset[1].instance + instance = dataset[1].info scenario = generate_scenario(b, instance; seed=0) v, y = generate_anticipative_solution(b, env, scenario; nb_epochs=3, reset_env=true) @@ -23,7 +23,7 @@ policies = generate_policies(b) lazy = policies[1] _, d = evaluate_policy!(lazy, env) - fig3 = DVSP.plot_routes(d[1].instance.state, d[1].y_true) + fig3 = DVSP.plot_routes(d[1].info.state, d[1].y) @test fig3 isa Plots.Plot # Test animation diff --git a/test/fixed_size_shortest_path.jl b/test/fixed_size_shortest_path.jl index 257e7b2..c8f659a 100644 --- a/test/fixed_size_shortest_path.jl +++ b/test/fixed_size_shortest_path.jl @@ -18,12 +18,14 @@ @test gap >= 0 for sample in dataset - (; x, θ_true, y_true) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y @test all(θ_true .< 0) @test size(x) == (p,) @test length(θ_true) == A @test length(y_true) == A - @test isnothing(sample.instance) + @test isnothing(sample.info) @test all(y_true .== maximizer(θ_true)) θ = model(x) @test length(θ) == length(θ_true) diff --git a/test/portfolio_optimization.jl b/test/portfolio_optimization.jl index 7c983d7..6d0f0a2 100644 --- a/test/portfolio_optimization.jl +++ b/test/portfolio_optimization.jl @@ -10,11 +10,13 @@ maximizer = generate_maximizer(b) for sample in dataset - (; x, θ_true, y_true) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y @test size(x) == (p,) @test length(θ_true) == d @test length(y_true) == d - @test isnothing(sample.instance) + @test isnothing(sample.info) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/ranking.jl b/test/ranking.jl index 1756733..8c8cf10 100644 --- a/test/ranking.jl +++ b/test/ranking.jl @@ -15,11 +15,13 @@ maximizer = generate_maximizer(b) for (i, sample) in enumerate(dataset) - (; x, θ_true, y_true) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y @test size(x) == (nb_features, instance_dim) @test length(θ_true) == instance_dim @test length(y_true) == instance_dim - @test isnothing(sample.instance) + @test isnothing(sample.info) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/subset_selection.jl b/test/subset_selection.jl index d59ae54..6cb6ea9 100644 --- a/test/subset_selection.jl +++ b/test/subset_selection.jl @@ -17,11 +17,13 @@ maximizer = generate_maximizer(b) for (i, sample) in enumerate(dataset) - (; x, θ_true, y_true) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y @test size(x) == (n,) @test length(θ_true) == n @test length(y_true) == n - @test isnothing(sample.instance) + @test isnothing(sample.info) @test all(y_true .== maximizer(θ_true)) # Features and true weights should be equal diff --git a/test/utils.jl b/test/utils.jl index 8591750..21d0820 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -32,10 +32,7 @@ end function random_sample() return DataSample(; - x=randn(rng, 10, 5), - θ_true=rand(rng, 5), - y_true=rand(rng, 10), - instance="this is an instance", + x=randn(rng, 10, 5), θ=rand(rng, 5), y=rand(rng, 10), info="this is an instance" ) end @@ -45,7 +42,7 @@ end io = IOBuffer() show(io, sample) @test String(take!(io)) == - "DataSample(x=$(sample.x), θ_true=$(sample.θ_true), y_true=$(sample.y_true), instance=$(sample.instance))" + "DataSample(x=$(sample.x), θ_true=$(sample.θ), y_true=$(sample.y), instance=$(sample.info))" # Test StatsBase methods using StatsBase: @@ -76,9 +73,9 @@ end # Check that other fields are preserved for i in 1:N - @test dataset_zt[i].θ_true == dataset[i].θ_true - @test dataset_zt[i].y_true == dataset[i].y_true - @test dataset_zt[i].instance == dataset[i].instance + @test dataset_zt[i].θ == dataset[i].θ + @test dataset_zt[i].y == dataset[i].y + @test dataset_zt[i].info == dataset[i].info end # Check that features are actually transformed @@ -92,9 +89,9 @@ end # Check that other fields remain unchanged after transform! for i in 1:N - @test dataset_copy[i].θ_true == dataset[i].θ_true - @test dataset_copy[i].y_true == dataset[i].y_true - @test dataset_copy[i].instance == dataset[i].instance + @test dataset_copy[i].θ == dataset[i].θ + @test dataset_copy[i].y == dataset[i].y + @test dataset_copy[i].info == dataset[i].info end # Test reconstruct (non-mutating) @@ -104,9 +101,9 @@ end # Test round-trip consistency (should be close to original) for i in 1:N @test dataset_reconstructed[i].x ≈ dataset[i].x atol = 1e-10 - @test dataset_reconstructed[i].θ_true == dataset[i].θ_true - @test dataset_reconstructed[i].y_true == dataset[i].y_true - @test dataset_reconstructed[i].instance == dataset[i].instance + @test dataset_reconstructed[i].θ == dataset[i].θ + @test dataset_reconstructed[i].y == dataset[i].y + @test dataset_reconstructed[i].info == dataset[i].info end # Test reconstruct! (mutating) diff --git a/test/vsp.jl b/test/vsp.jl index 1c3b5fd..eb054da 100644 --- a/test/vsp.jl +++ b/test/vsp.jl @@ -42,7 +42,7 @@ for sample in dataset x = sample.x - instance = sample.instance + instance = sample.info E = ne(instance.graph) @test size(x) == (20, E) θ = model(x) diff --git a/test/warcraft.jl b/test/warcraft.jl index 5d52cdd..7653609 100644 --- a/test/warcraft.jl +++ b/test/warcraft.jl @@ -19,10 +19,12 @@ @test gap >= 0 for (i, sample) in enumerate(dataset) - (; x, θ_true, y_true) = sample + x = sample.x + θ_true = sample.θ + y_true = sample.y @test size(x) == (96, 96, 3, 1) @test all(θ_true .<= 0) - @test isnothing(sample.instance) + @test isnothing(sample.info) θ = model(x) @test size(θ) == size(θ_true) From abc217aa3f1501671c67fc617b764d391e8610bf Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Tue, 30 Sep 2025 11:37:48 +0200 Subject: [PATCH 3/4] Adjust doc, and fix tutorial --- docs/src/benchmark_interfaces.md | 8 ++++---- docs/src/tutorials/warcraft_tutorial.jl | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/src/benchmark_interfaces.md b/docs/src/benchmark_interfaces.md index faa6fb7..01f4953 100644 --- a/docs/src/benchmark_interfaces.md +++ b/docs/src/benchmark_interfaces.md @@ -11,10 +11,10 @@ All benchmarks work with [`DataSample`](@ref) objects that encapsulate the data ```julia @kwdef struct DataSample{I,F,S,C} - x::F = nothing # Input features - θ_true::C = nothing # True cost/utility parameters - y_true::S = nothing # True optimal solution - instance::I = nothing # Problem instance object/additional data + x::F = nothing # Input features of the policy + θ::C = nothing # Intermediate cost/utility parameters + y::S = nothing # Output solution + info::I = nothing # Additional data information (e.g., problem instance) end ``` diff --git a/docs/src/tutorials/warcraft_tutorial.jl b/docs/src/tutorials/warcraft_tutorial.jl index 2d41563..9039d09 100644 --- a/docs/src/tutorials/warcraft_tutorial.jl +++ b/docs/src/tutorials/warcraft_tutorial.jl @@ -21,16 +21,16 @@ dataset = generate_dataset(b, 50); # Subdatasets can be created through regular slicing: train_dataset, test_dataset = dataset[1:45], dataset[46:50] -# And getting an individual sample will return a [`DataSample`](@ref) with four fields: `x`, `instance`, `θ`, and `y`: +# And getting an individual sample will return a [`DataSample`](@ref) with four fields: `x`, `info`, `θ`, and `y`: sample = test_dataset[1] # `x` correspond to the input features, i.e. the input image (3D array) in the Warcraft benchmark case: x = sample.x -# `θ_true` correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem: -θ_true = sample.θ_true -# `y_true` correspond to the optimal shortest path, encoded as a binary matrix: -y_true = sample.y_true -# `instance` is not used in this benchmark, therefore set to nothing: -isnothing(sample.instance) +# `θ` correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem: +θ_true = sample.θ +# `y` correspond to the optimal shortest path, encoded as a binary matrix: +y_true = sample.y +# `info` is not used in this benchmark, therefore set to nothing: +isnothing(sample.info) # For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data: plot_data(b, sample) @@ -50,7 +50,7 @@ maximizer = generate_maximizer(b; dijkstra=true) # In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm. y = maximizer(θ) # As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path. -plot_data(b, DataSample(; x, θ_true=θ, y_true=y)) +plot_data(b, DataSample(; x, θ, y)) # We can evaluate the current pipeline performance using the optimality gap metric: starting_gap = compute_gap(b, test_dataset, model, maximizer) @@ -70,7 +70,7 @@ opt_state = Flux.setup(Adam(1e-3), model) loss_history = Float64[] for epoch in 1:50 val, grads = Flux.withgradient(model) do m - sum(loss(m(x), y_true) for (; x, y_true) in train_dataset) / length(train_dataset) + sum(loss(m(x), y) for (; x, y) in train_dataset) / length(train_dataset) end Flux.update!(opt_state, model, grads[1]) push!(loss_history, val) @@ -85,7 +85,7 @@ final_gap = compute_gap(b, test_dataset, model, maximizer) # θ = model(x) y = maximizer(θ) -plot_data(b, DataSample(; x, θ_true=θ, y_true=y)) +plot_data(b, DataSample(; x, θ, y)) using Test #src @test final_gap < starting_gap #src From fb7977a3bfb2e499098547f941a2103b41ff5e0c Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Tue, 30 Sep 2025 11:41:51 +0200 Subject: [PATCH 4/4] small fix --- docs/src/benchmark_interfaces.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/benchmark_interfaces.md b/docs/src/benchmark_interfaces.md index 01f4953..7543836 100644 --- a/docs/src/benchmark_interfaces.md +++ b/docs/src/benchmark_interfaces.md @@ -18,7 +18,7 @@ All benchmarks work with [`DataSample`](@ref) objects that encapsulate the data end ``` -The `DataSample` provides flexibility - not all fields need to be populated depending on the benchmark type and use case. +The `DataSample` provides flexibility, not all fields need to be populated depending on the benchmark type and use. ### Benchmark Type Hierarchy