Skip to content

Commit 341ed4a

Browse files
authored
Merge pull request #62 from JuliaDecisionFocusedLearning/cleanup
Cleanup and reinforce the interface
2 parents 4282039 + 2557f07 commit 341ed4a

29 files changed

+315
-200
lines changed

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ x = sample.x
3030
θ_true = sample.θ
3131
# `y` correspond to the optimal shortest path, encoded as a binary matrix:
3232
y_true = sample.y
33-
# `maximizer_kwargs` is not used in this benchmark (no solver kwargs needed), so it is empty:
34-
isempty(sample.maximizer_kwargs)
33+
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
34+
isempty(sample.context)
3535

3636
# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
3737
plot_solution(b, sample)

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ Reconstruct a new sample with `y` overridden and delegate to the 2-arg
2121
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
2222
return plot_solution(
2323
bench,
24-
DataSample(;
25-
sample.maximizer_kwargs..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra
26-
);
24+
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
2725
kwargs...,
2826
)
2927
end

src/Argmax/Argmax.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Chain, Dense
66
using Random
77

8+
using LinearAlgebra: dot
9+
810
"""
911
$TYPEDEF
1012
@@ -29,6 +31,8 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
2931
)
3032
end
3133

34+
Utils.objective_value(::ArgmaxBenchmark, sample::DataSample, y) = dot(sample.θ, y)
35+
3236
"""
3337
$TYPEDSIGNATURES
3438

src/Argmax2D/Argmax2D.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function Base.show(io::IO, bench::Argmax2DBenchmark)
3030
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
3131
end
3232

