diff --git a/docs/src/tutorials/warcraft_tutorial.jl b/docs/src/tutorials/warcraft_tutorial.jl index 8b7b8d9..b801d7a 100644 --- a/docs/src/tutorials/warcraft_tutorial.jl +++ b/docs/src/tutorials/warcraft_tutorial.jl @@ -30,8 +30,8 @@ x = sample.x θ_true = sample.θ # `y` correspond to the optimal shortest path, encoded as a binary matrix: y_true = sample.y -# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty: -isempty(sample.maximizer_kwargs) +# `context` is not used in this benchmark (no solver kwargs needed), so it is empty: +isempty(sample.context) # For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data: plot_solution(b, sample) diff --git a/ext/DFLBenchmarksPlotsExt.jl b/ext/DFLBenchmarksPlotsExt.jl index 23fe0d5..0a5caae 100644 --- a/ext/DFLBenchmarksPlotsExt.jl +++ b/ext/DFLBenchmarksPlotsExt.jl @@ -21,9 +21,7 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...) 
return plot_solution( bench, - DataSample(; - sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra - ); + DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra); kwargs..., ) end diff --git a/src/Argmax/Argmax.jl b/src/Argmax/Argmax.jl index a4faede..6775f9c 100644 --- a/src/Argmax/Argmax.jl +++ b/src/Argmax/Argmax.jl @@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES using Flux: Chain, Dense using Random +using LinearAlgebra: dot + """ $TYPEDEF @@ -29,6 +31,8 @@ function Base.show(io::IO, bench::ArgmaxBenchmark) ) end +Utils.objective_value(::ArgmaxBenchmark, sample::DataSample, y) = dot(sample.θ, y) + """ $TYPEDSIGNATURES diff --git a/src/Argmax2D/Argmax2D.jl b/src/Argmax2D/Argmax2D.jl index 968d63a..2bae9f5 100644 --- a/src/Argmax2D/Argmax2D.jl +++ b/src/Argmax2D/Argmax2D.jl @@ -30,6 +30,8 @@ function Base.show(io::IO, bench::Argmax2DBenchmark) return print(io, "Argmax2DBenchmark(nb_features=$nb_features)") end +Utils.objective_value(::Argmax2DBenchmark, sample::DataSample, y) = dot(sample.θ, y) + """ $TYPEDSIGNATURES diff --git a/src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl b/src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl index fb27286..f8009f9 100644 --- a/src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl +++ b/src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl @@ -3,6 +3,7 @@ module ContextualStochasticArgmax using ..Utils using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES using Flux: Dense +using LinearAlgebra: dot using Random: Random, AbstractRNG, MersenneTwister using Statistics: mean @@ -11,7 +12,7 @@ $TYPEDEF Minimal contextual stochastic argmax benchmark. -Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample). +Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `context` of the instance sample). Per context draw: `x_raw ~ N(0, I_d)` (observable context). 
Features: `x = [c_base; x_raw]`. Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`. The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`. @@ -44,17 +45,34 @@ end Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax +function Utils.objective_value( + ::ContextualStochasticArgmaxBenchmark, sample::DataSample, y, scenario +) + return dot(scenario, y) +end + +function Utils.objective_value( + bench::ContextualStochasticArgmaxBenchmark, sample::DataSample, y +) + if hasproperty(sample.extra, :scenario) + return Utils.objective_value(bench, sample, y, sample.scenario) + elseif hasproperty(sample.extra, :scenarios) + return mean(Utils.objective_value(bench, sample, y, ξ) for ξ in sample.scenarios) + end + return error("Sample must have scenario or scenarios") +end + """ generate_instance(::ContextualStochasticArgmaxBenchmark, rng) -Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed +Draw `c_base ~ U[0,1]^n` and store it in `context`. No solver kwargs are needed (the maximizer is `one_hot_argmax`, which takes no kwargs). """ function Utils.generate_instance( bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs... ) c_base = rand(rng, Float32, bench.n) - return DataSample(; extra=(; c_base)) + return DataSample(; c_base) end """ @@ -62,7 +80,7 @@ end Draw `x_raw ~ N(0, I_d)` and return a context sample with: - `x = [c_base; x_raw]`: full feature vector seen by the ML model. -- `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref). +- `c_base`, `x_raw` in `context`: spread into [`generate_scenario`](@ref). 
""" function Utils.generate_context( bench::ContextualStochasticArgmaxBenchmark, @@ -71,14 +89,14 @@ function Utils.generate_context( ) c_base = instance_sample.c_base x_raw = randn(rng, Float32, bench.d) - return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base)) + return DataSample(; x=vcat(c_base, x_raw), c_base, x_raw) end """ generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...) Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`. -`c_base` and `x_raw` are spread from `ctx.extra` by the framework. +`c_base` and `x_raw` are spread from `ctx.context` by the framework. """ function Utils.generate_scenario( bench::ContextualStochasticArgmaxBenchmark, diff --git a/src/ContextualStochasticArgmax/policies.jl b/src/ContextualStochasticArgmax/policies.jl index 76244ab..0e199b4 100644 --- a/src/ContextualStochasticArgmax/policies.jl +++ b/src/ContextualStochasticArgmax/policies.jl @@ -11,7 +11,7 @@ function csa_saa_policy(ctx_sample, scenarios) y = one_hot_argmax(mean(scenarios)) return [ DataSample(; - ctx_sample.maximizer_kwargs..., + ctx_sample.context..., x=ctx_sample.x, y=y, extra=(; ctx_sample.extra..., scenarios), diff --git a/src/DecisionFocusedLearningBenchmarks.jl b/src/DecisionFocusedLearningBenchmarks.jl index 6dc5db6..3ef3448 100644 --- a/src/DecisionFocusedLearningBenchmarks.jl +++ b/src/DecisionFocusedLearningBenchmarks.jl @@ -79,6 +79,7 @@ export generate_statistical_model export generate_maximizer export generate_anticipative_solver, generate_parametric_anticipative_solver export is_exogenous, is_endogenous +export is_minimization_problem export objective_value export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory diff --git a/src/FixedSizeShortestPath/FixedSizeShortestPath.jl b/src/FixedSizeShortestPath/FixedSizeShortestPath.jl index e4cc64b..700ff86 100644 --- a/src/FixedSizeShortestPath/FixedSizeShortestPath.jl +++ 
b/src/FixedSizeShortestPath/FixedSizeShortestPath.jl @@ -55,9 +55,9 @@ function FixedSizeShortestPathBenchmark(; end function Utils.objective_value( - ::FixedSizeShortestPathBenchmark, θ::AbstractArray, y::AbstractArray + ::FixedSizeShortestPathBenchmark, sample::DataSample, y::AbstractArray ) - return -dot(θ, y) + return -dot(sample.θ, y) end """ diff --git a/src/PortfolioOptimization/PortfolioOptimization.jl b/src/PortfolioOptimization/PortfolioOptimization.jl index 9e8c277..4f374ca 100644 --- a/src/PortfolioOptimization/PortfolioOptimization.jl +++ b/src/PortfolioOptimization/PortfolioOptimization.jl @@ -6,7 +6,7 @@ using Distributions: Uniform, Bernoulli using Flux: Chain, Dense using Ipopt: Ipopt using JuMP: @variable, @objective, @constraint, optimize!, value, Model, set_silent -using LinearAlgebra: I +using LinearAlgebra: I, dot using Random: Random, AbstractRNG, MersenneTwister """ @@ -38,6 +38,10 @@ struct PortfolioOptimizationBenchmark <: AbstractBenchmark f::Vector{Float32} end +function Utils.objective_value(::PortfolioOptimizationBenchmark, sample::DataSample, y) + return dot(sample.θ, y) +end + """ $TYPEDSIGNATURES diff --git a/src/Ranking/Ranking.jl b/src/Ranking/Ranking.jl index 269b98c..b826575 100644 --- a/src/Ranking/Ranking.jl +++ b/src/Ranking/Ranking.jl @@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES using Flux: Chain, Dense using Random +using LinearAlgebra: dot + """ $TYPEDEF @@ -29,6 +31,9 @@ function Base.show(io::IO, bench::RankingBenchmark) ) end +Utils.objective_value(::RankingBenchmark, sample::DataSample, y) = dot(sample.θ, y) +Utils.is_minimization_problem(::RankingBenchmark) = false + """ $TYPEDSIGNATURES diff --git a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl index ca28004..0f1afbb 100644 --- a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl +++ 
b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl @@ -68,12 +68,27 @@ end include("policies.jl") function Utils.objective_value( - ::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector + ::StochasticVehicleSchedulingBenchmark, + sample::DataSample, + y::BitVector, + scenario::VSPScenario, ) - stoch = build_stochastic_instance(sample.instance, sample.extra.scenarios) + stoch = build_stochastic_instance(sample.instance, [scenario]) return evaluate_solution(y, stoch) end +function Utils.objective_value( + bench::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector +) + if hasproperty(sample.extra, :scenario) + return Utils.objective_value(bench, sample, y, sample.extra.scenario) + elseif hasproperty(sample.extra, :scenarios) + stoch = build_stochastic_instance(sample.instance, sample.extra.scenarios) + return evaluate_solution(y, stoch) + end + return error("Sample must have scenario or scenarios") +end + """ $TYPEDSIGNATURES @@ -116,7 +131,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=noth To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref): ```julia -policy = sample -> DataSample(; sample.maximizer_kwargs..., x=sample.x, +policy = sample -> DataSample(; sample.context..., x=sample.x, y=column_generation_algorithm(sample.instance)) dataset = generate_dataset(benchmark, N; target_policy=policy) ``` diff --git a/src/StochasticVehicleScheduling/policies.jl b/src/StochasticVehicleScheduling/policies.jl index 49d6607..69fc2bf 100644 --- a/src/StochasticVehicleScheduling/policies.jl +++ b/src/StochasticVehicleScheduling/policies.jl @@ -10,7 +10,7 @@ function svs_saa_policy(ctx_sample, scenarios) y = column_generation_algorithm(stochastic_inst) return [ DataSample(; - ctx_sample.maximizer_kwargs..., + ctx_sample.context..., x=ctx_sample.x, y, extra=(; ctx_sample.extra..., scenarios), @@ -28,7 +28,7 @@ function svs_deterministic_policy(ctx_sample, scenarios; 
model_builder=highs_mod y = deterministic_mip(ctx_sample.instance; model_builder) return [ DataSample(; - ctx_sample.maximizer_kwargs..., + ctx_sample.context..., x=ctx_sample.x, y, extra=(; ctx_sample.extra..., scenarios), @@ -48,7 +48,7 @@ function svs_local_search_policy(ctx_sample, scenarios) y = local_search(stochastic_inst) return [ DataSample(; - ctx_sample.maximizer_kwargs..., + ctx_sample.context..., x=ctx_sample.x, y, extra=(; ctx_sample.extra..., scenarios), @@ -70,7 +70,7 @@ function svs_saa_mip_policy(ctx_sample, scenarios; model_builder=scip_model) y = compact_linearized_mip(ctx_sample.instance, scenarios; model_builder) return [ DataSample(; - ctx_sample.maximizer_kwargs..., + ctx_sample.context..., x=ctx_sample.x, y, extra=(; ctx_sample.extra..., scenarios), diff --git a/src/SubsetSelection/SubsetSelection.jl b/src/SubsetSelection/SubsetSelection.jl index a05359d..98fa0ea 100644 --- a/src/SubsetSelection/SubsetSelection.jl +++ b/src/SubsetSelection/SubsetSelection.jl @@ -3,6 +3,7 @@ module SubsetSelection using ..Utils using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES using Flux: Chain, Dense +using LinearAlgebra: dot using Random """ @@ -31,6 +32,9 @@ function Base.show(io::IO, bench::SubsetSelectionBenchmark) return print(io, "SubsetSelectionBenchmark(n=$n, k=$k)") end +Utils.objective_value(::SubsetSelectionBenchmark, sample::DataSample, y) = dot(sample.θ, y) +Utils.is_minimization_problem(::SubsetSelectionBenchmark) = false + function SubsetSelectionBenchmark(; n::Int=25, k::Int=5, identity_mapping::Bool=true) @assert n >= k "number of items n must be greater than k" mapping = if identity_mapping diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 51e0834..3b6b4fc 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -36,6 +36,7 @@ export generate_environment, generate_environments export SampleAverageApproximation export generate_baseline_policies export generate_anticipative_solver, 
generate_parametric_anticipative_solver +export is_minimization_problem export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory export compute_gap diff --git a/src/Utils/data_sample.jl b/src/Utils/data_sample.jl index a22963f..1147761 100644 --- a/src/Utils/data_sample.jl +++ b/src/Utils/data_sample.jl @@ -4,8 +4,9 @@ $TYPEDEF Data sample data structure. Its main purpose is to store datasets generated by the benchmarks. It has 3 main (optional) fields: features `x`, cost parameters `θ`, and solution `y`. -Additionally, it has an `maximizer_kwargs` field (solver kwargs, spread into the maximizer as -`maximizer(θ; sample.maximizer_kwargs...)`) and an `extra` field (non-solver data, never passed +Currently, all three are restricted to `AbstractArray` or `nothing`. +Additionally, it has a `context` field (solver and scenario-generation context, spread into the +maximizer as `maximizer(θ; sample.context...)`) and an `extra` field (non-solver data, never passed to the maximizer). The separation prevents silent breakage from accidentally passing non-solver data @@ -27,8 +28,8 @@ struct DataSample{ θ::C "output solution (optional)" y::S - "solver kwargs, e.g. instance, graph, etc." - maximizer_kwargs::K + "solver and scenario-generation context, e.g. instance, graph, contextual information" + context::K "additional data, never passed to the maximizer, e.g. scenario, objective value, reward, step count, etc. Can be used for any purpose by the user, such as plotting utilities." extra::E @@ -39,37 +40,46 @@ $TYPEDSIGNATURES Constructor for `DataSample` with keyword arguments. -All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `maximizer_kwargs` -field (solver kwargs). The `extra` keyword accepts a `NamedTuple` of non-solver data. +All keyword arguments beyond `x`, `θ`, `y`, and `extra` are collected into the `context` +field (solver and scenario-generation context). 
The `extra` keyword accepts a `NamedTuple` of non-solver data. -Fields in `maximizer_kwargs` and `extra` must be disjoint. An error is thrown if they overlap. +Fields in `context` and `extra` must be disjoint. Neither may use a reserved +struct field name (`x`, `θ`, `y`, `context`, `extra`). An error is thrown in +both cases. Both can be accessed directly via property forwarding. # Examples ```julia -# Instance goes in maximizer_kwargs +# Instance goes in context d = DataSample(x=[1,2,3], θ=[4,5,6], y=[7,8,9], instance="my_instance") -d.instance # "my_instance" (from maximizer_kwargs) +d.instance # "my_instance" (from context) # Scenario goes in extra d = DataSample(x=x, y=y, instance=inst, extra=(; scenario=ξ)) d.scenario # ξ (from extra) -# State goes in maximizer_kwargs, reward in extra +# State goes in context, reward in extra d = DataSample(x=x, y=y, instance=state, extra=(; reward=-1.5)) -d.instance # state (from maximizer_kwargs) +d.instance # state (from context) d.reward # -1.5 (from extra) ``` """ function DataSample(; x=nothing, θ=nothing, y=nothing, extra=NamedTuple(), kwargs...) - maximizer_kwargs = (; kwargs...) - overlap = intersect(keys(maximizer_kwargs), keys(extra)) + context = (; kwargs...) 
+ overlap = intersect(keys(context), keys(extra)) if !isempty(overlap) - error( - "Keys $(collect(overlap)) appear in both maximizer_kwargs and extra of DataSample", - ) + error("Keys $(collect(overlap)) appear in both context and extra of DataSample") end - return DataSample(x, θ, y, maximizer_kwargs, extra) + reserved = (:x, :θ, :y, :context, :extra) + shadowed_ctx = intersect(keys(context), reserved) + if !isempty(shadowed_ctx) + error("Keys $(collect(shadowed_ctx)) in context shadow DataSample struct fields") + end + shadowed_extra = intersect(keys(extra), reserved) + if !isempty(shadowed_extra) + error("Keys $(collect(shadowed_extra)) in extra shadow DataSample struct fields") + end + return DataSample(x, θ, y, context, extra) end """ @@ -77,14 +87,14 @@ $TYPEDSIGNATURES Extended property access for `DataSample`. -Allows accessing `maximizer_kwargs` and `extra` fields directly as properties. -`maximizer_kwargs` is searched first; if the key is not found there, `extra` is searched. +Allows accessing `context` and `extra` fields directly as properties. +`context` is searched first; if the key is not found there, `extra` is searched. """ function Base.getproperty(d::DataSample, name::Symbol) - if name in (:x, :θ, :y, :maximizer_kwargs, :extra) + if name in (:x, :θ, :y, :context, :extra) return getfield(d, name) else - ctx = getfield(d, :maximizer_kwargs) + ctx = getfield(d, :context) if haskey(ctx, name) return getproperty(ctx, name) end @@ -96,12 +106,12 @@ end $TYPEDSIGNATURES Return all property names of a `DataSample`, including both struct fields and forwarded -fields from `maximizer_kwargs` and `extra`. +fields from `context` and `extra`. This enables tab completion for all available properties. 
""" function Base.propertynames(d::DataSample, private::Bool=false) - ctx_names = propertynames(getfield(d, :maximizer_kwargs), private) + ctx_names = propertynames(getfield(d, :context), private) extra_names = propertynames(getfield(d, :extra), private) return (fieldnames(DataSample)..., ctx_names..., extra_names...) end @@ -122,13 +132,13 @@ function Base.show(io::IO, d::DataSample) end if !isnothing(d.θ) θ_str = sprint(show, d.θ; context=io_limited) - push!(fields, "θ_true=$θ_str") + push!(fields, "θ=$θ_str") end if !isnothing(d.y) y_str = sprint(show, d.y; context=io_limited) - push!(fields, "y_true=$y_str") + push!(fields, "y=$y_str") end - for (key, value) in pairs(d.maximizer_kwargs) + for (key, value) in pairs(d.context) value_str = sprint(show, value; context=io_limited) push!(fields, "$key=$value_str") end @@ -156,8 +166,8 @@ Transform the features in the dataset. """ function StatsBase.transform(t, dataset::AbstractVector{<:DataSample}) return map(dataset) do d - (; maximizer_kwargs, extra, x, θ, y) = d - DataSample(StatsBase.transform(t, x), θ, y, maximizer_kwargs, extra) + (; context, extra, x, θ, y) = d + DataSample(; x=StatsBase.transform(t, x), θ, y, context..., extra) end end @@ -179,8 +189,8 @@ Reconstruct the features in the dataset. """ function StatsBase.reconstruct(t, dataset::AbstractVector{<:DataSample}) return map(dataset) do d - (; maximizer_kwargs, extra, x, θ, y) = d - DataSample(StatsBase.reconstruct(t, x), θ, y, maximizer_kwargs, extra) + (; context, extra, x, θ, y) = d + DataSample(StatsBase.reconstruct(t, x), θ, y, context, extra) end end diff --git a/src/Utils/interface.jl b/src/Utils/interface.jl index 206ca4f..6af5979 100644 --- a/src/Utils/interface.jl +++ b/src/Utils/interface.jl @@ -8,7 +8,10 @@ Choose one of three primary implementation strategies: - Implement [`generate_instance`](@ref) (returns a [`DataSample`](@ref) with `y=nothing`). 
The default [`generate_sample`](@ref) forwards the call directly; [`generate_dataset`](@ref) applies `target_policy` afterwards if provided. -- Override [`generate_sample`](@ref) directly when the sample requires custom logic. +- Override [`generate_sample`](@ref) directly when the sample requires custom logic + that cannot be expressed via [`generate_instance`](@ref). Applies to static benchmarks + only; stochastic benchmarks should implement the finer-grained hooks instead + ([`generate_instance`](@ref), [`generate_context`](@ref), [`generate_scenario`](@ref)). [`generate_dataset`](@ref) applies `target_policy` to the result after the call returns. - Override [`generate_dataset`](@ref) directly when samples cannot be drawn independently. @@ -51,6 +54,12 @@ Calls [`generate_instance`](@ref) and returns the result directly. Override this method when sample generation requires custom logic. Labeling via `target_policy` is always applied by [`generate_dataset`](@ref) after this call returns. + +!!! note + This is an internal hook called by [`generate_dataset`](@ref). Prefer calling + [`generate_dataset`](@ref) rather than this method directly. For stochastic + benchmarks, implement [`generate_instance`](@ref), [`generate_context`](@ref), + and [`generate_scenario`](@ref) instead of overriding this method. """ function generate_sample(bench::AbstractBenchmark, rng; kwargs...) return generate_instance(bench, rng; kwargs...) @@ -74,7 +83,6 @@ function generate_dataset( rng=MersenneTwister(seed), kwargs..., ) - Random.seed!(rng, seed) return [ begin sample = generate_sample(bench, rng; kwargs...) @@ -88,7 +96,12 @@ end Returns a callable `f(θ; kwargs...) -> y`, solving a maximization problem. """ -function generate_maximizer end +function generate_maximizer(bench::AbstractBenchmark; kwargs...) + return error( + "`generate_maximizer` is not implemented for $(typeof(bench)). " * + "Implement `generate_maximizer(::$(typeof(bench)); kwargs...) -> f(θ; kwargs...) 
-> y`.", + ) +end """ generate_statistical_model(::AbstractBenchmark, seed=nothing; kwargs...) @@ -97,7 +110,12 @@ Returns an untrained statistical model (usually a Flux neural network) that maps feature matrix `x` to an output array `θ`. The `seed` parameter controls initialization randomness for reproducibility. """ -function generate_statistical_model end +function generate_statistical_model(bench::AbstractBenchmark, seed=nothing; kwargs...) + return error( + "`generate_statistical_model` is not implemented for $(typeof(bench)). " * + "Implement `generate_statistical_model(::$(typeof(bench)), seed=nothing; kwargs...) -> model`.", + ) +end """ generate_baseline_policies(::AbstractBenchmark) -> NamedTuple or Tuple @@ -131,32 +149,14 @@ Plot the instance with `sample.y` overlaid. Only available when `Plots` is loade """ function plot_solution end -""" - compute_gap(::AbstractBenchmark, dataset::Vector{<:DataSample}, statistical_model, maximizer) -> Float64 - -Compute the average relative optimality gap of the pipeline on the dataset. -""" -function compute_gap end - """ $TYPEDSIGNATURES -Compute `dot(θ, y)`. Override for non-linear objectives. +Compute the objective value of given solution `y` for a specific benchmark. +Must be implemented by each concrete benchmark type. For stochastic benchmarks, +an additional `scenario` argument is required. """ -function objective_value(::AbstractBenchmark, θ::AbstractArray, y::AbstractArray) - return dot(θ, y) -end - -""" -$TYPEDSIGNATURES - -Compute the objective value of given solution `y`. -""" -function objective_value( - bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}, y::AbstractArray -) where {CTX,EX,F,S,C<:AbstractArray} - return objective_value(bench, sample.θ, y) -end +function objective_value end """ $TYPEDSIGNATURES @@ -173,6 +173,9 @@ end $TYPEDSIGNATURES Check if the benchmark is a minimization problem. + +Defaults to `true`. 
**Maximization benchmarks must override this method**; forgetting to do +so will cause `compute_gap` to compute the gap with the wrong sign without any error or warning. """ function is_minimization_problem(::AbstractBenchmark) return true @@ -182,11 +185,12 @@ end $TYPEDSIGNATURES Default implementation of [`compute_gap`](@ref): average relative optimality gap over `dataset`. -Requires samples with `x`, `θ`, and `y` fields. Override for custom evaluation logic. +Requires labeled samples (`y ≠ nothing`) with `x` and `context` fields. +Override for custom evaluation logic. """ function compute_gap( bench::AbstractBenchmark, - dataset::AbstractVector{<:DataSample}, + dataset::AbstractVector{<:DataSample{<:Any,<:Any,<:Any,<:AbstractArray}}, statistical_model, maximizer, op=mean, @@ -198,7 +202,7 @@ function compute_gap( target_obj = objective_value(bench, sample) x = sample.x θ = statistical_model(x) - y = maximizer(θ; sample.maximizer_kwargs...) + y = maximizer(θ; sample.context...) obj = objective_value(bench, sample, y) Δ = check ? obj - target_obj : target_obj - obj return Δ / abs(target_obj) @@ -217,11 +221,13 @@ part). Decisions are taken by seeing only the instance. Scenarios are used to ge anticipative targets and compute objective values. # Required methods ([`ExogenousStochasticBenchmark`](@ref) only) -- [`generate_instance`](@ref)`(bench, rng)`: returns a [`DataSample`](@ref) with instance - and features but **no scenario**. Scenarios are added later by [`generate_dataset`](@ref) - via [`generate_scenario`](@ref). +- [`generate_instance`](@ref)`(bench, rng)`: returns a [`DataSample`](@ref) with the + problem instance (solver kwargs) and, if not overriding [`generate_context`](@ref), + the ML features `x`. Scenarios are added later by [`generate_dataset`](@ref) via + [`generate_scenario`](@ref). When [`generate_context`](@ref) is overridden, `x` may + be absent here and constructed there instead. 
- [`generate_scenario`](@ref)`(bench, rng; kwargs...)`: draws a random scenario. - Solver kwargs are spread from `sample.maximizer_kwargs`; context latents from `ctx.extra`. + Solver kwargs are spread from `ctx.context`. # Optional methods - [`generate_context`](@ref)`(bench, rng, instance_sample)`: enriches the instance with @@ -234,16 +240,16 @@ anticipative targets and compute objective values. # Dataset generation (exogenous only) [`generate_dataset`](@ref) is specialised for [`ExogenousStochasticBenchmark`](@ref) and -supports all three standard structures via `nb_scenarios` and `nb_contexts`: +supports all three standard structures via `nb_scenarios` and `contexts_per_instance`: | Setting | Call | |---------|------| | 1 instance with K scenarios | `generate_dataset(bench, 1; nb_scenarios=K)` | | N instances with 1 scenario | `generate_dataset(bench, N)` (default) | | N instances with K scenarios | `generate_dataset(bench, N; nb_scenarios=K)` | -| N instances with M contexts × K scenarios | `generate_dataset(bench, N; nb_contexts=M, nb_scenarios=K)` | +| N instances with M contexts × K scenarios | `generate_dataset(bench, N; contexts_per_instance=M, nb_scenarios=K)` | -By default (no `target_policy`), each [`DataSample`](@ref) has `maximizer_kwargs` holding +By default (no `target_policy`), each [`DataSample`](@ref) has `context` holding the solver kwargs and `extra=(; scenario)` holding one scenario. Provide a `target_policy(ctx_sample, scenarios) -> Vector{DataSample}` @@ -265,9 +271,9 @@ const EndogenousStochasticBenchmark = AbstractStochasticBenchmark{false} generate_scenario(::ExogenousStochasticBenchmark, rng::AbstractRNG; kwargs...) -> scenario Draw a random scenario. Solver kwargs are passed as keyword arguments spread from -`sample.maximizer_kwargs`, and context latents (if any) are spread from `ctx.extra`: +`ctx.context`: - ξ = generate_scenario(bench, rng; ctx.extra..., ctx.maximizer_kwargs...) + ξ = generate_scenario(bench, rng; ctx.context...) 
""" function generate_scenario end @@ -289,12 +295,12 @@ function generate_context(bench::MyBench, rng, instance_sample::DataSample) x_raw = randn(rng, Float32, bench.d) return DataSample(; x=vcat(instance_sample.x, x_raw), - instance_sample.maximizer_kwargs..., - extra=(; x_raw), + instance_sample.context..., + x_raw, ) end ``` -Fields in `.extra` are spread into [`generate_scenario`](@ref) as kwargs. +Fields in `.context` are spread into [`generate_scenario`](@ref) as kwargs. """ function generate_context(::AbstractStochasticBenchmark, rng, instance_sample::DataSample) return instance_sample @@ -303,13 +309,17 @@ end """ generate_anticipative_solver(::AbstractBenchmark) -> callable -Return a callable that computes the anticipative solution. +Return a callable that computes the anticipative (oracle) solution. +The calling convention differs by benchmark category: + +**Stochastic benchmarks** ([`AbstractStochasticBenchmark`](@ref)): +Returns `(scenario; context...) -> y`. +Called once per scenario to obtain the optimal label. -- For [`AbstractStochasticBenchmark`](@ref): returns `(scenario; context...) -> y`. -- For [`AbstractDynamicBenchmark`](@ref): returns - `(env; reset_env=true, kwargs...) -> Vector{DataSample}`, a full training trajectory. - `reset_env=true` resets the env before solving (initial dataset building); - `reset_env=false` starts from the current env state. +**Dynamic benchmarks** ([`AbstractDynamicBenchmark`](@ref)): +Returns `(env; reset_env=true, kwargs...) -> Vector{DataSample}`, a full trajectory. +`reset_env=true` resets the environment before solving (used for initial dataset building); +`reset_env=false` starts from the current environment state (used inside DAgger rollouts). """ function generate_anticipative_solver end @@ -330,7 +340,7 @@ Default [`generate_sample`](@ref) for exogenous stochastic benchmarks. 
Calls [`generate_instance`](@ref), then [`generate_context`](@ref) (default: identity), draws scenarios via [`generate_scenario`](@ref), then: -- Without `target_policy`: returns M×K unlabeled samples (`nb_contexts` contexts × +- Without `target_policy`: returns M×K unlabeled samples (`contexts_per_instance` contexts × `nb_scenarios` scenarios each), each with one scenario in `extra=(; scenario=ξ)`. - With `target_policy`: calls `target_policy(ctx_sample, scenarios)` per context and returns the result. @@ -338,41 +348,50 @@ draws scenarios via [`generate_scenario`](@ref), then: `target_policy(ctx_sample, scenarios) -> Vector{DataSample}` enables anticipative labeling (K samples, one per scenario) or SAA (1 sample aggregating all K scenarios). + +!!! note + This is an internal override of [`generate_sample`](@ref) for the stochastic pipeline, + called by [`generate_dataset`](@ref). New stochastic benchmarks should implement + [`generate_instance`](@ref), [`generate_context`](@ref), and [`generate_scenario`](@ref) + rather than overriding this method. Note that the return type is `Vector{DataSample}` + (one per context × scenario combination), unlike the base method which returns a + single `DataSample`. """ function generate_sample( bench::ExogenousStochasticBenchmark, rng; target_policy=nothing, nb_scenarios::Int=1, - nb_contexts::Int=1, + contexts_per_instance::Int=1, kwargs..., ) instance_sample = generate_instance(bench, rng; kwargs...) - result = DataSample[] - for _ in 1:nb_contexts - ctx = generate_context(bench, rng, instance_sample) - if isnothing(target_policy) - for _ in 1:nb_scenarios - ξ = generate_scenario(bench, rng; ctx.extra..., ctx.maximizer_kwargs...) - push!( - result, - DataSample(; - x=ctx.x, - θ=ctx.θ, - ctx.maximizer_kwargs..., - extra=(; ctx.extra..., scenario=ξ), - ), - ) - end - else - scenarios = [ - generate_scenario(bench, rng; ctx.extra..., ctx.maximizer_kwargs...) 
for - _ in 1:nb_scenarios - ] - append!(result, target_policy(ctx, scenarios)) - end - end - return result + return reduce( + vcat, + ( + let ctx = generate_context(bench, rng, instance_sample) + if isnothing(target_policy) + [ + DataSample(; + x=ctx.x, + θ=ctx.θ, + ctx.context..., + extra=(; + ctx.extra..., + scenario=generate_scenario(bench, rng; ctx.context...), + ), + ) for _ in 1:nb_scenarios + ] + else + scenarios = [ + generate_scenario(bench, rng; ctx.context...) for + _ in 1:nb_scenarios + ] + target_policy(ctx, scenarios) + end + end for _ in 1:contexts_per_instance + ), + ) end """ @@ -380,7 +399,7 @@ $TYPEDSIGNATURES Specialised [`generate_dataset`](@ref) for exogenous stochastic benchmarks. -Generates `nb_instances` problem instances, each with `nb_contexts` context draws +Generates `nb_instances` problem instances, each with `contexts_per_instance` context draws and `nb_scenarios` scenario draws per context. The scenario→sample mapping is controlled by the `target_policy`: - Without `target_policy` (default): M contexts × K scenarios produce M×K unlabeled @@ -391,7 +410,7 @@ by the `target_policy`: # Keyword arguments - `nb_scenarios::Int = 1`: scenarios per context (K). -- `nb_contexts::Int = 1`: context draws per instance (M). +- `contexts_per_instance::Int = 1`: context draws per instance (M). - `target_policy`: when provided, called as `target_policy(ctx_sample, scenarios)` to compute labels. Defaults to `nothing` (unlabeled samples). @@ -404,20 +423,19 @@ function generate_dataset( nb_instances::Int; target_policy=nothing, nb_scenarios::Int=1, - nb_contexts::Int=1, + contexts_per_instance::Int=1, seed=nothing, rng=MersenneTwister(seed), kwargs..., ) - Random.seed!(rng, seed) - samples = DataSample[] - for _ in 1:nb_instances - new_samples = generate_sample( - bench, rng; target_policy, nb_scenarios, nb_contexts, kwargs... 
- ) - append!(samples, new_samples) - end - return samples + return reduce( + vcat, + ( + generate_sample( + bench, rng; target_policy, nb_scenarios, contexts_per_instance, kwargs... + ) for _ in 1:nb_instances + ), + ) end """ @@ -430,6 +448,12 @@ For each (instance, context) pair, draws `nb_scenarios` fixed scenarios. These a in the sample and used for feature computation, target labeling (via `target_policy`), and gap evaluation. +!!! note + `SampleAverageApproximation <: AbstractBenchmark`, not `AbstractStochasticBenchmark`. + This is intentional: after wrapping, the scenarios are fixed at dataset-generation time + and the benchmark behaves as a static problem. Functions dispatching on + `AbstractStochasticBenchmark` (e.g. `is_exogenous`) will not match SAA instances. + # Fields $TYPEDFIELDS """ @@ -459,15 +483,10 @@ function generate_sample( instance_sample = generate_instance(inner, rng; kwargs...) ctx = generate_context(inner, rng, instance_sample) scenarios = [ - generate_scenario(inner, rng; ctx.extra..., ctx.maximizer_kwargs...) for - _ in 1:(saa.nb_scenarios) + generate_scenario(inner, rng; ctx.context...) for _ in 1:(saa.nb_scenarios) ] if isnothing(target_policy) - return [ - DataSample(; - x=ctx.x, ctx.maximizer_kwargs..., extra=(; ctx.extra..., scenarios) - ), - ] + return [DataSample(; x=ctx.x, ctx.context..., extra=(; ctx.extra..., scenarios))] else return target_policy(ctx, scenarios) end @@ -492,12 +511,9 @@ function generate_dataset( rng=MersenneTwister(seed), kwargs..., ) - Random.seed!(rng, seed) - samples = DataSample[] - for _ in 1:nb_instances - append!(samples, generate_sample(saa, rng; target_policy, kwargs...)) - end - return samples + return reduce( + vcat, (generate_sample(saa, rng; target_policy, kwargs...) for _ in 1:nb_instances) + ) end """ @@ -508,7 +524,9 @@ Evaluate a decision `y` against stored scenarios (average over scenarios). 
function objective_value( saa::SampleAverageApproximation, sample::DataSample, y::AbstractArray ) - return mean(objective_value(saa.benchmark, ξ, y) for ξ in sample.extra.scenarios) + return mean( + objective_value(saa.benchmark, sample, y, ξ) for ξ in sample.extra.scenarios + ) end """ @@ -537,6 +555,8 @@ meaning (whether uncertainty is independent of decisions). # Additional optional methods - [`generate_environment`](@ref)`(bench, rng)`: initialize a single rollout environment. + Must return an [`AbstractEnvironment`](@ref) (see `environment.jl` for the full protocol: + [`reset!`](@ref), [`observe`](@ref), [`step!`](@ref), [`is_terminated`](@ref)). Implement this instead of overriding [`generate_environments`](@ref) when environments can be drawn independently. - [`generate_baseline_policies`](@ref)`(bench)`: returns named baseline callables of @@ -551,6 +571,20 @@ meaning (whether uncertainty is independent of decisions). """ abstract type AbstractDynamicBenchmark{exogenous} <: AbstractStochasticBenchmark{exogenous} end +""" +$TYPEDSIGNATURES + +Intercepts accidental calls to `generate_sample` on dynamic benchmarks and throws a +descriptive error pointing at the correct entry point. +""" +function generate_sample(bench::AbstractDynamicBenchmark, rng; kwargs...) + return error( + "`generate_sample` is not supported for dynamic benchmarks ($(typeof(bench))). " * + "Use `generate_environments` and " * + "`generate_dataset(bench, environments; target_policy=...)` instead.", + ) +end + "Alias for [`AbstractDynamicBenchmark`](@ref)`{true}`. Uncertainty is independent of decisions." const ExogenousDynamicBenchmark = AbstractDynamicBenchmark{true} @@ -558,7 +592,7 @@ const ExogenousDynamicBenchmark = AbstractDynamicBenchmark{true} const EndogenousDynamicBenchmark = AbstractDynamicBenchmark{false} """ - generate_environment(::AbstractDynamicBenchmark, rng::AbstractRNG; kwargs...) + generate_environment(::AbstractDynamicBenchmark, rng::AbstractRNG; kwargs...) 
-> AbstractEnvironment Initialize a single environment for the given dynamic benchmark. Primary implementation target for the count-based [`generate_environments`](@ref) default. @@ -581,7 +615,6 @@ function generate_environments( rng=MersenneTwister(seed), kwargs..., ) - Random.seed!(rng, seed) return [generate_environment(bench, rng; kwargs...) for _ in 1:n] end @@ -602,20 +635,9 @@ to obtain standard baseline callables (e.g. the anticipative solver). - `rng`: random number generator. """ function generate_dataset( - bench::ExogenousDynamicBenchmark, - environments::AbstractVector; - target_policy, - seed=nothing, - rng=MersenneTwister(seed), - kwargs..., + bench::ExogenousDynamicBenchmark, environments::AbstractVector; target_policy, kwargs... ) - Random.seed!(rng, seed) - samples = DataSample[] - for env in environments - trajectory = target_policy(env) - append!(samples, trajectory) - end - return samples + return reduce(vcat, (target_policy(env) for env in environments)) end """ diff --git a/src/Warcraft/Warcraft.jl b/src/Warcraft/Warcraft.jl index bd37418..6be07f9 100644 --- a/src/Warcraft/Warcraft.jl +++ b/src/Warcraft/Warcraft.jl @@ -24,8 +24,8 @@ Does not have any field. 
""" struct WarcraftBenchmark <: AbstractBenchmark end -function Utils.objective_value(::WarcraftBenchmark, θ::AbstractArray, y::AbstractArray) - return -dot(θ, y) +function Utils.objective_value(::WarcraftBenchmark, sample::DataSample, y::AbstractArray) + return -dot(sample.θ, y) end """ diff --git a/test/argmax.jl b/test/argmax.jl index afde005..d772c4d 100644 --- a/test/argmax.jl +++ b/test/argmax.jl @@ -24,7 +24,7 @@ @test size(x) == (nb_features, instance_dim) @test length(θ_true) == instance_dim @test length(y_true) == instance_dim - @test isempty(sample.maximizer_kwargs) + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) diff --git a/test/contextual_stochastic_argmax.jl b/test/contextual_stochastic_argmax.jl index 7fe6061..980816d 100644 --- a/test/contextual_stochastic_argmax.jl +++ b/test/contextual_stochastic_argmax.jl @@ -4,11 +4,11 @@ b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0) # Unlabeled: N instances × M contexts × K scenarios = N*M*K samples - dataset = generate_dataset(b, 10; nb_contexts=2, nb_scenarios=4) + dataset = generate_dataset(b, 10; contexts_per_instance=2, nb_scenarios=4) @test length(dataset) == 80 sample = first(dataset) @test size(sample.x) == (8,) # n+d - @test sample.x ≈ vcat(sample.c_base, sample.extra.x_raw) # features = [c_base; x_raw] + @test sample.x ≈ vcat(sample.c_base, sample.x_raw) # features = [c_base; x_raw] @test sample.y === nothing @test sample.scenario isa AbstractVector{Float32} && length(sample.scenario) == 5 @@ -21,7 +21,7 @@ policy = (ctx_sample, scenarios) -> [ DataSample(; - ctx_sample.maximizer_kwargs..., + ctx_sample.context..., x=ctx_sample.x, y=maximizer(s), extra=(; ctx_sample.extra..., scenario=s), @@ -69,7 +69,7 @@ end maximizer = generate_maximizer(saa) labeled = map(dataset) do s y_saa = maximizer(mean(s.scenarios)) - DataSample(; s.maximizer_kwargs..., x=s.x, y=y_saa, extra=s.extra) + DataSample(; s.context..., x=s.x, y=y_saa, extra=s.extra) end @test 
sum(first(labeled).y) ≈ 1.0 @@ -78,3 +78,13 @@ end gap = compute_gap(saa, labeled, model, maximizer) @test isfinite(gap) end + +@testset "csa_objective_value_error" begin + using DecisionFocusedLearningBenchmarks + + b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0) + maximizer = generate_maximizer(b) + # Sample with neither :scenario nor :scenarios in extra → objective_value should error + s = DataSample(; x=randn(Float32, 8), y=maximizer(randn(Float32, 5))) + @test_throws Exception objective_value(b, s, s.y) +end diff --git a/test/dynamic_assortment.jl b/test/dynamic_assortment.jl index a617170..3189f57 100644 --- a/test/dynamic_assortment.jl +++ b/test/dynamic_assortment.jl @@ -321,14 +321,14 @@ end # vector-of-environments overload dataset = generate_dataset(b, envs; target_policy=target_policy) - @test dataset isa Vector{DataSample} + @test dataset isa Vector{D} where {D<:DataSample} @test !isempty(dataset) @test all(!isnothing(s.x) for s in dataset) @test all(!isnothing(s.y) for s in dataset) # count-based wrapper dataset2 = generate_dataset(b, 3; seed=7, target_policy=target_policy) - @test dataset2 isa Vector{DataSample} + @test dataset2 isa Vector{D} where {D<:DataSample} @test !isempty(dataset2) end @@ -342,7 +342,7 @@ end # Test integration with sample data sample = generate_sample(b, MersenneTwister(42)) - @test hasfield(typeof(sample), :maximizer_kwargs) + @test hasfield(typeof(sample), :context) environments = generate_environments(b, 3; seed=42) diff --git a/test/dynamic_vsp.jl b/test/dynamic_vsp.jl index 9c208a9..ca67de8 100644 --- a/test/dynamic_vsp.jl +++ b/test/dynamic_vsp.jl @@ -63,14 +63,14 @@ end # vector-of-environments overload dataset = generate_dataset(b, envs; target_policy=target_policy) - @test dataset isa Vector{DataSample} + @test dataset isa Vector{D} where {D<:DataSample} @test !isempty(dataset) @test all(!isnothing(s.x) for s in dataset) @test all(!isnothing(s.y) for s in dataset) # count-based wrapper dataset2 = 
generate_dataset(b, 3; seed=1, target_policy=target_policy) - @test dataset2 isa Vector{DataSample} + @test dataset2 isa Vector{D} where {D<:DataSample} @test !isempty(dataset2) # seed keyword is forwarded: same seed → same dataset diff --git a/test/fixed_size_shortest_path.jl b/test/fixed_size_shortest_path.jl index d8a547f..eacdd64 100644 --- a/test/fixed_size_shortest_path.jl +++ b/test/fixed_size_shortest_path.jl @@ -25,7 +25,7 @@ @test size(x) == (p,) @test length(θ_true) == A @test length(y_true) == A - @test isempty(sample.maximizer_kwargs) + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) @test length(θ) == length(θ_true) diff --git a/test/maintenance.jl b/test/maintenance.jl index 80e648a..a2a9983 100644 --- a/test/maintenance.jl +++ b/test/maintenance.jl @@ -197,7 +197,7 @@ end # Test integration with sample data sample = generate_sample(b, MersenneTwister(42)) - @test hasfield(typeof(sample), :maximizer_kwargs) + @test hasfield(typeof(sample), :context) environments = generate_environments(b, 3; seed=42) diff --git a/test/portfolio_optimization.jl b/test/portfolio_optimization.jl index 023214f..b436c81 100644 --- a/test/portfolio_optimization.jl +++ b/test/portfolio_optimization.jl @@ -16,7 +16,7 @@ @test size(x) == (p,) @test length(θ_true) == d @test length(y_true) == d - @test isempty(sample.maximizer_kwargs) + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) @@ -26,4 +26,7 @@ @test length(y) == d @test sum(y) <= 1 + 1e-6 end + + gap = compute_gap(b, dataset[1:5], model, maximizer) + @test isfinite(gap) end diff --git a/test/ranking.jl b/test/ranking.jl index 8045cc6..b3c2a3b 100644 --- a/test/ranking.jl +++ b/test/ranking.jl @@ -21,7 +21,7 @@ @test size(x) == (nb_features, instance_dim) @test length(θ_true) == instance_dim @test length(y_true) == instance_dim - @test isempty(sample.maximizer_kwargs) + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) θ = model(x) 
@@ -30,4 +30,8 @@ y = maximizer(θ) @test length(y) == instance_dim end + + gap = compute_gap(b, dataset[1:5], model, maximizer) + @test isfinite(gap) + @test gap >= 0 end diff --git a/test/subset_selection.jl b/test/subset_selection.jl index da6e677..90d3150 100644 --- a/test/subset_selection.jl +++ b/test/subset_selection.jl @@ -23,7 +23,7 @@ @test size(x) == (n,) @test length(θ_true) == n @test length(y_true) == n - @test isempty(sample.maximizer_kwargs) + @test isempty(sample.context) @test all(y_true .== maximizer(θ_true)) # Features and true weights should be equal @@ -36,4 +36,8 @@ @test length(y) == n @test sum(y) == k end + + gap = compute_gap(b, dataset[1:5], model, maximizer) + @test isfinite(gap) + @test gap >= 0 end diff --git a/test/utils.jl b/test/utils.jl index 9e92457..393335e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -54,11 +54,11 @@ end show(io, sample) s = String(take!(io)) @test occursin("DataSample(", s) - @test occursin("θ_true", s) - @test occursin("y_true", s) + @test occursin("θ", s) + @test occursin("y", s) @test occursin("instance=\"this is an instance\"", s) - @test propertynames(sample) == (:x, :θ, :y, :maximizer_kwargs, :extra, :instance) + @test propertynames(sample) == (:x, :θ, :y, :context, :extra, :instance) # Create a dataset for testing N = 5 @@ -81,7 +81,7 @@ end for i in 1:N @test dataset_zt[i].θ == dataset[i].θ @test dataset_zt[i].y == dataset[i].y - @test dataset_zt[i].maximizer_kwargs == dataset[i].maximizer_kwargs + @test dataset_zt[i].context == dataset[i].context end # Check that features are actually transformed @@ -97,7 +97,7 @@ end for i in 1:N @test dataset_copy[i].θ == dataset[i].θ @test dataset_copy[i].y == dataset[i].y - @test dataset_copy[i].maximizer_kwargs == dataset[i].maximizer_kwargs + @test dataset_copy[i].context == dataset[i].context end # Test reconstruct (non-mutating) @@ -109,7 +109,7 @@ end @test dataset_reconstructed[i].x ≈ dataset[i].x atol = 1e-10 @test dataset_reconstructed[i].θ == 
dataset[i].θ @test dataset_reconstructed[i].y == dataset[i].y - @test dataset_reconstructed[i].maximizer_kwargs == dataset[i].maximizer_kwargs + @test dataset_reconstructed[i].context == dataset[i].context end # Test reconstruct! (mutating) @@ -117,6 +117,16 @@ end for i in 1:N @test dataset_zt[i].x ≈ dataset[i].x atol = 1e-10 end + + # Error: overlap between context and extra + @test_throws Exception DataSample(x=[1], foo=1, extra=(; foo=2)) + # Error: reserved name used as context kwarg + @test_throws Exception DataSample(x=[1], context=[2]) + # Error: reserved name in extra + @test_throws Exception DataSample(x=[1], extra=(; x=[2])) + + # is_minimization_problem: maximization benchmarks override to false + @test is_minimization_problem(ArgmaxBenchmark()) == false end @testset "Maximizers" begin diff --git a/test/vsp.jl b/test/vsp.jl index 1a27d36..fd10c5e 100644 --- a/test/vsp.jl +++ b/test/vsp.jl @@ -91,6 +91,6 @@ anticipative_solver = generate_anticipative_solver(b) sample = unlabeled[1] - y_anticipative = anticipative_solver(sample.scenario; sample.maximizer_kwargs...) + y_anticipative = anticipative_solver(sample.scenario; sample.context...) 
@test y_anticipative isa BitVector end diff --git a/test/warcraft.jl b/test/warcraft.jl index f361eb3..94a4678 100644 --- a/test/warcraft.jl +++ b/test/warcraft.jl @@ -29,7 +29,7 @@ y_true = sample.y @test size(x) == (96, 96, 3, 1) @test all(θ_true .<= 0) - @test isempty(sample.maximizer_kwargs) + @test isempty(sample.context) θ = model(x) @test size(θ) == size(θ_true) @@ -39,19 +39,19 @@ y_dijkstra = dijkstra_maximizer(θ) @test size(y_bellman) == size(y_true) @test size(y_dijkstra) == size(y_true) - @test objective_value(b, θ_true, y_bellman) == - objective_value(b, θ_true, y_dijkstra) + @test objective_value(b, sample, y_bellman) == + objective_value(b, sample, y_dijkstra) y_bellman_true = bellman_maximizer(θ_true) y_dijkstra_true = dijkstra_maximizer(θ_true) - @test objective_value(b, θ_true, y_true) == - objective_value(b, θ_true, y_dijkstra_true) + @test objective_value(b, sample, y_true) == + objective_value(b, sample, y_dijkstra_true) if i == 32 # TODO: bellman seems to be broken for some edge cases ? - @test_broken objective_value(b, θ_true, y_bellman_true) == - objective_value(b, θ_true, y_dijkstra_true) + @test_broken objective_value(b, sample, y_bellman_true) == + objective_value(b, sample, y_dijkstra_true) else - @test objective_value(b, θ_true, y_bellman_true) == - objective_value(b, θ_true, y_dijkstra_true) + @test objective_value(b, sample, y_bellman_true) == + objective_value(b, sample, y_dijkstra_true) end end end