@@ -34,7 +34,7 @@ function train_policy!(
3434 train_environments;
3535 anticipative_policy,
3636 metrics:: Tuple = (),
37- maximizer_kwargs= get_state ,
37+ maximizer_kwargs= sample -> sample . context ,
3838)
3939 (; inner_algorithm, iterations, epochs_per_iteration, α_decay) = algorithm
4040 (; statistical_model, maximizer) = policy
@@ -43,15 +43,14 @@ function train_policy!(
4343
4444 # Initial dataset from expert demonstrations
4545 train_dataset = vcat (map (train_environments) do env
46- v, y = anticipative_policy (env; reset_env= true )
47- return y
46+ return anticipative_policy (env; reset_env= true )
4847 end ... )
4948
5049 dataset = deepcopy (train_dataset)
5150
5251 # Initialize combined history for all DAgger iterations
5352 combined_history = MVHistory ()
54- global_epoch = 0
53+ epoch_offset = 0
5554
5655 for iter in 1 : iterations
5756 println (" DAgger iteration $iter /$iterations (α=$(round (α, digits= 3 )) )" )
@@ -68,50 +67,24 @@ function train_policy!(
6867
6968 # Merge iteration history into combined history
7069 for key in keys (iter_history)
71- epochs, values = get (iter_history, key)
72- for i in eachindex (epochs)
73- # Calculate global epoch number
74- if iter == 1
75- # First iteration: use epochs as-is [0, 1, 2, ...]
76- global_epoch_value = epochs[i]
77- else
78- # Later iterations: skip epoch 0 and renumber starting from global_epoch
79- if epochs[i] == 0
80- continue # Skip epoch 0 for iterations > 1
81- end
82- # Map epoch 1 → global_epoch, epoch 2 → global_epoch+1, etc.
83- global_epoch_value = global_epoch + epochs[i] - 1
84- end
85-
86- # For the epoch key, use global_epoch_value as both time and value
87- # For other keys, use global_epoch_value as time and original value
88- if key == :epoch
89- push! (combined_history, key, global_epoch_value, global_epoch_value)
90- else
91- push! (combined_history, key, global_epoch_value, values[i])
92- end
70+ local_epochs, values = get (iter_history, key)
71+ for i in eachindex (local_epochs)
72+ # Skip epoch 0 for all iterations after the first
73+ local_epochs[i] == 0 && epoch_offset > 0 && continue
74+ global_e = epoch_offset + local_epochs[i]
75+ push! (combined_history, key, global_e, key == :epoch ? global_e : values[i])
9376 end
9477 end
9578
96- # Update global_epoch for next iteration
97- # After each iteration, advance by the number of non-zero epochs processed
98- if iter == 1
99- # First iteration processes all epochs [0, 1, ..., epochs_per_iteration]
100- # Next iteration should start at epochs_per_iteration + 1
101- global_epoch = epochs_per_iteration + 1
102- else
103- # Subsequent iterations skip epoch 0, so they process epochs_per_iteration epochs
104- # Next iteration should start epochs_per_iteration later
105- global_epoch += epochs_per_iteration
106- end
79+ epoch_offset += epochs_per_iteration
10780
10881 # Dataset update - collect new samples using mixed policy
10982 new_samples = eltype (dataset)[]
11083 for env in train_environments
11184 DecisionFocusedLearningBenchmarks. reset! (env; reset_rng= false )
11285 while ! is_terminated (env)
11386 x_before = copy (observe (env)[1 ])
114- _, anticipative_solution = anticipative_policy (env; reset_env= false )
87+ anticipative_solution = anticipative_policy (env; reset_env= false )
11588 p = rand ()
11689 target = anticipative_solution[1 ]
11790 x, state = observe (env)
@@ -149,25 +122,21 @@ This high-level function handles all setup from the benchmark and returns a trai
149122"""
150123function train_policy (
151124 algorithm:: DAgger ,
152- benchmark:: AbstractStochasticBenchmark{true} ;
125+ benchmark:: ExogenousDynamicBenchmark ;
153126 dataset_size= 30 ,
154- split_ratio= (0.3 , 0.3 , 0.4 ),
155127 metrics:: Tuple = (),
156128 seed= 0 ,
157129)
158- # Generate dataset and environments
159- dataset = generate_dataset (benchmark, dataset_size)
160- train_instances, validation_instances, _ = splitobs (dataset; at= split_ratio)
161- train_environments = generate_environments (benchmark, train_instances; seed)
130+ # Generate environments
131+ train_environments = generate_environments (benchmark, dataset_size; seed)
162132
163133 # Initialize model and create policy
164134 model = generate_statistical_model (benchmark)
165135 maximizer = generate_maximizer (benchmark)
166136 policy = DFLPolicy (model, maximizer)
167137
168138 # Define anticipative policy from benchmark
169- anticipative_policy =
170- (env; reset_env) -> generate_anticipative_solution (benchmark, env; reset_env)
139+ anticipative_policy = generate_anticipative_solver (benchmark)
171140
172141 # Train policy
173142 history = train_policy! (
@@ -176,7 +145,7 @@ function train_policy(
176145 train_environments;
177146 anticipative_policy= anticipative_policy,
178147 metrics= metrics,
179- maximizer_kwargs= get_state ,
148+ maximizer_kwargs= sample -> sample . context ,
180149 )
181150
182151 return history, policy
0 commit comments