|
| 1 | +# # Basic Tutorial: Training with FYL on Argmax Benchmark |
| 2 | +# |
| 3 | +# This tutorial demonstrates the basic workflow for training a policy |
| 4 | +# using the Perturbed Fenchel-Young Loss algorithm. |
| 5 | + |
| 6 | +# ## Setup |
| 7 | +using DecisionFocusedLearningAlgorithms |
| 8 | +using DecisionFocusedLearningBenchmarks |
| 9 | +using MLUtils: splitobs |
| 10 | +using Plots |
| 11 | + |
| 12 | +# ## Create Benchmark and Data |
| 13 | +b = ArgmaxBenchmark() |
| 14 | +dataset = generate_dataset(b, 100) |
| 15 | +train_data, val_data, test_data = splitobs(dataset; at=(0.3, 0.3, 0.4)) |
| 16 | + |
| 17 | +# ## Create Policy |
| 18 | +model = generate_statistical_model(b; seed=0) |
| 19 | +maximizer = generate_maximizer(b) |
| 20 | +policy = DFLPolicy(model, maximizer) |
| 21 | + |
| 22 | +# ## Configure Algorithm |
| 23 | +algorithm = PerturbedFenchelYoungLossImitation(; |
| 24 | + nb_samples=10, ε=0.1, threaded=true, seed=0 |
| 25 | +) |
| 26 | + |
| 27 | +# ## Define Metrics to track during training |
| 28 | +validation_loss_metric = FYLLossMetric(val_data, :validation_loss) |
| 29 | + |
| 30 | +val_gap_metric = FunctionMetric(:val_gap, val_data) do ctx, data |
| 31 | + compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer) |
| 32 | +end |
| 33 | + |
| 34 | +test_gap_metric = FunctionMetric(:test_gap, test_data) do ctx, data |
| 35 | + compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer) |
| 36 | +end |
| 37 | + |
| 38 | +metrics = (validation_loss_metric, val_gap_metric, test_gap_metric) |
| 39 | + |
| 40 | +# ## Train the Policy |
| 41 | +history = train_policy!(algorithm, policy, train_data; epochs=100, metrics=metrics) |
| 42 | + |
| 43 | +# ## Plot Results |
| 44 | +val_gap_epochs, val_gap_values = get(history, :val_gap) |
| 45 | +test_gap_epochs, test_gap_values = get(history, :test_gap) |
| 46 | + |
| 47 | +plot( |
| 48 | + [val_gap_epochs, test_gap_epochs], |
| 49 | + [val_gap_values, test_gap_values]; |
| 50 | + labels=["Val Gap" "Test Gap"], |
| 51 | + xlabel="Epoch", |
| 52 | + ylabel="Gap", |
| 53 | + title="Gap Evolution During Training", |
| 54 | +) |
| 55 | + |
| 56 | +# Plot loss evolution |
| 57 | +train_loss_epochs, train_loss_values = get(history, :training_loss) |
| 58 | +val_loss_epochs, val_loss_values = get(history, :validation_loss) |
| 59 | + |
| 60 | +plot( |
| 61 | + [train_loss_epochs, val_loss_epochs], |
| 62 | + [train_loss_values, val_loss_values]; |
| 63 | + labels=["Training Loss" "Validation Loss"], |
| 64 | + xlabel="Epoch", |
| 65 | + ylabel="Loss", |
| 66 | + title="Loss Evolution During Training", |
| 67 | +) |
0 commit comments