diff --git a/Project.toml b/Project.toml index 09e60e4..88905d2 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" @@ -40,6 +41,7 @@ ConstrainedShortestPaths = "0.6.0" DataDeps = "0.7" Distributions = "0.25" DocStringExtensions = "0.9" +FileIO = "1.17.0" Flux = "0.14, 0.15, 0.16" Graphs = "1.11" HiGHS = "1.9" diff --git a/README.md b/README.md index 585501c..d241833 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ [![Coverage](https://codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningBenchmarks.jl/branch/main/graph/badge.svg)](https://app.codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningBenchmarks.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle) +!!! warning + This package is currently under active development. The API may change in future releases. + Please refer to the [documentation](https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningBenchmarks.jl/stable/) for the latest updates. + ## What is Decision-Focused Learning? Decision-focused learning (DFL) is a paradigm that integrates machine learning prediction with combinatorial optimization to make better decisions under uncertainty. Unlike traditional "predict-then-optimize" approaches that optimize prediction accuracy independently of downstream decision quality, DFL directly optimizes end-to-end decision performance. diff --git a/docs/src/benchmark_interfaces.md b/docs/src/benchmark_interfaces.md index 19fc231..faa6fb7 100644 --- a/docs/src/benchmark_interfaces.md +++ b/docs/src/benchmark_interfaces.md @@ -26,13 +26,13 @@ The package defines a hierarchy of three abstract types: ``` AbstractBenchmark -├── AbstractStochasticBenchmark{exogenous} +└── AbstractStochasticBenchmark{exogenous} └── AbstractDynamicBenchmark{exogenous} ``` - **`AbstractBenchmark`**: static, single-stage optimization problems - **`AbstractStochasticBenchmark{exogenous}`**: stochastic, single stage optimization problems - **`AbstractDynamicBenchmark{exogenous}`**: multi-stage sequential decision problems +- **`AbstractDynamicBenchmark{exogenous}`**: multi-stage sequential decision-making problems The `{exogenous}` type parameter indicates whether uncertainty distribution comes from external sources (`true`) or is influenced by decisions (`false`), which affects available methods. diff --git a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl index 38bd82a..b940743 100644 --- a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl +++ b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl @@ -13,7 +13,7 @@ using IterTools: partition using JSON using JuMP using Plots: plot, plot!, scatter!, @animate, Plots, gif -using Printf: @printf +using Printf: @printf, @sprintf using Random: Random, AbstractRNG, MersenneTwister, seed!, randperm using Requires: @require using Statistics: mean, quantile diff --git a/src/DynamicVehicleScheduling/anticipative_solver.jl b/src/DynamicVehicleScheduling/anticipative_solver.jl index b2e2452..7178532 100644 --- a/src/DynamicVehicleScheduling/anticipative_solver.jl +++ b/src/DynamicVehicleScheduling/anticipative_solver.jl @@ -54,6 +54,7 @@ function anticipative_solver( ) if reset_env reset!(env; reset_rng=true, seed) + scenario = env.scenario end @assert !is_terminated(env) @@ -213,13 +214,15 @@ function anticipative_solver( current_epoch=epoch, ) + reward = -cost(state, decode_bitmatrix_to_routes(y_true)) + x = if two_dimensional_features compute_2D_features(state, env.instance) else compute_features(state, env.instance) end - return DataSample(; instance=state, y_true, x) + return DataSample(; instance=(; state, reward), y_true, x) end return obj, dataset diff --git a/src/DynamicVehicleScheduling/plot.jl b/src/DynamicVehicleScheduling/plot.jl index 93343b6..33dd940 100644 --- a/src/DynamicVehicleScheduling/plot.jl +++ b/src/DynamicVehicleScheduling/plot.jl @@ -1,11 +1,34 @@ -function plot_instancee(env::DVSPEnv; kwargs...) +function plot_instance(env::DVSPEnv; kwargs...) return plot_instance(env.instance.static_instance; kwargs...) end +function build_state_data(state::DVSPState) + coords = coordinate(state) + x = [p.x for p in coords] + y = [p.y for p in coords] + x_depot = x[1] + y_depot = y[1] + x_customers = x[2:end] + y_customers = y[2:end] + start_times_customers = start_time(state)[2:end] + service_times_customers = service_time(state)[2:end] + must_customers = state.is_must_dispatch[2:end] + + return (; + x_depot=x_depot, + y_depot=y_depot, + x_customers=x_customers, + y_customers=y_customers, + is_must_dispatch=must_customers, + start_times=start_times_customers, + service_times=service_times_customers, + ) +end + """ $TYPEDSIGNATURES -Plot a given DVSPState showing depot, must-dispatch requests, and postponable requests. +Plot a given DVSPState showing depot, must-dispatch customers, and postponable customers. """ function plot_state( state::DVSPState; @@ -13,21 +36,20 @@ function plot_state( depot_markersize=8, alpha_depot=0.8, depot_color=:lightgreen, + depot_marker=:rect, must_dispatch_color=:red, postponable_color=:lightblue, + must_dispatch_marker=:star5, + postponable_marker=:utriangle, show_axis_labels=true, markerstrokewidth=0.5, + show_colorbar=true, kwargs..., ) - # Get coordinates from the state instance - coordinates = coordinate(state) - start_times = start_time(state) - - # Extract x and y coordinates - x = [p.x for p in coordinates] - y = [p.y for p in coordinates] + (; x_depot, y_depot, x_customers, y_customers, is_must_dispatch, start_times) = build_state_data( + state + ) - # Create the plot plot_args = Dict( :legend => :topleft, :title => "DVSP State - Epoch $(state.current_epoch)" ) @@ -37,57 +59,68 @@ function plot_state( plot_args[:ylabel] = "y coordinate" end - # Merge with kwargs + # Merge with kwargs (possibly overriding defaults) for (k, v) in kwargs plot_args[k] = v end fig = plot(; plot_args...) - # Plot depot (always the first coordinate) + # Display depot scatter!( fig, - [x[1]], - [y[1]]; + [x_depot], + [y_depot]; label="Depot", markercolor=depot_color, - marker=:rect, + marker=depot_marker, markersize=depot_markersize, alpha=alpha_depot, markerstrokewidth=markerstrokewidth, ) - # Plot must-dispatch customers - if sum(state.is_must_dispatch) > 0 - must_dispatch_indices = findall(state.is_must_dispatch) + scatter_must_dispatch_args = Dict( + :label => "Must-dispatch customers", + :markercolor => must_dispatch_color, + :marker => must_dispatch_marker, + :markersize => customer_markersize, + :markerstrokewidth => markerstrokewidth, + ) + + scatter_postponable_args = Dict( + :label => "Postponable customers", + :markercolor => postponable_color, + :marker => postponable_marker, + :markersize => customer_markersize, + :markerstrokewidth => markerstrokewidth, + ) + if show_colorbar + scatter_must_dispatch_args[:marker_z] = start_times[is_must_dispatch] + scatter_postponable_args[:marker_z] = start_times[.!is_must_dispatch] + # scatter_postponable_args[:label] = "Postponable customers (start time)" + scatter_postponable_args[:colormap] = :plasma + scatter_must_dispatch_args[:colormap] = :plasma + scatter_postponable_args[:colorbar] = :right + scatter_must_dispatch_args[:colorbar] = :right + Plots.gr_cbar_width[] = 0.01 + end + + # Display customers, separating must-dispatch and postponable + if length(x_customers[is_must_dispatch]) > 0 scatter!( fig, - x[must_dispatch_indices], - y[must_dispatch_indices]; - label="Must-dispatch requests", - markercolor=must_dispatch_color, - marker=:star5, - markersize=customer_markersize, - marker_z=start_times[must_dispatch_indices], - colormap=:plasma, - markerstrokewidth=markerstrokewidth, + x_customers[is_must_dispatch], + y_customers[is_must_dispatch]; + scatter_must_dispatch_args..., ) end - # Plot postponable customers - if sum(state.is_postponable) > 0 - postponable_indices = findall(state.is_postponable) + if length(x_customers[.!is_must_dispatch]) > 0 scatter!( fig, - x[postponable_indices], - y[postponable_indices]; - label="Postponable requests", - markercolor=postponable_color, - marker=:utriangle, - markersize=customer_markersize, - marker_z=start_times[postponable_indices], - colormap=:viridis, - markerstrokewidth=markerstrokewidth, + x_customers[.!is_must_dispatch], + y_customers[.!is_must_dispatch]; + scatter_postponable_args..., ) end @@ -97,58 +130,52 @@ end """ $TYPEDSIGNATURES -Plot a given DVSPState with routes overlaid, showing depot, requests, and vehicle routes. +Plot a given DVSPState with routes overlaid, showing depot, customers, and vehicle routes. Routes should be provided as a vector of vectors, where each inner vector contains the indices of locations visited by that route (excluding the depot). """ function plot_routes( state::DVSPState, routes::Vector{Vector{Int}}; - route_colors=nothing, - route_linewidth=3, # Increased from 2 to 3 - route_alpha=0.7, - show_route_labels=true, + reward=nothing, + route_color=nothing, + route_linewidth=2, + route_alpha=0.8, kwargs..., ) + cost_text = if !isnothing(reward) + " (" * @sprintf("%.2f", -reward) * ")" + else + "" + end # Start with the basic state plot - fig = plot_state(state; kwargs...) + fig = plot_state( + state; + kwargs..., + title="DVSP State with Routes - Epoch $(state.current_epoch)$cost_text", + ) - # Get coordinates for route plotting - coordinates = coordinate(state) - x = [p.x for p in coordinates] - y = [p.y for p in coordinates] + (; x_depot, y_depot, x_customers, y_customers) = build_state_data(state) - # Depot coordinates (always first) - x_depot = x[1] - y_depot = y[1] + x = vcat(x_depot, x_customers) + y = vcat(y_depot, y_customers) - # Default route colors if not provided - if isnothing(route_colors) - route_colors = [:blue, :purple, :orange, :brown, :pink, :gray, :olive, :cyan] + plot_args = Dict( + :linewidth => route_linewidth, :alpha => route_alpha, :z_order => :back; + ) + + if !isnothing(route_color) + plot_args[:color] = route_color end # Plot each route - for (route_idx, route) in enumerate(routes) + for route in routes if !isempty(route) # Create route path: depot -> customers -> depot route_x = vcat(x_depot, x[route], x_depot) route_y = vcat(y_depot, y[route], y_depot) - # Select color for this route - color = route_colors[(route_idx - 1) % length(route_colors) + 1] - - # Plot the route with more visible styling - label = show_route_labels ? "Route $route_idx" : nothing - plot!( - fig, - route_x, - route_y; - # color=color, - linewidth=route_linewidth, - alpha=1.0, # Make routes fully opaque - label=label, - linestyle=:solid, - ) + plot!(fig, route_x, route_y; label=false, plot_args...) end end @@ -158,42 +185,36 @@ end """ $TYPEDSIGNATURES -Plot a given DVSPState with routes overlaid. This version accepts routes as a single -vector where routes are separated by depot visits (index 1). +Plot a given DVSPState with routes overlaid. This version accepts routes as a BitMatrix +where entry (i,j) = true indicates an edge from location i to location j. """ -function plot_routes(state::DVSPState, routes::Vector{Int}; kwargs...) - # Convert single route vector to vector of route vectors - route_vectors = Vector{Int}[] - current_route = Int[] - - for location in routes - if location == 1 # Depot visit indicates end of route - if !isempty(current_route) - push!(route_vectors, copy(current_route)) - empty!(current_route) - end - else - push!(current_route, location) - end - end - - # Add the last route if it doesn't end with depot - if !isempty(current_route) - push!(route_vectors, current_route) - end - +function plot_routes(state::DVSPState, routes::BitMatrix; kwargs...) + route_vectors = decode_bitmatrix_to_routes(routes) return plot_routes(state, route_vectors; kwargs...) end """ -$TYPEDSIGNATURES +Return a Dict with plot-ready information extracted from a vector of DataSample objects. -Plot a given DVSPState with routes overlaid. This version accepts routes as a BitMatrix -where entry (i,j) = true indicates an edge from location i to location j. + +The returned dictionary contains: +- :n_epochs => Int +- :coordinates => Vector{Vector{Tuple{Float64,Float64}}} (per-epoch list of (x,y) tuples, empty if instance missing) +- :start_times => Vector{Vector{Float64}} (per-epoch start times, empty if instance missing) +- :node_types => Vector{Vector{Symbol}} (per-epoch node-type labels aligned with coordinates) +- :routes => Vector{Vector{Vector{Int}}} (per-epoch normalized routes; empty vector when no routes) +- :epoch_costs => Vector{Float64} (per-epoch cost; NaN if not computable) + +This lets plotting code build figures without depending on plotting internals. """ -function plot_routes(state::DVSPState, routes::BitMatrix; kwargs...) - route_vectors = decode_bitmatrix_to_routes(routes) - return plot_routes(state, route_vectors; kwargs...) +function build_plot_data(data_samples::Vector{<:DataSample}) + state_data = [build_state_data(sample.instance.state) for sample in data_samples] + rewards = [sample.instance.reward for sample in data_samples] + routess = [sample.y_true for sample in data_samples] + return [ + (; state..., reward, routes) for + (state, reward, routes) in zip(state_data, rewards, routess) + ] end """ @@ -207,7 +228,7 @@ function plot_epochs( data_samples::Vector{<:DataSample}; plot_routes_flag=true, cols=nothing, - figsize=(1800, 600), + figsize=nothing, margin=0.05, legend_margin_factor=0.15, titlefontsize=14, @@ -215,15 +236,17 @@ function plot_epochs( legendfontsize=11, tickfontsize=10, show_axis_labels=false, - show_colorbar=false, + show_colorbar=true, kwargs..., ) - n_epochs = length(data_samples) - - if n_epochs == 0 + if length(data_samples) == 0 error("No data samples provided") end + # Build centralized plot data + pd = build_plot_data(data_samples) + n_epochs = length(pd) + # Determine grid layout if isnothing(cols) cols = min(n_epochs, 3) # Default to max 3 columns @@ -231,108 +254,59 @@ function plot_epochs( rows = ceil(Int, n_epochs / cols) # Calculate global xlims and ylims from all states - all_coordinates = [] - for sample in data_samples - if !isnothing(sample.instance) - coords = coordinate(sample.instance) - append!(all_coordinates, coords) - end - end - - if isempty(all_coordinates) - error("No valid coordinates found in data samples") - end - - xlims = ( - minimum(p.x for p in all_coordinates) - margin, - maximum(p.x for p in all_coordinates) + margin, - ) + x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd) + x_max = maximum(max(data.x_depot, maximum(data.x_customers)) for data in pd) + y_min = minimum(min(data.y_depot, minimum(data.y_customers)) for data in pd) + y_max = maximum(max(data.y_depot, maximum(data.y_customers)) for data in pd) + xlims = (x_min - margin, x_max + margin) # Add extra margin at the top for legend space - y_min = minimum(p.y for p in all_coordinates) - margin - y_max = maximum(p.y for p in all_coordinates) + margin - y_range = y_max - y_min + y_range = y_max - y_min + 2 * margin legend_margin = y_range * legend_margin_factor - - ylims = (y_min, y_max + legend_margin) + ylims = (y_min - margin, y_max + margin + legend_margin) # Calculate global color limits for consistent scaling across subplots - all_start_times = [] - for sample in data_samples - if !isnothing(sample.instance) - times = start_time(sample.instance) - append!(all_start_times, times) - end - end - - clims = if !isempty(all_start_times) - (minimum(all_start_times), maximum(all_start_times)) - else - (0.0, 1.0) # Default range - end + min_start_time = minimum(minimum(data.start_times) for data in pd) + max_start_time = maximum(maximum(data.start_times) for data in pd) + clims = (min_start_time, max_start_time) # Create subplots - plots = [] - - for (i, sample) in enumerate(data_samples) - state = sample.instance + plots = map(1:n_epochs) do i + sample = data_samples[i] + state = sample.instance.state + reward = sample.instance.reward + + common_kwargs = Dict( + :xlims => xlims, + :ylims => ylims, + :clims => clims, + :show_colorbar => show_colorbar, + :titlefontsize => titlefontsize, + :guidefontsize => guidefontsize, + :legendfontsize => legendfontsize, + :tickfontsize => tickfontsize, + :show_axis_labels => show_axis_labels, + :markerstrokewidth => 0.5, + ) - if isnothing(state) - # Create empty plot if no state - fig = plot(; - xlims=xlims, - ylims=ylims, - title="Epoch $i (No Data)", - titlefontsize=titlefontsize, - guidefontsize=guidefontsize, - tickfontsize=tickfontsize, - legend=false, + if plot_routes_flag + fig = plot_routes( + state, + sample.y_true; + reward=reward, + show_route_labels=false, + common_kwargs..., kwargs..., ) else - # Plot with or without routes - if plot_routes_flag && !isnothing(sample.y_true) - fig = plot_routes( - state, - sample.y_true; - xlims=xlims, - ylims=ylims, - clims=clims, - colorbar=false, - title="Epoch $(state.current_epoch)", - titlefontsize=titlefontsize, - guidefontsize=guidefontsize, - legendfontsize=legendfontsize, - tickfontsize=tickfontsize, - show_axis_labels=show_axis_labels, - markerstrokewidth=0.5, - show_route_labels=false, - kwargs..., - ) - else - fig = plot_state( - state; - xlims=xlims, - ylims=ylims, - clims=clims, - colorbar=false, - title="Epoch $(state.current_epoch)", - titlefontsize=titlefontsize, - guidefontsize=guidefontsize, - legendfontsize=legendfontsize, - tickfontsize=tickfontsize, - show_axis_labels=show_axis_labels, - markerstrokewidth=0.5, - kwargs..., - ) - end + fig = plot_state(state; common_kwargs..., kwargs...) end - push!(plots, fig) + return fig end # Calculate dynamic figure size if not specified - if figsize == (1800, 600) # Using default size + if isnothing(figsize) plot_width = 600 * cols plot_height = 500 * rows figsize = (plot_width, plot_height) @@ -341,12 +315,7 @@ function plot_epochs( # Combine plots in a grid layout with optional shared colorbar if show_colorbar combined_plot = plot( - plots...; - layout=(rows, cols), - size=figsize, - link=:both, - colorbar=:right, - clims=clims, + plots...; layout=(rows, cols), size=figsize, link=:both, clims=clims ) else combined_plot = plot( @@ -356,19 +325,6 @@ function plot_epochs( return combined_plot end - -""" -$TYPEDSIGNATURES - -Plot multiple epochs side by side, optionally filtering to specific epoch indices. -""" -function plot_epochs( - data_samples::Vector{<:DataSample}, epoch_indices::Vector{Int}; kwargs... -) - filtered_samples = data_samples[epoch_indices] - return plot_epochs(filtered_samples; kwargs...) -end - """ $TYPEDSIGNATURES @@ -387,79 +343,48 @@ function animate_epochs( legendfontsize=12, tickfontsize=11, show_axis_labels=true, + show_cost_bar=true, + show_colorbar=false, + cost_bar_width=0.05, + cost_bar_margin=0.02, + cost_bar_color_palette=:turbo, kwargs..., ) - n_epochs = length(data_samples) - - if n_epochs == 0 - error("No data samples provided") - end + pd = build_plot_data(data_samples) + epoch_costs = [-sample.instance.reward for sample in data_samples] - # Calculate global limits for consistent scaling - all_coordinates = [] - for sample in data_samples - if !isnothing(sample.instance) - coords = coordinate(sample.instance) - append!(all_coordinates, coords) - end - end - - if isempty(all_coordinates) - error("No valid coordinates found in data samples") - end - - xlims = ( - minimum(p.x for p in all_coordinates) - margin, - maximum(p.x for p in all_coordinates) + margin, - ) + # Calculate global xlims and ylims from all states + x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd) + x_max = maximum(max(data.x_depot, maximum(data.x_customers)) for data in pd) + y_min = minimum(min(data.y_depot, minimum(data.y_customers)) for data in pd) + y_max = maximum(max(data.y_depot, maximum(data.y_customers)) for data in pd) + xlims = (x_min - margin, x_max + margin) # Add extra margin at the top for legend space - y_min = minimum(p.y for p in all_coordinates) - margin - y_max = maximum(p.y for p in all_coordinates) + margin - y_range = y_max - y_min + y_range = y_max - y_min + 2 * margin legend_margin = y_range * legend_margin_factor - ylims = (y_min, y_max + legend_margin) - - # Calculate global color limits - all_start_times = [] - for sample in data_samples - if !isnothing(sample.instance) - times = start_time(sample.instance) - append!(all_start_times, times) - end - end + ylims = (y_min - margin, y_max + margin + legend_margin) - clims = if !isempty(all_start_times) - (minimum(all_start_times), maximum(all_start_times)) - else - (0.0, 1.0) - end - - # Helper function to check if routes exist and are non-empty - function has_routes(routes) - if isnothing(routes) - return false - elseif routes isa Vector{Vector{Int}} - return any(!isempty(route) for route in routes) - elseif routes isa Vector{Int} - return !isempty(routes) - elseif routes isa BitMatrix - return any(routes) - else - return false - end + # Calculate global color limits for consistent scaling across subplots + min_start_time = minimum(minimum(data.start_times) for data in pd) + max_start_time = maximum(maximum(data.start_times) for data in pd) + clims = (min_start_time, max_start_time) + + # Adjust x-axis if showing cost bar + if show_cost_bar + x_min, x_max = xlims + x_range = x_max - x_min + cost_bar_space = x_range * (cost_bar_width + cost_bar_margin) + xlims = (x_min, x_max + cost_bar_space) end - # Create frame plan: determine which epochs have routes + # Create interleaved frame plan: always include a state frame and a routes frame + # for every epoch. The routes-frame will render a 'no routes' message when + # no routes are present, which keeps timing consistent and the code simpler. frame_plan = [] - for (epoch_idx, sample) in enumerate(data_samples) - # Always add state frame + for (epoch_idx, _) in enumerate(data_samples) push!(frame_plan, (epoch_idx, :state)) - - # Add routes frame only if routes exist - if has_routes(sample.y_true) - push!(frame_plan, (epoch_idx, :routes)) - end + push!(frame_plan, (epoch_idx, :routes)) end total_frames = length(frame_plan) @@ -468,60 +393,158 @@ function animate_epochs( anim = @animate for frame_idx in 1:total_frames epoch_idx, frame_type = frame_plan[frame_idx] sample = data_samples[epoch_idx] - state = sample.instance + state = sample.instance.state - if isnothing(state) - # Empty frame for missing data - plot(; + if frame_type == :routes + fig = plot_routes( + state, + sample.y_true; xlims=xlims, ylims=ylims, - title="Epoch $epoch_idx (No Data)", + clims=clims, + title="Epoch $(state.current_epoch) - Routes Dispatched", titlefontsize=titlefontsize, guidefontsize=guidefontsize, + legendfontsize=legendfontsize, tickfontsize=tickfontsize, - legend=false, + show_axis_labels=show_axis_labels, + markerstrokewidth=0.5, + show_route_labels=false, + show_colorbar=show_colorbar, size=figsize, kwargs..., ) - else - if frame_type == :routes - # Show state with routes - plot_routes( - state, - sample.y_true; - xlims=xlims, - ylims=ylims, - clims=clims, - title="Epoch $(state.current_epoch) - Routes Dispatched", - titlefontsize=titlefontsize, - guidefontsize=guidefontsize, - legendfontsize=legendfontsize, - tickfontsize=tickfontsize, - show_axis_labels=show_axis_labels, - markerstrokewidth=0.5, - show_route_labels=false, - size=figsize, - kwargs..., - ) - else # frame_type == :state - # Show state only - plot_state( - state; - xlims=xlims, - ylims=ylims, - clims=clims, - title="Epoch $(state.current_epoch) - Available Requests", - titlefontsize=titlefontsize, - guidefontsize=guidefontsize, - legendfontsize=legendfontsize, - tickfontsize=tickfontsize, - show_axis_labels=show_axis_labels, - markerstrokewidth=0.5, - size=figsize, - kwargs..., + else # frame_type == :state + # Show state only + fig = plot_state( + state; + xlims=xlims, + ylims=ylims, + clims=clims, + title="Epoch $(state.current_epoch) - Available Customers", + titlefontsize=titlefontsize, + guidefontsize=guidefontsize, + legendfontsize=legendfontsize, + tickfontsize=tickfontsize, + show_axis_labels=show_axis_labels, + markerstrokewidth=0.5, + show_colorbar=show_colorbar, + size=figsize, + kwargs..., + ) + end + + # Add cost bar if requested + if show_cost_bar + # Calculate cost bar position on the right side of the plot + x_min, x_max = xlims + x_range = x_max - x_min + bar_x_start = x_max - cost_bar_width * x_range + bar_x_end = x_max - cost_bar_margin * x_range + + y_min, y_max = ylims + y_range = y_max - y_min + bar_y_start = y_min + 0.1 * y_range + bar_y_end = y_max - 0.1 * y_range + bar_height = bar_y_end - bar_y_start + + # Calculate current cumulative cost based on frame type + # Cost increases only when routes are displayed (dispatched) + current_cost = 0.0 + + # Go through all frames up to the current one to see which epochs have had routes dispatched + for frame_i in 1:frame_idx + frame_epoch, frame_frame_type = frame_plan[frame_i] + + # Add cost only when we encounter a routes frame + if frame_frame_type == :routes && frame_epoch <= length(epoch_costs) + current_cost += epoch_costs[frame_epoch] + end + end + + # Calculate filled height + max_cost = sum(epoch_costs) + if max_cost > 0 + filled_height = (current_cost / max_cost) * bar_height + else + filled_height = 0.0 + end + + # Draw the cost bar background (empty bar) + plot!( + fig, + [bar_x_start, bar_x_end, bar_x_end, bar_x_start, bar_x_start], + [bar_y_start, bar_y_start, bar_y_end, bar_y_end, bar_y_start]; + seriestype=:shape, + color=:white, + alpha=0.8, + linecolor=:black, + linewidth=2, + label="", + ) + + cmap = Plots.cgrad(cost_bar_color_palette) + # Draw the filled portion with solid color + if filled_height > 0 + # Get a color at a value between 0 and 1 + ratio = current_cost / max_cost + color_at_val = Plots.get(cmap, ratio) + plot!( + fig, + [bar_x_start, bar_x_end, bar_x_end, bar_x_start, bar_x_start], + [ + bar_y_start, + bar_y_start, + bar_y_start + filled_height, + bar_y_start + filled_height, + bar_y_start, + ]; + seriestype=:shape, + color=color_at_val, + alpha=0.7, + linecolor=:darkred, + linewidth=1, + label="", ) end + + # Add current cost value + cost_text_y = bar_y_start + filled_height + 0.02 * y_range + if cost_text_y > bar_y_end + cost_text_y = bar_y_end + end + + plot!( + fig, + [bar_x_start + (bar_x_end - bar_x_start) / 2], + [cost_text_y]; + seriestype=:scatter, + markersize=0, + label="", + annotations=( + bar_x_start - 0.04 * x_range, + cost_text_y, + (@sprintf("%.1f", current_cost), :center, guidefontsize), + ), + ) + + # Add cost bar title + plot!( + fig, + [(bar_x_start + bar_x_end) / 2], + [bar_y_end + 0.05 * y_range]; + seriestype=:scatter, + markersize=0, + label="", + annotations=( + (bar_x_start + bar_x_end) / 2, + bar_y_end + 0.05 * y_range, + ("Cost", :center, guidefontsize), + ), + ) end + + fig end # Save as GIF @@ -529,138 +552,3 @@ function animate_epochs( return anim end - -# """ -# $TYPEDSIGNATURES - -# Plot the environment of a DVSPEnv, restricted to the given `epoch_indices` (all epoch if not given). -# """ -# function plot_environment( -# env::DVSPEnv; -# customer_markersize=4, -# depot_markersize=7, -# alpha_depot=0.8, -# depot_color=:lightgreen, -# epoch_indices=nothing, -# kwargs..., -# ) -# draw_all_epochs!(env) - -# epoch_appearance = env.request_epoch -# coordinates = coordinate(get_state(env)) - -# epoch_indices = isnothing(epoch_indices) ? get_epoch_indices(env) : epoch_indices - -# xlims = (minimum(c.x for c in coordinates), maximum(c.x for c in coordinates)) -# ylims = (minimum(c.y for c in coordinates), maximum(c.y for c in coordinates)) - -# fig = plot(; -# legend=:topleft, -# xlabel="x coordinate", -# ylabel="y coordinate", -# xlims, -# ylims, -# kwargs..., -# ) - -# for epoch in epoch_indices -# requests = findall(epoch_appearance .== epoch) -# x = [coordinates[request].x for request in requests] -# y = [coordinates[request].y for request in requests] -# scatter!( -# fig, x, y; label="Epoch $epoch", marker=:circle, markersize=customer_markersize -# ) -# end -# scatter!( -# fig, -# [coordinates[1].x], -# [coordinates[1].y]; -# label="Depot", -# markercolor=depot_color, -# marker=:rect, -# markersize=depot_markersize, -# alpha=alpha_depot, -# ) - -# return fig -# end - -# """ -# $TYPEDSIGNATURES - -# Plot the given `routes`` for a VSP `state`. -# """ -# function plot_epoch(state::DVSPState, routes; kwargs...) -# (; coordinate, start_time) = state.instance -# x_depot = coordinate[1].x -# y_depot = coordinate[1].y -# X = [p.x for p in coordinate] -# Y = [p.y for p in coordinate] -# markersize = 5 -# fig = plot(; -# legend=:topleft, xlabel="x", ylabel="y", clim=(0.0, maximum(start_time)), kwargs... -# ) -# for route in routes -# x_points = vcat(x_depot, X[route], x_depot) -# y_points = vcat(y_depot, Y[route], y_depot) -# plot!(fig, x_points, y_points; label=nothing) -# end -# scatter!( -# fig, -# [x_depot], -# [y_depot]; -# label="depot", -# markercolor=:lightgreen, -# markersize, -# marker=:rect, -# ) -# if sum(state.is_postponable) > 0 -# scatter!( -# fig, -# X[state.is_postponable], -# Y[state.is_postponable]; -# label="Postponable customers", -# marker_z=start_time[state.is_postponable], -# markersize, -# colormap=:turbo, -# marker=:utriangle, -# ) -# end -# if sum(state.is_must_dispatch) > 0 -# scatter!( -# fig, -# X[state.is_must_dispatch], -# Y[state.is_must_dispatch]; -# label="Must-dispatch customers", -# marker_z=start_time[state.is_must_dispatch], -# markersize, -# colormap=:turbo, -# marker=:star5, -# ) -# end -# return fig -# end - -# """ -# $TYPEDSIGNATURES - -# Create a plot of routes for each epoch. -# """ -# function plot_routes(env::DVSPEnv, routes; epoch_indices=nothing, kwargs...) -# reset!(env) -# epoch_indices = isnothing(epoch_indices) ? get_epoch_indices(env) : epoch_indices - -# coordinates = env.config.static_instance.coordinate -# xlims = (minimum(c.x for c in coordinates), maximum(c.x for c in coordinates)) -# ylims = (minimum(c.y for c in coordinates), maximum(c.y for c in coordinates)) - -# figs = map(epoch_indices) do epoch -# s = next_epoch!(env) -# fig = plot_epoch( -# s, state_route_from_env_routes(env, routes[epoch]); xlims, ylims, kwargs... -# ) -# apply_decision!(env, routes[epoch]) -# return fig -# end -# return figs -# end diff --git a/src/DynamicVehicleScheduling/static_vsp/plot.jl b/src/DynamicVehicleScheduling/static_vsp/plot.jl index 515ab3d..fd7f62f 100644 --- a/src/DynamicVehicleScheduling/static_vsp/plot.jl +++ b/src/DynamicVehicleScheduling/static_vsp/plot.jl @@ -1,10 +1,8 @@ -""" -$TYPEDSIGNATURES - -Plot the given static VSP `instance`. -""" function plot_instance( - instance::StaticInstance; + x_depot, + y_depot, + x_customers, + y_customers; customer_markersize=4, depot_markersize=7, alpha_depot=0.8, @@ -12,14 +10,11 @@ function plot_instance( depot_color=:lightgreen, kwargs..., ) - x = [p.x for p in instance.coordinate] - y = [p.y for p in instance.coordinate] - fig = plot(; legend=:topleft, xlabel="x coordinate", ylabel="y coordinate", kwargs...) scatter!( fig, - x[2:end], - y[2:end]; + x_customers, + y_customers; label="Customers", markercolor=customer_color, marker=:circle, @@ -27,8 +22,8 @@ function plot_instance( ) scatter!( fig, - [x[1]], - [y[1]]; + [x_depot], + [y_depot]; label="Depot", markercolor=depot_color, marker=:rect, @@ -37,3 +32,20 @@ function plot_instance( ) return fig end + +function build_instance_data(instance::StaticInstance) + x = [p.x for p in instance.coordinate] + y = [p.y for p in instance.coordinate] + return (x_depot=x[1], y_depot=y[1], x_customers=x[2:end], y_customers=y[2:end]) +end + +""" +$TYPEDSIGNATURES + +Plot the given static VSP `instance`. +""" +function plot_instance(instance::StaticInstance; kwargs...) + x_depot, y_depot, x, y = build_instance_data(instance) + + return plot_instance(x_depot, y_depot, x, y; kwargs...) +end diff --git a/src/Utils/policy.jl b/src/Utils/policy.jl index 2b3c8e5..7fdf582 100644 --- a/src/Utils/policy.jl +++ b/src/Utils/policy.jl @@ -31,38 +31,30 @@ $TYPEDSIGNATURES Run the policy on the environment and return the total reward and a dataset of observations. By default, the environment is reset before running the policy. """ -function evaluate_policy!(policy, env::AbstractEnvironment; kwargs...) +function evaluate_policy!( + policy, env::AbstractEnvironment; reset_env=true, seed=get_seed(env), kwargs... +) + if reset_env + reset!(env; reset_rng=true, seed=seed) + end total_reward = 0.0 - reset!(env; reset_rng=false) local labeled_dataset while !is_terminated(env) y = policy(env; kwargs...) features, state = observe(env) + state_copy = deepcopy(state) # To avoid mutation issues + reward = step!(env, y) + sample = DataSample(; x=features, y_true=y, instance=(; state=state_copy, reward)) if @isdefined labeled_dataset - push!( - labeled_dataset, - DataSample(; x=features, y_true=y, instance=deepcopy(state)), - ) + push!(labeled_dataset, sample) else - labeled_dataset = [DataSample(; x=features, y_true=y, instance=deepcopy(state))] + labeled_dataset = [sample] end - reward = step!(env, y) total_reward += reward end return total_reward, labeled_dataset end -# function evaluate_policy!(policy, envs::Vector{<:AbstractEnvironment}; kwargs...) -# E = length(envs) -# rewards = zeros(Float64, E) -# datasets = map(1:E) do e -# reward, dataset = evaluate_policy!(policy, envs[e]; kwargs...) -# rewards[e] = reward -# return dataset -# end -# return rewards, vcat(datasets...) -# end - """ $TYPEDSIGNATURES @@ -72,10 +64,14 @@ By default, the environment is reset before running the policy. function evaluate_policy!( policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env), kwargs... ) - reset!(env; reset_rng=true, seed) total_reward = 0.0 datasets = map(1:episodes) do _i - reward, dataset = evaluate_policy!(policy, env; kwargs...) + if _i == 1 + reset!(env; reset_rng=true, seed=seed) + else + reset!(env; reset_rng=false) + end + reward, dataset = evaluate_policy!(policy, env; reset_env=false, kwargs...) total_reward += reward return dataset end diff --git a/test/dynamic_vsp_plots.jl b/test/dynamic_vsp_plots.jl index cc7c962..185c91a 100644 --- a/test/dynamic_vsp_plots.jl +++ b/test/dynamic_vsp_plots.jl @@ -10,7 +10,7 @@ env = environments[1] # Test basic plotting functions - fig1 = DVSP.plot_instancee(env) + fig1 = DVSP.plot_instance(env) @test fig1 isa Plots.Plot instance = dataset[1].instance @@ -23,7 +23,7 @@ policies = generate_policies(b) lazy = policies[1] _, d = evaluate_policy!(lazy, env) - fig3 = DVSP.plot_routes(d[1].instance, d[1].y_true) + fig3 = DVSP.plot_routes(d[1].instance.state, d[1].y_true) @test fig3 isa Plots.Plot # Test animation diff --git a/test/utils.jl b/test/utils.jl index 4fd4b4f..8591750 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -62,11 +62,11 @@ end dataset = [random_sample() for _ in 1:N] # Test fit with ZScoreTransform - zt = fit(ZScoreTransform, dataset) + zt = fit(ZScoreTransform, dataset; dims=2) @test zt isa ZScoreTransform # Test fit with UnitRangeTransform - ut = fit(UnitRangeTransform, dataset) + ut = fit(UnitRangeTransform, dataset; dims=2) @test ut isa UnitRangeTransform # Test transform (non-mutating)