Skip to content

Commit 386e6a5

Browse files
committed
refactor: make objective_value explicit and remove implicit fallback
1 parent 9069660 commit 386e6a5

11 files changed

Lines changed: 74 additions & 31 deletions

File tree

src/Argmax/Argmax.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Chain, Dense
66
using Random
77

8+
using LinearAlgebra: dot
9+
810
"""
911
$TYPEDEF
1012
@@ -29,6 +31,8 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
2931
)
3032
end
3133

34+
Utils.objective_value(::ArgmaxBenchmark, sample::DataSample, y) = dot(sample.θ, y)
35+
3236
"""
3337
$TYPEDSIGNATURES
3438

src/Argmax2D/Argmax2D.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function Base.show(io::IO, bench::Argmax2DBenchmark)
3030
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
3131
end
3232

33+
Utils.objective_value(::Argmax2DBenchmark, sample::DataSample, y) = dot(sample.θ, y)
34+
3335
"""
3436
$TYPEDSIGNATURES
3537

src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ContextualStochasticArgmax
33
using ..Utils
44
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Dense
6+
using LinearAlgebra: dot
67
using Random: Random, AbstractRNG, MersenneTwister
78
using Statistics: mean
89

@@ -44,6 +45,25 @@ end
4445
Utils.is_minimization_problem(::ContextualStochasticArgmaxBenchmark) = false
4546
Utils.generate_maximizer(::ContextualStochasticArgmaxBenchmark) = one_hot_argmax
4647

48+
function Utils.objective_value(
49+
::ContextualStochasticArgmaxBenchmark, sample::DataSample, y, scenario
50+
)
51+
return dot(scenario, y)
52+
end
53+
54+
function Utils.objective_value(
55+
bench::ContextualStochasticArgmaxBenchmark, sample::DataSample, y
56+
)
57+
if hasproperty(sample.extra, :scenario)
58+
return Utils.objective_value(bench, sample, y, sample.extra.scenario)
59+
elseif hasproperty(sample.extra, :scenarios)
60+
return mean(
61+
Utils.objective_value(bench, sample, y, ξ) for ξ in sample.extra.scenarios
62+
)
63+
end
64+
return error("Sample must have scenario or scenarios")
65+
end
66+
4767
"""
4868
generate_instance(::ContextualStochasticArgmaxBenchmark, rng)
4969

src/FixedSizeShortestPath/FixedSizeShortestPath.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ function FixedSizeShortestPathBenchmark(;
5555
end
5656

5757
function Utils.objective_value(
58-
::FixedSizeShortestPathBenchmark, θ::AbstractArray, y::AbstractArray
58+
::FixedSizeShortestPathBenchmark, sample::DataSample, y::AbstractArray
5959
)
60-
return -dot(θ, y)
60+
return -dot(sample.θ, y)
6161
end
6262

6363
"""

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Distributions: Uniform, Bernoulli
66
using Flux: Chain, Dense
77
using Ipopt: Ipopt
88
using JuMP: @variable, @objective, @constraint, optimize!, value, Model, set_silent
9-
using LinearAlgebra: I
9+
using LinearAlgebra: I, dot
1010
using Random: Random, AbstractRNG, MersenneTwister
1111

1212
"""
@@ -38,6 +38,10 @@ struct PortfolioOptimizationBenchmark <: AbstractBenchmark
3838
f::Vector{Float32}
3939
end
4040

41+
function Utils.objective_value(::PortfolioOptimizationBenchmark, sample::DataSample, y)
42+
return dot(sample.θ, y)
43+
end
44+
4145
"""
4246
$TYPEDSIGNATURES
4347

src/Ranking/Ranking.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Chain, Dense
66
using Random
77

8+
using LinearAlgebra: dot
9+
810
"""
911
$TYPEDEF
1012
@@ -29,6 +31,8 @@ function Base.show(io::IO, bench::RankingBenchmark)
2931
)
3032
end
3133

