"""
$TYPEDEF

An abstract type for imitation learning algorithms.

All subtypes must implement:
- `train_policy!(algorithm::AbstractImitationAlgorithm, policy::DFLPolicy, train_data; epochs, metrics)`
"""
abstract type AbstractImitationAlgorithm <: AbstractAlgorithm end
17+
"""
$TYPEDSIGNATURES

Train a new DFLPolicy on a benchmark using any imitation learning algorithm.

Convenience wrapper that handles dataset generation, model initialization, and policy
creation. Returns the training history and the trained policy as a `(history, policy)`
tuple.

For dynamic benchmarks, use the algorithm-specific `train_policy` overload that accepts
environments and an anticipative policy.

# Arguments
- `algorithm`: imitation learning algorithm used for training.
- `benchmark`: benchmark from which the dataset, statistical model, and maximizer are
  generated.

# Keywords
- `target_policy`: optional policy used to label samples during dataset generation.
- `dataset_size`: number of samples to generate (default `30`).
- `epochs`: number of training epochs (default `100`).
- `metrics`: tuple of metrics tracked during training.
- `seed`: optional seed forwarded to statistical model generation.
"""
function train_policy(
    algorithm::AbstractImitationAlgorithm,
    benchmark::AbstractBenchmark;
    target_policy=nothing,
    dataset_size=30,
    epochs=100,
    metrics::Tuple=(),
    seed=nothing,
)
    dataset = generate_dataset(benchmark, dataset_size; target_policy)

    # Imitation learning needs labeled targets; fail fast with an actionable message
    # instead of erroring deep inside the training loop.
    if any(s -> isnothing(s.y), dataset)
        error(
            "Training dataset contains unlabeled samples (y=nothing). " *
            "Provide a `target_policy` kwarg to label samples during dataset generation.",
        )
    end

    model = generate_statistical_model(benchmark; seed)
    maximizer = generate_maximizer(benchmark)
    policy = DFLPolicy(model, maximizer)

    # Forward each sample's context as maximizer kwargs so the optimization layer
    # receives per-instance data during training.
    history = train_policy!(
        algorithm, policy, dataset; epochs, metrics, maximizer_kwargs=s -> s.context
    )

    return history, policy
end