-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathFixedSizeShortestPath.jl
More file actions
145 lines (123 loc) · 3.7 KB
/
FixedSizeShortestPath.jl
File metadata and controls
145 lines (123 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
module FixedSizeShortestPath
using ..Utils
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Distributions
using Flux: Chain, Dense
using Graphs
using LinearAlgebra
using Random
using SparseArrays
"""
$TYPEDEF

Benchmark for the shortest path problem on a fixed directed grid graph.

In this benchmark, all instances share the same acyclic directed grid, built from
`grid_size`; only the edge weights differ between instances.
Features are given at instance level (a one-dimensional vector of length `p` per instance),
and true edge weights are derived from them following the data generation process described
in <https://arxiv.org/abs/2307.13565>.

# Fields
$TYPEDFIELDS
"""
struct FixedSizeShortestPathBenchmark <: AbstractBenchmark
"directed acyclic grid graph shared by all instances"
graph::SimpleDiGraph{Int64}
"grid dimensions passed to `Graphs.grid` to build `graph`"
grid_size::Tuple{Int,Int}
"length of the feature vector of each instance"
p::Int
"exponent applied in the formula mapping features to true edge weights"
deg::Int
"multiplicative noise level: true weights are scaled by a factor drawn uniformly in [1-ν, 1+ν]; should be between 0 and 1"
ν::Float32
end
# One-line display listing the benchmark parameters,
# e.g. `FixedSizeShortestPathBenchmark(grid_size=(5, 5), p=5, deg=1, ν=0.0)`.
function Base.show(io::IO, bench::FixedSizeShortestPathBenchmark)
    grid_size, p, deg, ν = bench.grid_size, bench.p, bench.deg, bench.ν
    text = "FixedSizeShortestPathBenchmark(grid_size=$grid_size, p=$p, deg=$deg, ν=$ν)"
    print(io, text)
    return nothing
end
"""
$TYPEDSIGNATURES

Build a [`FixedSizeShortestPathBenchmark`](@ref) whose graph is a directed version of the
grid `Graphs.grid(grid_size)`.

# Keywords
- `grid_size`: grid dimensions passed to `Graphs.grid` (default `(5, 5)`).
- `p`: length of the feature vectors (default `5`).
- `deg`: exponent of the feature-to-weight formula (default `1`).
- `ν`: multiplicative noise level, checked to lie in `[0, 1]` (default `0.0f0`).
"""
function FixedSizeShortestPathBenchmark(;
    grid_size::Tuple{Int,Int}=(5, 5), p::Int=5, deg::Int=1, ν=0.0f0
)
    @assert ν >= 0.0 && ν <= 1.0
    # Each undirected grid edge becomes a single arc, oriented as listed by `edges`.
    undirected = Graphs.grid(grid_size)
    arcs = collect(edges(undirected))
    directed = DiGraph(arcs)
    return FixedSizeShortestPathBenchmark(directed, grid_size, p, deg, ν)
end
"""
$TYPEDSIGNATURES

Return the objective value of solution `y` for edge weights `θ`, namely `-(θ ⋅ y)`.

The benchmark maximizer solves a shortest path problem with costs `-θ`, so this value is
the total cost of path `y` under those costs.
"""
function Utils.objective_value(
    ::FixedSizeShortestPathBenchmark, θ::AbstractArray, y::AbstractArray
)
    score = θ ⋅ y
    return -score
end
"""
$TYPEDSIGNATURES

Return a function `maximizer(θ)` that computes a longest (maximum total weight) path from
vertex `1` to vertex `nv(bench.graph)`, given one weight per edge in `θ`.

Weights in `θ` follow the order of `edges(bench.graph)`. Internally, a shortest path is
computed with edge costs `-θ`. The search uses `Graphs.dijkstra_shortest_paths` when
`use_dijkstra` is `true`, and `Graphs.bellman_ford_shortest_paths` otherwise.

The maximizer returns a `BitVector` of length `ne(bench.graph)`. Its `i`-th entry is `true`
when the `i`-th edge lies on the path. Extra keyword arguments passed to the maximizer are
accepted and ignored.

```julia
maximizer = generate_maximizer(bench)
maximizer(θ)
```
"""
function Utils.generate_maximizer(bench::FixedSizeShortestPathBenchmark; use_dijkstra=true)
    g = bench.graph
    V = Graphs.nv(g)
    E = Graphs.ne(g)
    I = [src(e) for e in edges(g)]
    J = [dst(e) for e in edges(g)]
    # Position of each arc (u, v) in `edges(g)`, built once so that decoding a path costs
    # O(path length) instead of allocating a dense V×V matrix on every call.
    edge_index = Dict((src(e), dst(e)) => i for (i, e) in enumerate(edges(g)))
    algo =
        use_dijkstra ? Graphs.dijkstra_shortest_paths : Graphs.bellman_ford_shortest_paths
    function shortest_path_maximizer(θ; kwargs...)
        # Maximizing θ ⋅ y is equivalent to a shortest path with costs -θ.
        weights = sparse(I, J, -θ, V, V)
        parents = algo(g, 1, weights).parents
        solution = falses(E)
        # Walk back from the target vertex V to the source vertex 1 along parent links.
        u = V
        while u != 1
            prev = parents[u]
            prev == 0 && error("vertex $u is not reachable from vertex 1")
            solution[edge_index[(prev, u)]] = true
            u = prev
        end
        return solution
    end
    return shortest_path_maximizer
end
"""
$TYPEDSIGNATURES

Generate a labeled sample for the fixed size shortest path benchmark.

- Features `x`: a standard normal vector of length `bench.p`, with element type `type`.
- True edge weights: `θ = -(1 .+ (3 .+ B * x ./ sqrt(p)) .^ deg) .* ξ`.
  - `B` is an `ne(graph) × p` matrix of Bernoulli(0.5) draws.
  - `ξ` is multiplicative noise drawn uniformly in `[1 - ν, 1 + ν]`, or all ones when
    `ν == 0`.
- Label `y`: the path returned by the benchmark maximizer for `θ`.
"""
function Utils.generate_sample(
    bench::FixedSizeShortestPathBenchmark, rng::AbstractRNG; type::Type=Float32
)
    (; graph, p, deg, ν) = bench
    # Draw features in the requested element type so that `x` and `θ` share `type`.
    features = randn(rng, type, p)
    E = Graphs.ne(graph)
    # True weights
    B = rand(rng, Bernoulli(0.5), E, p)
    ξ = if ν == 0.0
        ones(type, E)
    else
        rand(rng, Uniform{type}(1 - ν, 1 + ν), E)
    end
    θ_true = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ
    maximizer = Utils.generate_maximizer(bench)
    y_true = maximizer(θ_true)
    return DataSample(; x=features, θ=θ_true, y=y_true)
end
"""
$TYPEDSIGNATURES

Initialize a linear model for `bench` using `Flux`.

The model is a single `Dense(p, ne(graph))` layer. It maps a feature vector of length
`bench.p` to one score per edge of `bench.graph`.

When `seed` is provided, the global RNG is seeded with it before the layer is created. When
`seed` is `nothing` (the default), the global RNG is left untouched, so seeding done by the
caller is not overwritten.
"""
function Utils.generate_statistical_model(
    bench::FixedSizeShortestPathBenchmark; seed=nothing
)
    # `Random.seed!(nothing)` reseeds the global RNG from system entropy, which would silently
    # discard the caller's own seeding; only seed when a seed is explicitly requested.
    isnothing(seed) || Random.seed!(seed)
    (; p, graph) = bench
    return Chain(Dense(p, ne(graph)))
end
export FixedSizeShortestPathBenchmark
end