@@ -7,7 +7,7 @@ using Flux: Chain, Dense
77using LaTeXStrings: @L_str
88using LinearAlgebra: dot, norm
99using Plots: Plots
10- using Random: Random, MersenneTwister
10+ using Random: Random, MersenneTwister, AbstractRNG
1111
1212include (" polytope.jl" )
1313
@@ -53,20 +53,16 @@ maximizer(θ; instance, kwargs...) = instance[argmax(dot(θ, v) for v in instanc
5353"""
5454$TYPEDSIGNATURES
5555
56- Generate a dataset for the [`Argmax2DBenchmark`](@ref).
56+ Generate a sample for the [`Argmax2DBenchmark`](@ref).
5757"""
58- function Utils. generate_dataset (
59- bench:: Argmax2DBenchmark , dataset_size= 10 ; seed= nothing , rng= MersenneTwister (seed)
60- )
58+ function Utils. generate_sample (bench:: Argmax2DBenchmark , rng:: AbstractRNG )
6159 (; nb_features, encoder, polytope_vertex_range) = bench
62- return map (1 : dataset_size) do _
63- x = randn (rng, Float32, nb_features)
64- θ_true = encoder (x)
65- θ_true ./= 2 * norm (θ_true)
66- instance = build_polytope (rand (rng, polytope_vertex_range); shift= rand (rng))
67- y_true = maximizer (θ_true; instance)
68- return DataSample (; x= x, θ_true= θ_true, y_true= y_true, instance= instance)
69- end
60+ x = randn (rng, Float32, nb_features)
61+ θ_true = encoder (x)
62+ θ_true ./= 2 * norm (θ_true)
63+ instance = build_polytope (rand (rng, polytope_vertex_range); shift= rand (rng))
64+ y_true = maximizer (θ_true; instance)
65+ return DataSample (; x= x, θ_true= θ_true, y_true= y_true, instance= instance)
7066end
7167
7268"""
0 commit comments