-# Tutorial
+# # Basic Tutorial: Training with FYL on Argmax Benchmark
+#
+# This tutorial demonstrates the basic workflow for training a policy
+# using the Perturbed Fenchel-Young Loss algorithm.
+
+# ## Setup
 using DecisionFocusedLearningAlgorithms
 using DecisionFocusedLearningBenchmarks
 using MLUtils: splitobs
 using Plots
 
+# ## Create Benchmark and Data
 b = ArgmaxBenchmark()
 dataset = generate_dataset(b, 100)
-train_instances, validation_instances, test_instances = splitobs(
-    dataset; at=(0.3, 0.3, 0.4)
-)
+train_data, val_data, test_data = splitobs(dataset; at=(0.3, 0.3, 0.4))
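+# With 100 instances, `at=(0.3, 0.3, 0.4)` gives a 30/30/40 split into
+# training, validation, and test sets.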
 
+# ## Create Policy
 model = generate_statistical_model(b; seed=0)
 maximizer = generate_maximizer(b)
+policy = DFLPolicy(model, maximizer)
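+# The policy composes its two parts at prediction time: the statistical model
+# maps instance features to objective coefficients θ, and the maximizer turns
+# θ into a feasible decision. Schematically (a sketch assuming a sample stores
+# its features in a field `x`; field names may vary across benchmarks):
+## θ = policy.statistical_model(first(train_data).x)
+## y = policy.maximizer(θ)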
 
-# Compute initial gap
-initial_gap = compute_gap(b, test_instances, model, maximizer)
-println("Initial test gap: $initial_gap")
-
-# Configure the training algorithm
-algorithm = PerturbedImitationAlgorithm(; nb_samples=10, ε=0.1, threaded=true, seed=0)
+# ## Configure Algorithm
+algorithm = PerturbedFenchelYoungLossImitation(;
+    nb_samples=10, ε=0.1, threaded=true, seed=0
+)
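+# `ε` sets the scale of the Gaussian noise added to the predicted objective,
+# and `nb_samples` the number of Monte-Carlo draws: averaging the maximizer
+# over perturbed objectives yields a smooth, differentiable surrogate of the
+# argmax, which is what makes the Fenchel-Young loss trainable. A minimal
+# sketch of that smoothing idea (not the package's internal implementation,
+# and assuming the maximizer accepts a plain coefficient vector):
+using Statistics: mean
+smoothed_maximizer(θ; ε=0.1, M=10) = mean(maximizer(θ .+ ε .* randn(length(θ))) for _ in 1:M)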
 
-# Define metrics to track during training
-validation_loss_metric = FYLLossMetric(validation_instances, :validation_loss)
+# ## Define Metrics to Track During Training
+validation_loss_metric = FYLLossMetric(val_data, :validation_loss)
 
-# Validation gap metric
-val_gap_metric = FunctionMetric(:val_gap, validation_instances) do ctx, data
-    compute_gap(b, data, ctx.model, ctx.maximizer)
+val_gap_metric = FunctionMetric(:val_gap, val_data) do ctx, data
+    compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer)
 end
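+# The `do ctx, data` callback receives a training context `ctx`, which exposes
+# the policy being trained, together with the `data` the metric was created
+# with; here it recomputes the average optimality gap on the validation set.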
 
-# Test gap metric
-test_gap_metric = FunctionMetric(:test_gap, test_instances) do ctx, data
-    compute_gap(b, data, ctx.model, ctx.maximizer)
+test_gap_metric = FunctionMetric(:test_gap, test_data) do ctx, data
+    compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer)
 end
 
-# Combine metrics
 metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)
 
-# Train the model
-fyl_model = deepcopy(model)
-history = train_policy!(
-    algorithm, fyl_model, maximizer, train_instances; epochs=100, metrics=metrics
-)
+# ## Train the Policy
+history = train_policy!(algorithm, policy, train_data; epochs=100, metrics=metrics)
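+# `train_policy!` updates the policy in place and returns a history in which
+# each metric is logged per epoch under its symbol; the algorithm additionally
+# records its own `:training_loss` series.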
 
-# Plot validation and test gaps
+# ## Plot Results
 val_gap_epochs, val_gap_values = get(history, :val_gap)
 test_gap_epochs, test_gap_values = get(history, :test_gap)
 
@@ ... @@
     title="Gap Evolution During Training",
 )
 
-# Plot validation loss
+# Plot loss evolution
 train_loss_epochs, train_loss_values = get(history, :training_loss)
 val_loss_epochs, val_loss_values = get(history, :validation_loss)
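+# A sketch of the corresponding loss plot, mirroring the gap figure above
+# (labels are illustrative):
+plot(train_loss_epochs, train_loss_values; label="training loss", xlabel="Epoch", ylabel="FYL loss")
+plot!(val_loss_epochs, val_loss_values; label="validation loss")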
 