diff --git a/docs/src/custom_benchmarks.md b/docs/src/custom_benchmarks.md index 9c95b8b..289152d 100644 --- a/docs/src/custom_benchmarks.md +++ b/docs/src/custom_benchmarks.md @@ -18,7 +18,7 @@ AbstractBenchmark |------|----------| | `AbstractBenchmark` | Static, single-stage optimization (e.g. shortest path, portfolio) | | `AbstractStochasticBenchmark{true}` | Single-stage with exogenous uncertainty (scenarios drawn independently of decisions) | -| `AbstractStochasticBenchmark{false}` | Single-stage with endogenous uncertainty (not yet used) | +| `AbstractStochasticBenchmark{false}` | Single-stage with endogenous uncertainty | | `AbstractDynamicBenchmark{true}` | Multi-stage sequential decisions with exogenous uncertainty | | `AbstractDynamicBenchmark{false}` | Multi-stage sequential decisions with endogenous uncertainty | @@ -90,14 +90,9 @@ generate_instance(bench::MyStochasticBenchmark, rng::AbstractRNG; kwargs...) -> # Draw one scenario given the instance encoded in context generate_scenario(bench::MyStochasticBenchmark, rng::AbstractRNG; context...) -> scenario -# Note: sample.context is spread as kwargs when called by the framework +# Note: sample.context is spread as kwargs when called ``` -The framework `generate_sample` calls `generate_instance`, draws `nb_scenarios` -scenarios via `generate_scenario`, then: -- If `target_policy` is provided: calls `target_policy(sample, scenarios) -> Vector{DataSample}`. -- Otherwise: returns unlabeled samples with `extra=(; scenario=ξ)` for each scenario. - #### Anticipative solver (optional) ```julia @@ -189,89 +184,3 @@ DataSample(; x=feat, y=nothing, instance=inst, extra=(; scenario=ξ)) ``` Keys must not appear in both `context` and `extra`, the constructor raises an error. 
- ---- - -## Small examples - -### Static benchmark - -```julia -using DecisionFocusedLearningBenchmarks -const DFLBenchmarks = DecisionFocusedLearningBenchmarks - -struct MyStaticBenchmark <: AbstractBenchmark end - -function DFLBenchmarks.generate_instance(bench::MyStaticBenchmark, rng::AbstractRNG; kwargs...) - instance = build_my_instance(rng) - x = compute_features(instance) - return DataSample(; x=x, instance=instance) # y = nothing -end - - -DFLBenchmarks.generate_statistical_model(bench::MyStaticBenchmark; seed=nothing) = - Chain(Dense(10 => 32, relu), Dense(32 => 5)) - -DFLBenchmarks.generate_maximizer(bench::MyStaticBenchmark) = - (θ; instance, kwargs...) -> solve_my_problem(θ, instance) -``` - -### Stochastic benchmark - -```julia - -struct MyStochasticBenchmark <: AbstractStochasticBenchmark{true} end - -function DFLBenchmarks.generate_instance(bench::MyStochasticBenchmark, rng::AbstractRNG; kwargs...) - instance = build_my_instance(rng) - x = compute_features(instance) - return DataSample(; x=x, instance=instance) -end - -function DFLBenchmarks.generate_scenario(bench::MyStochasticBenchmark, rng::AbstractRNG; instance, kwargs...) - return sample_scenario(instance, rng) -end - -DFLBenchmarks.generate_anticipative_solver(bench::MyStochasticBenchmark) = - (scenario; instance, kwargs...) 
-> solve_with_scenario(instance, scenario) -``` - -### Dynamic benchmark - -```julia -struct MyDynamicBenchmark <: AbstractDynamicBenchmark{true} end - -mutable struct MyEnv <: AbstractEnvironment - const instance::MyInstance - const seed::Int - state::MyState -end - -DFLBenchmarks.get_seed(env::MyEnv) = env.seed -DFLBenchmarks.reset!(env::MyEnv; reset_rng=true, seed=env.seed) = (env.state = initial_state(env.instance)) -DFLBenchmarks.observe(env::MyEnv) = (env.state, nothing) -DFLBenchmarks.step!(env::MyEnv, action) = apply_action!(env.state, action) -DFLBenchmarks.is_terminated(env::MyEnv) = env.state.done - -function DFLBenchmarks.generate_environment(bench::MyDynamicBenchmark, rng::AbstractRNG; kwargs...) - inst = build_my_instance(rng) - seed = rand(rng, Int) - return MyEnv(inst, seed, initial_state(inst)) -end - -function DFLBenchmarks.generate_baseline_policies(bench::MyDynamicBenchmark) - greedy = function(env) - samples = DataSample[] - reset!(env) - while !is_terminated(env) - obs, _ = observe(env) - x = compute_features(obs) - y = greedy_action(obs) - r = step!(env, y) - push!(samples, DataSample(; x=x, y=y, instance=obs, extra=(; reward=r))) - end - return samples - end - return (; greedy) -end -``` diff --git a/docs/src/index.md b/docs/src/index.md index 363abee..ac486e9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -24,20 +24,20 @@ x \;\longrightarrow\; \boxed{\,\text{Statistical model } \varphi_w\,} ``` Where: -- **Statistical model** $\varphi_w$: machine learning predictor (e.g., neural network) -- **CO algorithm** $f$: combinatorial optimization solver - **Instance** $x$: input data (e.g., features, context) +- **Statistical model** $\varphi_w$: machine learning predictor (e.g., neural network) - **Parameters** $\theta$: predicted parameters for the optimization problem solved by `f` +- **CO algorithm** $f$: combinatorial optimization solver - **Solution** $y$: output decision/solution ## Package Overview 
-**DecisionFocusedLearningBenchmarks.jl** provides a comprehensive collection of benchmark problems for evaluating decision-focused learning algorithms. The package offers: +**DecisionFocusedLearningBenchmarks.jl** provides a collection of benchmark problems for evaluating decision-focused learning algorithms. The package offers: -- **Standardized benchmark problems** spanning diverse application domains -- **Common interfaces** for creating datasets, statistical models, and optimization algorithms -- **Ready-to-use DFL policies** compatible with [InferOpt.jl](https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl) and the whole [JuliaDecisionFocusedLearning](https://github.com/JuliaDecisionFocusedLearning) ecosystem -- **Evaluation tools** for comparing algorithm performance +- **Collection of benchmark problems** spanning diverse applications +- **Common tools** for creating datasets, statistical models, and optimization algorithms +- **Generic interface** for building custom benchmarks +- Compatibility with [InferOpt.jl](https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl) and the whole [JuliaDecisionFocusedLearning](https://github.com/JuliaDecisionFocusedLearning) ecosystem ## Benchmark Categories diff --git a/docs/src/using_benchmarks.md b/docs/src/using_benchmarks.md index a1e95dd..b136424 100644 --- a/docs/src/using_benchmarks.md +++ b/docs/src/using_benchmarks.md @@ -1,8 +1,42 @@ # Using Benchmarks -This guide covers everything you need to work with existing benchmarks in -DecisionFocusedLearningBenchmarks.jl: generating datasets, assembling DFL pipeline -components, and evaluating results. +This guide covers everything you need to work with existing benchmarks in DecisionFocusedLearningBenchmarks.jl: generating datasets, assembling DFL pipeline components, applying algorithms, and evaluating results. + +--- + +## What is a benchmark? 
+ +A benchmark bundles a problem family (an instance generator, a combinatorial solver, and a statistical model architecture) into a single object. It provides everything needed to run a Decision-Focused Learning experiment out of the box, without having to create each component from scratch. +Three abstract types cover the main settings: +- **`AbstractBenchmark`**: static problems (one instance, one decision) +- **`AbstractStochasticBenchmark{exogenous}`**: stochastic problems (type parameter indicates whether uncertainty is exogenous) +- **`AbstractDynamicBenchmark`**: sequential / multi-stage problems + +The sections below explain what changes between these settings. For most purposes, start with a static benchmark to understand the core workflow. + +--- + +## Core workflow + +Every benchmark exposes three key methods. For any static benchmark: + +```julia +bench = ArgmaxBenchmark() +model = generate_statistical_model(bench; seed=0) # Flux model +maximizer = generate_maximizer(bench) # combinatorial oracle +dataset = generate_dataset(bench, 100; seed=0) # Vector{DataSample} +``` + +- **`generate_statistical_model`**: returns an untrained neural network that maps input features `x` to cost parameters `θ`. +- **`generate_maximizer`**: returns a callable `(θ; context...) -> y` that solves the combinatorial problem given cost parameters. +- **`generate_dataset`**: returns labeled training data as a `Vector{DataSample}`. + +At inference time these two pieces compose naturally as an end-to-end policy: + +```julia +θ = model(sample.x) # predict cost parameters +y = maximizer(θ; sample.context...) # solve the optimization problem +``` --- @@ -15,11 +49,10 @@ All data in the package is represented as [`DataSample`](@ref) objects. 
| `x` | any | Input features (fed to the statistical model) | | `θ` | any | Intermediate cost parameters | | `y` | any | Output decision / solution | -| `context` | `NamedTuple` | Solver kwargs — spread into `maximizer(θ; sample.context...)` | -| `extra` | `NamedTuple` | Non-solver data (scenario, reward, step, …) — never passed to the solver | +| `context` | `NamedTuple` | Solver kwargs spread into `maximizer(θ; sample.context...)` | +| `extra` | `NamedTuple` | Non-solver data (scenario, reward, step, …), never passed to the solver | -Not all fields are populated in every sample. For convenience, named entries inside -`context` and `extra` can be accessed directly on the sample via property forwarding: +Not all fields are populated in every sample, depending on the setting. For convenience, named entries inside `context` and `extra` can be accessed directly on the sample via property forwarding: ```julia sample.instance # looks up :instance in context first, then in extra @@ -28,12 +61,11 @@ sample.scenario # looks up :scenario in context first, then in extra --- -## Generating datasets for training +## Benchmark type specifics ### Static benchmarks -For static benchmarks (`<:AbstractBenchmark`) the framework already computes the -ground-truth label `y`: +For static benchmarks (`<:AbstractBenchmark`), `generate_dataset` may compute a default ground-truth label `y` if the benchmark implements it: ```julia bench = ArgmaxBenchmark() @@ -43,15 +75,13 @@ dataset = generate_dataset(bench, 100; seed=0) # Vector{DataSample} with x, y, You can override the labels by providing a `target_policy`: ```julia -my_policy = sample -> DataSample(; sample.context..., x=sample.x, - y=my_algorithm(sample.instance)) +my_policy = sample -> DataSample(; sample.context..., x=sample.x, y=my_algorithm(sample.instance)) dataset = generate_dataset(bench, 100; seed=0, target_policy=my_policy) ``` ### Stochastic benchmarks (exogenous) -For `AbstractStochasticBenchmark{true}` benchmarks the 
default call returns -*unlabeled* samples, each sample carries one scenario in `sample.extra.scenario`: +For `AbstractStochasticBenchmark{true}` benchmarks the default call returns *unlabeled* samples, each sample carries one scenario in `sample.extra.scenario`: ```julia bench = StochasticVehicleSchedulingBenchmark() @@ -85,20 +115,22 @@ Dynamic benchmarks use a two-step workflow: ```julia bench = DynamicVehicleSchedulingBenchmark() -# Step 1 — create environments (reusable across experiments) +# Step 1: create environments (reusable across experiments) envs = generate_environments(bench, 10; seed=0) -# Step 2 — roll out a policy to collect training trajectories +# Step 2: roll out a policy to collect training trajectories policy = generate_baseline_policies(bench)[1] # e.g. lazy policy dataset = generate_dataset(bench, envs; target_policy=policy) # dataset is a flat Vector{DataSample} of all steps across all trajectories ``` -`target_policy` is **required** for dynamic benchmarks (there is no default label). +`target_policy` is **required** to create datasets for dynamic benchmarks (there is no default label). It must be a callable `(env) -> Vector{DataSample}` that performs a full episode rollout and returns the resulting trajectory. -### Seed / RNG control +--- + +## Seed / RNG control All `generate_dataset` and `generate_environments` calls accept either `seed` (creates an internal `MersenneTwister`) or `rng` for full control: @@ -111,22 +143,6 @@ dataset = generate_dataset(bench, 50; rng=rng) --- -## DFL pipeline components - -```julia -model = generate_statistical_model(bench; seed=0) # untrained Flux model -maximizer = generate_maximizer(bench) # combinatorial oracle -``` - -These two pieces compose naturally: - -```julia -θ = model(sample.x) # predict cost parameters -y = maximizer(θ; sample.context...) 
# solve the optimization problem -``` - ---- - ## Evaluation ```julia diff --git a/src/Argmax/Argmax.jl b/src/Argmax/Argmax.jl index 60f37c5..a4faede 100644 --- a/src/Argmax/Argmax.jl +++ b/src/Argmax/Argmax.jl @@ -60,7 +60,7 @@ $TYPEDSIGNATURES Return an argmax maximizer. """ -function Utils.generate_maximizer(bench::ArgmaxBenchmark) +function Utils.generate_maximizer(::ArgmaxBenchmark) return one_hot_argmax end diff --git a/src/DecisionFocusedLearningBenchmarks.jl b/src/DecisionFocusedLearningBenchmarks.jl index 6561b7a..dd6cb94 100644 --- a/src/DecisionFocusedLearningBenchmarks.jl +++ b/src/DecisionFocusedLearningBenchmarks.jl @@ -73,7 +73,7 @@ export generate_scenario export generate_baseline_policies export generate_statistical_model export generate_maximizer -export generate_anticipative_solution +export generate_anticipative_solver, generate_parametric_anticipative_solver export is_exogenous, is_endogenous export objective_value diff --git a/src/DynamicAssortment/DynamicAssortment.jl b/src/DynamicAssortment/DynamicAssortment.jl index df0e64e..8c372ca 100644 --- a/src/DynamicAssortment/DynamicAssortment.jl +++ b/src/DynamicAssortment/DynamicAssortment.jl @@ -139,7 +139,7 @@ function Utils.generate_baseline_policies(::DynamicAssortmentBenchmark) "policy that selects the assortment with the highest expected revenue", expert_policy, ) - return (expert, greedy) + return (; expert, greedy) end export DynamicAssortmentBenchmark diff --git a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl index 1eba500..92ff5d5 100644 --- a/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl +++ b/src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl @@ -115,28 +115,14 @@ end """ $TYPEDSIGNATURES -Generate an anticipative solution for the dynamic vehicle scheduling benchmark. -The solution is computed using the anticipative solver with the benchmark's feature configuration. 
-""" -function Utils.generate_anticipative_solution( - b::DynamicVehicleSchedulingBenchmark, args...; kwargs... -) - return anticipative_solver( - args...; kwargs..., two_dimensional_features=b.two_dimensional_features - ) -end - -""" -$TYPEDSIGNATURES - Return the anticipative solver for the dynamic vehicle scheduling benchmark. -The callable takes a scenario and solver kwargs (including `instance`) and returns a -training trajectory as a `Vector{DataSample}`. +The callable takes an environment and solver kwargs and returns a training trajectory +as a `Vector{DataSample}`. Set `reset_env=true` (default) to reset the environment +before solving, or `reset_env=false` to plan from the current state. """ function Utils.generate_anticipative_solver(::DynamicVehicleSchedulingBenchmark) - return (scenario; instance, kwargs...) -> begin - env = DVSPEnv(instance, scenario) - _, trajectory = anticipative_solver(env; reset_env=false, kwargs...) + return (env; reset_env=true, kwargs...) -> begin + _, trajectory = anticipative_solver(env; reset_env, kwargs...) 
return trajectory end end @@ -160,7 +146,7 @@ function Utils.generate_baseline_policies(::DynamicVehicleSchedulingBenchmark) "Greedy policy that dispatches vehicles to the nearest customer.", greedy_policy, ) - return (lazy, greedy) + return (; lazy, greedy) end """ diff --git a/src/FixedSizeShortestPath/FixedSizeShortestPath.jl b/src/FixedSizeShortestPath/FixedSizeShortestPath.jl index ee0586a..e4cc64b 100644 --- a/src/FixedSizeShortestPath/FixedSizeShortestPath.jl +++ b/src/FixedSizeShortestPath/FixedSizeShortestPath.jl @@ -142,6 +142,4 @@ function Utils.generate_statistical_model( end export FixedSizeShortestPathBenchmark -export generate_dataset, generate_maximizer, generate_statistical_model - end diff --git a/src/Maintenance/Maintenance.jl b/src/Maintenance/Maintenance.jl index 64e5ec5..d9e831b 100644 --- a/src/Maintenance/Maintenance.jl +++ b/src/Maintenance/Maintenance.jl @@ -22,7 +22,6 @@ The number of simultaneous maintenance operations is limited by a maintenance ca # Fields $TYPEDFIELDS - """ struct MaintenanceBenchmark <: AbstractDynamicBenchmark{true} "number of components" @@ -126,7 +125,7 @@ end """ $TYPEDSIGNATURES -Returns two policies for the dynamic assortment benchmark: +Returns a policy for the maintenance benchmark: - `Greedy`: maintains components when they are in the last state before failure, up to the maintenance capacity """ function Utils.generate_baseline_policies(::MaintenanceBenchmark) @@ -135,7 +134,7 @@ function Utils.generate_baseline_policies(::MaintenanceBenchmark) "policy that maintains components when they are in the last state before failure, up to the maintenance capacity", greedy_policy, ) - return (greedy,) + return (; greedy) end export MaintenanceBenchmark diff --git a/src/PortfolioOptimization/PortfolioOptimization.jl b/src/PortfolioOptimization/PortfolioOptimization.jl index 37631eb..9e8c277 100644 --- a/src/PortfolioOptimization/PortfolioOptimization.jl +++ b/src/PortfolioOptimization/PortfolioOptimization.jl @@ 
-116,6 +116,4 @@ function Utils.generate_statistical_model( end export PortfolioOptimizationBenchmark -export generate_dataset, generate_maximizer, generate_statistical_model - end diff --git a/src/SubsetSelection/SubsetSelection.jl b/src/SubsetSelection/SubsetSelection.jl index 416745f..a05359d 100644 --- a/src/SubsetSelection/SubsetSelection.jl +++ b/src/SubsetSelection/SubsetSelection.jl @@ -76,13 +76,11 @@ $TYPEDSIGNATURES Initialize a linear model for `bench` using `Flux`. """ -function Utils.generate_statistical_model(bench::SubsetSelectionBenchmark; seed=0) +function Utils.generate_statistical_model(bench::SubsetSelectionBenchmark; seed=nothing) Random.seed!(seed) (; n) = bench return Dense(n => n; bias=false) end export SubsetSelectionBenchmark -export generate_dataset, generate_maximizer, generate_statistical_model - end diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 89a6c67..ae766b1 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -32,7 +32,7 @@ export generate_statistical_model, generate_maximizer export generate_scenario export generate_environment, generate_environments export generate_baseline_policies -export generate_anticipative_solution +export generate_anticipative_solver, generate_parametric_anticipative_solver export plot_data, compute_gap export grid_graph, get_path, path_to_matrix diff --git a/src/Utils/interface.jl b/src/Utils/interface.jl index 23d6eb5..b97cc2e 100644 --- a/src/Utils/interface.jl +++ b/src/Utils/interface.jl @@ -6,8 +6,9 @@ Abstract type interface for benchmark problems. # Mandatory methods to implement for any benchmark: Choose one of three primary implementation strategies: - Implement [`generate_instance`](@ref) (returns a [`DataSample`](@ref) with `y=nothing`). - The default [`generate_sample`](@ref) then applies `target_policy` if provided. -- Override [`generate_sample`](@ref) directly when the sample requires custom logic. 
In this case, + The default [`generate_sample`](@ref) forwards the call directly; [`generate_dataset`](@ref) + applies `target_policy` afterwards if provided. +- Override [`generate_sample`](@ref) directly when the sample requires custom logic. [`generate_dataset`](@ref) applies `target_policy` to the result after the call returns. - Override [`generate_dataset`](@ref) directly when samples cannot be drawn independently. @@ -40,20 +41,18 @@ function generate_instance(bench::AbstractBenchmark, rng::AbstractRNG; kwargs... end """ - generate_sample(::AbstractBenchmark, rng::AbstractRNG; target_policy=nothing, kwargs...) -> DataSample + generate_sample(::AbstractBenchmark, rng::AbstractRNG; kwargs...) -> DataSample Generate a single [`DataSample`](@ref) for the benchmark. -**Framework default** (when [`generate_instance`](@ref) is implemented): -Calls [`generate_instance`](@ref), then applies `target_policy(sample)` if provided. +**Default** (when [`generate_instance`](@ref) is implemented): +Calls [`generate_instance`](@ref) and returns the result directly. -Override directly (instead of implementing [`generate_instance`](@ref)) when the sample -requires custom logic. In this case, [`generate_dataset`](@ref) applies `target_policy` -after the call returns. +Override this method when sample generation requires custom logic. Labeling via +`target_policy` is always applied by [`generate_dataset`](@ref) after this call returns. """ -function generate_sample(bench::AbstractBenchmark, rng; target_policy=nothing, kwargs...) - sample = generate_instance(bench, rng; kwargs...) - return isnothing(target_policy) ? sample : target_policy(sample) +function generate_sample(bench::AbstractBenchmark, rng; kwargs...) + return generate_instance(bench, rng; kwargs...) end """ @@ -63,8 +62,8 @@ Generate a `Vector` of [`DataSample`](@ref) of length `dataset_size` for given b Content of the dataset can be visualized using [`plot_data`](@ref), when it applies. 
By default, it uses [`generate_sample`](@ref) to create each sample in the dataset, and passes any -keyword arguments to it. If `target_policy` is provided, it is applied to each sample after -[`generate_sample`](@ref) returns. +keyword arguments to it. `target_policy` is applied if provided: it is called on each sample +after [`generate_sample`](@ref) returns. """ function generate_dataset( bench::AbstractBenchmark, @@ -263,18 +262,15 @@ spread from `sample.context`: function generate_scenario end """ - generate_anticipative_solver(::AbstractStochasticBenchmark{true}) -> callable + generate_anticipative_solver(::AbstractBenchmark) -> callable -Return a callable that computes the anticipative solution for a given scenario. -The instance and other solver-relevant fields are spread from the sample context. +Return a callable that computes the anticipative solution. - For [`AbstractStochasticBenchmark`](@ref): returns `(scenario; context...) -> y`. - For [`AbstractDynamicBenchmark`](@ref): returns - `(scenario; context...) -> Vector{DataSample}` — a full training trajectory. - - solver = generate_anticipative_solver(bench) - y = solver(scenario; sample.context...) # stochastic - trajectory = solver(scenario; sample.context...) # dynamic + `(env; reset_env=true, kwargs...) -> Vector{DataSample}`, a full training trajectory. + `reset_env=true` resets the env before solving (initial dataset building); + `reset_env=false` starts from the current env state. """ function generate_anticipative_solver end @@ -288,16 +284,6 @@ parametric anticipative subproblem: """ function generate_parametric_anticipative_solver end -""" - generate_anticipative_solution(::AbstractStochasticBenchmark, instance, scenario; kwargs...) - -!!! warning "Deprecated" - Use [`generate_anticipative_solver`](@ref) instead, which returns a callable - `(scenario; kwargs...) -> y` consistent with the [`generate_maximizer`](@ref) - convention. 
-""" -function generate_anticipative_solution end - """ $TYPEDSIGNATURES diff --git a/src/Warcraft/Warcraft.jl b/src/Warcraft/Warcraft.jl index 6452d33..a169fc5 100644 --- a/src/Warcraft/Warcraft.jl +++ b/src/Warcraft/Warcraft.jl @@ -64,7 +64,7 @@ The embedding is made as follows: 4) The element-wize `neg_tensor` function to get cell weights of proper sign to apply shortest path algorithms. 5) A squeeze function to forget the two last dimensions. """ -function Utils.generate_statistical_model(::WarcraftBenchmark; seed=0) +function Utils.generate_statistical_model(::WarcraftBenchmark; seed=nothing) Random.seed!(seed) resnet18 = ResNet(18; pretrain=false, nclasses=1) model_embedding = Chain( diff --git a/test/dynamic_vsp.jl b/test/dynamic_vsp.jl index 8564a2e..9c208a9 100644 --- a/test/dynamic_vsp.jl +++ b/test/dynamic_vsp.jl @@ -25,8 +25,7 @@ @test mean(r_lazy) <= mean(r_greedy) env = environments[1] - scenario = env.scenario - v, y = generate_anticipative_solution(b, env, scenario; nb_epochs=2, reset_env=true) + y = generate_anticipative_solver(b)(env; nb_epochs=2) maximizer = generate_maximizer(b) @@ -44,11 +43,10 @@ @test size(x, 1) == 2 @test size(x2, 1) == 27 - anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true) + solution = generate_anticipative_solver(b)(env) reset!(env; reset_rng=true) cost = sum(step!(env, sample.y) for sample in solution) cost2 = sum(sample.reward for sample in solution) - @test isapprox(cost, anticipative_value; atol=1e-5) @test isapprox(cost, cost2; atol=1e-5) end diff --git a/test/dynamic_vsp_plots.jl b/test/dynamic_vsp_plots.jl index 345e823..721d9a0 100644 --- a/test/dynamic_vsp_plots.jl +++ b/test/dynamic_vsp_plots.jl @@ -12,7 +12,7 @@ @test fig1 isa Plots.Plot scenario = env.scenario - v, y = generate_anticipative_solution(b, env, scenario; nb_epochs=3, reset_env=true) + y = generate_anticipative_solver(b)(env; nb_epochs=3) fig2 = DVSP.plot_epochs(y) @test fig2 isa Plots.Plot