Skip to content

Commit 6a98761

Browse files
committed
Cleanup train_policy interface and add new version of DFLBenchmarks to Project.toml
1 parent 52e80c7 commit 6a98761

File tree

6 files changed

+49
-51
lines changed

6 files changed

+49
-51
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
1919
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
2020

2121
[compat]
22-
DecisionFocusedLearningBenchmarks = "0.4"
22+
DecisionFocusedLearningBenchmarks = "0.5.0"
2323
DocStringExtensions = "0.9.5"
24-
Flux = "0.16.5"
24+
Flux = "0.16.9"
2525
InferOpt = "0.7.1"
2626
MLUtils = "0.4.8"
2727
ProgressMeter = "1.11.0"
2828
Random = "1.11.0"
2929
Statistics = "1.11.1"
30-
UnicodePlots = "3.8.1"
31-
ValueHistories = "0.5.4"
30+
UnicodePlots = "3.8.2"
31+
ValueHistories = "0.5.6"
3232
julia = "1.11"

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DecisionFocusedLearningBenchmarks
44
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
55
using Flux: Flux, Adam
66
using InferOpt: InferOpt, FenchelYoungLoss, PerturbedAdditive, PerturbedMultiplicative
7-
using MLUtils: splitobs, DataLoader
7+
using MLUtils: DataLoader
88
using ProgressMeter: @showprogress
99
using Random: Random, MersenneTwister
1010
using Statistics: mean
@@ -39,6 +39,7 @@ export AbstractMetric,
3939
compute!,
4040
evaluate_metrics!
4141

42+
export AbstractAlgorithm, AbstractImitationAlgorithm
4243
export PerturbedFenchelYoungLossImitation,
4344
DAgger, AnticipativeImitation, train_policy!, train_policy
4445
export AbstractPolicy, DFLPolicy

src/algorithms/abstract_algorithm.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,46 @@ $TYPEDEF
1111
An abstract type for imitation learning algorithms.
1212
1313
All subtypes must implement:
14-
- `train_policy!(algorithm::AbstractImitationAlgorithm, model, maximizer, train_data; epochs, metrics)`
14+
- `train_policy!(algorithm::AbstractImitationAlgorithm, policy::DFLPolicy, train_data; epochs, metrics)`
1515
"""
1616
abstract type AbstractImitationAlgorithm <: AbstractAlgorithm end
17+
18+
"""
19+
$TYPEDSIGNATURES
20+
21+
Train a new DFLPolicy on a benchmark using any imitation learning algorithm.
22+
23+
Convenience wrapper that handles dataset generation, model initialization, and policy
24+
creation. Returns the training history and the trained policy.
25+
26+
For dynamic benchmarks, use the algorithm-specific `train_policy` overload that accepts
27+
environments and an anticipative policy.
28+
"""
29+
function train_policy(
30+
algorithm::AbstractImitationAlgorithm,
31+
benchmark::AbstractBenchmark;
32+
target_policy=nothing,
33+
dataset_size=30,
34+
epochs=100,
35+
metrics::Tuple=(),
36+
seed=nothing,
37+
)
38+
dataset = generate_dataset(benchmark, dataset_size; target_policy)
39+
40+
if any(s -> isnothing(s.y), dataset)
41+
error(
42+
"Training dataset contains unlabeled samples (y=nothing). " *
43+
"Provide a `target_policy` kwarg to label samples during dataset generation.",
44+
)
45+
end
46+
47+
model = generate_statistical_model(benchmark; seed)
48+
maximizer = generate_maximizer(benchmark)
49+
policy = DFLPolicy(model, maximizer)
50+
51+
history = train_policy!(
52+
algorithm, policy, dataset; epochs, metrics, maximizer_kwargs=s -> s.context
53+
)
54+
55+
return history, policy
56+
end

src/algorithms/supervised/fyl.jl

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -118,46 +118,3 @@ function train_policy!(
118118
maximizer_kwargs=maximizer_kwargs,
119119
)
120120
end
121-
122-
"""
123-
$TYPEDSIGNATURES
124-
125-
Train a DFLPolicy using the Perturbed Fenchel-Young Loss Imitation Algorithm on a benchmark.
126-
127-
# Benchmark convenience wrapper
128-
129-
This high-level function handles all setup from the benchmark and returns a trained policy.
130-
"""
131-
function train_policy(
132-
algorithm::PerturbedFenchelYoungLossImitation,
133-
benchmark::AbstractBenchmark;
134-
target_policy=nothing,
135-
dataset_size=30,
136-
split_ratio=(0.3, 0.3),
137-
epochs=100,
138-
metrics::Tuple=(),
139-
seed=nothing,
140-
)
141-
# Generate dataset and split
142-
dataset = generate_dataset(benchmark, dataset_size; target_policy)
143-
train_instances, _, _ = splitobs(dataset; at=split_ratio)
144-
145-
if any(s -> isnothing(s.y), train_instances)
146-
error(
147-
"Training dataset contains unlabeled samples (y=nothing). " *
148-
"Provide a `target_policy` kwarg to label samples during dataset generation.",
149-
)
150-
end
151-
152-
# Initialize model and create policy
153-
model = generate_statistical_model(benchmark; seed)
154-
maximizer = generate_maximizer(benchmark)
155-
policy = DFLPolicy(model, maximizer)
156-
157-
# Train policy
158-
history = train_policy!(
159-
algorithm, policy, train_instances; epochs, metrics, maximizer_kwargs=s -> s.context
160-
)
161-
162-
return history, policy
163-
end

src/metrics/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Metrics must return a `Number`, a `NamedTuple`, or `nothing`.
7979
function _store_metric_value!(::MVHistory, metric_name::Symbol, ::Int, value)
8080
return error(
8181
"Metric `$metric_name` returned a value of type $(typeof(value)), which cannot " *
82-
"be stored in history. Metrics must return a Number, a NamedTuple, or nothing."
82+
"be stored in history. Metrics must return a Number, a NamedTuple, or nothing.",
8383
)
8484
end
8585

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ DecisionFocusedLearningAlgorithms = {path = ".."}
1515

1616
[compat]
1717
Aqua = "0.8"
18-
DecisionFocusedLearningBenchmarks = "0.4"
18+
DecisionFocusedLearningBenchmarks = "0.5"
1919
Documenter = "1"
2020
JuliaFormatter = "1"
2121
MLUtils = "0.4"

0 commit comments

Comments
 (0)