Skip to content

Commit 52e80c7

Browse files
committed
add replay buffer option to dagger + bugfix
1 parent 37b816a commit 52e80c7

File tree

6 files changed

+63
-15
lines changed

6 files changed

+63
-15
lines changed

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Flux: Flux, Adam
66
using InferOpt: InferOpt, FenchelYoungLoss, PerturbedAdditive, PerturbedMultiplicative
77
using MLUtils: splitobs, DataLoader
88
using ProgressMeter: @showprogress
9+
using Random: Random, MersenneTwister
910
using Statistics: mean
1011
using UnicodePlots: lineplot
1112
using ValueHistories: MVHistory

src/algorithms/supervised/dagger.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Reference: <https://arxiv.org/abs/2402.04463>
88
# Fields
99
$TYPEDFIELDS
1010
"""
11-
@kwdef struct DAgger{A} <: AbstractImitationAlgorithm
11+
@kwdef struct DAgger{A,S} <: AbstractImitationAlgorithm
1212
"inner imitation algorithm for supervised learning"
1313
inner_algorithm::A = PerturbedFenchelYoungLossImitation()
1414
"number of DAgger iterations"
@@ -17,6 +17,11 @@ $TYPEDFIELDS
1717
epochs_per_iteration::Int = 3
1818
"decay factor for mixing expert and learned policy"
1919
α_decay::Float64 = 0.9
20+
"random seed for the expert/policy mixing coin-flip (nothing = non-reproducible)"
21+
seed::S = nothing
22+
"maximum dataset size across iterations (nothing keeps all samples,
23+
an integer caps to the most recent N samples via FIFO)"
24+
max_dataset_size::Union{Int,Nothing} = nothing
2025
end
2126

2227
"""
@@ -36,9 +41,10 @@ function train_policy!(
3641
metrics::Tuple=(),
3742
maximizer_kwargs=sample -> sample.context,
3843
)
39-
(; inner_algorithm, iterations, epochs_per_iteration, α_decay) = algorithm
44+
(; inner_algorithm, iterations, epochs_per_iteration, α_decay, seed) = algorithm
4045
(; statistical_model, maximizer) = policy
4146

47+
rng = isnothing(seed) ? MersenneTwister() : MersenneTwister(seed)
4248
α = 1.0
4349

4450
# Initial dataset from expert demonstrations
@@ -85,7 +91,7 @@ function train_policy!(
8591
while !is_terminated(env)
8692
x_before = copy(observe(env)[1])
8793
anticipative_solution = anticipative_policy(env; reset_env=false)
88-
p = rand()
94+
p = rand(rng)
8995
target = anticipative_solution[1]
9096
x, state = observe(env)
9197
if size(target.x) != size(x)
@@ -104,7 +110,10 @@ function train_policy!(
104110
step!(env, action)
105111
end
106112
end
107-
dataset = new_samples # TODO: replay buffer
113+
dataset = vcat(dataset, new_samples)
114+
if !isnothing(algorithm.max_dataset_size)
115+
dataset = last(dataset, algorithm.max_dataset_size)
116+
end
108117
α *= α_decay # Decay factor for mixing expert and learned policy
109118
end
110119

src/metrics/interface.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,7 @@ Internal helper to store a single metric value in the history.
4343
function _store_metric_value!(
4444
history::MVHistory, metric_name::Symbol, epoch::Int, value::Number
4545
)
46-
try
47-
push!(history, metric_name, epoch, value)
48-
catch e
49-
throw(
50-
ErrorException(
51-
"Failed to store metric '$metric_name' at epoch $epoch: $(e.msg)"
52-
),
53-
)
54-
end
46+
push!(history, metric_name, epoch, value)
5547
return nothing
5648
end
5749

@@ -81,6 +73,19 @@ end
8173
"""
8274
$TYPEDSIGNATURES
8375
76+
Fallback that throws a descriptive error for unsupported return types.
77+
Metrics must return a `Number`, a `NamedTuple`, or `nothing`.
78+
"""
79+
function _store_metric_value!(::MVHistory, metric_name::Symbol, ::Int, value)
80+
return error(
81+
"Metric `$metric_name` returned a value of type $(typeof(value)), which cannot " *
82+
"be stored in history. Metrics must return a Number, a NamedTuple, or nothing."
83+
)
84+
end
85+
86+
"""
87+
$TYPEDSIGNATURES
88+
8489
Evaluate all metrics and store their results in the history.
8590
8691
This function handles three types of metric returns through multiple dispatch:

src/metrics/periodic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ This is useful for expensive metrics that don't need to be computed every epoch
1010
$TYPEDFIELDS
1111
1212
# Behavior
13-
The metric is evaluated when `(epoch - offset) % frequency == 0`.
13+
The metric is evaluated when `epoch >= offset` and `(epoch - offset) % frequency == 0`.
1414
On other epochs, `evaluate!` returns `nothing` (which is skipped by `evaluate_metrics!`).
1515
1616
# See also
@@ -82,7 +82,7 @@ Evaluate the wrapped metric only if the current epoch matches the frequency patt
8282
- `nothing` otherwise (which is skipped by `evaluate_metrics!`)
8383
"""
8484
function evaluate!(pm::PeriodicMetric, context)
85-
if (context.epoch - pm.offset) % pm.frequency == 0
85+
if context.epoch >= pm.offset && (context.epoch - pm.offset) % pm.frequency == 0
8686
return evaluate!(pm.metric, context)
8787
else
8888
return nothing # Skip evaluation on this epoch

test/dagger.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ using ValueHistories
6969
@test policy.statistical_model !== nothing
7070
@test haskey(history, :training_loss)
7171
end
72+
73+
@testset "DAgger - max_dataset_size cap" begin
74+
algorithm = DAgger(; iterations=2, epochs_per_iteration=1, max_dataset_size=10)
75+
model = generate_statistical_model(benchmark)
76+
maximizer = generate_maximizer(benchmark)
77+
policy = DFLPolicy(model, maximizer)
78+
anticipative_policy = generate_anticipative_solver(benchmark)
79+
80+
history = train_policy!(
81+
algorithm, policy, train_envs; anticipative_policy=anticipative_policy
82+
)
83+
@test history isa MVHistory
84+
@test haskey(history, :training_loss)
85+
end
7286
end
7387

7488
@testset "Integration Tests" begin

test/fyl.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,23 @@ using ValueHistories
138138
_, epoch_sq_values = get(history, :epoch_squared)
139139
@test epoch_sq_values == [0.0, 1.0, 4.0, 9.0]
140140
end
141+
142+
@testset "PeriodicMetric offset guard" begin
143+
model = generate_statistical_model(benchmark)
144+
maximizer = generate_maximizer(benchmark)
145+
policy = DFLPolicy(model, maximizer)
146+
algorithm = PerturbedFenchelYoungLossImitation()
147+
148+
fired_at = Int[]
149+
probe = FunctionMetric(ctx -> (push!(fired_at, ctx.epoch); nothing), :probe)
150+
# offset=5: should fire at epochs 5, 10, ... but NOT at epoch 0
151+
periodic = PeriodicMetric(probe, 5; offset=5)
152+
153+
train_policy!(algorithm, policy, train_data; epochs=10, metrics=(periodic,))
154+
155+
@test 0 ∉ fired_at # must not fire before offset
156+
@test 5 ∈ fired_at # must fire at offset
157+
@test 10 ∈ fired_at # must fire at offset + frequency
158+
@test 3 ∉ fired_at # must not fire between
159+
end
141160
end

0 commit comments

Comments
 (0)