@@ -8,7 +8,7 @@ Reference: <https://arxiv.org/abs/2402.04463>
88# Fields
99$TYPEDFIELDS
1010"""
11- @kwdef struct DAgger{A} <: AbstractImitationAlgorithm
11+ @kwdef struct DAgger{A,S } <: AbstractImitationAlgorithm
1212 " inner imitation algorithm for supervised learning"
1313 inner_algorithm:: A = PerturbedFenchelYoungLossImitation ()
1414 " number of DAgger iterations"
@@ -17,6 +17,11 @@ $TYPEDFIELDS
1717 epochs_per_iteration:: Int = 3
1818 " decay factor for mixing expert and learned policy"
1919 α_decay:: Float64 = 0.9
20+ " random seed for the expert/policy mixing coin-flip (nothing = non-reproducible)"
21+ seed:: S = nothing
22+ " maximum dataset size across iterations (nothing keeps all samples,
23+ an integer caps to the most recent N samples via FIFO)"
24+ max_dataset_size:: Union{Int,Nothing} = nothing
2025end
2126
2227"""
@@ -36,9 +41,10 @@ function train_policy!(
3641 metrics:: Tuple = (),
3742 maximizer_kwargs= sample -> sample. context,
3843)
39- (; inner_algorithm, iterations, epochs_per_iteration, α_decay) = algorithm
44+ (; inner_algorithm, iterations, epochs_per_iteration, α_decay, seed ) = algorithm
4045 (; statistical_model, maximizer) = policy
4146
47+ rng = isnothing (seed) ? MersenneTwister () : MersenneTwister (seed)
4248 α = 1.0
4349
4450 # Initial dataset from expert demonstrations
@@ -85,7 +91,7 @@ function train_policy!(
8591 while ! is_terminated (env)
8692 x_before = copy (observe (env)[1 ])
8793 anticipative_solution = anticipative_policy (env; reset_env= false )
88- p = rand ()
94+ p = rand (rng )
8995 target = anticipative_solution[1 ]
9096 x, state = observe (env)
9197 if size (target. x) != size (x)
@@ -104,7 +110,10 @@ function train_policy!(
104110 step! (env, action)
105111 end
106112 end
107- dataset = new_samples # TODO : replay buffer
113+ dataset = vcat (dataset, new_samples)
114+ if ! isnothing (algorithm. max_dataset_size)
115+ dataset = last (dataset, algorithm. max_dataset_size)
116+ end
108117 α *= α_decay # Decay factor for mixing expert and learned policy
109118 end
110119
0 commit comments