# TODO: parallelize loss computation on validation set
# TODO: add a generic supervised-learning training method that fyl_train delegates to, so new supervised losses can be tested easily
55
"""
$TYPEDEF

Structured imitation learning with a perturbed Fenchel-Young loss.

The policy is trained by wrapping the combinatorial maximizer in a
`PerturbedAdditive` layer and minimizing a `FenchelYoungLoss` on the
imitation targets (see `train_policy!`).

# Fields
$TYPEDFIELDS
"""
@kwdef struct PerturbedImitationAlgorithm{O,S} <: AbstractImitationAlgorithm
    "number of perturbation samples drawn per forward pass"
    nb_samples::Int = 10
    "perturbation magnitude (additive noise scale)"
    ε::Float64 = 0.1
    "whether to evaluate perturbation samples with threading"
    threaded::Bool = true
    "optimizer used for training"
    training_optimizer::O = Adam()
    "random seed for the perturbations (`nothing` means unseeded)"
    seed::S = nothing
end
1326
14- reset! (algorithm:: PerturbedImitationAlgorithm ) = empty! (algorithm. history)
27+ """
28+ $TYPEDSIGNATURES
1529
30+ Train a model using the Perturbed Imitation Algorithm on the provided training dataset.
31+ """
1632function train_policy! (
1733 algorithm:: PerturbedImitationAlgorithm ,
1834 model,
@@ -21,9 +37,7 @@ function train_policy!(
2137 epochs= 100 ,
2238 maximizer_kwargs= get_info,
2339 metrics:: Tuple = (),
24- reset= false ,
2540)
26- reset && reset! (algorithm)
2741 (; nb_samples, ε, threaded, training_optimizer, seed) = algorithm
2842 perturbed = PerturbedAdditive (maximizer; nb_samples, ε, threaded, seed)
2943 loss = FenchelYoungLoss (perturbed)
@@ -32,23 +46,21 @@ function train_policy!(
3246
3347 history = MVHistory ()
3448
35- train_loss_metric = LossAccumulator ( :training_loss )
49+ train_loss_metric = FYLLossMetric (train_dataset, :training_loss )
3650
37- # Store initial losses (epoch 0)
38- # Epoch 0
39- for sample in train_dataset
40- (; x, y) = sample
41- val = loss (model (x), y; maximizer_kwargs (sample)... )
42- update! (train_loss_metric, val)
43- end
44- push! (history, :training_loss , 0 , compute (train_loss_metric))
45- reset! (train_loss_metric)
46-
47- # Initial metric evaluation
48- context = TrainingContext (; model= model, epoch= 0 , maximizer= maximizer, loss= loss)
49- run_metrics! (history, metrics, context)
51+ # Initial metric evaluation and training loss (epoch 0)
52+ context = TrainingContext (;
53+ model= model,
54+ epoch= 0 ,
55+ maximizer= maximizer,
56+ maximizer_kwargs= maximizer_kwargs,
57+ loss= loss,
58+ )
59+ push! (history, :training_loss , 0 , evaluate! (train_loss_metric, context))
60+ evaluate_metrics! (history, metrics, context)
5061
5162 @showprogress for epoch in 1 : epochs
63+ next_epoch! (context)
5264 # Training step
5365 for sample in train_dataset
5466 (; x, y) = sample
@@ -59,13 +71,9 @@ function train_policy!(
5971 update! (train_loss_metric, val)
6072 end
6173
62- # Store training loss
63- push! (history, :training_loss , epoch, compute (train_loss_metric))
64- reset! (train_loss_metric)
65-
66- # Evaluate all metrics - update epoch in context
67- context. epoch = epoch
68- run_metrics! (history, metrics, context)
74+ # Log metrics
75+ push! (history, :training_loss , epoch, compute! (train_loss_metric))
76+ evaluate_metrics! (history, metrics, context)
6977 end
7078
7179 # Plot training loss (or first metric if available)
0 commit comments