@@ -8,7 +8,7 @@ Reference: <https://arxiv.org/abs/2402.04463>
88# Fields
99$TYPEDFIELDS
1010"""
11- @kwdef struct DAgger{A} <: AbstractImitationAlgorithm
11+ @kwdef struct DAgger{A,S } <: AbstractImitationAlgorithm
1212 " inner imitation algorithm for supervised learning"
1313 inner_algorithm:: A = PerturbedFenchelYoungLossImitation ()
1414 " number of DAgger iterations"
@@ -17,6 +17,11 @@ $TYPEDFIELDS
1717 epochs_per_iteration:: Int = 3
1818 " decay factor for mixing expert and learned policy"
1919 α_decay:: Float64 = 0.9
20+ " random seed for the expert/policy mixing coin-flip (nothing = non-reproducible)"
21+ seed:: S = nothing
22+ " maximum dataset size across iterations (nothing keeps all samples,
23+ an integer caps to the most recent N samples via FIFO)"
24+ max_dataset_size:: Union{Int,Nothing} = nothing
2025end
2126
2227"""
@@ -34,24 +39,24 @@ function train_policy!(
3439 train_environments;
3540 anticipative_policy,
3641 metrics:: Tuple = (),
37- maximizer_kwargs= get_state ,
42+ maximizer_kwargs= sample -> sample . context ,
3843)
39- (; inner_algorithm, iterations, epochs_per_iteration, α_decay) = algorithm
44+ (; inner_algorithm, iterations, epochs_per_iteration, α_decay, seed ) = algorithm
4045 (; statistical_model, maximizer) = policy
4146
47+ rng = isnothing (seed) ? MersenneTwister () : MersenneTwister (seed)
4248 α = 1.0
4349
4450 # Initial dataset from expert demonstrations
4551 train_dataset = vcat (map (train_environments) do env
46- v, y = anticipative_policy (env; reset_env= true )
47- return y
52+ return anticipative_policy (env; reset_env= true )
4853 end ... )
4954
5055 dataset = deepcopy (train_dataset)
5156
5257 # Initialize combined history for all DAgger iterations
5358 combined_history = MVHistory ()
54- global_epoch = 0
59+ epoch_offset = 0
5560
5661 for iter in 1 : iterations
5762 println (" DAgger iteration $iter /$iterations (α=$(round (α, digits= 3 )) )" )
@@ -68,53 +73,26 @@ function train_policy!(
6873
6974 # Merge iteration history into combined history
7075 for key in keys (iter_history)
71- epochs, values = get (iter_history, key)
72- for i in eachindex (epochs)
73- # Calculate global epoch number
74- if iter == 1
75- # First iteration: use epochs as-is [0, 1, 2, ...]
76- global_epoch_value = epochs[i]
77- else
78- # Later iterations: skip epoch 0 and renumber starting from global_epoch
79- if epochs[i] == 0
80- continue # Skip epoch 0 for iterations > 1
81- end
82- # Map epoch 1 → global_epoch, epoch 2 → global_epoch+1, etc.
83- global_epoch_value = global_epoch + epochs[i] - 1
84- end
85-
86- # For the epoch key, use global_epoch_value as both time and value
87- # For other keys, use global_epoch_value as time and original value
88- if key == :epoch
89- push! (combined_history, key, global_epoch_value, global_epoch_value)
90- else
91- push! (combined_history, key, global_epoch_value, values[i])
92- end
76+ local_epochs, values = get (iter_history, key)
77+ for i in eachindex (local_epochs)
78+ # Skip epoch 0 for all iterations after the first
79+ local_epochs[i] == 0 && epoch_offset > 0 && continue
80+ global_e = epoch_offset + local_epochs[i]
81+ push! (combined_history, key, global_e, key == :epoch ? global_e : values[i])
9382 end
9483 end
9584
96- # Update global_epoch for next iteration
97- # After each iteration, advance by the number of non-zero epochs processed
98- if iter == 1
99- # First iteration processes all epochs [0, 1, ..., epochs_per_iteration]
100- # Next iteration should start at epochs_per_iteration + 1
101- global_epoch = epochs_per_iteration + 1
102- else
103- # Subsequent iterations skip epoch 0, so they process epochs_per_iteration epochs
104- # Next iteration should start epochs_per_iteration later
105- global_epoch += epochs_per_iteration
106- end
85+ epoch_offset += epochs_per_iteration
10786
10887 # Dataset update - collect new samples using mixed policy
10988 new_samples = eltype (dataset)[]
11089 for env in train_environments
11190 DecisionFocusedLearningBenchmarks. reset! (env; reset_rng= false )
11291 while ! is_terminated (env)
113- x_before = copy (observe (env)[1 ])
114- _, anticipative_solution = anticipative_policy (env; reset_env= false )
115- p = rand ()
92+ anticipative_solution = anticipative_policy (env; reset_env= false )
93+ p = rand (rng)
11694 target = anticipative_solution[1 ]
117- x, state = observe (env)
95+ x, _ = observe (env)
11896 if size (target. x) != size (x)
11997 @error " Mismatch between expert and observed state" size (target. x) size (
12098 x
@@ -124,14 +102,16 @@ function train_policy!(
124102 if p < α
125103 action = target. y
126104 else
127- x, state = observe (env)
128105 θ = statistical_model (x)
129106 action = maximizer (θ; maximizer_kwargs (target)... )
130107 end
131108 step! (env, action)
132109 end
133110 end
134- dataset = new_samples # TODO : replay buffer
111+ dataset = vcat (dataset, new_samples)
112+ if ! isnothing (algorithm. max_dataset_size)
113+ dataset = last (dataset, algorithm. max_dataset_size)
114+ end
135115 α *= α_decay # Decay factor for mixing expert and learned policy
136116 end
137117
@@ -149,25 +129,21 @@ This high-level function handles all setup from the benchmark and returns a trai
149129"""
150130function train_policy (
151131 algorithm:: DAgger ,
152- benchmark:: AbstractStochasticBenchmark{true} ;
132+ benchmark:: ExogenousDynamicBenchmark ;
153133 dataset_size= 30 ,
154- split_ratio= (0.3 , 0.3 , 0.4 ),
155134 metrics:: Tuple = (),
156- seed= 0 ,
135+ seed= nothing ,
157136)
158- # Generate dataset and environments
159- dataset = generate_dataset (benchmark, dataset_size)
160- train_instances, validation_instances, _ = splitobs (dataset; at= split_ratio)
161- train_environments = generate_environments (benchmark, train_instances; seed)
137+ # Generate environments
138+ train_environments = generate_environments (benchmark, dataset_size; seed)
162139
163140 # Initialize model and create policy
164- model = generate_statistical_model (benchmark)
141+ model = generate_statistical_model (benchmark; seed )
165142 maximizer = generate_maximizer (benchmark)
166143 policy = DFLPolicy (model, maximizer)
167144
168145 # Define anticipative policy from benchmark
169- anticipative_policy =
170- (env; reset_env) -> generate_anticipative_solution (benchmark, env; reset_env)
146+ anticipative_policy = generate_anticipative_solver (benchmark)
171147
172148 # Train policy
173149 history = train_policy! (
@@ -176,7 +152,7 @@ function train_policy(
176152 train_environments;
177153 anticipative_policy= anticipative_policy,
178154 metrics= metrics,
179- maximizer_kwargs= get_state ,
155+ maximizer_kwargs= sample -> sample . context ,
180156 )
181157
182158 return history, policy
0 commit comments