Skip to content

Commit 0f5590e

Browse files
authored
Merge pull request #58 from JuliaDecisionFocusedLearning/revamp-plot-interface
Refactor visualization interface: optional Plots, uniform API
2 parents 668ecf9 + 64f52d7 commit 0f5590e

23 files changed

+900
-278
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ authors = ["Members of JuliaDecisionFocusedLearning"]
77
projects = ["docs", "test"]
88

99
[deps]
10-
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1110
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1211
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
1312
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
@@ -27,7 +26,6 @@ LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
2726
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2827
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
2928
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
30-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
3129
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3230
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3331
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -37,8 +35,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3735
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3836
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3937

38+
[weakdeps]
39+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
40+
41+
[extensions]
42+
DFLBenchmarksPlotsExt = "Plots"
43+
4044
[compat]
41-
Colors = "0.13.1"
4245
Combinatorics = "1.0.3"
4346
ConstrainedShortestPaths = "0.6.0"
4447
DataDeps = "0.7"

docs/src/benchmarks/maintenance.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ The Maintenance problem with resource constraint is a sequential decision-making
77

88
### Overview
99

10-
In this benchmark, a system consists of $N$ identical components, each of which can degrade over $n$ discrete states. State $1$ means that the component is new, state $n$ means that the component is failed. At each time step, the agent can maintain up to $K$ components.
10+
In this benchmark, a system consists of ``N`` identical components, each of which can degrade over ``n`` discrete states. State ``1`` means that the component is new, state $n$ means that the component is failed. At each time step, the agent can maintain up to $K$ components.
1111

1212
This forms an endogenous multistage stochastic optimization problem, where the agent must plan maintenance actions over the horizon.
1313

1414
### Mathematical Formulation
1515

1616
The maintenance problem can be formulated as a finite-horizon Markov Decision Process (MDP) with the following components:
1717

18-
**State Space** $\mathcal{S}$: At time step $t$, the state $s_t \in [1:n]^N$ is the degradation state for each component.
18+
**State Space** ``\mathcal{S}``: At time step ``t``, the state ``s_t \in [1:n]^N`` is the degradation state for each component.
1919

20-
**Action Space** $\mathcal{A}$: The action at time $t$ is the set of components that are maintained at time $t$:
20+
**Action Space** ``\mathcal{A}``: The action at time ``t`` is the set of components that are maintained at time ``t``:
2121
```math
2222
a_t \subseteq \{1, 2, \ldots, N\} \text{ such that } |a_t| \leq K
2323
```
@@ -51,9 +51,9 @@ Here, \(p\) is the degradation probability, \(s_t^i\) is the current state of co
5151

5252
The immediate cost at time \(t\) is:
5353

54-
$$
54+
```math
5555
c(s_t, a_t) = \Big( c_m \cdot |a_t| + c_f \cdot \#\{ i : s_t^i = n \} \Big)
56-
$$
56+
```
5757

5858
Where:
5959

docs/src/custom_benchmarks.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ generate_maximizer(bench::MyBenchmark)
6969
### Optional methods
7070

7171
```julia
72+
generate_baseline_policies(bench::MyBenchmark) -> collection of callables
7273
is_minimization_problem(bench::MyBenchmark) -> Bool # default: false (maximization)
7374
objective_value(bench::MyBenchmark, sample::DataSample, y) -> Real
7475
compute_gap(bench::MyBenchmark, dataset, model, maximizer) -> Float64
75-
plot_data(bench::MyBenchmark, sample::DataSample; kwargs...)
76-
plot_instance(bench::MyBenchmark, instance; kwargs...)
77-
plot_solution(bench::MyBenchmark, sample::DataSample, y; kwargs...)
78-
generate_baseline_policies(bench::MyBenchmark) -> collection of callables
76+
has_visualization(bench::MyBenchmark) -> Bool # default: false; return true when plot methods are implemented/available
77+
plot_instance(bench::MyBenchmark, sample::DataSample; kwargs...)
78+
plot_solution(bench::MyBenchmark, sample::DataSample; kwargs...)
7979
```
8080

