Skip to content

Commit 0fed793

Browse files
committed
Revamp plotting interface
1 parent 668ecf9 commit 0fed793

23 files changed

Lines changed: 901 additions & 274 deletions

Project.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ authors = ["Members of JuliaDecisionFocusedLearning"]
77
projects = ["docs", "test"]
88

99
[deps]
10-
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1110
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1211
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
1312
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
@@ -22,12 +21,11 @@ InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
2221
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
2322
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
2423
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
25-
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
2624
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
25+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
2726
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2827
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
2928
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
30-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
3129
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3230
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3331
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -37,8 +35,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3735
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3836
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3937

38+
[weakdeps]
39+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
40+
41+
[extensions]
42+
DFLBenchmarksPlotsExt = "Plots"
43+
4044
[compat]
41-
Colors = "0.13.1"
4245
Combinatorics = "1.0.3"
4346
ConstrainedShortestPaths = "0.6.0"
4447
DataDeps = "0.7"
@@ -53,8 +56,8 @@ InferOpt = "0.7.0"
5356
Ipopt = "1.6"
5457
IterTools = "1.10.0"
5558
JSON = "1"
56-
JuMP = "1.22"
5759
LaTeXStrings = "1.4.0"
60+
JuMP = "1.22"
5861
LinearAlgebra = "1"
5962
Metalhead = "0.9.4"
6063
NPZ = "0.4"

docs/src/custom_benchmarks.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ generate_maximizer(bench::MyBenchmark)
6969
### Optional methods
7070

7171
```julia
72+
generate_baseline_policies(bench::MyBenchmark) -> collection of callables
7273
is_minimization_problem(bench::MyBenchmark) -> Bool # default: false (maximization)
7374
objective_value(bench::MyBenchmark, sample::DataSample, y) -> Real
7475
compute_gap(bench::MyBenchmark, dataset, model, maximizer) -> Float64
75-
plot_data(bench::MyBenchmark, sample::DataSample; kwargs...)
76-
plot_instance(bench::MyBenchmark, instance; kwargs...)
77-
plot_solution(bench::MyBenchmark, sample::DataSample, y; kwargs...)
78-
generate_baseline_policies(bench::MyBenchmark) -> collection of callables
76+
has_visualization(bench::MyBenchmark) -> Bool # default: false; return true when plot methods are implemented/available
77+
plot_instance(bench::MyBenchmark, sample::DataSample; kwargs...)
78+
plot_solution(bench::MyBenchmark, sample::DataSample; kwargs...)
7979
```
8080

