Skip to content

Commit 37b816a

Browse files
committed
wip
1 parent 13d2398 commit 37b816a

File tree

12 files changed

+85
-110
lines changed

12 files changed

+85
-110
lines changed

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ using Statistics: mean
1010
using UnicodePlots: lineplot
1111
using ValueHistories: MVHistory
1212

13-
include("utils.jl")
1413
include("training_context.jl")
1514

1615
include("metrics/interface.jl")

src/algorithms/supervised/anticipative_imitation.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ function train_policy!(
3131
anticipative_policy,
3232
epochs=10,
3333
metrics::Tuple=(),
34-
maximizer_kwargs=get_state,
34+
maximizer_kwargs=sample -> sample.context,
3535
)
3636
# Generate anticipative solutions as training data
3737
train_dataset = vcat(map(train_environments) do env
38-
v, y = anticipative_policy(env; reset_env=true)
39-
return y
38+
return anticipative_policy(env; reset_env=true)
4039
end...)
4140

4241
# Delegate to inner algorithm
@@ -62,26 +61,22 @@ Uses anticipative solutions as expert demonstrations.
6261
"""
6362
function train_policy(
6463
algorithm::AnticipativeImitation,
65-
benchmark::AbstractStochasticBenchmark{true};
64+
benchmark::ExogenousDynamicBenchmark;
6665
dataset_size=30,
67-
split_ratio=(0.3, 0.3),
6866
epochs=10,
6967
metrics::Tuple=(),
7068
seed=nothing,
7169
)
72-
# Generate instances and environments
73-
dataset = generate_dataset(benchmark, dataset_size)
74-
train_instances, validation_instances, _ = splitobs(dataset; at=split_ratio)
75-
train_environments = generate_environments(benchmark, train_instances)
70+
# Generate environments
71+
train_environments = generate_environments(benchmark, dataset_size; seed)
7672

7773
# Initialize model and create policy
7874
model = generate_statistical_model(benchmark; seed)
7975
maximizer = generate_maximizer(benchmark)
8076
policy = DFLPolicy(model, maximizer)
8177

8278
# Define anticipative policy from benchmark
83-
anticipative_policy =
84-
(env; reset_env) -> generate_anticipative_solution(benchmark, env; reset_env)
79+
anticipative_policy = generate_anticipative_solver(benchmark)
8580

8681
# Train policy
8782
history = train_policy!(

src/algorithms/supervised/dagger.jl

Lines changed: 16 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function train_policy!(
3434
train_environments;
3535
anticipative_policy,
3636
metrics::Tuple=(),
37-
maximizer_kwargs=get_state,
37+
maximizer_kwargs=sample -> sample.context,
3838
)
3939
(; inner_algorithm, iterations, epochs_per_iteration, α_decay) = algorithm
4040
(; statistical_model, maximizer) = policy
@@ -43,15 +43,14 @@ function train_policy!(
4343

4444
# Initial dataset from expert demonstrations
4545
train_dataset = vcat(map(train_environments) do env
46-
v, y = anticipative_policy(env; reset_env=true)
47-
return y
46+
return anticipative_policy(env; reset_env=true)
4847
end...)
4948

5049
dataset = deepcopy(train_dataset)
5150

5251
# Initialize combined history for all DAgger iterations
5352
combined_history = MVHistory()
54-
global_epoch = 0
53+
epoch_offset = 0
5554

5655
for iter in 1:iterations
5756
println("DAgger iteration $iter/$iterations (α=$(round(α, digits=3)))")
@@ -68,50 +67,24 @@ function train_policy!(
6867

6968
# Merge iteration history into combined history
7069
for key in keys(iter_history)
71-
epochs, values = get(iter_history, key)
72-
for i in eachindex(epochs)
73-
# Calculate global epoch number
74-
if iter == 1
75-
# First iteration: use epochs as-is [0, 1, 2, ...]
76-
global_epoch_value = epochs[i]
77-
else
78-
# Later iterations: skip epoch 0 and renumber starting from global_epoch
79-
if epochs[i] == 0
80-
continue # Skip epoch 0 for iterations > 1
81-
end
82-
# Map epoch 1 → global_epoch, epoch 2 → global_epoch+1, etc.
83-
global_epoch_value = global_epoch + epochs[i] - 1
84-
end
85-
86-
# For the epoch key, use global_epoch_value as both time and value
87-
# For other keys, use global_epoch_value as time and original value
88-
if key == :epoch
89-
push!(combined_history, key, global_epoch_value, global_epoch_value)
90-
else
91-
push!(combined_history, key, global_epoch_value, values[i])
92-
end
70+
local_epochs, values = get(iter_history, key)
71+
for i in eachindex(local_epochs)
72+
# Skip epoch 0 for all iterations after the first
73+
local_epochs[i] == 0 && epoch_offset > 0 && continue
74+
global_e = epoch_offset + local_epochs[i]
75+
push!(combined_history, key, global_e, key == :epoch ? global_e : values[i])
9376
end
9477
end
9578

96-
# Update global_epoch for next iteration
97-
# After each iteration, advance by the number of non-zero epochs processed
98-
if iter == 1
99-
# First iteration processes all epochs [0, 1, ..., epochs_per_iteration]
100-
# Next iteration should start at epochs_per_iteration + 1
101-
global_epoch = epochs_per_iteration + 1
102-
else
103-
# Subsequent iterations skip epoch 0, so they process epochs_per_iteration epochs
104-
# Next iteration should start epochs_per_iteration later
105-
global_epoch += epochs_per_iteration
106-
end
79+
epoch_offset += epochs_per_iteration
10780

10881
# Dataset update - collect new samples using mixed policy
10982
new_samples = eltype(dataset)[]
11083
for env in train_environments
11184
DecisionFocusedLearningBenchmarks.reset!(env; reset_rng=false)
11285
while !is_terminated(env)
11386
x_before = copy(observe(env)[1])
114-
_, anticipative_solution = anticipative_policy(env; reset_env=false)
87+
anticipative_solution = anticipative_policy(env; reset_env=false)
11588
p = rand()
11689
target = anticipative_solution[1]
11790
x, state = observe(env)
@@ -149,25 +122,21 @@ This high-level function handles all setup from the benchmark and returns a trai
149122
"""
150123
function train_policy(
151124
algorithm::DAgger,
152-
benchmark::AbstractStochasticBenchmark{true};
125+
benchmark::ExogenousDynamicBenchmark;
153126
dataset_size=30,
154-
split_ratio=(0.3, 0.3, 0.4),
155127
metrics::Tuple=(),
156128
seed=0,
157129
)
158-
# Generate dataset and environments
159-
dataset = generate_dataset(benchmark, dataset_size)
160-
train_instances, validation_instances, _ = splitobs(dataset; at=split_ratio)
161-
train_environments = generate_environments(benchmark, train_instances; seed)
130+
# Generate environments
131+
train_environments = generate_environments(benchmark, dataset_size; seed)
162132

