File tree Expand file tree Collapse file tree 3 files changed +46
-1
lines changed
src/ContextualStochasticArgmax Expand file tree Collapse file tree 3 files changed +46
-1
lines changed Original file line number Diff line number Diff line change 11module ContextualStochasticArgmax
22
33using .. Utils
4- using DocStringExtensions: TYPEDEF, TYPEDFIELDS
4+ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55using Flux: Dense
66using Random: Random, AbstractRNG, MersenneTwister
7+ using Statistics: mean
78
89"""
910$TYPEDEF
@@ -97,6 +98,8 @@ function Utils.generate_statistical_model(
9798 return Dense (bench. n + bench. d => bench. n; bias= false )
9899end
99100
101+ include (" policies.jl" )
102+
100103export ContextualStochasticArgmaxBenchmark
101104
102105end
Original file line number Diff line number Diff line change 1+ using Statistics: mean
2+
3+ """
4+ $TYPEDSIGNATURES
5+
6+ SAA baseline policy: returns `argmax(mean(scenarios))`.
7+ For a linear argmax problem this is the exact SAA-optimal decision.
8+ Returns a single labeled [`DataSample`](@ref) with `extra=(; scenarios)`.
9+ """
10+ function csa_saa_policy (ctx_sample, scenarios)
11+ y = one_hot_argmax (mean (scenarios))
12+ return [
13+ DataSample (;
14+ ctx_sample. maximizer_kwargs... ,
15+ x= ctx_sample. x,
16+ y= y,
17+ extra= (; ctx_sample. extra... , scenarios),
18+ ),
19+ ]
20+ end
21+
22+ """
23+ $TYPEDSIGNATURES
24+
25+ Return the named baseline policies for [`ContextualStochasticArgmaxBenchmark`](@ref).
26+ Each policy has signature `(ctx_sample, scenarios) -> Vector{DataSample}`.
27+ """
28+ function Utils. generate_baseline_policies (:: ContextualStochasticArgmaxBenchmark )
29+ return (; saa= Policy (" SAA" , " argmax of mean scenarios" , csa_saa_policy))
30+ end
Original file line number Diff line number Diff line change 3838 @test first (d1). x ≈ first (d2). x
3939end
4040
41+ @testset " csa_saa_policy" begin
42+ using DecisionFocusedLearningBenchmarks
43+
44+ b = ContextualStochasticArgmaxBenchmark (; n= 5 , d= 3 , seed= 0 )
45+ policies = generate_baseline_policies (b)
46+
47+ labeled = generate_dataset (b, 3 ; nb_scenarios= 4 , target_policy= policies. saa)
48+ @test length (labeled) == 3 # one sample per context (SAA aggregates)
49+ @test sum (first (labeled). y) ≈ 1.0 # one-hot label
50+ @test length (first (labeled). extra. scenarios) == 4 # scenarios stored in extra
51+ end
52+
4153@testset " SampleAverageApproximation wrapper on ContextualStochasticArgmax" begin
4254 using DecisionFocusedLearningBenchmarks
4355 using Statistics: mean
You can’t perform that action at this time.
0 commit comments