11"""
2- LossAccumulator <: AbstractMetric
2+ $TYPEDEF
33
44Accumulates loss values during training and computes their average.
55
66This metric is used internally by training loops to track training loss.
77It accumulates loss values via `update!` calls and computes the average via `compute`.
88
99# Fields
10- - `name::Symbol` - Identifier for this metric (e.g., `:training_loss`)
11- - `total_loss::Float64` - Running sum of loss values
12- - `count::Int` - Number of samples accumulated
10+ $TYPEDFIELDS
1311
1412# Examples
1513```julia
@@ -31,32 +29,27 @@ avg_loss = compute(metric) # Automatically resets
3129- [`update!`](@ref)
3230- [`compute`](@ref)
3331"""
34- mutable struct LossAccumulator <: AbstractMetric
32+ mutable struct LossAccumulator
33+ " Identifier for this metric (e.g., `:training_loss`)"
3534 const name:: Symbol
35+ " Running sum of loss values"
3636 total_loss:: Float64
37+ " Number of samples accumulated"
3738 count:: Int
3839end
3940
4041"""
41- LossAccumulator(name::Symbol=:training_loss)
42+ $TYPEDSIGNATURES
4243
4344Construct a LossAccumulator with the given name.
44-
45- # Arguments
46- - `name::Symbol` - Identifier for the metric (default: `:training_loss`)
47-
48- # Examples
49- ```julia
50- train_metric = LossAccumulator(:training_loss)
51- val_metric = LossAccumulator(:validation_loss)
52- ```
45+ Initializes total loss and count to zero.
5346"""
5447function LossAccumulator (name:: Symbol = :training_loss )
5548 return LossAccumulator (name, 0.0 , 0 )
5649end
5750
5851"""
59- reset!(metric::LossAccumulator)
52+ $TYPEDSIGNATURES
6053
6154Reset the accumulator to its initial state (zero total loss and count).
6255
@@ -74,14 +67,10 @@ function reset!(metric::LossAccumulator)
7467end
7568
7669"""
77- update!(metric::LossAccumulator, loss_value::Float64)
70+ $TYPEDSIGNATURES
7871
7972Add a loss value to the accumulator.
8073
81- # Arguments
82- - `metric::LossAccumulator` - The accumulator to update
83- - `loss_value::Float64` - Loss value to add
84-
8574# Examples
8675```julia
8776metric = LossAccumulator()
@@ -96,7 +85,7 @@ function update!(metric::LossAccumulator, loss_value::Float64)
9685end
9786
9887"""
99- compute(metric::LossAccumulator; reset::Bool=true)
88+ $TYPEDSIGNATURES
10089
10190Compute the average loss from accumulated values.
10291
@@ -130,12 +119,11 @@ Metric for evaluating Fenchel-Young Loss over a dataset.
130119
131120This metric stores a dataset and computes the average Fenchel-Young Loss
132121when `evaluate!` is called. Useful for tracking validation loss during training.
122+ Can also be used in the algorithms to accumulate loss over training data.
133123
134124# Fields
135- - `name::Symbol` - Identifier for this metric (e.g., `:validation_loss`)
136125- `dataset::D` - Dataset to evaluate on (stored internally)
137- - `total_loss::Float64` - Running sum during evaluation
138- - `count::Int` - Number of samples evaluated
126+ - `accumulator::LossAccumulator` - Embedded accumulator holding `name`, `total_loss`, and `count`.
139127
140128# Examples
141129```julia
@@ -151,11 +139,9 @@ avg_loss = evaluate!(val_metric, context)
151139- [`LossAccumulator`](@ref)
152140- [`FunctionMetric`](@ref)
153141"""
154- mutable struct FYLLossMetric{D} <: AbstractMetric
155- const name:: Symbol
156- const dataset:: D
157- total_loss:: Float64
158- count:: Int
142+ struct FYLLossMetric{D} <: AbstractMetric
143+ dataset:: D
144+ accumulator:: LossAccumulator
159145end
160146
161147"""
@@ -174,7 +160,7 @@ test_metric = FYLLossMetric(test_dataset, :test_loss)
174160```
175161"""
176162function FYLLossMetric (dataset, name:: Symbol = :fyl_loss )
177- return FYLLossMetric (name, dataset, 0.0 , 0 )
163+ return FYLLossMetric (dataset, LossAccumulator (name) )
178164end
179165
180166"""
183169Reset the metric's accumulated loss to zero.
184170"""
185171function reset! (metric:: FYLLossMetric )
186- metric. total_loss = 0.0
187- return metric. count = 0
172+ return reset! (metric. accumulator)
173+ end
174+
175+ function Base. getproperty (metric:: FYLLossMetric , s:: Symbol )
176+ if s === :name
177+ return metric. accumulator. name
178+ else
179+ return getfield (metric, s)
180+ end
188181end
189182
190183"""
@@ -204,8 +197,7 @@ Update the metric with a single loss computation.
204197"""
205198function update! (metric:: FYLLossMetric , loss:: FenchelYoungLoss , θ, y_target; kwargs... )
206199 l = loss (θ, y_target; kwargs... )
207- metric. total_loss += l
208- metric. count += 1
200+ update! (metric. accumulator, l)
209201 return l
210202end
211203
@@ -231,7 +223,7 @@ context = TrainingContext(model=model, epoch=5, maximizer=maximizer, loss=loss)
231223avg_loss = evaluate!(val_metric, context)
232224```
233225"""
234- function evaluate! (metric:: FYLLossMetric , context)
226+ function evaluate! (metric:: FYLLossMetric , context:: TrainingContext )
235227 reset! (metric)
236228 for sample in metric. dataset
237229 θ = context. model (sample. x)
@@ -250,5 +242,5 @@ Compute the average loss from accumulated values.
250242- `Float64` - Average loss (or 0.0 if no values accumulated)
251243"""
252244function compute (metric:: FYLLossMetric )
253- return metric. count == 0 ? 0.0 : metric . total_loss / metric . count
245+ return compute ( metric. accumulator)
254246end
0 commit comments