Skip to content

Commit 77613cb

Browse files
committed
wip
1 parent 2d3aa84 commit 77613cb

File tree

15 files changed

+110
-94
lines changed

15 files changed

+110
-94
lines changed

docs/src/tutorials/tutorial.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,7 @@ metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)
3939
# Train the model
4040
fyl_model = deepcopy(model)
4141
history = train_policy!(
42-
algorithm,
43-
fyl_model,
44-
maximizer,
45-
train_instances,
46-
validation_instances;
47-
epochs=100,
48-
metrics=metrics,
42+
algorithm, fyl_model, maximizer, train_instances; epochs=100, metrics=metrics
4943
)
5044

5145
# Plot validation and test gaps

docs/src/tutorials/tutorial.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ history = train_policy!(
7171
algorithm,
7272
fyl_model,
7373
maximizer,
74-
train_instances,
75-
validation_instances;
74+
train_instances;
7675
epochs=100,
7776
metrics=metrics,
7877
)

scripts/main.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ metrics = (
4747

4848
model = deepcopy(initial_model)
4949
history = train_policy!(
50-
algorithm, model, maximizer, train_dataset, val_dataset; epochs=50, metrics=metrics
50+
algorithm, model, maximizer, train_dataset; epochs=50, metrics=metrics
5151
)
5252
X_train, Y_train = get(history, :training_loss)
5353
X_val, Y_val = get(history, :validation_loss)

scripts/old/main.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ res = fyl_train_model(StochasticVehicleSchedulingBenchmark(); epochs=100)
1919
plot(res.validation_loss; label="Validation Loss")
2020
plot!(res.training_loss; label="Training Loss")
2121

22-
baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
22+
kleopatra_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
2323
DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
2424

2525
struct KleopatraPolicy{M}
@@ -79,7 +79,6 @@ dagger_history = DAgger_train_model!(
7979
dagger_model,
8080
maximizer,
8181
train_environments,
82-
validation_environments,
8382
anticipative_policy;
8483
iterations=10,
8584
fyl_epochs=10,

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("metrics/accumulators.jl")
1818
include("metrics/function_metric.jl")
1919
include("metrics/periodic.jl")
2020

21+
include("algorithms/abstract_algorithm.jl")
2122
include("algorithms/supervised/fyl.jl")
2223
include("algorithms/supervised/kleopatra.jl")
2324
include("algorithms/supervised/dagger.jl")
@@ -32,10 +33,10 @@ export AbstractMetric,
3233
reset!,
3334
update!,
3435
evaluate!,
35-
compute,
36-
run_metrics!
36+
compute!,
37+
evaluate_metrics!
3738

38-
export fyl_train_model, baty_train_model, DAgger_train_model!, DAgger_train_model
39+
export fyl_train_model, kleopatra_train_model, DAgger_train_model!, DAgger_train_model
3940
export PerturbedImitationAlgorithm, train_policy!
4041

4142
end
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
$TYPEDEF
3+
4+
An abstract type for decision-focused learning algorithms.
5+
"""
6+
abstract type AbstractAlgorithm end
7+
8+
"""
9+
$TYPEDEF
10+
11+
An abstract type for imitation learning algorithms.
12+
13+
All subtypes must implement:
14+
- `train_policy!(algorithm::AbstractImitationAlgorithm, model, maximizer, train_data; epochs, metrics)`
15+
"""
16+
abstract type AbstractImitationAlgorithm <: AbstractAlgorithm end

src/algorithms/supervised/dagger.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ function DAgger_train_model!(
33
model,
44
maximizer,
55
train_environments,
6-
validation_environments,
76
anticipative_policy;
87
iterations=5,
98
fyl_epochs=3,
@@ -16,10 +15,6 @@ function DAgger_train_model!(
1615
v, y = anticipative_policy(env; reset_env=true)
1716
return y
1817
end...)
19-
val_dataset = vcat(map(validation_environments) do env
20-
v, y = anticipative_policy(env; reset_env=true)
21-
return y
22-
end...)
2318

2419
dataset = deepcopy(train_dataset)
2520

@@ -117,18 +112,12 @@ function DAgger_train_model(b::AbstractStochasticBenchmark{true}; kwargs...)
117112
dataset = generate_dataset(b, 30)
118113
train_instances, validation_instances, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
119114
train_environments = generate_environments(b, train_instances; seed=0)
120-
validation_environments = generate_environments(b, validation_instances)
121115
model = generate_statistical_model(b)
122116
maximizer = generate_maximizer(b)
123117
anticipative_policy =
124118
(env; reset_env) -> generate_anticipative_solution(b, env; reset_env)
125119
history = DAgger_train_model!(
126-
model,
127-
maximizer,
128-
train_environments,
129-
validation_environments,
130-
anticipative_policy;
131-
kwargs...,
120+
model, maximizer, train_environments, anticipative_policy; kwargs...
132121
)
133122
return history, model
134123
end

src/algorithms/supervised/fyl.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,32 @@
33
# TODO: parallelize loss computation on validation set
44
# TODO: have supervised learning training method, where fyl_train calls it, therefore we can easily test new supervised losses if needed
55

6-
@kwdef struct PerturbedImitationAlgorithm{O,S}
6+
"""
7+
$TYPEDEF
8+
9+
Structured imitation learning with a perturbed Fenchel-Young loss.
10+
11+
# Fields
12+
$TYPEDFIELDS
13+
"""
14+
@kwdef struct PerturbedImitationAlgorithm{O,S} <: AbstractImitationAlgorithm
15+
"number of perturbation samples"
716
nb_samples::Int = 10
17+
"perturbation magnitude"
818
ε::Float64 = 0.1
19+
"whether to use threading for perturbations"
920
threaded::Bool = true
21+
"optimizer used for training"
1022
training_optimizer::O = Adam()
23+
"random seed for perturbations"
1124
seed::S = nothing
1225
end
1326

14-
reset!(algorithm::PerturbedImitationAlgorithm) = empty!(algorithm.history)
27+
"""
28+
$TYPEDSIGNATURES
1529
30+
Train a model using the Perturbed Imitation Algorithm on the provided training dataset.
31+
"""
1632
function train_policy!(
1733
algorithm::PerturbedImitationAlgorithm,
1834
model,
@@ -21,9 +37,7 @@ function train_policy!(
2137
epochs=100,
2238
maximizer_kwargs=get_info,
2339
metrics::Tuple=(),
24-
reset=false,
2540
)
26-
reset && reset!(algorithm)
2741
(; nb_samples, ε, threaded, training_optimizer, seed) = algorithm
2842
perturbed = PerturbedAdditive(maximizer; nb_samples, ε, threaded, seed)
2943
loss = FenchelYoungLoss(perturbed)
@@ -32,23 +46,21 @@ function train_policy!(
3246

3347
history = MVHistory()
3448

35-
train_loss_metric = LossAccumulator(:training_loss)
49+
train_loss_metric = FYLLossMetric(train_dataset, :training_loss)
3650

37-
# Store initial losses (epoch 0)
38-
# Epoch 0
39-
for sample in train_dataset
40-
(; x, y) = sample
41-
val = loss(model(x), y; maximizer_kwargs(sample)...)
42-
update!(train_loss_metric, val)
43-
end
44-
push!(history, :training_loss, 0, compute(train_loss_metric))
45-
reset!(train_loss_metric)
46-
47-
# Initial metric evaluation
48-
context = TrainingContext(; model=model, epoch=0, maximizer=maximizer, loss=loss)
49-
run_metrics!(history, metrics, context)
51+
# Initial metric evaluation and training loss (epoch 0)
52+
context = TrainingContext(;
53+
model=model,
54+
epoch=0,
55+
maximizer=maximizer,
56+
maximizer_kwargs=maximizer_kwargs,
57+
loss=loss,
58+
)
59+
push!(history, :training_loss, 0, evaluate!(train_loss_metric, context))
60+
evaluate_metrics!(history, metrics, context)
5061

5162
@showprogress for epoch in 1:epochs
63+
next_epoch!(context)
5264
# Training step
5365
for sample in train_dataset
5466
(; x, y) = sample
@@ -59,13 +71,9 @@ function train_policy!(
5971
update!(train_loss_metric, val)
6072
end
6173

62-
# Store training loss
63-
push!(history, :training_loss, epoch, compute(train_loss_metric))
64-
reset!(train_loss_metric)
65-
66-
# Evaluate all metrics - update epoch in context
67-
context.epoch = epoch
68-
run_metrics!(history, metrics, context)
74+
# Log metrics
75+
push!(history, :training_loss, epoch, compute!(train_loss_metric))
76+
evaluate_metrics!(history, metrics, context)
6977
end
7078

7179
# Plot training loss (or first metric if available)
Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function baty_train_model(
1+
function kleopatra_train_model(
22
b::AbstractStochasticBenchmark{true};
33
epochs=10,
44
metrics::Tuple=(),
@@ -8,7 +8,6 @@ function baty_train_model(
88
dataset = generate_dataset(b, 30)
99
train_instances, validation_instances, _ = splitobs(dataset; at=(0.3, 0.3))
1010
train_environments = generate_environments(b, train_instances)
11-
validation_environments = generate_environments(b, validation_instances)
1211

1312
# Generate anticipative solutions
1413
train_dataset = vcat(
@@ -18,11 +17,6 @@ function baty_train_model(
1817
end...
1918
)
2019

21-
val_dataset = vcat(map(validation_environments) do env
22-
v, y = generate_anticipative_solution(b, env; reset_env=true)
23-
return y
24-
end...)
25-
2620
# Initialize model and maximizer
2721
model = generate_statistical_model(b)
2822
maximizer = generate_maximizer(b)
@@ -32,12 +26,11 @@ function baty_train_model(
3226
algorithm,
3327
model,
3428
maximizer,
35-
train_dataset,
36-
val_dataset;
29+
train_dataset;
3730
epochs=epochs,
3831
metrics=metrics,
3932
maximizer_kwargs=get_state,
4033
)
4134

4235
return history, model
43-
end
36+
end

src/metrics/accumulators.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ for sample in dataset
2020
end
2121
2222
# Get average and reset
23-
avg_loss = compute(metric) # Automatically resets
23+
avg_loss = compute!(metric) # Automatically resets
2424
```
2525
2626
# See also
@@ -76,7 +76,7 @@ Add a loss value to the accumulator.
7676
metric = LossAccumulator()
7777
update!(metric, 1.5)
7878
update!(metric, 2.0)
79-
compute(metric) # Returns 1.75
79+
compute!(metric) # Returns 1.75
8080
```
8181
"""
8282
function update!(metric::LossAccumulator, loss_value::Float64)
@@ -101,10 +101,10 @@ Compute the average loss from accumulated values.
101101
metric = LossAccumulator()
102102
update!(metric, 1.5)
103103
update!(metric, 2.5)
104-
avg = compute(metric) # Returns 2.0, then resets
104+
avg = compute!(metric) # Returns 2.0, then resets
105105
```
106106
"""
107-
function compute(metric::LossAccumulator; reset::Bool=true)
107+
function compute!(metric::LossAccumulator; reset::Bool=true)
108108
value = metric.count == 0 ? 0.0 : metric.total_loss / metric.count
109109
reset && reset!(metric)
110110
return value
@@ -130,7 +130,7 @@ Can also be used in the algorithms to accumulate loss over training data.
130130
# Create metric with validation dataset
131131
val_metric = FYLLossMetric(val_dataset, :validation_loss)
132132
133-
# Evaluate during training (called by run_metrics!)
133+
# Evaluate during training (called by evaluate_metrics!)
134134
context = TrainingContext(model=model, epoch=5, maximizer=maximizer, loss=loss)
135135
avg_loss = evaluate!(val_metric, context)
136136
```
@@ -228,19 +228,33 @@ function evaluate!(metric::FYLLossMetric, context::TrainingContext)
228228
for sample in metric.dataset
229229
θ = context.model(sample.x)
230230
y_target = sample.y
231-
update!(metric, context.loss, θ, y_target; sample.info...)
231+
update!(metric, context.loss, θ, y_target; context.maximizer_kwargs(sample)...)
232232
end
233-
return compute(metric)
233+
return compute!(metric)
234234
end
235235

236236
"""
237-
compute(metric::FYLLossMetric)
237+
$TYPEDSIGNATURES
238+
239+
Update the metric with an already-computed loss value. This avoids re-evaluating
240+
the loss inside the metric when the loss was computed during training.
241+
242+
# Returns
243+
- `Float64` - The provided loss value
244+
"""
245+
function update!(metric::FYLLossMetric, loss_value::Float64)
246+
update!(metric.accumulator, loss_value)
247+
return loss_value
248+
end
249+
250+
"""
251+
compute!(metric::FYLLossMetric)
238252
239253
Compute the average loss from accumulated values.
240254
241255
# Returns
242256
- `Float64` - Average loss (or 0.0 if no values accumulated)
243257
"""
244-
function compute(metric::FYLLossMetric)
245-
return compute(metric.accumulator)
258+
function compute!(metric::FYLLossMetric)
259+
return compute!(metric.accumulator)
246260
end

0 commit comments

Comments (0)