diff --git a/Project.toml b/Project.toml index 150c205..7ad625a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ authors = ["Members of JuliaDecisionFocusedLearning"] projects = ["docs", "test"] [deps] -Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" @@ -27,7 +26,6 @@ LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -37,8 +35,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" + +[extensions] +DFLBenchmarksPlotsExt = "Plots" + [compat] -Colors = "0.13.1" Combinatorics = "1.0.3" ConstrainedShortestPaths = "0.6.0" DataDeps = "0.7" diff --git a/docs/src/benchmarks/maintenance.md b/docs/src/benchmarks/maintenance.md index 236501c..060099d 100644 --- a/docs/src/benchmarks/maintenance.md +++ b/docs/src/benchmarks/maintenance.md @@ -7,7 +7,7 @@ The Maintenance problem with resource constraint is a sequential decision-making ### Overview -In this benchmark, a system consists of $N$ identical components, each of which can degrade over $n$ discrete states. State $1$ means that the component is new, state $n$ means that the component is failed. At each time step, the agent can maintain up to $K$ components. +In this benchmark, a system consists of ``N`` identical components, each of which can degrade over ``n`` discrete states. State ``1`` means that the component is new, state $n$ means that the component is failed. At each time step, the agent can maintain up to $K$ components. This forms an endogenous multistage stochastic optimization problem, where the agent must plan maintenance actions over the horizon. @@ -15,9 +15,9 @@ This forms an endogenous multistage stochastic optimization problem, where the a The maintenance problem can be formulated as a finite-horizon Markov Decision Process (MDP) with the following components: -**State Space** $\mathcal{S}$: At time step $t$, the state $s_t \in [1:n]^N$ is the degradation state for each component. +**State Space** ``\mathcal{S}``: At time step ``t``, the state ``s_t \in [1:n]^N`` is the degradation state for each component. -**Action Space** $\mathcal{A}$: The action at time $t$ is the set of components that are maintained at time $t$: +**Action Space** ``\mathcal{A}``: The action at time ``t`` is the set of components that are maintained at time ``t``: ```math a_t \subseteq \{1, 2, \ldots, N\} \text{ such that } |a_t| \leq K ``` @@ -51,9 +51,9 @@ Here, \(p\) is the degradation probability, \(s_t^i\) is the current state of co The immediate cost at time \(t\) is: -$$ +```math c(s_t, a_t) = \Big( c_m \cdot |a_t| + c_f \cdot \#\{ i : s_t^i = n \} \Big) -$$ +``` Where: diff --git a/docs/src/custom_benchmarks.md b/docs/src/custom_benchmarks.md index 289152d..f4a4e7f 100644 --- a/docs/src/custom_benchmarks.md +++ b/docs/src/custom_benchmarks.md @@ -69,13 +69,13 @@ generate_maximizer(bench::MyBenchmark) ### Optional methods ```julia +generate_baseline_policies(bench::MyBenchmark) -> collection of callables is_minimization_problem(bench::MyBenchmark) -> Bool # default: false (maximization) objective_value(bench::MyBenchmark, sample::DataSample, y) -> Real compute_gap(bench::MyBenchmark, dataset, model, maximizer) -> Float64 -plot_data(bench::MyBenchmark, sample::DataSample; kwargs...) -plot_instance(bench::MyBenchmark, instance; kwargs...) -plot_solution(bench::MyBenchmark, sample::DataSample, y; kwargs...) -generate_baseline_policies(bench::MyBenchmark) -> collection of callables +has_visualization(bench::MyBenchmark) -> Bool # default: false; return true when plot methods are implemented/available +plot_instance(bench::MyBenchmark, sample::DataSample; kwargs...) +plot_solution(bench::MyBenchmark, sample::DataSample; kwargs...) ``` --- @@ -148,6 +148,13 @@ generate_baseline_policies(bench::MyDynamicBenchmark) # Each callable performs a full episode rollout and returns the trajectory. ``` +### Optional visualization methods + +```julia +plot_trajectory(bench::MyDynamicBenchmark, traj::Vector{DataSample}; kwargs...) +animate_trajectory(bench::MyDynamicBenchmark, traj::Vector{DataSample}; kwargs...) +``` + `generate_dataset` for dynamic benchmarks **requires** a `target_policy` kwarg, there is no default. The `target_policy` must be a callable `(env) -> Vector{DataSample}`. diff --git a/docs/src/tutorials/warcraft_tutorial.jl b/docs/src/tutorials/warcraft_tutorial.jl index 8f78db8..b801d7a 100644 --- a/docs/src/tutorials/warcraft_tutorial.jl +++ b/docs/src/tutorials/warcraft_tutorial.jl @@ -8,6 +8,7 @@ The map is represented as a 2D image representing a 12x12 grid, each cell having # First, let's load the package and create a benchmark object as follows: using DecisionFocusedLearningBenchmarks +using Plots b = WarcraftBenchmark() # ## Dataset generation @@ -32,8 +33,8 @@ y_true = sample.y # `context` is not used in this benchmark (no solver kwargs needed), so it is empty: isempty(sample.context) -# For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data: -plot_data(b, sample) +# For some benchmarks, we provide the following plotting method [`plot_solution`](@ref) to visualize the data: +plot_solution(b, sample) # We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells. # ## Building a pipeline @@ -50,7 +51,7 @@ maximizer = generate_maximizer(b; dijkstra=true) # In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm. y = maximizer(θ) # As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path. -plot_data(b, DataSample(; x, θ, y)) +plot_solution(b, DataSample(; x, θ, y)) # We can evaluate the current pipeline performance using the optimality gap metric: starting_gap = compute_gap(b, test_dataset, model, maximizer) @@ -59,7 +60,6 @@ starting_gap = compute_gap(b, test_dataset, model, maximizer) # We can now train the model using the InferOpt.jl package: using InferOpt using Flux -using Plots perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100) loss = FenchelYoungLoss(perturbed_maximizer) @@ -85,7 +85,7 @@ final_gap = compute_gap(b, test_dataset, model, maximizer) # θ = model(x) y = maximizer(θ) -plot_data(b, DataSample(; x, θ, y)) +plot_solution(b, DataSample(; x, θ, y)) using Test #src @test final_gap < starting_gap #src diff --git a/docs/src/using_benchmarks.md b/docs/src/using_benchmarks.md index b136424..bd64437 100644 --- a/docs/src/using_benchmarks.md +++ b/docs/src/using_benchmarks.md @@ -10,7 +10,7 @@ A benchmark bundles a problem family (an instance generator, a combinatorial sol Three abstract types cover the main settings: - **`AbstractBenchmark`**: static problems (one instance, one decision) - **`AbstractStochasticBenchmark{exogenous}`**: stochastic problems (type parameter indicates whether uncertainty is exogenous) -- **`AbstractDynamicBenchmark`**: sequential / multi-stage problems +- **`AbstractDynamicBenchmark{exogenous}`**: sequential / multi-stage problems The sections below explain what changes between these settings. For most purposes, start with a static benchmark to understand the core workflow. @@ -180,10 +180,29 @@ rewards, samples = evaluate_policy!(pol, envs, n_episodes) ## Visualization -Where implemented, benchmarks provide benchmark-specific plotting helpers: - +Plots is an **optional** dependency, load it with `using Plots` to unlock the plot functions. Not all benchmarks support visualization, call `has_visualization(bench)` to check. ```julia -plot_data(bench, sample) # overview of a data sample -plot_instance(bench, instance) # raw problem instance -plot_solution(bench, sample, y) # overlay solution on instance +using Plots + +bench = Argmax2DBenchmark() +dataset = generate_dataset(bench, 10) +sample = dataset[1] + +has_visualization(bench) # true +plot_instance(bench, sample) # problem geometry only +plot_solution(bench, sample) # sample.y overlaid on the instance +plot_solution(bench, sample, y) # convenience 3-arg form: override y before plotting + +# Dynamic benchmarks only +traj = generate_anticipative_solver(bench)(env) +plot_trajectory(bench, traj) # grid of epoch subplots +anim = animate_trajectory(bench, traj; fps=2) +gif(anim, "episode.gif") ``` + +- `has_visualization(bench)`: returns `true` for benchmarks that implement plot support (if Plots is loaded). +- `plot_instance(bench, sample; kwargs...)`: renders the problem geometry without any solution. +- `plot_solution(bench, sample; kwargs...)`: renders `sample.y` overlaid on the instance. +- `plot_solution(bench, sample, y; kwargs...)`: 3-arg convenience form that overrides `y` before plotting. +- `plot_trajectory(bench, traj; kwargs...)`: dynamic benchmarks only; produces a grid of per-epoch subplots. +- `animate_trajectory(bench, traj; kwargs...)`: dynamic benchmarks only, returns a `Plots.Animation` that can be saved with `gif(anim, "file.gif")`. diff --git a/ext/DFLBenchmarksPlotsExt.jl b/ext/DFLBenchmarksPlotsExt.jl new file mode 100644 index 0000000..0a5caae --- /dev/null +++ b/ext/DFLBenchmarksPlotsExt.jl @@ -0,0 +1,29 @@ +module DFLBenchmarksPlotsExt + +using DecisionFocusedLearningBenchmarks +using DocStringExtensions: TYPEDSIGNATURES +using LaTeXStrings: @L_str +using Plots +import DecisionFocusedLearningBenchmarks: + has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory + +include("plots/argmax2d_plots.jl") +include("plots/warcraft_plots.jl") +include("plots/svs_plots.jl") +include("plots/dvs_plots.jl") + +""" + plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...) + +Reconstruct a new sample with `y` overridden and delegate to the 2-arg +[`plot_solution`](@ref). Only available when `Plots` is loaded. +""" +function plot_solution(bench::AbstractBenchmark, sample::DataSample, y; kwargs...) + return plot_solution( + bench, + DataSample(; sample.context..., x=sample.x, θ=sample.θ, y=y, extra=sample.extra); + kwargs..., + ) +end + +end diff --git a/ext/plots/argmax2d_plots.jl b/ext/plots/argmax2d_plots.jl new file mode 100644 index 0000000..cdb9800 --- /dev/null +++ b/ext/plots/argmax2d_plots.jl @@ -0,0 +1,66 @@ +function _init_plot(title=""; kwargs...) + pl = Plots.plot(; + aspect_ratio=:equal, + legend=:outerleft, + xlim=(-1.1, 1.1), + ylim=(-1.1, 1.1), + title=title, + kwargs..., + ) + return pl +end + +function _plot_polytope!(pl, vertices) + return Plots.plot!( + pl, + vcat(map(first, vertices), first(vertices[1])), + vcat(map(last, vertices), last(vertices[1])); + fillrange=0, + fillcolor=:gray, + fillalpha=0.2, + linecolor=:black, + label=L"\mathrm{conv}(\mathcal{Y}(x))", + ) +end + +function _plot_objective!(pl, θ) + Plots.plot!( + pl, [0.0, θ[1]], [0.0, θ[2]]; color="#9558B2", arrow=true, lw=2, label=nothing + ) + Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"]) + return pl +end + +function _plot_y!(pl, y) + return Plots.scatter!( + pl, + [y[1]], + [y[2]]; + color="#CB3C33", + markersize=9, + markershape=:square, + label=L"f(\theta)", + ) +end + +has_visualization(::Argmax2DBenchmark) = true + +function plot_instance(::Argmax2DBenchmark, sample::DataSample; kwargs...) + pl = _init_plot(; kwargs...) + _plot_polytope!(pl, sample.instance) + return pl +end + +function plot_solution(::Argmax2DBenchmark, sample::DataSample; kwargs...) + pl = _init_plot(; kwargs...) + _plot_polytope!(pl, sample.instance) + _plot_objective!(pl, sample.θ) + return _plot_y!(pl, sample.y) +end + +function plot_solution(::Argmax2DBenchmark, sample::DataSample, y; θ=sample.θ, kwargs...) + pl = _init_plot(; kwargs...) + _plot_polytope!(pl, sample.instance) + _plot_objective!(pl, θ) + return _plot_y!(pl, y) +end diff --git a/ext/plots/dvs_plots.jl b/ext/plots/dvs_plots.jl new file mode 100644 index 0000000..0b61a5e --- /dev/null +++ b/ext/plots/dvs_plots.jl @@ -0,0 +1,491 @@ +import DecisionFocusedLearningBenchmarks.DynamicVehicleScheduling as DVS +using Printf: @sprintf + +has_visualization(::DynamicVehicleSchedulingBenchmark) = true + +# ── helpers (moved from static_vsp/plot.jl) ───────────────────────────────── + +function _plot_static_instance( + x_depot, + y_depot, + x_customers, + y_customers; + customer_markersize=4, + depot_markersize=7, + alpha_depot=0.8, + customer_color=:lightblue, + depot_color=:lightgreen, + kwargs..., +) + fig = Plots.plot(; + legend=:topleft, xlabel="x coordinate", ylabel="y coordinate", kwargs... + ) + Plots.scatter!( + fig, + x_customers, + y_customers; + label="Customers", + markercolor=customer_color, + marker=:circle, + markersize=customer_markersize, + ) + Plots.scatter!( + fig, + [x_depot], + [y_depot]; + label="Depot", + markercolor=depot_color, + marker=:rect, + markersize=depot_markersize, + alpha=alpha_depot, + ) + return fig +end + +# ── plot_state ─────────────────────────────────────────────────────────────── + +""" +$TYPEDSIGNATURES + +Plot a given DVSPState showing depot, must-dispatch customers, and postponable customers. +""" +function plot_state( + state::DVS.DVSPState; + customer_markersize=6, + depot_markersize=8, + alpha_depot=0.8, + depot_color=:lightgreen, + depot_marker=:rect, + must_dispatch_color=:red, + postponable_color=:lightblue, + must_dispatch_marker=:star5, + postponable_marker=:utriangle, + show_axis_labels=true, + markerstrokewidth=0.5, + show_colorbar=true, + kwargs..., +) + (; x_depot, y_depot, x_customers, y_customers, is_must_dispatch, start_times) = DVS.build_state_data( + state + ) + + plot_args = Dict( + :legend => :topleft, :title => "DVSP State - Epoch $(state.current_epoch)" + ) + + if show_axis_labels + plot_args[:xlabel] = "x coordinate" + plot_args[:ylabel] = "y coordinate" + end + + for (k, v) in kwargs + plot_args[k] = v + end + + fig = Plots.plot(; plot_args...) + + Plots.scatter!( + fig, + [x_depot], + [y_depot]; + label="Depot", + markercolor=depot_color, + marker=depot_marker, + markersize=depot_markersize, + alpha=alpha_depot, + markerstrokewidth=markerstrokewidth, + ) + + scatter_must_dispatch_args = Dict( + :label => "Must-dispatch customers", + :markercolor => must_dispatch_color, + :marker => must_dispatch_marker, + :markersize => customer_markersize, + :markerstrokewidth => markerstrokewidth, + ) + + scatter_postponable_args = Dict( + :label => "Postponable customers", + :markercolor => postponable_color, + :marker => postponable_marker, + :markersize => customer_markersize, + :markerstrokewidth => markerstrokewidth, + ) + if show_colorbar + scatter_must_dispatch_args[:marker_z] = start_times[is_must_dispatch] + scatter_postponable_args[:marker_z] = start_times[.!is_must_dispatch] + scatter_postponable_args[:colormap] = :plasma + scatter_must_dispatch_args[:colormap] = :plasma + scatter_postponable_args[:colorbar] = :right + scatter_must_dispatch_args[:colorbar] = :right + end + + if length(x_customers[is_must_dispatch]) > 0 + Plots.scatter!( + fig, + x_customers[is_must_dispatch], + y_customers[is_must_dispatch]; + scatter_must_dispatch_args..., + ) + end + + if length(x_customers[.!is_must_dispatch]) > 0 + Plots.scatter!( + fig, + x_customers[.!is_must_dispatch], + y_customers[.!is_must_dispatch]; + scatter_postponable_args..., + ) + end + + return fig +end + +# ── plot_routes ────────────────────────────────────────────────────────────── + +function plot_routes( + state::DVS.DVSPState, + routes::Vector{Vector{Int}}; + reward=nothing, + route_color=nothing, + route_linewidth=2, + route_alpha=0.8, + kwargs..., +) + cost_text = if !isnothing(reward) + " (" * @sprintf("%.2f", -reward) * ")" + else + "" + end + fig = plot_state( + state; + kwargs..., + title="DVSP State with Routes - Epoch $(state.current_epoch)$cost_text", + ) + + (; x_depot, y_depot, x_customers, y_customers) = DVS.build_state_data(state) + + x = vcat(x_depot, x_customers) + y = vcat(y_depot, y_customers) + + plot_args = Dict( + :linewidth => route_linewidth, :alpha => route_alpha, :z_order => :back; + ) + + if !isnothing(route_color) + plot_args[:color] = route_color + end + + for route in routes + if !isempty(route) + route_x = vcat(x_depot, x[route], x_depot) + route_y = vcat(y_depot, y[route], y_depot) + Plots.plot!(fig, route_x, route_y; label=false, plot_args...) + end + end + + return fig +end + +function plot_routes(state::DVS.DVSPState, routes::BitMatrix; kwargs...) + route_vectors = DVS.decode_bitmatrix_to_routes(routes) + return plot_routes(state, route_vectors; kwargs...) +end + +# ── interface methods ──────────────────────────────────────────────────────── + +function plot_instance( + bench::DynamicVehicleSchedulingBenchmark, sample::DataSample; kwargs... +) + return plot_state(sample.instance; kwargs...) +end + +function plot_solution( + bench::DynamicVehicleSchedulingBenchmark, sample::DataSample; kwargs... +) + return plot_routes(sample.instance, sample.y; reward=sample.reward, kwargs...) +end + +function plot_trajectory( + bench::DynamicVehicleSchedulingBenchmark, + traj::Vector{<:DataSample}; + plot_routes_flag=true, + cols=nothing, + figsize=nothing, + margin=0.05, + legend_margin_factor=0.15, + titlefontsize=14, + guidefontsize=12, + legendfontsize=11, + tickfontsize=10, + show_axis_labels=false, + show_colorbar=true, + kwargs..., +) + if length(traj) == 0 + error("No data samples provided") + end + + pd = DVS.build_plot_data(traj) + n_epochs = length(pd) + + if isnothing(cols) + cols = min(n_epochs, 3) + end + rows = ceil(Int, n_epochs / cols) + + x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd) + x_max = maximum(max(data.x_depot, maximum(data.x_customers)) for data in pd) + y_min = minimum(min(data.y_depot, minimum(data.y_customers)) for data in pd) + y_max = maximum(max(data.y_depot, maximum(data.y_customers)) for data in pd) + + xlims = (x_min - margin, x_max + margin) + y_range = y_max - y_min + 2 * margin + legend_margin = y_range * legend_margin_factor + ylims = (y_min - margin, y_max + margin + legend_margin) + + min_start_time = minimum(minimum(data.start_times) for data in pd) + max_start_time = maximum(maximum(data.start_times) for data in pd) + clims = (min_start_time, max_start_time) + + plots = map(1:n_epochs) do i + sample = traj[i] + state = sample.instance + reward = sample.reward + + common_kwargs = Dict( + :xlims => xlims, + :ylims => ylims, + :clims => clims, + :show_colorbar => show_colorbar, + :titlefontsize => titlefontsize, + :guidefontsize => guidefontsize, + :legendfontsize => legendfontsize, + :tickfontsize => tickfontsize, + :show_axis_labels => show_axis_labels, + :markerstrokewidth => 0.5, + ) + + if plot_routes_flag + fig = plot_routes( + state, + sample.y; + reward=reward, + show_route_labels=false, + common_kwargs..., + kwargs..., + ) + else + fig = plot_state(state; common_kwargs..., kwargs...) + end + + return fig + end + + if isnothing(figsize) + plot_width = 600 * cols + plot_height = 500 * rows + figsize = (plot_width, plot_height) + end + + combined_plot = Plots.plot( + plots...; layout=(rows, cols), size=figsize, link=:both, clims=clims + ) + + return combined_plot +end + +function animate_trajectory( + bench::DynamicVehicleSchedulingBenchmark, + traj::Vector{<:DataSample}; + figsize=(800, 600), + margin=0.1, + legend_margin_factor=0.2, + titlefontsize=16, + guidefontsize=14, + legendfontsize=12, + tickfontsize=11, + show_axis_labels=true, + show_cost_bar=true, + show_colorbar=false, + cost_bar_width=0.05, + cost_bar_margin=0.02, + cost_bar_color_palette=:turbo, + kwargs..., +) + pd = DVS.build_plot_data(traj) + epoch_costs = [-sample.reward for sample in traj] + + x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd) + x_max = maximum(max(data.x_depot, maximum(data.x_customers)) for data in pd) + y_min = minimum(min(data.y_depot, minimum(data.y_customers)) for data in pd) + y_max = maximum(max(data.y_depot, maximum(data.y_customers)) for data in pd) + + xlims = (x_min - margin, x_max + margin) + y_range = y_max - y_min + 2 * margin + legend_margin = y_range * legend_margin_factor + ylims = (y_min - margin, y_max + margin + legend_margin) + + min_start_time = minimum(minimum(data.start_times) for data in pd) + max_start_time = maximum(maximum(data.start_times) for data in pd) + clims = (min_start_time, max_start_time) + + if show_cost_bar + x_min, x_max = xlims + x_range = x_max - x_min + cost_bar_space = x_range * (cost_bar_width + cost_bar_margin) + xlims = (x_min, x_max + cost_bar_space) + end + + frame_plan = [] + for (epoch_idx, _) in enumerate(traj) + push!(frame_plan, (epoch_idx, :state)) + push!(frame_plan, (epoch_idx, :routes)) + end + + total_frames = length(frame_plan) + + anim = @animate for frame_idx in 1:total_frames + epoch_idx, frame_type = frame_plan[frame_idx] + sample = traj[epoch_idx] + state = sample.instance + + if frame_type == :routes + fig = plot_routes( + state, + sample.y; + xlims=xlims, + ylims=ylims, + clims=clims, + title="Epoch $(state.current_epoch) - Routes Dispatched", + titlefontsize=titlefontsize, + guidefontsize=guidefontsize, + legendfontsize=legendfontsize, + tickfontsize=tickfontsize, + show_axis_labels=show_axis_labels, + markerstrokewidth=0.5, + show_route_labels=false, + show_colorbar=show_colorbar, + size=figsize, + kwargs..., + ) + else + fig = plot_state( + state; + xlims=xlims, + ylims=ylims, + clims=clims, + title="Epoch $(state.current_epoch) - Available Customers", + titlefontsize=titlefontsize, + guidefontsize=guidefontsize, + legendfontsize=legendfontsize, + tickfontsize=tickfontsize, + show_axis_labels=show_axis_labels, + markerstrokewidth=0.5, + show_colorbar=show_colorbar, + size=figsize, + kwargs..., + ) + end + + if show_cost_bar + x_min, x_max = xlims + x_range = x_max - x_min + bar_x_start = x_max - cost_bar_width * x_range + bar_x_end = x_max - cost_bar_margin * x_range + + y_min, y_max = ylims + y_range = y_max - y_min + bar_y_start = y_min + 0.1 * y_range + bar_y_end = y_max - 0.1 * y_range + bar_height = bar_y_end - bar_y_start + + current_cost = 0.0 + for frame_i in 1:frame_idx + frame_epoch, frame_frame_type = frame_plan[frame_i] + if frame_frame_type == :routes && frame_epoch <= length(epoch_costs) + current_cost += epoch_costs[frame_epoch] + end + end + + max_cost = sum(epoch_costs) + if max_cost > 0 + filled_height = (current_cost / max_cost) * bar_height + else + filled_height = 0.0 + end + + Plots.plot!( + fig, + [bar_x_start, bar_x_end, bar_x_end, bar_x_start, bar_x_start], + [bar_y_start, bar_y_start, bar_y_end, bar_y_end, bar_y_start]; + seriestype=:shape, + color=:white, + alpha=0.8, + linecolor=:black, + linewidth=2, + label="", + ) + + cmap = Plots.cgrad(cost_bar_color_palette) + if filled_height > 0 + ratio = current_cost / max_cost + color_at_val = Plots.get(cmap, ratio) + Plots.plot!( + fig, + [bar_x_start, bar_x_end, bar_x_end, bar_x_start, bar_x_start], + [ + bar_y_start, + bar_y_start, + bar_y_start + filled_height, + bar_y_start + filled_height, + bar_y_start, + ]; + seriestype=:shape, + color=color_at_val, + alpha=0.7, + linecolor=:darkred, + linewidth=1, + label="", + ) + end + + cost_text_y = bar_y_start + filled_height + 0.02 * y_range + if cost_text_y > bar_y_end + cost_text_y = bar_y_end + end + + Plots.plot!( + fig, + [bar_x_start + (bar_x_end - bar_x_start) / 2], + [cost_text_y]; + seriestype=:scatter, + markersize=0, + label="", + annotations=( + bar_x_start - 0.04 * x_range, + cost_text_y, + (@sprintf("%.1f", current_cost), :center, guidefontsize), + ), + ) + + Plots.plot!( + fig, + [(bar_x_start + bar_x_end) / 2], + [bar_y_end + 0.05 * y_range]; + seriestype=:scatter, + markersize=0, + label="", + annotations=( + (bar_x_start + bar_x_end) / 2, + bar_y_end + 0.05 * y_range, + ("Cost", :center, guidefontsize), + ), + ) + end + + fig + end + + return anim +end diff --git a/ext/plots/svs_plots.jl b/ext/plots/svs_plots.jl new file mode 100644 index 0000000..9a6161e --- /dev/null +++ b/ext/plots/svs_plots.jl @@ -0,0 +1,104 @@ +import DecisionFocusedLearningBenchmarks.StochasticVehicleScheduling: + Solution, compute_path_list + +has_visualization(::StochasticVehicleSchedulingBenchmark) = true + +function plot_instance( + ::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs... +) + @assert hasproperty(sample.instance, :city) "Sample does not contain city information." + (; tasks, district_width, width) = sample.instance.city + ticks = 0:district_width:width + max_time = maximum(t.end_time for t in sample.instance.city.tasks[1:(end - 1)]) + fig = Plots.plot(; + xlabel="x", + ylabel="y", + gridlinewidth=3, + aspect_ratio=:equal, + size=(500, 500), + xticks=ticks, + yticks=ticks, + xlims=(-1, width + 1), + ylims=(-1, width + 1), + clim=(0.0, max_time), + label=nothing, + colorbar_title="Time", + kwargs..., + ) + Plots.scatter!( + fig, + [tasks[1].start_point.x], + [tasks[1].start_point.y]; + label=nothing, + marker=:rect, + markersize=10, + ) + Plots.annotate!( + fig, (tasks[1].start_point.x, tasks[1].start_point.y, Plots.text("0", 10)) + ) + for (i_task, task) in enumerate(tasks[2:(end - 1)]) + (; start_point, end_point) = task + points = [(start_point.x, start_point.y), (end_point.x, end_point.y)] + Plots.plot!(fig, points; color=:black, label=nothing) + Plots.scatter!( + fig, + points[1]; + markersize=10, + marker=:rect, + marker_z=task.start_time, + colormap=:turbo, + label=nothing, + ) + Plots.scatter!( + fig, + points[2]; + markersize=10, + marker=:rect, + marker_z=task.end_time, + colormap=:turbo, + label=nothing, + ) + Plots.annotate!(fig, (points[1]..., Plots.text("$(i_task)", 10))) + end + return fig +end + +function plot_solution( + ::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs... +) + @assert hasproperty(sample.instance, :city) "Sample does not contain city information." + (; tasks, district_width, width) = sample.instance.city + ticks = 0:district_width:width + solution = Solution(sample.y, sample.instance) + path_list = compute_path_list(solution) + fig = Plots.plot(; + xlabel="x", + ylabel="y", + legend=false, + gridlinewidth=3, + aspect_ratio=:equal, + size=(500, 500), + xticks=ticks, + yticks=ticks, + xlims=(-1, width + 1), + ylims=(-1, width + 1), + kwargs..., + ) + for path in path_list + X = Float64[] + Y = Float64[] + (; start_point, end_point) = tasks[path[1]] + (; x, y) = end_point + push!(X, x) + push!(Y, y) + for task in path[2:end] + (; start_point, end_point) = tasks[task] + push!(X, start_point.x) + push!(Y, start_point.y) + push!(X, end_point.x) + push!(Y, end_point.y) + end + Plots.plot!(fig, X, Y; marker=:circle) + end + return fig +end diff --git a/ext/plots/warcraft_plots.jl b/ext/plots/warcraft_plots.jl new file mode 100644 index 0000000..2029225 --- /dev/null +++ b/ext/plots/warcraft_plots.jl @@ -0,0 +1,49 @@ +import DecisionFocusedLearningBenchmarks.Warcraft as W +using Images: Gray + +has_visualization(::WarcraftBenchmark) = true + +function plot_instance(::WarcraftBenchmark, sample::DataSample; kwargs...) + im = dropdims(sample.x; dims=4) + img = W.convert_image_for_plot(im) + return Plots.plot( + img; aspect_ratio=:equal, framestyle=:none, title="Terrain image", kwargs... + ) +end + +function plot_solution( + ::WarcraftBenchmark, + sample::DataSample; + θ_true=sample.θ, + θ_title="Weights", + y_title="Path", + kwargs..., +) + x = sample.x + y = sample.y + θ = sample.θ + im = dropdims(x; dims=4) + img = W.convert_image_for_plot(im) + p1 = Plots.plot( + img; aspect_ratio=:equal, framestyle=:none, size=(300, 300), title="Terrain image" + ) + p2 = Plots.heatmap( + -θ; + yflip=true, + aspect_ratio=:equal, + framestyle=:none, + padding=(0.0, 0.0), + size=(300, 300), + legend=false, + title=θ_title, + clim=(minimum(-θ_true), maximum(-θ_true)), + ) + p3 = Plots.plot( + Gray.(y .* 0.7); + aspect_ratio=:equal, + framestyle=:none, + size=(300, 300), + title=y_title, + ) + return Plots.plot(p1, p2, p3; layout=(1, 3), size=(900, 300), kwargs...) +end diff --git a/src/Argmax2D/Argmax2D.jl b/src/Argmax2D/Argmax2D.jl index 512d60b..968d63a 100644 --- a/src/Argmax2D/Argmax2D.jl +++ b/src/Argmax2D/Argmax2D.jl @@ -1,12 +1,9 @@ module Argmax2D using ..Utils -using Colors: Colors using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES using Flux: Chain, Dense -using LaTeXStrings: @L_str using LinearAlgebra: dot, norm -using Plots: Plots using Random: Random, MersenneTwister, AbstractRNG include("polytope.jl") @@ -88,28 +85,6 @@ function Utils.generate_statistical_model( return model end -function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...) - pl = init_plot() - plot_polytope!(pl, instance) - plot_objective!(pl, θ) - return plot_maximizer!(pl, θ, instance, maximizer) -end - -""" -$TYPEDSIGNATURES - -Plot the data sample for the [`Argmax2DBenchmark`](@ref). -""" -function Utils.plot_data( - bench::Argmax2DBenchmark, - sample::DataSample; - instance=sample.instance, - θ=sample.θ, - kwargs..., -) - return Utils.plot_data(bench; instance, θ, kwargs...) -end - export Argmax2DBenchmark end diff --git a/src/Argmax2D/polytope.jl b/src/Argmax2D/polytope.jl index dd8057a..a694779 100644 --- a/src/Argmax2D/polytope.jl +++ b/src/Argmax2D/polytope.jl @@ -1,53 +1,3 @@ function build_polytope(N; shift=0.0) return [[cospi(2k / N + shift), sinpi(2k / N + shift)] for k in 0:(N - 1)] end - -function init_plot(title="") - pl = Plots.plot(; - aspect_ratio=:equal, - legend=:outerleft, - xlim=(-1.1, 1.1), - ylim=(-1.1, 1.1), - title=title, - ) - return pl -end; - -function plot_polytope!(pl, vertices) - return Plots.plot!( - vcat(map(first, vertices), first(vertices[1])), - vcat(map(last, vertices), last(vertices[1])); - fillrange=0, - fillcolor=:gray, - fillalpha=0.2, - linecolor=:black, - label=L"\mathrm{conv}(\mathcal{Y}(x))", - ) -end; - -function plot_objective!(pl, θ) - Plots.plot!( - pl, - [0.0, θ[1]], - [0.0, θ[2]]; - color=Colors.JULIA_LOGO_COLORS.purple, - arrow=true, - lw=2, - label=nothing, - ) - Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"]) - return pl -end; - -function plot_maximizer!(pl, θ, instance, maximizer) - ŷ = maximizer(θ; instance) - return Plots.scatter!( - pl, - [ŷ[1]], - [ŷ[2]]; - color=Colors.JULIA_LOGO_COLORS.red, - markersize=9, - markershape=:square, - label=L"f(\theta)", - ) -end; diff --git a/src/DecisionFocusedLearningBenchmarks.jl b/src/DecisionFocusedLearningBenchmarks.jl index 428346c..f0b4f5b 100644 --- a/src/DecisionFocusedLearningBenchmarks.jl +++ b/src/DecisionFocusedLearningBenchmarks.jl @@ -79,7 +79,7 @@ export generate_anticipative_solver, generate_parametric_anticipative_solver export is_exogenous, is_endogenous export objective_value -export plot_data, plot_instance, plot_solution +export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory export compute_gap # Export all benchmarks diff --git a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl index 92ff5d5..1923bbb 100644 --- a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl +++ b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl @@ -12,7 +12,6 @@ using InferOpt: LinearMaximizer using IterTools: partition using JSON using JuMP -using Plots: plot, plot!, scatter!, @animate, Plots, gif using Printf: @printf, @sprintf using Random: Random, AbstractRNG, MersenneTwister, seed!, randperm using Requires: @require @@ -43,13 +42,12 @@ include("utils.jl") include("static_vsp/instance.jl") include("static_vsp/parsing.jl") include("static_vsp/solution.jl") -include("static_vsp/plot.jl") include("instance.jl") include("state.jl") include("scenario.jl") include("environment.jl") -include("plot.jl") +include("plot_data.jl") include("maximizer.jl") include("anticipative_solver.jl") diff --git a/src/DynamicVehicleScheduling/plot_data.jl b/src/DynamicVehicleScheduling/plot_data.jl new file mode 100644 index 0000000..757645b --- /dev/null +++ b/src/DynamicVehicleScheduling/plot_data.jl @@ -0,0 +1,46 @@ +function build_state_data(state::DVSPState) + coords = coordinate(state) + x = [p.x for p in coords] + y = [p.y for p in coords] + x_depot = x[1] + y_depot = y[1] + x_customers = x[2:end] + y_customers = y[2:end] + start_times_customers = start_time(state)[2:end] + service_times_customers = service_time(state)[2:end] + must_customers = state.is_must_dispatch[2:end] + + return (; + x_depot=x_depot, + y_depot=y_depot, + x_customers=x_customers, + y_customers=y_customers, + is_must_dispatch=must_customers, + start_times=start_times_customers, + service_times=service_times_customers, + ) +end + +""" +Return a Dict with plot-ready information extracted from a vector of DataSample objects. + + +The returned dictionary contains: +- :n_epochs => Int +- :coordinates => Vector{Vector{Tuple{Float64,Float64}}} (per-epoch list of (x,y) tuples, empty if instance missing) +- :start_times => Vector{Vector{Float64}} (per-epoch start times, empty if instance missing) +- :node_types => Vector{Vector{Symbol}} (per-epoch node-type labels aligned with coordinates) +- :routes => Vector{Vector{Vector{Int}}} (per-epoch normalized routes; empty vector when no routes) +- :epoch_costs => Vector{Float64} (per-epoch cost; NaN if not computable) + +This lets plotting code build figures without depending on plotting internals. +""" +function build_plot_data(data_samples::Vector{<:DataSample}) + state_data = [build_state_data(sample.instance) for sample in data_samples] + rewards = [sample.reward for sample in data_samples] + routess = [sample.y for sample in data_samples] + return [ + (; state..., reward, routes) for + (state, reward, routes) in zip(state_data, rewards, routess) + ] +end diff --git a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl index 534cc76..2b31618 100644 --- a/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl +++ b/src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl @@ -2,7 +2,6 @@ module StochasticVehicleScheduling export StochasticVehicleSchedulingBenchmark export generate_dataset, generate_maximizer, generate_statistical_model -export plot_instance, plot_solution export compact_linearized_mip, compact_mip, column_generation_algorithm, local_search, deterministic_mip export evaluate_solution, is_feasible @@ -28,7 +27,6 @@ using Graphs: outneighbors using JuMP: JuMP, Model, @variable, @objective, @constraint, optimize!, value, set_silent, dual -using Plots: Plots, plot, plot!, scatter!, annotate!, text using Printf: @printf using Random: Random, AbstractRNG, MersenneTwister using SparseArrays: sparse, SparseMatrixCSC @@ -179,106 +177,4 @@ function Utils.generate_statistical_model( return Chain(Dense(20 => 1; bias=false), vec) end -""" -$TYPEDSIGNATURES -""" -function plot_instance( - ::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs... -) - @assert hasproperty(sample.instance, :city) "Sample does not contain city information." - (; tasks, district_width, width) = sample.instance.city - ticks = 0:district_width:width - max_time = maximum(t.end_time for t in sample.instance.city.tasks[1:(end - 1)]) - fig = plot(; - xlabel="x", - ylabel="y", - gridlinewidth=3, - aspect_ratio=:equal, - size=(500, 500), - xticks=ticks, - yticks=ticks, - xlims=(-1, width + 1), - ylims=(-1, width + 1), - clim=(0.0, max_time), - label=nothing, - colorbar_title="Time", - ) - scatter!( - fig, - [tasks[1].start_point.x], - [tasks[1].start_point.y]; - label=nothing, - marker=:rect, - markersize=10, - ) - annotate!(fig, (tasks[1].start_point.x, tasks[1].start_point.y, text("0", 10))) - for (i_task, task) in enumerate(tasks[2:(end - 1)]) - (; start_point, end_point) = task - points = [(start_point.x, start_point.y), (end_point.x, end_point.y)] - plot!(fig, points; color=:black, label=nothing) - scatter!( - fig, - points[1]; - markersize=10, - marker=:rect, - marker_z=task.start_time, - colormap=:turbo, - label=nothing, - ) - scatter!( - fig, - points[2]; - markersize=10, - marker=:rect, - marker_z=task.end_time, - colormap=:turbo, - label=nothing, - ) - annotate!(fig, (points[1]..., text("$(i_task)", 10))) - end - return fig -end - -""" -$TYPEDSIGNATURES -""" -function plot_solution( - ::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs... -) - @assert hasproperty(sample.instance, :city) "Sample does not contain city information." - (; tasks, district_width, width) = sample.instance.city - ticks = 0:district_width:width - solution = Solution(sample.y, sample.instance) - path_list = compute_path_list(solution) - fig = plot(; - xlabel="x", - ylabel="y", - legend=false, - gridlinewidth=3, - aspect_ratio=:equal, - size=(500, 500), - xticks=ticks, - yticks=ticks, - xlims=(-1, width + 1), - ylims=(-1, width + 1), - ) - for path in path_list - X = Float64[] - Y = Float64[] - (; start_point, end_point) = tasks[path[1]] - (; x, y) = end_point - push!(X, x) - push!(Y, y) - for task in path[2:end] - (; start_point, end_point) = tasks[task] - push!(X, start_point.x) - push!(Y, start_point.y) - push!(X, end_point.x) - push!(Y, end_point.y) - end - plot!(fig, X, Y; marker=:circle) - end - return fig -end - end diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 76d66a0..3711916 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -36,7 +36,8 @@ export generate_environment, generate_environments export generate_baseline_policies export generate_anticipative_solver, generate_parametric_anticipative_solver -export plot_data, compute_gap +export has_visualization, plot_instance, plot_solution, plot_trajectory, animate_trajectory +export compute_gap export grid_graph, get_path, path_to_matrix export neg_tensor, squeeze_last_dims, average_tensor export scip_model, highs_model diff --git a/src/Utils/interface.jl b/src/Utils/interface.jl index 5f42e98..efc9c26 100644 --- a/src/Utils/interface.jl +++ b/src/Utils/interface.jl @@ -20,9 +20,10 @@ Also implement: - [`is_minimization_problem`](@ref): defaults to `true` - [`objective_value`](@ref): defaults to `dot(θ, y)` - [`compute_gap`](@ref): default implementation provided; override for custom evaluation +- [`has_visualization`](@ref): defaults to `false` -# Optional methods (no default) -- [`plot_data`](@ref), [`plot_instance`](@ref), [`plot_solution`](@ref) +# Optional methods (no default, require `Plots` to be loaded) +- [`plot_instance`](@ref), [`plot_solution`](@ref) - [`generate_baseline_policies`](@ref) """ abstract type AbstractBenchmark end @@ -59,7 +60,7 @@ end generate_dataset(::AbstractBenchmark, dataset_size::Int; target_policy=nothing, kwargs...) -> Vector{<:DataSample} Generate a `Vector` of [`DataSample`](@ref) of length `dataset_size` for given benchmark. -Content of the dataset can be visualized using [`plot_data`](@ref), when it applies. +Content of the dataset can be visualized using [`plot_solution`](@ref), when it applies. By default, it uses [`generate_sample`](@ref) to create each sample in the dataset, and passes any keyword arguments to it. `target_policy` is applied if provided, it is called on each sample @@ -109,24 +110,24 @@ Return named baseline policies for the benchmark. Each policy is a callable. function generate_baseline_policies end """ - plot_data(::AbstractBenchmark, ::DataSample; kwargs...) + has_visualization(::AbstractBenchmark) -> Bool -Plot a data sample from the dataset created by [`generate_dataset`](@ref). -Check the specific benchmark documentation of `plot_data` for more details on the arguments. +Return `true` if `plot_instance` and `plot_solution` are implemented for this benchmark +(requires `Plots` to be loaded). Default is `false`. """ -function plot_data end +has_visualization(::AbstractBenchmark) = false """ - plot_instance(::AbstractBenchmark, instance; kwargs...) + plot_instance(bench::AbstractBenchmark, sample::DataSample; kwargs...) -Plot the instance object of the sample. +Plot the problem instance (no solution). Only available when `Plots` is loaded. """ function plot_instance end """ - plot_solution(::AbstractBenchmark, sample::DataSample, [solution]; kwargs...) + plot_solution(bench::AbstractBenchmark, sample::DataSample; kwargs...) -Plot `solution` if given, else plot the target solution in the sample. +Plot the instance with `sample.y` overlaid. Only available when `Plots` is loaded. """ function plot_solution end @@ -382,6 +383,10 @@ meaning (whether uncertainty is independent of decisions). - [`generate_dataset`](@ref)`(bench, environments; target_policy, ...)`: generates training-ready [`DataSample`](@ref)s by calling `target_policy(env)` for each environment. Requires `target_policy` as a mandatory keyword argument. + +# Optional visualization methods (require `Plots` to be loaded) +- [`plot_trajectory`](@ref)`(bench, traj)`: plot a full episode as a grid of subplots. +- [`animate_trajectory`](@ref)`(bench, traj)`: animate a full episode. """ abstract type AbstractDynamicBenchmark{exogenous} <: AbstractStochasticBenchmark{exogenous} end @@ -467,3 +472,19 @@ function generate_dataset( environments = generate_environments(bench, n; seed) return generate_dataset(bench, environments; target_policy, seed, kwargs...) end + +""" + plot_trajectory(bench::AbstractDynamicBenchmark, trajectory::Vector{<:DataSample}; kwargs...) + +Plot a full dynamic episode as a grid of state/decision subplots. +Only available when `Plots` is loaded. +""" +function plot_trajectory end + +""" + animate_trajectory(bench::AbstractDynamicBenchmark, trajectory::Vector{<:DataSample}; kwargs...) + +Animate a full dynamic episode. Returns a `Plots.Animation` object +(save with `gif(result, filename)`). Only available when `Plots` is loaded. +""" +function animate_trajectory end diff --git a/src/Warcraft/Warcraft.jl b/src/Warcraft/Warcraft.jl index a169fc5..bd37418 100644 --- a/src/Warcraft/Warcraft.jl +++ b/src/Warcraft/Warcraft.jl @@ -10,7 +10,6 @@ using Images using LinearAlgebra using Metalhead using NPZ -using Plots using Random using SimpleWeightedGraphs using SparseArrays @@ -80,51 +79,6 @@ function Utils.generate_statistical_model(::WarcraftBenchmark; seed=nothing) return model_embedding end -""" -$TYPEDSIGNATURES - -Plot the content of input `DataSample` as images. -`x` as the initial image, `θ` as the weights, and `y` as the path. - -The keyword argument `θ_true` is used to set the color range of the weights plot. -""" -function Utils.plot_data( - ::WarcraftBenchmark, - sample::DataSample; - θ_true=sample.θ, - θ_title="Weights", - y_title="Path", - kwargs..., -) - x = sample.x - y = sample.y - θ = sample.θ - im = dropdims(x; dims=4) - img = convert_image_for_plot(im) - p1 = Plots.plot( - img; aspect_ratio=:equal, framestyle=:none, size=(300, 300), title="Terrain image" - ) - p2 = Plots.heatmap( - -θ; - yflip=true, - aspect_ratio=:equal, - framestyle=:none, - padding=(0.0, 0.0), - size=(300, 300), - legend=false, - title=θ_title, - clim=(minimum(-θ_true), maximum(-θ_true)), - ) - p3 = Plots.plot( - Gray.(y .* 0.7); - aspect_ratio=:equal, - framestyle=:none, - size=(300, 300), - title=y_title, - ) - return plot(p1, p2, p3; layout=(1, 3), size=(900, 300)) -end - export WarcraftBenchmark end diff --git a/test/argmax_2d.jl b/test/argmax_2d.jl index 089e013..e3bd6ff 100644 --- a/test/argmax_2d.jl +++ b/test/argmax_2d.jl @@ -16,9 +16,13 @@ gap = compute_gap(b, dataset, model, maximizer) @test gap >= 0 - # Test plot_data - figure = plot_data(b, dataset[1]) + @test has_visualization(b) + figure = plot_solution(b, dataset[1]) @test figure isa Plots.Plot + figure2 = plot_instance(b, dataset[1]) + @test figure2 isa Plots.Plot + figure3 = plot_solution(b, dataset[1], dataset[2].y) + @test figure3 isa Plots.Plot for (i, sample) in enumerate(dataset) x = sample.x diff --git a/test/code.jl b/test/code.jl index af617fc..c9752d8 100644 --- a/test/code.jl +++ b/test/code.jl @@ -4,6 +4,7 @@ DecisionFocusedLearningBenchmarks; ambiguities=false, deps_compat=(check_extras = false), + stale_deps=(ignore=[:LaTeXStrings],), # used only inside DFLBenchmarksPlotsExt ) end diff --git a/test/dynamic_vsp_plots.jl b/test/dynamic_vsp_plots.jl index f643551..32cbc4a 100644 --- a/test/dynamic_vsp_plots.jl +++ b/test/dynamic_vsp_plots.jl @@ -1,32 +1,35 @@ @testset "Dynamic VSP Plots" begin - import DecisionFocusedLearningBenchmarks.DynamicVehicleScheduling as DVSP using Plots - # Create test benchmark and data (similar to scripts/a.jl) + # Create test benchmark and environments b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=true) environments = generate_environments(b, 3; seed=0) env = environments[1] - # Test basic plotting functions - fig1 = DVSP.plot_instance(env) - @test fig1 isa Plots.Plot - + # Get a trajectory via the anticipative solver y = generate_anticipative_solver(b)(env; nb_epochs=3) - fig2 = DVSP.plot_epochs(y) + # Test plot_instance (shows first epoch state) + fig1 = plot_instance(b, y[1]) + @test fig1 isa Plots.Plot + + # Test plot_trajectory (grid of epoch subplots) + fig2 = plot_trajectory(b, y) @test fig2 isa Plots.Plot + # Test plot_solution via baseline policy policies = generate_baseline_policies(b) lazy = policies[1] _, d = evaluate_policy!(lazy, env) - fig3 = DVSP.plot_routes(d[1].instance, d[1].y) + fig3 = plot_solution(b, d[1]) @test fig3 isa Plots.Plot - # Test animation + # Test animate_trajectory — returns Animation, save separately with gif() temp_filename = tempname() * ".gif" try - anim = DVSP.animate_epochs(y; filename=temp_filename, fps=1) - @test anim isa Plots.AnimatedGif || anim isa Plots.Animation + anim = animate_trajectory(b, y; fps=1) + @test anim isa Plots.Animation + gif(anim, temp_filename; fps=1) @test isfile(temp_filename) finally if isfile(temp_filename) diff --git a/test/warcraft.jl b/test/warcraft.jl index 6a8ecd9..4d5f67d 100644 --- a/test/warcraft.jl +++ b/test/warcraft.jl @@ -13,8 +13,13 @@ bellman_maximizer = generate_maximizer(b; dijkstra=false) dijkstra_maximizer = generate_maximizer(b; dijkstra=true) - figure = plot_data(b, dataset[1]) + @test has_visualization(b) + figure = plot_solution(b, dataset[1]) @test figure isa Plots.Plot + figure2 = plot_instance(b, dataset[1]) + @test figure2 isa Plots.Plot + figure3 = plot_solution(b, dataset[1], dataset[2].y) + @test figure3 isa Plots.Plot gap = compute_gap(b, dataset, model, dijkstra_maximizer) @test gap >= 0