Skip to content

Commit 7c51353

Browse files
authored
Merge pull request #8 from JuliaDecisionFocusedLearning/new-benchmark-interface
Adapt to new benchmark interface
2 parents c084670 + 28d2aa6 commit 7c51353

File tree

16 files changed

+189
-170
lines changed

16 files changed

+189
-170
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
1919
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
2020

2121
[compat]
22-
DecisionFocusedLearningBenchmarks = "0.4"
22+
DecisionFocusedLearningBenchmarks = "0.5.0"
2323
DocStringExtensions = "0.9.5"
24-
Flux = "0.16.5"
24+
Flux = "0.16.9"
2525
InferOpt = "0.7.1"
2626
MLUtils = "0.4.8"
2727
ProgressMeter = "1.11.0"
2828
Random = "1.11.0"
2929
Statistics = "1.11.1"
30-
UnicodePlots = "3.8.1"
31-
ValueHistories = "0.5.4"
30+
UnicodePlots = "3.8.2"
31+
ValueHistories = "0.5.6"
3232
julia = "1.11"

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ using DecisionFocusedLearningBenchmarks
44
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Flux, Adam
66
using InferOpt: InferOpt, FenchelYoungLoss, PerturbedAdditive, PerturbedMultiplicative
7-
using MLUtils: splitobs, DataLoader
7+
using MLUtils: DataLoader
88
using ProgressMeter: @showprogress
9+
using Random: Random, MersenneTwister
910
using Statistics: mean
1011
using UnicodePlots: lineplot
1112
using ValueHistories: MVHistory
1213

13-
include("utils.jl")
1414
include("training_context.jl")
1515

1616
include("metrics/interface.jl")
@@ -39,6 +39,7 @@ export AbstractMetric,
3939
compute!,
4040
evaluate_metrics!
4141

42+
export AbstractAlgorithm, AbstractImitationAlgorithm
4243
export PerturbedFenchelYoungLossImitation,
4344
DAgger, AnticipativeImitation, train_policy!, train_policy
4445
export AbstractPolicy, DFLPolicy

src/algorithms/abstract_algorithm.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,46 @@ $TYPEDEF
1111
An abstract type for imitation learning algorithms.
1212
1313
All subtypes must implement:
14-
- `train_policy!(algorithm::AbstractImitationAlgorithm, model, maximizer, train_data; epochs, metrics)`
14+
- `train_policy!(algorithm::AbstractImitationAlgorithm, policy::DFLPolicy, train_data; epochs, metrics)`
1515
"""
1616
abstract type AbstractImitationAlgorithm <: AbstractAlgorithm end
17+
18+
"""
19+
$TYPEDSIGNATURES
20+
21+
Train a new DFLPolicy on a benchmark using any imitation learning algorithm.
22+
23+
Convenience wrapper that handles dataset generation, model initialization, and policy
24+
creation. Returns the training history and the trained policy.
25+
26+
For dynamic benchmarks, use the algorithm-specific `train_policy` overload that accepts
27+
environments and an anticipative policy.
28+
"""
29+
function train_policy(
30+
algorithm::AbstractImitationAlgorithm,
31+
benchmark::AbstractBenchmark;
32+
target_policy=nothing,
33+
dataset_size=30,
34+
epochs=100,
35+
metrics::Tuple=(),
36+
seed=nothing,
37+
)
38+
dataset = generate_dataset(benchmark, dataset_size; target_policy)
39+
40+
if any(s -> isnothing(s.y), dataset)
41+
error(
42+
"Training dataset contains unlabeled samples (y=nothing). " *
43+
"Provide a `target_policy` kwarg to label samples during dataset generation.",
44+
)
45+
end
46+
47+
model = generate_statistical_model(benchmark; seed)
48+
maximizer = generate_maximizer(benchmark)
49+
policy = DFLPolicy(model, maximizer)
50+
51+
history = train_policy!(
52+
algorithm, policy, dataset; epochs, metrics, maximizer_kwargs=s -> s.context
53+
)
54+
55+
return history, policy
56+
end

src/algorithms/supervised/anticipative_imitation.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ function train_policy!(
3131
anticipative_policy,
3232
epochs=10,
3333
metrics::Tuple=(),
34-
maximizer_kwargs=get_state,
34+
maximizer_kwargs=sample -> sample.context,
3535
)
3636
# Generate anticipative solutions as training data
3737
train_dataset = vcat(map(train_environments) do env
38-
v, y = anticipative_policy(env; reset_env=true)
39-
return y
38+
return anticipative_policy(env; reset_env=true)
4039
end...)
4140

4241
# Delegate to inner algorithm
@@ -62,26 +61,22 @@ Uses anticipative solutions as expert demonstrations.
6261
"""
6362
function train_policy(
6463
algorithm::AnticipativeImitation,
65-
benchmark::AbstractStochasticBenchmark{true};
64+
benchmark::ExogenousDynamicBenchmark;
6665
dataset_size=30,
67-
split_ratio=(0.3, 0.3),
6866
epochs=10,
6967
metrics::Tuple=(),
7068
seed=nothing,
7169
)
72-
# Generate instances and environments
73-
dataset = generate_dataset(benchmark, dataset_size)
74-
train_instances, validation_instances, _ = splitobs(dataset; at=split_ratio)
75-
train_environments = generate_environments(benchmark, train_instances)
70+
# Generate environments
71+
train_environments = generate_environments(benchmark, dataset_size; seed)
7672

