Skip to content

Commit 2d3aa84

Browse files
committed
cleanup
1 parent 541e8be commit 2d3aa84

File tree

10 files changed

+151
-257
lines changed

10 files changed

+151
-257
lines changed

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ using ValueHistories: MVHistory
1313
include("utils.jl")
1414
include("training_context.jl")
1515

16-
# Metrics subsystem
1716
include("metrics/interface.jl")
1817
include("metrics/accumulators.jl")
1918
include("metrics/function_metric.jl")
2019
include("metrics/periodic.jl")
2120

22-
include("algorithms/fyl.jl")
23-
include("algorithms/dagger.jl")
21+
include("algorithms/supervised/fyl.jl")
22+
include("algorithms/supervised/kleopatra.jl")
23+
include("algorithms/supervised/dagger.jl")
2424

2525
export TrainingContext
2626

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ function DAgger_train_model!(
3535
algorithm,
3636
model,
3737
maximizer,
38-
dataset,
39-
val_dataset;
38+
dataset;
4039
epochs=fyl_epochs,
4140
metrics=metrics,
4241
maximizer_kwargs=maximizer_kwargs,
@@ -45,7 +44,7 @@ function DAgger_train_model!(
4544
# Merge iteration history into combined history
4645
for key in keys(iter_history)
4746
epochs, values = get(iter_history, key)
48-
for i in 1:length(epochs)
47+
for i in eachindex(epochs)
4948
# Calculate global epoch number
5049
if iter == 1
5150
# First iteration: use epochs as-is [0, 1, 2, ...]
Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
# TODO: every N epochs
21
# TODO: best_model saving method, using default metric validation loss, overwritten in dagger
3-
# TODO: Implement validation loss as a metric callback
42
# TODO: batch training option
53
# TODO: parallelize loss computation on validation set
64
# TODO: have supervised learning training method, where fyl_train calls it, therefore we can easily test new supervised losses if needed
@@ -19,8 +17,7 @@ function train_policy!(
1917
algorithm::PerturbedImitationAlgorithm,
2018
model,
2119
maximizer,
22-
train_dataset::AbstractArray{<:DataSample},
23-
validation_dataset;
20+
train_dataset::AbstractArray{<:DataSample};
2421
epochs=100,
2522
maximizer_kwargs=get_info,
2623
metrics::Tuple=(),
@@ -85,58 +82,11 @@ end
8582
function fyl_train_model(
8683
initial_model,
8784
maximizer,
88-
train_dataset,
89-
validation_dataset;
85+
train_dataset;
9086
algorithm=PerturbedImitationAlgorithm(),
9187
kwargs...,
9288
)
9389
model = deepcopy(initial_model)
94-
history = train_policy!(
95-
algorithm, model, maximizer, train_dataset, validation_dataset; kwargs...
96-
)
97-
return history, model
98-
end
99-
100-
function baty_train_model(
101-
b::AbstractStochasticBenchmark{true};
102-
epochs=10,
103-
metrics::Tuple=(),
104-
algorithm::PerturbedImitationAlgorithm=PerturbedImitationAlgorithm(),
105-
)
106-
# Generate instances and environments
107-
dataset = generate_dataset(b, 30)
108-
train_instances, validation_instances, _ = splitobs(dataset; at=(0.3, 0.3))
109-
train_environments = generate_environments(b, train_instances)
110-
validation_environments = generate_environments(b, validation_instances)
111-
112-
# Generate anticipative solutions
113-
train_dataset = vcat(
114-
map(train_environments) do env
115-
v, y = generate_anticipative_solution(b, env; reset_env=true)
116-
return y
117-
end...
118-
)
119-
120-
val_dataset = vcat(map(validation_environments) do env
121-
v, y = generate_anticipative_solution(b, env; reset_env=true)
122-
return y
123-
end...)
124-
125-
# Initialize model and maximizer
126-
model = generate_statistical_model(b)
127-
maximizer = generate_maximizer(b)
128-
129-
# Train with algorithm
130-
history = train_policy!(
131-
algorithm,
132-
model,
133-
maximizer,
134-
train_dataset,
135-
val_dataset;
136-
epochs=epochs,
137-
metrics=metrics,
138-
maximizer_kwargs=get_state,
139-
)
140-
90+
history = train_policy!(algorithm, model, maximizer, train_dataset; kwargs...)
14191
return history, model
14292
end
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
    baty_train_model(b::AbstractStochasticBenchmark{true}; epochs=10, metrics=(), algorithm=PerturbedImitationAlgorithm())

Train a statistical model for benchmark `b` on anticipative solutions.

Generates a dataset of 30 instances, splits it into train/validation/rest
partitions (30%/30%), builds anticipative supervised targets for the train and
validation environments, then trains the benchmark's statistical model with
[`train_policy!`](@ref).

# Keywords
- `epochs` - number of training epochs (default `10`)
- `metrics::Tuple` - extra metrics to evaluate during training (default `()`)
- `algorithm::PerturbedImitationAlgorithm` - imitation-learning algorithm to use

# Returns
- `(history, model)` - the training history and the trained model
"""
function baty_train_model(
    b::AbstractStochasticBenchmark{true};
    epochs=10,
    metrics::Tuple=(),
    algorithm::PerturbedImitationAlgorithm=PerturbedImitationAlgorithm(),
)
    # Generate instances and environments
    dataset = generate_dataset(b, 30)
    train_instances, validation_instances, _ = splitobs(dataset; at=(0.3, 0.3))
    train_environments = generate_environments(b, train_instances)
    validation_environments = generate_environments(b, validation_instances)

    # Build anticipative supervised targets; the objective value is unused,
    # only the solution `y` is kept as a training target.
    train_dataset = vcat(
        map(train_environments) do env
            _, y = generate_anticipative_solution(b, env; reset_env=true)
            return y
        end...
    )

    val_dataset = vcat(
        map(validation_environments) do env
            _, y = generate_anticipative_solution(b, env; reset_env=true)
            return y
        end...
    )

    # Initialize model and maximizer
    model = generate_statistical_model(b)
    maximizer = generate_maximizer(b)

    # BUG FIX: `train_policy!` for `PerturbedImitationAlgorithm` no longer
    # accepts a positional validation dataset (the parameter was removed in
    # this same refactor), so passing `val_dataset` positionally raised a
    # MethodError. Validation loss is now tracked through the metrics
    # subsystem instead, via a `FYLLossMetric` over the validation set.
    history = train_policy!(
        algorithm,
        model,
        maximizer,
        train_dataset;
        epochs=epochs,
        metrics=(metrics..., FYLLossMetric(val_dataset, :validation_loss)),
        maximizer_kwargs=get_state,
    )

    return history, model
end

src/metrics/accumulators.jl

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
"""
2-
LossAccumulator <: AbstractMetric
2+
$TYPEDEF
33
44
Accumulates loss values during training and computes their average.
55
66
This metric is used internally by training loops to track training loss.
77
It accumulates loss values via `update!` calls and computes the average via `compute`.
88
99
# Fields
10-
- `name::Symbol` - Identifier for this metric (e.g., `:training_loss`)
11-
- `total_loss::Float64` - Running sum of loss values
12-
- `count::Int` - Number of samples accumulated
10+
$TYPEDFIELDS
1311
1412
# Examples
1513
```julia
@@ -31,32 +29,27 @@ avg_loss = compute(metric) # Automatically resets
3129
- [`update!`](@ref)
3230
- [`compute`](@ref)
3331
"""
34-
mutable struct LossAccumulator <: AbstractMetric
32+
mutable struct LossAccumulator
33+
"Identifier for this metric (e.g., `:training_loss`)"
3534
const name::Symbol
35+
"Running sum of loss values"
3636
total_loss::Float64
37+
"Number of samples accumulated"
3738
count::Int
3839
end
3940

4041
"""
41-
LossAccumulator(name::Symbol=:training_loss)
42+
$TYPEDSIGNATURES
4243
4344
Construct a LossAccumulator with the given name.
44-
45-
# Arguments
46-
- `name::Symbol` - Identifier for the metric (default: `:training_loss`)
47-
48-
# Examples
49-
```julia
50-
train_metric = LossAccumulator(:training_loss)
51-
val_metric = LossAccumulator(:validation_loss)
52-
```
45+
Initializes total loss and count to zero.
5346
"""
5447
function LossAccumulator(name::Symbol=:training_loss)
5548
return LossAccumulator(name, 0.0, 0)
5649
end
5750

5851
"""
59-
reset!(metric::LossAccumulator)
52+
$TYPEDSIGNATURES
6053
6154
Reset the accumulator to its initial state (zero total loss and count).
6255
@@ -74,14 +67,10 @@ function reset!(metric::LossAccumulator)
7467
end
7568

7669
"""
77-
update!(metric::LossAccumulator, loss_value::Float64)
70+
$TYPEDSIGNATURES
7871
7972
Add a loss value to the accumulator.
8073
81-
# Arguments
82-
- `metric::LossAccumulator` - The accumulator to update
83-
- `loss_value::Float64` - Loss value to add
84-
8574
# Examples
8675
```julia
8776
metric = LossAccumulator()
@@ -96,7 +85,7 @@ function update!(metric::LossAccumulator, loss_value::Float64)
9685
end
9786

9887
"""
99-
compute(metric::LossAccumulator; reset::Bool=true)
88+
$TYPEDSIGNATURES
10089
10190
Compute the average loss from accumulated values.
10291
@@ -130,12 +119,11 @@ Metric for evaluating Fenchel-Young Loss over a dataset.
130119
131120
This metric stores a dataset and computes the average Fenchel-Young Loss
132121
when `evaluate!` is called. Useful for tracking validation loss during training.
122+
Can also be used in the algorithms to accumulate loss over training data.
133123
134124
# Fields
135-
- `name::Symbol` - Identifier for this metric (e.g., `:validation_loss`)
136125
- `dataset::D` - Dataset to evaluate on (stored internally)
137-
- `total_loss::Float64` - Running sum during evaluation
138-
- `count::Int` - Number of samples evaluated
126+
- `accumulator::LossAccumulator` - Embedded accumulator holding `name`, `total_loss`, and `count`.
139127
140128
# Examples
141129
```julia
@@ -151,11 +139,9 @@ avg_loss = evaluate!(val_metric, context)
151139
- [`LossAccumulator`](@ref)
152140
- [`FunctionMetric`](@ref)
153141
"""
154-
mutable struct FYLLossMetric{D} <: AbstractMetric
155-
const name::Symbol
156-
const dataset::D
157-
total_loss::Float64
158-
count::Int
142+
struct FYLLossMetric{D} <: AbstractMetric
143+
dataset::D
144+
accumulator::LossAccumulator
159145
end
160146

161147
"""
@@ -174,7 +160,7 @@ test_metric = FYLLossMetric(test_dataset, :test_loss)
174160
```
175161
"""
176162
function FYLLossMetric(dataset, name::Symbol=:fyl_loss)
177-
return FYLLossMetric(name, dataset, 0.0, 0)
163+
return FYLLossMetric(dataset, LossAccumulator(name))
178164
end
179165

180166
"""
@@ -183,8 +169,15 @@ end
183169
Reset the metric's accumulated loss to zero.
184170
"""
185171
function reset!(metric::FYLLossMetric)
186-
metric.total_loss = 0.0
187-
return metric.count = 0
172+
return reset!(metric.accumulator)
173+
end
174+
175+
function Base.getproperty(metric::FYLLossMetric, s::Symbol)
176+
if s === :name
177+
return metric.accumulator.name
178+
else
179+
return getfield(metric, s)
180+
end
188181
end
189182

190183
"""
@@ -204,8 +197,7 @@ Update the metric with a single loss computation.
204197
"""
205198
function update!(metric::FYLLossMetric, loss::FenchelYoungLoss, θ, y_target; kwargs...)
206199
l = loss(θ, y_target; kwargs...)
207-
metric.total_loss += l
208-
metric.count += 1
200+
update!(metric.accumulator, l)
209201
return l
210202
end
211203

@@ -231,7 +223,7 @@ context = TrainingContext(model=model, epoch=5, maximizer=maximizer, loss=loss)
231223
avg_loss = evaluate!(val_metric, context)
232224
```
233225
"""
234-
function evaluate!(metric::FYLLossMetric, context)
226+
function evaluate!(metric::FYLLossMetric, context::TrainingContext)
235227
reset!(metric)
236228
for sample in metric.dataset
237229
θ = context.model(sample.x)
@@ -250,5 +242,5 @@ Compute the average loss from accumulated values.
250242
- `Float64` - Average loss (or 0.0 if no values accumulated)
251243
"""
252244
function compute(metric::FYLLossMetric)
253-
return metric.count == 0 ? 0.0 : metric.total_loss / metric.count
245+
return compute(metric.accumulator)
254246
end

0 commit comments

Comments
 (0)