8181
---
@@ -148,6 +148,13 @@ generate_baseline_policies(bench::MyDynamicBenchmark)
148148
# Each callable performs a full episode rollout and returns the trajectory.
149149
```
150150

151+
### Optional visualization methods
152+
153+
```julia
154+
plot_trajectory(bench::MyDynamicBenchmark, traj::Vector{DataSample}; kwargs...)
155+
animate_trajectory(bench::MyDynamicBenchmark, traj::Vector{DataSample}; kwargs...)
156+
```
157+
151158
`generate_dataset` for dynamic benchmarks **requires** a `target_policy` kwarg,
152159
there is no default. The `target_policy` must be a callable `(env) -> Vector{DataSample}`.
153160

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ y_true = sample.y
3232
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
3333
isempty(sample.context)
3434

35-
# For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
36-
plot_data(b, sample)
35+
# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
36+
plot_solution(b, sample)
3737
# We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.
3838

3939
# ## Building a pipeline
@@ -50,7 +50,7 @@ maximizer = generate_maximizer(b; dijkstra=true)
5050
# In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
5151
y = maximizer(θ)
5252
# As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.
53-
plot_data(b, DataSample(; x, θ, y))
53+
plot_solution(b, DataSample(; x, θ, y))
5454
# We can evaluate the current pipeline performance using the optimality gap metric:
5555
starting_gap = compute_gap(b, test_dataset, model, maximizer)
5656

@@ -85,7 +85,7 @@ final_gap = compute_gap(b, test_dataset, model, maximizer)
8585
#
8686
θ = model(x)
8787
y = maximizer(θ)
88-
plot_data(b, DataSample(; x, θ, y))
88+
plot_solution(b, DataSample(; x, θ, y))
8989

9090
using Test #src
9191
@test final_gap < starting_gap #src

docs/src/using_benchmarks.md

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ A benchmark bundles a problem family (an instance generator, a combinatorial sol
1010
Three abstract types cover the main settings:
1111
- **`AbstractBenchmark`**: static problems (one instance, one decision)
1212
- **`AbstractStochasticBenchmark{exogenous}`**: stochastic problems (type parameter indicates whether uncertainty is exogenous)
13-
- **`AbstractDynamicBenchmark`**: sequential / multi-stage problems
13+
- **`AbstractDynamicBenchmark{exogenous}`**: sequential / multi-stage problems
1414

1515
The sections below explain what changes between these settings. For most purposes, start with a static benchmark to understand the core workflow.
1616

@@ -180,10 +180,29 @@ rewards, samples = evaluate_policy!(pol, envs, n_episodes)
180180

181181
## Visualization
182182

183-
Where implemented, benchmarks provide benchmark-specific plotting helpers:
184-
183+
Plots is an **optional** dependency, load it with `using Plots` to unlock the plot functions. Not all benchmarks support visualization, call `has_visualization(bench)` to check.
185184
```julia
186-
plot_data(bench, sample) # overview of a data sample
187-
plot_instance(bench, instance) # raw problem instance
188-
plot_solution(bench, sample, y) # overlay solution on instance
185+
using Plots
186+
187+
bench = Argmax2DBenchmark()
188+
dataset = generate_dataset(bench, 10)
189+
sample = dataset[1]
190+
191+
has_visualization(bench) # true
192+
plot_instance(bench, sample) # problem geometry only
193+
plot_solution(bench, sample) # sample.y overlaid on the instance
194+
plot_solution(bench, sample, y) # convenience 3-arg form: override y before plotting
195+
196+
# Dynamic benchmarks only
197+
traj = generate_anticipative_solver(bench)(env)
198+
plot_trajectory(bench, traj) # grid of epoch subplots
199+
anim = animate_trajectory(bench, traj; fps=2)
200+
gif(anim, "episode.gif")
189201
```
202+
203+
- `has_visualization(bench)`: returns `true` for benchmarks that implement plot support (if Plots is loaded).
204+
- `plot_instance(bench, sample; kwargs...)`: renders the problem geometry without any solution.
205+
- `plot_solution(bench, sample; kwargs...)`: renders `sample.y` overlaid on the instance.
206+
- `plot_solution(bench, sample, y; kwargs...)`: 3-arg convenience form that overrides `y` before plotting.
207+
- `plot_trajectory(bench, traj; kwargs...)`: dynamic benchmarks only; produces a grid of per-epoch subplots.
208+
- `animate_trajectory(bench, traj; kwargs...)`: dynamic benchmarks only, returns a `Plots.Animation` that can be saved with `gif(anim, "file.gif")`.

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module DFLBenchmarksPlotsExt
2+
3+
using DecisionFocusedLearningBenchmarks
4+
using DocStringExtensions: TYPEDSIGNATURES
5+
using LaTeXStrings: @L_str
6+
using Plots
7+
import DecisionFocusedLearningBenchmarks:
8+
has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory
9+
10+
include("plots/argmax2d_plots.jl")
11+
include("plots/warcraft_plots.jl")
12+
include("plots/svs_plots.jl")
13+
include("plots/dvs_plots.jl")
14+
15+
"""
16+
plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
17+
18+
Reconstruct a new sample with `y` overridden and delegate to the 2-arg
19+
[`plot_solution`](@ref). Only available when `Plots` is loaded.
20+
"""
21+
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
22+
return plot_solution(
23+
bench,
24+
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
25+
kwargs...,
26+
)
27+
end
28+
29+
end

ext/plots/argmax2d_plots.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
function _init_plot(title="")
2+
pl = Plots.plot(;
3+
aspect_ratio=:equal,
4+
legend=:outerleft,
5+
xlim=(-1.1, 1.1),
6+
ylim=(-1.1, 1.1),
7+
title=title,
8+
)
9+
return pl
10+
end
11+
12+
function _plot_polytope!(pl, vertices)
13+
return Plots.plot!(
14+
pl,
15+
vcat(map(first, vertices), first(vertices[1])),
16+
vcat(map(last, vertices), last(vertices[1]));
17+
fillrange=0,
18+
fillcolor=:gray,
19+
fillalpha=0.2,
20+
linecolor=:black,
21+
label=L"\mathrm{conv}(\mathcal{Y}(x))",
22+
)
23+
end
24+
25+
function _plot_objective!(pl, θ)
26+
Plots.plot!(
27+
pl, [0.0, θ[1]], [0.0, θ[2]]; color="#9558B2", arrow=true, lw=2, label=nothing
28+
)
29+
Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"])
30+
return pl
31+
end
32+
33+
function _plot_y!(pl, y)
34+
return Plots.scatter!(
35+
pl,
36+
[y[1]],
37+
[y[2]];
38+
color="#CB3C33",
39+
markersize=9,
40+
markershape=:square,
41+
label=L"f(\theta)",
42+
)
43+
end
44+
45+
function _plot_maximizer!(pl, θ, instance, maximizer)
46+
ŷ = maximizer(θ; instance)
47+
return _plot_y!(pl, ŷ)
48+
end
49+
50+
has_visualization(::Argmax2DBenchmark) = true
51+
52+
function plot_instance(::Argmax2DBenchmark, sample::DataSample)
53+
pl = _init_plot()
54+
_plot_polytope!(pl, sample.instance)
55+
return pl
56+
end
57+
58+
function plot_solution(::Argmax2DBenchmark, sample::DataSample)
59+
pl = _init_plot()
60+
_plot_polytope!(pl, sample.instance)
61+
_plot_objective!(pl, sample.θ)
62+
return _plot_y!(pl, sample.y)
63+
end
64+
65+
function plot_solution(::Argmax2DBenchmark, sample::DataSample, y; θ=sample.θ)
66+
pl = _init_plot()
67+
_plot_polytope!(pl, sample.instance)
68+
_plot_objective!(pl, θ)
69+
return _plot_y!(pl, y)
70+
end

0 commit comments

Comments
 (0)