33+
Utils.objective_value(::Argmax2DBenchmark, sample::DataSample, y) = dot(sample.θ, y)
34+
3335
"""
3436
$TYPEDSIGNATURES
3537

src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ContextualStochasticArgmax
33
using ..Utils
44
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Dense
6+
using LinearAlgebra: dot
67
using Random: Random, AbstractRNG, MersenneTwister
78
using Statistics: mean
89

@@ -11,7 +12,7 @@ $TYPEDEF
1112
1213
Minimal contextual stochastic argmax benchmark.
1314
14-
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `extra` of the instance sample).
15+
Per instance: `c_base ~ U[0,1]^n` (base utility, stored in `context` of the instance sample).
1516
Per context draw: `x_raw ~ N(0, I_d)` (observable context). Features: `x = [c_base; x_raw]`.
1617
Per scenario: `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
1718
The learner sees `x` and must predict `θ̂` so that `argmax(θ̂)` ≈ `argmax(ξ)`.
@@ -44,25 +45,42 @@ end
4445
Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false
4546
Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax
4647

48+
function Utils.objective_value(
49+
::ContextualStochasticArgmaxBenchmark, sample::DataSample, y, scenario
50+
)
51+
return dot(scenario, y)
52+
end
53+
54+
function Utils.objective_value(
55+
bench::ContextualStochasticArgmaxBenchmark, sample::DataSample, y
56+
)
57+
if hasproperty(sample.extra, :scenario)
58+
return Utils.objective_value(bench, sample, y, sample.scenario)
59+
elseif hasproperty(sample.extra, :scenarios)
60+
return mean(Utils.objective_value(bench, sample, y, ξ) for ξ in sample.scenarios)
61+
end
62+
return error("Sample must have scenario or scenarios")
63+
end
64+
4765
"""
4866
generate_instance(::ContextualStochasticArgmaxBenchmark, rng)
4967
50-
Draw `c_base ~ U[0,1]^n` and store it in `extra`. No solver kwargs are needed
68+
Draw `c_base ~ U[0,1]^n` and store it in `context`. No solver kwargs are needed
5169
(the maximizer is `one_hot_argmax`, which takes no kwargs).
5270
"""
5371
function Utils.generate_instance(
5472
bench::ContextualStochasticArgmaxBenchmark, rng::AbstractRNG; kwargs...
5573
)
5674
c_base = rand(rng, Float32, bench.n)
57-
return DataSample(; extra=(; c_base))
75+
return DataSample(; c_base)
5876
end
5977

6078
"""
6179
generate_context(::ContextualStochasticArgmaxBenchmark, rng, instance_sample)
6280
6381
Draw `x_raw ~ N(0, I_d)` and return a context sample with:
6482
- `x = [c_base; x_raw]`: full feature vector seen by the ML model.
65-
- `extra = (; c_base, x_raw)`: latents spread into [`generate_scenario`](@ref).
83+
- `c_base`, `x_raw` in `context`: spread into [`generate_scenario`](@ref).
6684
"""
6785
function Utils.generate_context(
6886
bench::ContextualStochasticArgmaxBenchmark,
@@ -71,14 +89,14 @@ function Utils.generate_context(
7189
)
7290
c_base = instance_sample.c_base
7391
x_raw = randn(rng, Float32, bench.d)
74-
return DataSample(; x=vcat(c_base, x_raw), extra=(; x_raw, c_base))
92+
return DataSample(; x=vcat(c_base, x_raw), c_base, x_raw)
7593
end
7694

7795
"""
7896
generate_scenario(::ContextualStochasticArgmaxBenchmark, rng; c_base, x_raw, kwargs...)
7997
8098
Draw `ξ = c_base + W * x_raw + noise`, `noise ~ N(0, noise_std² I)`.
81-
`c_base` and `x_raw` are spread from `ctx.extra` by the framework.
99+
`c_base` and `x_raw` are spread from `ctx.context` by the framework.
82100
"""
83101
function Utils.generate_scenario(
84102
bench::ContextualStochasticArgmaxBenchmark,

src/ContextualStochasticArgmax/policies.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function csa_saa_policy(ctx_sample, scenarios)
1111
y = one_hot_argmax(mean(scenarios))
1212
return [
1313
DataSample(;
14-
ctx_sample.maximizer_kwargs...,
14+
ctx_sample.context...,
1515
x=ctx_sample.x,
1616
y=y,
1717
extra=(; ctx_sample.extra..., scenarios),

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export generate_statistical_model
7979
export generate_maximizer
8080
export generate_anticipative_solver, generate_parametric_anticipative_solver
8181
export is_exogenous, is_endogenous
82+
export is_minimization_problem
8283

8384
export objective_value
8485
export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory

src/FixedSizeShortestPath/FixedSizeShortestPath.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ function FixedSizeShortestPathBenchmark(;
5555
end
5656

5757
function Utils.objective_value(
58-
::FixedSizeShortestPathBenchmark, θ::AbstractArray, y::AbstractArray
58+
::FixedSizeShortestPathBenchmark, sample::DataSample, y::AbstractArray
5959
)
60-
return -dot(θ, y)
60+
return -dot(sample.θ, y)
6161
end
6262

6363
"""

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Distributions: Uniform, Bernoulli
66
using Flux: Chain, Dense
77
using Ipopt: Ipopt
88
using JuMP: @variable, @objective, @constraint, optimize!, value, Model, set_silent
9-
using LinearAlgebra: I
9+
using LinearAlgebra: I, dot
1010
using Random: Random, AbstractRNG, MersenneTwister
1111

1212
"""
@@ -38,6 +38,10 @@ struct PortfolioOptimizationBenchmark <: AbstractBenchmark
3838
f::Vector{Float32}
3939
end
4040

41+
function Utils.objective_value(::PortfolioOptimizationBenchmark, sample::DataSample, y)
42+
return dot(sample.θ, y)
43+
end
44+
4145
"""
4246
$TYPEDSIGNATURES
4347

src/Ranking/Ranking.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Chain, Dense
66
using Random
77

8+
using LinearAlgebra: dot
9+
810
"""
911
$TYPEDEF
1012
@@ -29,6 +31,9 @@ function Base.show(io::IO, bench::RankingBenchmark)
2931
)
3032
end
3133

34+
Utils.objective_value(::RankingBenchmark, sample::DataSample, y) = dot(sample.θ, y)
35+
Utils.is_minimization_problem(::RankingBenchmark) = false
36+
3237
"""
3338
$TYPEDSIGNATURES
3439

0 commit comments

Comments
 (0)