163133
# Initialize model and create policy
164134
model = generate_statistical_model(benchmark)
165135
maximizer = generate_maximizer(benchmark)
166136
policy = DFLPolicy(model, maximizer)
167137

168138
# Define anticipative policy from benchmark
169-
anticipative_policy =
170-
(env; reset_env) -> generate_anticipative_solution(benchmark, env; reset_env)
139+
anticipative_policy = generate_anticipative_solver(benchmark)
171140

172141
# Train policy
173142
history = train_policy!(
@@ -176,7 +145,7 @@ function train_policy(
176145
train_environments;
177146
anticipative_policy=anticipative_policy,
178147
metrics=metrics,
179-
maximizer_kwargs=get_state,
148+
maximizer_kwargs=sample -> sample.context,
180149
)
181150

182151
return history, policy

src/algorithms/supervised/fyl.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function train_policy!(
4545
train_dataset::DataLoader;
4646
epochs=100,
4747
metrics::Tuple=(),
48-
maximizer_kwargs=get_info,
48+
maximizer_kwargs=sample -> sample.context,
4949
)
5050
(; nb_samples, ε, threaded, training_optimizer, seed) = algorithm
5151
(; statistical_model, maximizer) = policy
@@ -106,7 +106,7 @@ function train_policy!(
106106
train_dataset::AbstractArray{<:DataSample};
107107
epochs=100,
108108
metrics::Tuple=(),
109-
maximizer_kwargs=get_info,
109+
maximizer_kwargs=sample -> sample.context,
110110
)
111111
data_loader = DataLoader(train_dataset; batchsize=1, shuffle=false)
112112
return train_policy!(
@@ -131,24 +131,32 @@ This high-level function handles all setup from the benchmark and returns a trai
131131
function train_policy(
132132
algorithm::PerturbedFenchelYoungLossImitation,
133133
benchmark::AbstractBenchmark;
134+
target_policy=nothing,
134135
dataset_size=30,
135136
split_ratio=(0.3, 0.3),
136137
epochs=100,
137138
metrics::Tuple=(),
138139
seed=nothing,
139140
)
140141
# Generate dataset and split
141-
dataset = generate_dataset(benchmark, dataset_size)
142+
dataset = generate_dataset(benchmark, dataset_size; target_policy)
142143
train_instances, _, _ = splitobs(dataset; at=split_ratio)
143144

145+
if any(s -> isnothing(s.y), train_instances)
146+
error(
147+
"Training dataset contains unlabeled samples (y=nothing). " *
148+
"Provide a `target_policy` kwarg to label samples during dataset generation.",
149+
)
150+
end
151+
144152
# Initialize model and create policy
145153
model = generate_statistical_model(benchmark; seed)
146154
maximizer = generate_maximizer(benchmark)
147155
policy = DFLPolicy(model, maximizer)
148156

149157
# Train policy
150158
history = train_policy!(
151-
algorithm, policy, train_instances; epochs, metrics, maximizer_kwargs=get_info
159+
algorithm, policy, train_instances; epochs, metrics, maximizer_kwargs=s -> s.context
152160
)
153161

154162
return history, policy

src/metrics/accumulators.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ end
151151
Construct a FYLLossMetric for a given dataset.
152152
153153
# Arguments
154-
- `dataset` - Dataset to evaluate on (should have samples with `.x`, `.y`, and `.info` fields)
155-
- `name::Symbol` - Identifier for the metric (default: `:fyl_loss`)
154+
- `dataset`: Dataset to evaluate on (should have samples with `.x`, `.y`, and `.context` fields)
155+
- `name::Symbol`: Identifier for the metric (default: `:fyl_loss`)
156156
"""
157157
function FYLLossMetric(dataset, name::Symbol=:fyl_loss)
158158
return FYLLossMetric(dataset, LossAccumulator(name))
@@ -181,11 +181,11 @@ $TYPEDSIGNATURES
181181
Update the metric with a single loss computation.
182182
183183
# Arguments
184-
- `metric::FYLLossMetric` - The metric to update
185-
- `loss::FenchelYoungLoss` - Loss function to use
186-
- `θ` - Model prediction
187-
- `y_target` - Target value
188-
- `kwargs...` - Additional arguments passed to loss function
184+
- `metric::FYLLossMetric`: The metric to update
185+
- `loss::FenchelYoungLoss`: Loss function to use
186+
- `θ`: Model prediction
187+
- `y_target`: Target value
188+
- `kwargs...`: Additional arguments passed to loss function
189189
"""
190190
function update!(metric::FYLLossMetric, loss::FenchelYoungLoss, θ, y_target; kwargs...)
191191
l = loss(θ, y_target; kwargs...)
@@ -202,8 +202,8 @@ This method iterates through the dataset, computes predictions using `context.po
202202
and accumulates losses using `context.loss`. The dataset should be stored in the metric.
203203
204204
# Arguments
205-
- `metric::FYLLossMetric` - The metric to evaluate
206-
- `context` - TrainingContext with `policy`, `loss`, and other fields
205+
- `metric::FYLLossMetric`: The metric to evaluate
206+
- `context::TrainingContext`: TrainingContext with `policy`, `loss`, and other fields
207207
"""
208208
function evaluate!(metric::FYLLossMetric, context::TrainingContext)
209209
reset!(metric)

src/metrics/function_metric.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@ epoch_metric = FunctionMetric(ctx -> ctx.epoch, :current_epoch)
1818
1919
# Metric with stored data (dataset)
2020
gap_metric = FunctionMetric(:val_gap, val_data) do ctx, data
21-
compute_gap(benchmark, data, ctx.model, ctx.maximizer)
21+
compute_gap(benchmark, data, ctx.policy.statistical_model, ctx.policy.maximizer)
2222
end
2323
2424
# Metric returning multiple values
2525
dual_gap = FunctionMetric(:gaps, (train_data, val_data)) do ctx, datasets
2626
train_ds, val_ds = datasets
2727
return (
28-
train_gap = compute_gap(benchmark, train_ds, ctx.model, ctx.maximizer),
29-
val_gap = compute_gap(benchmark, val_ds, ctx.model, ctx.maximizer)
28+
train_gap = compute_gap(benchmark, train_ds, ctx.policy.statistical_model, ctx.policy.maximizer),
29+
val_gap = compute_gap(benchmark, val_ds, ctx.policy.statistical_model, ctx.policy.maximizer)
3030
)
3131
end
3232
```
3333
3434
# See also
35-
- [`PeriodicMetric`](@ref) - Wrap a metric to evaluate periodically
35+
- [`PeriodicMetric`](@ref): Wrap a metric to evaluate periodically
3636
- [`evaluate!`](@ref)
3737
"""
3838
struct FunctionMetric{F,D} <: AbstractMetric
@@ -52,8 +52,8 @@ Construct a FunctionMetric without stored data.
5252
The function should have signature `(context) -> value`.
5353
5454
# Arguments
55-
- `metric_fn::Function` - Function to compute the metric
56-
- `name::Symbol` - Identifier for the metric
55+
- `metric_fn::Function`: Function to compute the metric
56+
- `name::Symbol`: Identifier for the metric
5757
"""
5858
function FunctionMetric(metric_fn::F, name::Symbol) where {F}
5959
return FunctionMetric{F,Nothing}(metric_fn, name, nothing)
@@ -65,8 +65,8 @@ $TYPEDSIGNATURES
6565
Evaluate the function metric by calling the stored function.
6666
6767
# Arguments
68-
- `metric::FunctionMetric` - The metric to evaluate
69-
- `context` - TrainingContext with current training state
68+
- `metric::FunctionMetric`: The metric to evaluate
69+
- `context::TrainingContext`: TrainingContext with current training state
7070
7171
# Returns
7272
- The value returned by `metric.metric_fn` (can be single value or NamedTuple)

src/metrics/interface.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ $TYPEDEF
44
Abstract base type for all metrics used during training.
55
66
All concrete metric types should implement:
7-
- `evaluate!(metric, context)` - Evaluate the metric given a training context
7+
- `evaluate!(metric, context)`: Evaluate the metric given a training context
88
99
# See also
1010
- [`LossAccumulator`](@ref)
@@ -20,14 +20,14 @@ abstract type AbstractMetric end
2020
Evaluate the metric given the current training context.
2121
2222
# Arguments
23-
- `metric::AbstractMetric` - The metric to evaluate
24-
- `context::TrainingContext` - Current training state (model, epoch, maximizer, etc.)
23+
- `metric::AbstractMetric`: The metric to evaluate
24+
- `context::TrainingContext`: Current training state (model, epoch, maximizer, etc.)
2525
2626
# Returns
2727
Can return:
28-
- A single value (Float64, Int, etc.) - stored with `metric.name`
29-
- A `NamedTuple` - each key-value pair stored separately
30-
- `nothing` - skipped (e.g., periodic metrics on off-epochs)
28+
- A single value (Float64, Int, etc.): stored with `metric.name`
29+
- A `NamedTuple`: each key-value pair stored separately
30+
- `nothing`: skipped (e.g., periodic metrics on off-epochs)
3131
"""
3232
function evaluate! end
3333

@@ -89,9 +89,9 @@ This function handles three types of metric returns through multiple dispatch:
8989
- **nothing**: Skipped (e.g., periodic metrics on epochs when not evaluated)
9090
9191
# Arguments
92-
- `history::MVHistory` - MVHistory object to store metric values
93-
- `metrics::Tuple` - Tuple of AbstractMetric instances to evaluate
94-
- `context::TrainingContext` - TrainingContext with current training state (policy, epoch, etc.)
92+
- `history::MVHistory`: MVHistory object to store metric values
93+
- `metrics::Tuple`: Tuple of AbstractMetric instances to evaluate
94+
- `context::TrainingContext`: TrainingContext with current training state (policy, epoch, etc.)
9595
9696
# Examples
9797
```julia

src/metrics/periodic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ $TYPEDSIGNATURES
7474
Evaluate the wrapped metric only if the current epoch matches the frequency pattern.
7575
7676
# Arguments
77-
- `pm::PeriodicMetric` - The periodic metric wrapper
78-
- `context` - TrainingContext with current epoch
77+
- `pm::PeriodicMetric`: The periodic metric wrapper
78+
- `context::TrainingContext`: TrainingContext with current epoch
7979
8080
# Returns
8181
- The result of `evaluate!(pm.metric, context)` if epoch matches the pattern

src/policies/dfl_policy.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,17 @@ function (p::DFLPolicy)(features::AbstractArray; kwargs...)
2222
y = p.maximizer(θ; kwargs...)
2323
return y
2424
end
25+
26+
"""
27+
$TYPEDSIGNATURES
28+
29+
Convenience overload: evaluate the optimality gap using a [`DFLPolicy`](@ref) directly,
30+
instead of unpacking `policy.statistical_model` and `policy.maximizer`.
31+
"""
32+
function DecisionFocusedLearningBenchmarks.compute_gap(
33+
bench, dataset, policy::DFLPolicy, op=mean
34+
)
35+
return DecisionFocusedLearningBenchmarks.compute_gap(
36+
bench, dataset, policy.statistical_model, policy.maximizer, op
37+
)
38+
end

0 commit comments

Comments (0)