Skip to content
4 changes: 2 additions & 2 deletions docs/src/tutorials/warcraft_tutorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ x = sample.x
θ_true = sample.θ
# `y` correspond to the optimal shortest path, encoded as a binary matrix:
y_true = sample.y
# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
isempty(sample.maximizer_kwargs)
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
isempty(sample.context)

# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
plot_solution(b, sample)
Expand Down
4 changes: 1 addition & 3 deletions ext/DFLBenchmarksPlotsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
return plot_solution(
bench,
DataSample(;
sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
);
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
kwargs...,
)
end
Expand Down
4 changes: 4 additions & 0 deletions src/Argmax/Argmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Chain, Dense
using Random

using LinearAlgebra: dot

"""
$TYPEDEF

Expand All @@ -29,6 +31,8 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
)
end

Utils.objective_value(::ArgmaxBenchmark, sample::DataSample, y) = dot(sample.θ, y)

"""
$TYPEDSIGNATURES

Expand Down
2 changes: 2 additions & 0 deletions src/Argmax2D/Argmax2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ function Base.show(io::IO, bench::Argmax2DBenchmark)
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
end

Utils.objective_value(::Argmax2DBenchmark, sample::DataSample, y) = dot(sample.θ, y)

"""
$TYPEDSIGNATURES

Expand Down
30 changes: 24 additions & 6 deletions src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ContextualStochasticArgmax
using ..Utils
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Dense
using LinearAlgebra: dot
using Random: Random, AbstractRNG, MersenneTwister
using Statistics: mean

Expand All @@ -11,7 +12,7 @@ $TYPEDEF

Minimal contextual stochastic argmax benchmark.

Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample).
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `context` of the instance sample).
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
Expand Down Expand Up @@ -44,25 +45,42 @@ end
Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false
Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax

function Utils.objective_value(
::ContextualStochasticArgmaxBenchmark, sample::DataSample, y, scenario
)
return dot(scenario, y)
end

function Utils.objective_value(
bench::ContextualStochasticArgmaxBenchmark, sample::DataSample, y
)
if hasproperty(sample.extra, :scenario)
return Utils.objective_value(bench, sample, y, sample.scenario)
elseif hasproperty(sample.extra, :scenarios)
return mean(Utils.objective_value(bench, sample, y, ξ) for ξ in sample.scenarios)
end
return error("Sample must have scenario or scenarios")
end

"""
generate_instance(::ContextualStochasticArgmaxBenchmark, rng)

Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed
Draw `c_base ~ U[0,1]^n` and store it in `context`. No solver kwargs are needed
(the maximizer is `one_hot_argmax`, which takes no kwargs).
"""
function Utils.generate_instance(
bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
)
c_base = rand(rng, Float32, bench.n)
return DataSample(; extra=(; c_base))
return DataSample(; c_base)
end

"""
generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)

