Skip to content

Commit 579d5e2

Browse files
committed
Complete revamp of the interface
1 parent 95ecfcc commit 579d5e2

14 files changed

Lines changed: 176 additions & 193 deletions

File tree

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ export AbstractEnvironment, get_seed, is_terminated, observe, reset!, step!
6767

6868
export Policy, evaluate_policy!
6969

70-
export generate_sample, generate_dataset, generate_environments, generate_environment
70+
export generate_instance,
71+
generate_sample, generate_dataset, generate_environments, generate_environment
7172
export generate_scenario
72-
export generate_policies
73+
export generate_baseline_policies
7374
export generate_statistical_model
7475
export generate_maximizer
7576
export generate_anticipative_solution

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ Returns two policies for the dynamic assortment benchmark:
128128
- `Greedy`: selects the assortment containing items with the highest prices
129129
- `Expert`: selects the assortment with the highest expected revenue (through brute-force enumeration)
130130
"""
131-
function Utils.generate_policies(::DynamicAssortmentBenchmark)
131+
function Utils.generate_baseline_policies(::DynamicAssortmentBenchmark)
132132
greedy = Policy(
133133
"Greedy",
134134
"policy that selects the assortment with items with the highest prices",

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ Returns a tuple containing:
149149
- `lazy`: A policy that dispatches vehicles only when they are ready
150150
- `greedy`: A policy that dispatches vehicles to the nearest customer
151151
"""
152-
function Utils.generate_policies(b::DynamicVehicleSchedulingBenchmark)
152+
function Utils.generate_baseline_policies(::DynamicVehicleSchedulingBenchmark)
153153
lazy = Policy(
154154
"Lazy",
155155
"Lazy policy that dispatches vehicles only when they are ready.",

src/FixedSizeShortestPath/FixedSizeShortestPath.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ $TYPEDSIGNATURES
133133
134134
Initialize a linear model for `bench` using `Flux`.
135135
"""
136-
function Utils.generate_statistical_model(bench::FixedSizeShortestPathBenchmark)
136+
function Utils.generate_statistical_model(
137+
bench::FixedSizeShortestPathBenchmark; seed=nothing
138+
)
139+
Random.seed!(seed)
137140
(; p, graph) = bench
138141
return Chain(Dense(p, ne(graph)))
139142
end

src/Maintenance/Maintenance.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ $TYPEDSIGNATURES
129129
Returns two policies for the dynamic assortment benchmark:
130130
- `Greedy`: maintains components when they are in the last state before failure, up to the maintenance capacity
131131
"""
132-
function Utils.generate_policies(::MaintenanceBenchmark)
132+
function Utils.generate_baseline_policies(::MaintenanceBenchmark)
133133
greedy = Policy(
134134
"Greedy",
135135
"policy that maintains components when they are in the last state before failure, up to the maintenance capacity",

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Flux: Chain, Dense
77
using Ipopt: Ipopt
88
using JuMP: @variable, @objective, @constraint, optimize!, value, Model, set_silent
99
using LinearAlgebra: I
10-
using Random: AbstractRNG, MersenneTwister
10+
using Random: Random, AbstractRNG, MersenneTwister
1111

1212
"""
1313
$TYPEDEF
@@ -107,7 +107,10 @@ $TYPEDSIGNATURES
107107
108108
Initialize a linear model for `bench` using `Flux`.
109109
"""
110-
function Utils.generate_statistical_model(bench::PortfolioOptimizationBenchmark)
110+
function Utils.generate_statistical_model(
111+
bench::PortfolioOptimizationBenchmark; seed=nothing
112+
)
113+
Random.seed!(seed)
111114
(; p, d) = bench
112115
return Dense(p, d)
113116
end

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,32 +73,29 @@ end
7373
"""
7474
$TYPEDSIGNATURES
7575
76-
Generate a sample for the given `StochasticVehicleSchedulingBenchmark`.
77-
If you want to not add label solutions in the sample, set `compute_solutions=false`.
78-
By default, they will be computed using column generation.
79-
Note that computing solutions can be time-consuming, especially for large instances.
80-
You can also use instead `compact_mip` or `compact_linearized_mip` as the algorithm to compute solutions.
81-
If you want to provide a custom algorithm to compute solutions, you can pass it as the `algorithm` keyword argument.
82-
If `algorithm` takes keyword arguments, you can pass them as well directly in `kwargs...`.
83-
If `store_city=false`, the coordinates and unnecessary information about instances will not be stored in the sample.
84-
"""
85-
function Utils.generate_sample(
76+
Generate an unlabeled instance for the given `StochasticVehicleSchedulingBenchmark`.
77+
Returns a [`DataSample`](@ref) with features `x` and `instance` set, but `y=nothing`.
78+
79+
To obtain labeled samples, pass a `target_policy` to [`generate_dataset`](@ref):
80+
81+
```julia
82+
policy = sample -> DataSample(; sample.context..., x=sample.x,
83+
y=column_generation_algorithm(sample.instance))
84+
dataset = generate_dataset(benchmark, N; target_policy=policy)
85+
```
86+
87+
If `store_city=false`, coordinates and city information are not stored in the instance.
88+
"""
89+
function Utils.generate_instance(
8690
benchmark::StochasticVehicleSchedulingBenchmark,
8791
rng::AbstractRNG;
8892
store_city=true,
89-
compute_solutions=true,
90-
algorithm=column_generation_algorithm,
9193
kwargs...,
9294
)
9395
(; nb_tasks, nb_scenarios) = benchmark
9496
instance = Instance(; nb_tasks, nb_scenarios, rng, store_city)
9597
x = get_features(instance)
96-
y_true = if compute_solutions
97-
algorithm(instance; kwargs...)
98-
else
99-
nothing
100-
end
101-
return DataSample(; x, instance, y=y_true)
98+
return DataSample(; x, instance)
10299
end
103100

104101
"""

src/Utils/Utils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@ export TopKMaximizer
2727
export AbstractEnvironment, get_seed, is_terminated, observe, reset!, step!
2828

2929
export AbstractBenchmark, AbstractStochasticBenchmark, AbstractDynamicBenchmark
30-
export generate_sample, generate_dataset
30+
export generate_instance, generate_sample, generate_dataset
3131
export generate_statistical_model, generate_maximizer
3232
export generate_scenario
3333
export generate_environment, generate_environments
34-
export generate_policies
34+
export generate_baseline_policies
3535
export generate_anticipative_solution
36-
export generate_instance_samples, generate_environment_samples
3736

3837
export plot_data, compute_gap
3938
export grid_graph, get_path, path_to_matrix

0 commit comments

Comments
 (0)