8181
---
@@ -148,6 +148,13 @@ generate_baseline_policies(bench::MyDynamicBenchmark)
148148
# Each callable performs a full episode rollout and returns the trajectory.
149149
```
150150

151+
### Optional visualization methods
152+
153+
```julia
154+
plot_trajectory(bench::MyDynamicBenchmark, traj::Vector{DataSample}; kwargs...)
155+
animate_trajectory(bench::MyDynamicBenchmark, traj::Vector{DataSample}; kwargs...)
156+
```
157+
151158
`generate_dataset` for dynamic benchmarks **requires** a `target_policy` kwarg,
152159
there is no default. The `target_policy` must be a callable `(env) -> Vector{DataSample}`.
153160

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ The map is represented as a 2D image representing a 12x12 grid, each cell having
88

99
# First, let's load the package and create a benchmark object as follows:
1010
using DecisionFocusedLearningBenchmarks
11+
using Plots
1112
b = WarcraftBenchmark()
1213

1314
# ## Dataset generation
@@ -32,8 +33,8 @@ y_true = sample.y
3233
# `context` is not used in this benchmark (no solver kwargs needed), so it is empty:
3334
isempty(sample.context)
3435

35-
# For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
36-
plot_data(b, sample)
36+
# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data:
37+
plot_solution(b, sample)
3738
# We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.
3839

3940
# ## Building a pipeline
@@ -50,7 +51,7 @@ maximizer = generate_maximizer(b; dijkstra=true)
5051
# In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
5152
y = maximizer(θ)
5253
# As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.
53-
plot_data(b, DataSample(; x, θ, y))
54+
plot_solution(b, DataSample(; x, θ, y))
5455
# We can evaluate the current pipeline performance using the optimality gap metric:
5556
starting_gap = compute_gap(b, test_dataset, model, maximizer)
5657

@@ -59,7 +60,6 @@ starting_gap = compute_gap(b, test_dataset, model, maximizer)
5960
# We can now train the model using the InferOpt.jl package:
6061
using InferOpt
6162
using Flux
62-
using Plots
6363

6464
perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
6565
loss = FenchelYoungLoss(perturbed_maximizer)
@@ -85,7 +85,7 @@ final_gap = compute_gap(b, test_dataset, model, maximizer)
8585
#
8686
θ = model(x)
8787
y = maximizer(θ)
88-
plot_data(b, DataSample(; x, θ, y))
88+
plot_solution(b, DataSample(; x, θ, y))
8989

9090
using Test #src
9191
@test final_gap < starting_gap #src

docs/src/using_benchmarks.md

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ A benchmark bundles a problem family (an instance generator, a combinatorial sol
1010
Three abstract types cover the main settings:
1111
- **`AbstractBenchmark`**: static problems (one instance, one decision)
1212
- **`AbstractStochasticBenchmark{exogenous}`**: stochastic problems (type parameter indicates whether uncertainty is exogenous)
13-
- **`AbstractDynamicBenchmark`**: sequential / multi-stage problems
13+
- **`AbstractDynamicBenchmark{exogenous}`**: sequential / multi-stage problems
1414

1515
The sections below explain what changes between these settings. For most purposes, start with a static benchmark to understand the core workflow.
1616

@@ -180,10 +180,29 @@ rewards, samples = evaluate_policy!(pol, envs, n_episodes)
180180

181181
## Visualization
182182

183-
Where implemented, benchmarks provide benchmark-specific plotting helpers:
184-
183+
Plots is an **optional** dependency, load it with `using Plots` to unlock the plot functions. Not all benchmarks support visualization, call `has_visualization(bench)` to check.
185184
```julia
186-
plot_data(bench, sample) # overview of a data sample
187-
plot_instance(bench, instance) # raw problem instance
188-
plot_solution(bench, sample, y) # overlay solution on instance
185+
using Plots
186+
187+
bench = Argmax2DBenchmark()
188+
dataset = generate_dataset(bench, 10)
189+
sample = dataset[1]
190+
191+
has_visualization(bench) # true
192+
plot_instance(bench, sample) # problem geometry only
193+
plot_solution(bench, sample) # sample.y overlaid on the instance
194+
plot_solution(bench, sample, y) # convenience 3-arg form: override y before plotting
195+
196+
# Dynamic benchmarks only
197+
traj = generate_anticipative_solver(bench)(env)
198+
plot_trajectory(bench, traj) # grid of epoch subplots
199+
anim = animate_trajectory(bench, traj; fps=2)
200+
gif(anim, "episode.gif")
189201
```
202+
203+
- `has_visualization(bench)`: returns `true` for benchmarks that implement plot support (if Plots is loaded).
204+
- `plot_instance(bench, sample; kwargs...)`: renders the problem geometry without any solution.
205+
- `plot_solution(bench, sample; kwargs...)`: renders `sample.y` overlaid on the instance.
206+
- `plot_solution(bench, sample, y; kwargs...)`: 3-arg convenience form that overrides `y` before plotting.
207+
- `plot_trajectory(bench, traj; kwargs...)`: dynamic benchmarks only; produces a grid of per-epoch subplots.
208+
- `animate_trajectory(bench, traj; kwargs...)`: dynamic benchmarks only, returns a `Plots.Animation` that can be saved with `gif(anim, "file.gif")`.

ext/DFLBenchmarksPlotsExt.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module DFLBenchmarksPlotsExt
2+
3+
using DecisionFocusedLearningBenchmarks
4+
using DocStringExtensions: TYPEDSIGNATURES
5+
using LaTeXStrings: @L_str
6+
using Plots
7+
import DecisionFocusedLearningBenchmarks:
8+
has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory
9+
10+
include("plots/argmax2d_plots.jl")
11+
include("plots/warcraft_plots.jl")
12+
include("plots/svs_plots.jl")
13+
include("plots/dvs_plots.jl")
14+
15+
"""
16+
plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
17+
18+
Reconstruct a new sample with `y` overridden and delegate to the 2-arg
19+
[`plot_solution`](@ref). Only available when `Plots` is loaded.
20+
"""
21+
function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...)
22+
return plot_solution(
23+
bench,
24+
DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra);
25+
kwargs...,
26+
)
27+
end
28+
29+
end

ext/plots/argmax2d_plots.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
function _init_plot(title=""; kwargs...)
2+
pl = Plots.plot(;
3+
aspect_ratio=:equal,
4+
legend=:outerleft,
5+
xlim=(-1.1, 1.1),
6+
ylim=(-1.1, 1.1),
7+
title=title,
8+
kwargs...,
9+
)
10+
return pl
11+
end
12+
13+
function _plot_polytope!(pl, vertices)
14+
return Plots.plot!(
15+
pl,
16+
vcat(map(first, vertices), first(vertices[1])),
17+
vcat(map(last, vertices), last(vertices[1]));
18+
fillrange=0,
19+
fillcolor=:gray,
20+
fillalpha=0.2,
21+
linecolor=:black,
22+
label=L"\mathrm{conv}(\mathcal{Y}(x))",
23+
)
24+
end
25+
26+
function _plot_objective!(pl, θ)
27+
Plots.plot!(
28+
pl, [0.0, θ[1]], [0.0, θ[2]]; color="#9558B2", arrow=true, lw=2, label=nothing
29+
)
30+
Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"])
31+
return pl
32+
end
33+
34+
function _plot_y!(pl, y)
35+
return Plots.scatter!(
36+
pl,
37+
[y[1]],
38+
[y[2]];
39+
color="#CB3C33",
40+
markersize=9,
41+
markershape=:square,
42+
label=L"f(\theta)",
43+
)
44+
end
45+
46+
has_visualization(::Argmax2DBenchmark) = true
47+
48+
function plot_instance(::Argmax2DBenchmark, sample::DataSample; kwargs...)
49+
pl = _init_plot(; kwargs...)
50+
_plot_polytope!(pl, sample.instance)
51+
return pl
52+
end
53+
54+
function plot_solution(::Argmax2DBenchmark, sample::DataSample; kwargs...)
55+
pl = _init_plot(; kwargs...)
56+
_plot_polytope!(pl, sample.instance)
57+
_plot_objective!(pl, sample.θ)
58+
return _plot_y!(pl, sample.y)
59+
end
60+
61+
function plot_solution(::Argmax2DBenchmark, sample::DataSample, y; θ=sample.θ, kwargs...)
62+
pl = _init_plot(; kwargs...)
63+
_plot_polytope!(pl, sample.instance)
64+
_plot_objective!(pl, θ)
65+
return _plot_y!(pl, y)
66+
end

0 commit comments

Comments
 (0)