34+
Utils.objective_value(::RankingBenchmark, sample::DataSample, y) = dot(sample.θ, y)
35+
3236
"""
3337
$TYPEDSIGNATURES
3438

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,27 @@ end
6868
include("policies.jl")
6969

7070
function Utils.objective_value(
71-
::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector
71+
::StochasticVehicleSchedulingBenchmark,
72+
sample::DataSample,
73+
y::BitVector,
74+
scenario::VSPScenario,
7275
)
73-
stoch = build_stochastic_instance(sample.instance, sample.extra.scenarios)
76+
stoch = build_stochastic_instance(sample.instance, [scenario])
7477
return evaluate_solution(y, stoch)
7578
end
7679

80+
function Utils.objective_value(
81+
bench::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector
82+
)
83+
if hasproperty(sample.extra, :scenario)
84+
return Utils.objective_value(bench, sample, y, sample.extra.scenario)
85+
elseif hasproperty(sample.extra, :scenarios)
86+
stoch = build_stochastic_instance(sample.instance, sample.extra.scenarios)
87+
return evaluate_solution(y, stoch)
88+
end
89+
return error("Sample must have scenario or scenarios")
90+
end
91+
7792
"""
7893
$TYPEDSIGNATURES
7994

src/SubsetSelection/SubsetSelection.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SubsetSelection
33
using ..Utils
44
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Chain, Dense
6+
using LinearAlgebra: dot
67
using Random
78

89
"""
@@ -31,6 +32,8 @@ function Base.show(io::IO, bench::SubsetSelectionBenchmark)
3132
return print(io, "SubsetSelectionBenchmark(n=$n, k=$k)")
3233
end
3334

35+
Utils.objective_value(::SubsetSelectionBenchmark, sample::DataSample, y) = dot(sample.θ, y)
36+
3437
function SubsetSelectionBenchmark(; n::Int=25, k::Int=5, identity_mapping::Bool=true)
3538
@assert n >= k "number of items n must be greater than k"
3639
mapping = if identity_mapping

src/Utils/interface.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,11 @@ function compute_gap end
150150
"""
151151
$TYPEDSIGNATURES
152152
153-
Compute `dot(θ, y)`. Override for non-linear objectives.
153+
Compute the objective value of given solution `y` for a specific benchmark.
154+
Must be implemented by each concrete benchmark type. For stochastic benchmarks,
155+
an additional `scenario` argument is required.
154156
"""
155-
function objective_value(::AbstractBenchmark, θ::AbstractArray, y::AbstractArray)
156-
return dot(θ, y)
157-
end
158-
159-
"""
160-
$TYPEDSIGNATURES
161-
162-
Compute the objective value of given solution `y`.
163-
"""
164-
function objective_value(
165-
bench::AbstractBenchmark, sample::DataSample{CTX,EX,F,S,C}, y::AbstractArray
166-
) where {CTX,EX,F,S,C<:AbstractArray}
167-
return objective_value(bench, sample.θ, y)
168-
end
157+
function objective_value end
169158

170159
"""
171160
$TYPEDSIGNATURES
@@ -527,7 +516,9 @@ Evaluate a decision `y` against stored scenarios (average over scenarios).
527516
function objective_value(
528517
saa::SampleAverageApproximation, sample::DataSample, y::AbstractArray
529518
)
530-
return mean(objective_value(saa.benchmark, ξ, y) for ξ in sample.extra.scenarios)
519+
return mean(
520+
objective_value(saa.benchmark, sample, y, ξ) for ξ in sample.extra.scenarios
521+
)
531522
end
532523

533524
"""

src/Warcraft/Warcraft.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ Does not have any field.
2424
"""
2525
struct WarcraftBenchmark <: AbstractBenchmark end
2626

27-
function Utils.objective_value(::WarcraftBenchmark, θ::AbstractArray, y::AbstractArray)
28-
return -dot(θ, y)
27+
function Utils.objective_value(::WarcraftBenchmark, sample::DataSample, y::AbstractArray)
28+
return -dot(sample.θ, y)
2929
end
3030

3131
"""

0 commit comments

Comments
 (0)