7773
# Initialize model and create policy
7874
model = generate_statistical_model(benchmark; seed)
7975
maximizer = generate_maximizer(benchmark)
8076
policy = DFLPolicy(model, maximizer)
8177

8278
# Define anticipative policy from benchmark
83-
anticipative_policy =
84-
(env; reset_env) -> generate_anticipative_solution(benchmark, env; reset_env)
79+
anticipative_policy = generate_anticipative_solver(benchmark)
8580

8681
# Train policy
8782
history = train_policy!(

src/algorithms/supervised/dagger.jl

Lines changed: 32 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Reference: <https://arxiv.org/abs/2402.04463>
88
# Fields
99
$TYPEDFIELDS
1010
"""
11-
@kwdef struct DAgger{A} <: AbstractImitationAlgorithm
11+
@kwdef struct DAgger{A,S} <: AbstractImitationAlgorithm
1212
"inner imitation algorithm for supervised learning"
1313
inner_algorithm::A = PerturbedFenchelYoungLossImitation()
1414
"number of DAgger iterations"
@@ -17,6 +17,11 @@ $TYPEDFIELDS
1717
epochs_per_iteration::Int = 3
1818
"decay factor for mixing expert and learned policy"
1919
α_decay::Float64 = 0.9
20+
"random seed for the expert/policy mixing coin-flip (nothing = non-reproducible)"
21+
seed::S = nothing
22+
"maximum dataset size across iterations (nothing keeps all samples,
23+
an integer caps to the most recent N samples via FIFO)"
24+
max_dataset_size::Union{Int,Nothing} = nothing
2025
end
2126

2227
"""
@@ -34,24 +39,24 @@ function train_policy!(
3439
train_environments;
3540
anticipative_policy,
3641
metrics::Tuple=(),
37-
maximizer_kwargs=get_state,
42+
maximizer_kwargs=sample -> sample.context,
3843
)
39-
(; inner_algorithm, iterations, epochs_per_iteration, α_decay) = algorithm
44+
(; inner_algorithm, iterations, epochs_per_iteration, α_decay, seed) = algorithm
4045
(; statistical_model, maximizer) = policy
4146

47+
rng = isnothing(seed) ? MersenneTwister() : MersenneTwister(seed)
4248
α = 1.0
4349

4450
# Initial dataset from expert demonstrations
4551
train_dataset = vcat(map(train_environments) do env
46-
v, y = anticipative_policy(env; reset_env=true)
47-
return y
52+
return anticipative_policy(env; reset_env=true)
4853
end...)
4954

5055
dataset = deepcopy(train_dataset)
5156

5257
# Initialize combined history for all DAgger iterations
5358
combined_history = MVHistory()
54-
global_epoch = 0
59+
epoch_offset = 0
5560

5661
for iter in 1:iterations
5762
println("DAgger iteration $iter/$iterations (α=$(round(α, digits=3)))")
@@ -68,53 +73,26 @@ function train_policy!(
6873

6974
# Merge iteration history into combined history
7075
for key in keys(iter_history)
71-
epochs, values = get(iter_history, key)
72-
for i in eachindex(epochs)
73-
# Calculate global epoch number
74-
if iter == 1
75-
# First iteration: use epochs as-is [0, 1, 2, ...]
76-
global_epoch_value = epochs[i]
77-
else
78-
# Later iterations: skip epoch 0 and renumber starting from global_epoch
79-
if epochs[i] == 0
80-
continue # Skip epoch 0 for iterations > 1
81-
end
82-
# Map epoch 1 → global_epoch, epoch 2 → global_epoch+1, etc.
83-
global_epoch_value = global_epoch + epochs[i] - 1
84-
end
85-
86-
# For the epoch key, use global_epoch_value as both time and value
87-
# For other keys, use global_epoch_value as time and original value
88-
if key == :epoch
89-
push!(combined_history, key, global_epoch_value, global_epoch_value)
90-
else
91-
push!(combined_history, key, global_epoch_value, values[i])
92-
end
76+
local_epochs, values = get(iter_history, key)
77+
for i in eachindex(local_epochs)
78+
# Skip epoch 0 for all iterations after the first
79+
local_epochs[i] == 0 && epoch_offset > 0 && continue
80+
global_e = epoch_offset + local_epochs[i]
81+
push!(combined_history, key, global_e, key == :epoch ? global_e : values[i])
9382
end
9483
end
9584

96-
# Update global_epoch for next iteration
97-
# After each iteration, advance by the number of non-zero epochs processed
98-
if iter == 1
99-
# First iteration processes all epochs [0, 1, ..., epochs_per_iteration]
100-
# Next iteration should start at epochs_per_iteration + 1
101-
global_epoch = epochs_per_iteration + 1
102-
else
103-
# Subsequent iterations skip epoch 0, so they process epochs_per_iteration epochs
104-
# Next iteration should start epochs_per_iteration later
105-
global_epoch += epochs_per_iteration
106-
end
85+
epoch_offset += epochs_per_iteration
10786

10887
# Dataset update - collect new samples using mixed policy
10988
new_samples = eltype(dataset)[]
11089
for env in train_environments
11190
DecisionFocusedLearningBenchmarks.reset!(env; reset_rng=false)
11291
while !is_terminated(env)
113-
x_before = copy(observe(env)[1])
114-
_, anticipative_solution = anticipative_policy(env; reset_env=false)
115-
p = rand()
92+
anticipative_solution = anticipative_policy(env; reset_env=false)
93+
p = rand(rng)
11694
target = anticipative_solution[1]
117-
x, state = observe(env)
95+
x, _ = observe(env)
11896
if size(target.x) != size(x)
11997
@error "Mismatch between expert and observed state" size(target.x) size(
12098
x
@@ -124,14 +102,16 @@ function train_policy!(
124102
if p < α
125103
action = target.y
126104
else
127-
x, state = observe(env)
128105
θ = statistical_model(x)
129106
action = maximizer(θ; maximizer_kwargs(target)...)
130107
end
131108
step!(env, action)
132109
end
133110
end
134-
dataset = new_samples # TODO: replay buffer
111+
dataset = vcat(dataset, new_samples)
112+
if !isnothing(algorithm.max_dataset_size)
113+
dataset = last(dataset, algorithm.max_dataset_size)
114+
end
135115
α *= α_decay # Decay factor for mixing expert and learned policy
136116
end
137117

@@ -149,25 +129,21 @@ This high-level function handles all setup from the benchmark and returns a trai
149129
"""
150130
function train_policy(
151131
algorithm::DAgger,
152-
benchmark::AbstractStochasticBenchmark{true};
132+
benchmark::ExogenousDynamicBenchmark;
153133
dataset_size=30,
154-
split_ratio=(0.3, 0.3, 0.4),
155134
metrics::Tuple=(),
156-
seed=0,
135+
seed=nothing,
157136
)
158-
# Generate dataset and environments
159-
dataset = generate_dataset(benchmark, dataset_size)
160-
train_instances, validation_instances, _ = splitobs(dataset; at=split_ratio)
161-
train_environments = generate_environments(benchmark, train_instances; seed)
137+
# Generate environments
138+
train_environments = generate_environments(benchmark, dataset_size; seed)
162139

163140
# Initialize model and create policy
164-
model = generate_statistical_model(benchmark)
141+
model = generate_statistical_model(benchmark; seed)
165142
maximizer = generate_maximizer(benchmark)
166143
policy = DFLPolicy(model, maximizer)
167144

168145
# Define anticipative policy from benchmark
169-
anticipative_policy =
170-
(env; reset_env) -> generate_anticipative_solution(benchmark, env; reset_env)
146+
anticipative_policy = generate_anticipative_solver(benchmark)
171147

172148
# Train policy
173149
history = train_policy!(
@@ -176,7 +152,7 @@ function train_policy(
176152
train_environments;
177153
anticipative_policy=anticipative_policy,
178154
metrics=metrics,
179-
maximizer_kwargs=get_state,
155+
maximizer_kwargs=sample -> sample.context,
180156
)
181157

182158
return history, policy

src/algorithms/supervised/fyl.jl

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function train_policy!(
4545
train_dataset::DataLoader;
4646
epochs=100,
4747
metrics::Tuple=(),
48-
maximizer_kwargs=get_info,
48+
maximizer_kwargs=sample -> sample.context,
4949
)
5050
(; nb_samples, ε, threaded, training_optimizer, seed) = algorithm
5151
(; statistical_model, maximizer) = policy
@@ -106,7 +106,7 @@ function train_policy!(
106106
train_dataset::AbstractArray{<:DataSample};
107107
epochs=100,
108108
metrics::Tuple=(),
109-
maximizer_kwargs=get_info,
109+
maximizer_kwargs=sample -> sample.context,
110110
)
111111
data_loader = DataLoader(train_dataset; batchsize=1, shuffle=false)
112112
return train_policy!(
@@ -118,38 +118,3 @@ function train_policy!(
118118
maximizer_kwargs=maximizer_kwargs,
119119
)
120120
end
121-
122-
"""
123-
$TYPEDSIGNATURES
124-
125-
Train a DFLPolicy using the Perturbed Fenchel-Young Loss Imitation Algorithm on a benchmark.
126-
127-
# Benchmark convenience wrapper
128-
129-
This high-level function handles all setup from the benchmark and returns a trained policy.
130-
"""
131-
function train_policy(
132-
algorithm::PerturbedFenchelYoungLossImitation,
133-
benchmark::AbstractBenchmark;
134-
dataset_size=30,
135-
split_ratio=(0.3, 0.3),
136-
epochs=100,
137-
metrics::Tuple=(),
138-
seed=nothing,
139-
)
140-
# Generate dataset and split
141-
dataset = generate_dataset(benchmark, dataset_size)
142-
train_instances, _, _ = splitobs(dataset; at=split_ratio)
143-
144-
# Initialize model and create policy
145-
model = generate_statistical_model(benchmark; seed)
146-
maximizer = generate_maximizer(benchmark)
147-
policy = DFLPolicy(model, maximizer)
148-
149-
# Train policy
150-
history = train_policy!(
151-
algorithm, policy, train_instances; epochs, metrics, maximizer_kwargs=get_info
152-
)
153-
154-
return history, policy
155-
end

0 commit comments

Comments
 (0)