-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathArgmax2D.jl
More file actions
115 lines (94 loc) · 2.68 KB
/
Argmax2D.jl
File metadata and controls
115 lines (94 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
module Argmax2D
using ..Utils
using Colors: Colors
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Chain, Dense, glorot_uniform
using LaTeXStrings: @L_str
using LinearAlgebra: dot, norm
using Plots: Plots
using Random: Random, MersenneTwister, AbstractRNG
include("polytope.jl")
"""
$TYPEDEF

Argmax benchmark on a 2d polytope.

# Fields
$TYPEDFIELDS
"""
struct Argmax2DBenchmark{E,R} <: AbstractBenchmark
    "number of features"
    nb_features::Int
    "true mapping between features and costs"
    encoder::E
    "collection sampled once per generated instance; the drawn value is passed to `build_polytope` (presumably the number of vertices of the polytope)"
    polytope_vertex_range::R
end
# Compact display of the benchmark: only the number of features is printed,
# the encoder and the vertex range are omitted.
function Base.show(io::IO, bench::Argmax2DBenchmark)
    nb_features = bench.nb_features
    print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
    return nothing
end
"""
$TYPEDSIGNATURES

Custom constructor for [`Argmax2DBenchmark`](@ref).

The encoder is a bias-free `Dense` layer mapping `nb_features` inputs to 2 outputs.

# Keywords
- `nb_features`: number of input features of the encoder (default `5`).
- `seed`: when not `nothing`, the global RNG is seeded with this value right before the
  encoder is created. When `nothing` (default), the global RNG is left untouched.
- `polytope_vertex_range`: stored unchanged in the benchmark (default `[6]`).
"""
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
    # `Random.seed!(nothing)` re-seeds the global RNG from fresh entropy, which would
    # silently discard any seeding done by the caller. Only seed on explicit request.
    isnothing(seed) || Random.seed!(seed)
    model = Dense(nb_features => 2; bias=false)
    return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
end
# The benchmark is a maximization problem (solutions come from `maximizer`).
Utils.is_minimization_problem(::Argmax2DBenchmark) = false
# Return the element of `instance` whose dot product with `θ` is the largest.
# Ties are resolved by `argmax`; additional keyword arguments are accepted and ignored.
function maximizer(θ; instance, kwargs...)
    best_index = argmax(dot(θ, vertex) for vertex in instance)
    return instance[best_index]
end
"""
$TYPEDSIGNATURES

Generate a sample for the [`Argmax2DBenchmark`](@ref).

Features are drawn as a standard normal `Float32` vector, mapped through the benchmark
encoder and rescaled in place to norm 1/2. A polytope is then built with `build_polytope`,
and the returned `DataSample` holds the features (`x`), the rescaled cost vector (`θ`),
the output of `maximizer` on that polytope (`y`) and the polytope itself (`instance`).
"""
function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG)
    features = randn(rng, Float32, bench.nb_features)
    costs = bench.encoder(features)
    # In-place rescaling: afterwards `norm(costs)` equals 1/2.
    costs ./= 2 * norm(costs)
    # Kept as a single call so that the order of the two RNG draws is unchanged.
    polytope = build_polytope(rand(rng, bench.polytope_vertex_range); shift=rand(rng))
    solution = maximizer(costs; instance=polytope)
    return DataSample(; x=features, θ=costs, y=solution, instance=polytope)
end
"""
$TYPEDSIGNATURES

Return the maximizer of the [`Argmax2DBenchmark`](@ref), i.e. the module-level
`maximizer` function.
"""
Utils.generate_maximizer(::Argmax2DBenchmark) = maximizer
"""
$TYPEDSIGNATURES

Generate a statistical model for the [`Argmax2DBenchmark`](@ref).

The model is a bias-free `Dense` layer mapping `bench.nb_features` inputs to 2 outputs,
whose weights are initialized with `glorot_uniform` drawing from `rng`.

# Keywords
- `seed`: when not `nothing`, `rng` is (re-)seeded with this value before the weights are
  drawn, so that two calls with the same `seed` return identical models.
- `rng`: random number generator used for the weight initialization. Defaults to
  `MersenneTwister(seed)`. When `seed` is `nothing`, a caller-provided `rng` is used in
  its current state.
"""
function Utils.generate_statistical_model(
    bench::Argmax2DBenchmark; seed=nothing, rng=MersenneTwister(seed)
)
    # Only re-seed on explicit request: `Random.seed!(rng, nothing)` would overwrite the
    # state of a generator that the caller seeded on purpose.
    isnothing(seed) || Random.seed!(rng, seed)
    (; nb_features) = bench
    # The initializer has to be tied to `rng` explicitly; without it `Dense` falls back
    # to its default initializer and both `seed` and `rng` have no effect on the model.
    model = Dense(nb_features => 2; bias=false, init=glorot_uniform(rng))
    return model
end
# Build a plot from `init_plot()`, add the polytope `instance` and the objective `θ` to it,
# then return whatever `plot_maximizer!` returns for this `θ`, `instance` and `maximizer`.
# Additional keyword arguments are accepted and ignored.
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
    figure = init_plot()
    plot_polytope!(figure, instance)
    plot_objective!(figure, θ)
    result = plot_maximizer!(figure, θ, instance, maximizer)
    return result
end
"""
$TYPEDSIGNATURES

Plot the data sample for the [`Argmax2DBenchmark`](@ref).

`instance` and `θ` default to the corresponding fields of `sample` and can be overridden.
All keyword arguments are forwarded to the keyword-only `plot_data` method.
"""
function Utils.plot_data(
    bench::Argmax2DBenchmark,
    sample::DataSample;
    instance=sample.instance,
    θ=sample.θ,
    kwargs...,
)
    return Utils.plot_data(bench; instance=instance, θ=θ, kwargs...)
end
export Argmax2DBenchmark
end