Skip to content

Commit 9443ac7

Browse files
committed
saa policy for the contextual argmax
1 parent 15edb07 commit 9443ac7

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

src/ContextualStochasticArgmax/ContextualStochasticArgmax.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module ContextualStochasticArgmax
22

33
using ..Utils
4-
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
4+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Dense
66
using Random: Random, AbstractRNG, MersenneTwister
7+
using Statistics: mean
78

89
"""
910
$TYPEDEF
@@ -97,6 +98,8 @@ function Utils.generate_statistical_model(
9798
return Dense(bench.n + bench.d => bench.n; bias=false)
9899
end
99100

101+
include("policies.jl")
102+
100103
export ContextualStochasticArgmaxBenchmark
101104

102105
end
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using Statistics: mean
2+
3+
"""
4+
$TYPEDSIGNATURES
5+
6+
SAA baseline policy: returns `argmax(mean(scenarios))`.
7+
For a linear argmax problem this is the exact SAA-optimal decision.
8+
Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
9+
"""
10+
function csa_saa_policy(ctx_sample, scenarios)
11+
y = one_hot_argmax(mean(scenarios))
12+
return [
13+
DataSample(;
14+
ctx_sample.maximizer_kwargs...,
15+
x=ctx_sample.x,
16+
y=y,
17+
extra=(; ctx_sample.extra..., scenarios),
18+
),
19+
]
20+
end
21+
22+
"""
23+
$TYPEDSIGNATURES
24+
25+
Return the named baseline policies for [`ContextualStochasticArgmaxBenchmark`](@ref).
26+
Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
27+
"""
28+
function Utils.generate_baseline_policies(::ContextualStochasticArgmaxBenchmark)
29+
return (; saa=Policy("SAA", "argmax of mean scenarios", csa_saa_policy))
30+
end

test/contextual_stochastic_argmax.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@
3838
@test first(d1).x first(d2).x
3939
end
4040

41+
@testset "csa_saa_policy" begin
42+
using DecisionFocusedLearningBenchmarks
43+
44+
b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0)
45+
policies = generate_baseline_policies(b)
46+
47+
labeled = generate_dataset(b, 3; nb_scenarios=4, target_policy=policies.saa)
48+
@test length(labeled) == 3 # one sample per context (SAA aggregates)
49+
@test sum(first(labeled).y) 1.0 # one-hot label
50+
@test length(first(labeled).extra.scenarios) == 4 # scenarios stored in extra
51+
end
52+
4153
@testset "SampleAverageApproximation wrapper on ContextualStochasticArgmax" begin
4254
using DecisionFocusedLearningBenchmarks
4355
using Statistics: mean

0 commit comments

Comments
 (0)