Draw `x_raw ~ N(0, I_d)` and return a context sample with:
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
- `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref).
- `c_base`, `x_raw` in `context`: spread into [`generate_scenario`](@ref).
"""
function Utils.generate_context(
bench::ContextualStochasticArgmaxBenchmark,
Expand All @@ -71,14 +89,14 @@ function Utils.generate_context(
)
c_base = instance_sample.c_base
x_raw = randn(rng, Float32, bench.d)
return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base))
return DataSample(; x=vcat(c_base, x_raw), c_base, x_raw)
end

"""
generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)

Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
`c_base` and `x_raw` are spread from `ctx.extra` by the framework.
`c_base` and `x_raw` are spread from `ctx.context` by the framework.
"""
function Utils.generate_scenario(
bench::ContextualStochasticArgmaxBenchmark,
Expand Down
2 changes: 1 addition & 1 deletion src/ContextualStochasticArgmax/policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function csa_saa_policy(ctx_sample, scenarios)
y = one_hot_argmax(mean(scenarios))
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
ctx_sample.context...,
x=ctx_sample.x,
y=y,
extra=(; ctx_sample.extra..., scenarios),
Expand Down
1 change: 1 addition & 0 deletions src/DecisionFocusedLearningBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export generate_statistical_model
export generate_maximizer
export generate_anticipative_solver, generate_parametric_anticipative_solver
export is_exogenous, is_endogenous
export is_minimization_problem

export objective_value
export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory
Expand Down
4 changes: 2 additions & 2 deletions src/FixedSizeShortestPath/FixedSizeShortestPath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ function FixedSizeShortestPathBenchmark(;
end

function Utils.objective_value(
::FixedSizeShortestPathBenchmark, θ::AbstractArray, y::AbstractArray
::FixedSizeShortestPathBenchmark, sample::DataSample, y::AbstractArray
)
return -dot(θ, y)
return -dot(sample.θ, y)
end

"""
Expand Down
6 changes: 5 additions & 1 deletion src/PortfolioOptimization/PortfolioOptimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Distributions: Uniform, Bernoulli
using Flux: Chain, Dense
using Ipopt: Ipopt
using JuMP: @variable, @objective, @constraint, optimize!, value, Model, set_silent
using LinearAlgebra: I
using LinearAlgebra: I, dot
using Random: Random, AbstractRNG, MersenneTwister

"""
Expand Down Expand Up @@ -38,6 +38,10 @@ struct PortfolioOptimizationBenchmark <: AbstractBenchmark
f::Vector{Float32}
end

function Utils.objective_value(::PortfolioOptimizationBenchmark, sample::DataSample, y)
return dot(sample.θ, y)
end

"""
$TYPEDSIGNATURES

Expand Down
5 changes: 5 additions & 0 deletions src/Ranking/Ranking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Chain, Dense
using Random

using LinearAlgebra: dot

"""
$TYPEDEF

Expand All @@ -29,6 +31,9 @@ function Base.show(io::IO, bench::RankingBenchmark)
)
end

Utils.objective_value(::RankingBenchmark, sample::DataSample, y) = dot(sample.θ, y)
Utils.is_minimization_problem(::RankingBenchmark) = false

"""
$TYPEDSIGNATURES

Expand Down
21 changes: 18 additions & 3 deletions src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,27 @@ end
include("policies.jl")

function Utils.objective_value(
::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector
::StochasticVehicleSchedulingBenchmark,
sample::DataSample,
y::BitVector,
scenario::VSPScenario,
)
stoch = build_stochastic_instance(sample.instance, sample.extra.scenarios)
stoch = build_stochastic_instance(sample.instance, [scenario])
return evaluate_solution(y, stoch)
end

function Utils.objective_value(
bench::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector
)
if hasproperty(sample.extra, :scenario)
return Utils.objective_value(bench, sample, y, sample.extra.scenario)
elseif hasproperty(sample.extra, :scenarios)
stoch = build_stochastic_instance(sample.instance, sample.extra.scenarios)
return evaluate_solution(y, stoch)
end
return error("Sample must have scenario or scenarios")
end

"""
$TYPEDSIGNATURES

Expand Down Expand Up @@ -116,7 +131,7 @@ Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=noth
To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):

```julia
policy = sample -> DataSample(; sample.maximizer_kwargs..., x=sample.x,
policy = sample -> DataSample(; sample.context..., x=sample.x,
y=column_generation_algorithm(sample.instance))
dataset = generate_dataset(benchmark, N; target_policy=policy)
```
Expand Down
8 changes: 4 additions & 4 deletions src/StochasticVehicleScheduling/policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function svs_saa_policy(ctx_sample, scenarios)
y = column_generation_algorithm(stochastic_inst)
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
ctx_sample.context...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
Expand All @@ -28,7 +28,7 @@ function svs_deterministic_policy(ctx_sample, scenarios; model_builder=highs_mod
y = deterministic_mip(ctx_sample.instance; model_builder)
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
ctx_sample.context...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
Expand All @@ -48,7 +48,7 @@ function svs_local_search_policy(ctx_sample, scenarios)
y = local_search(stochastic_inst)
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
ctx_sample.context...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
Expand All @@ -70,7 +70,7 @@ function svs_saa_mip_policy(ctx_sample, scenarios; model_builder=scip_model)
y = compact_linearized_mip(ctx_sample.instance, scenarios; model_builder)
return [
DataSample(;
ctx_sample.maximizer_kwargs...,
ctx_sample.context...,
x=ctx_sample.x,
y,
extra=(; ctx_sample.extra..., scenarios),
Expand Down
4 changes: 4 additions & 0 deletions src/SubsetSelection/SubsetSelection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module SubsetSelection
using ..Utils
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Chain, Dense
using LinearAlgebra: dot
using Random

"""
Expand Down Expand Up @@ -31,6 +32,9 @@ function Base.show(io::IO, bench::SubsetSelectionBenchmark)
return print(io, "SubsetSelectionBenchmark(n=$n, k=$k)")
end

Utils.objective_value(::SubsetSelectionBenchmark, sample::DataSample, y) = dot(sample.θ, y)
Utils.is_minimization_problem(::SubsetSelectionBenchmark) = false

function SubsetSelectionBenchmark(; n::Int=25, k::Int=5, identity_mapping::Bool=true)
@assert n >= k "number of items n must be greater than k"
mapping = if identity_mapping
Expand Down
1 change: 1 addition & 0 deletions src/Utils/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export generate_environment, generate_environments
export SampleAverageApproximation
export generate_baseline_policies
export generate_anticipative_solver, generate_parametric_anticipative_solver
export is_minimization_problem

export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory
export compute_gap
Expand Down
Loading
Loading