-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathabstract_algorithm.jl
More file actions
56 lines (44 loc) · 1.55 KB
/
abstract_algorithm.jl
File metadata and controls
56 lines (44 loc) · 1.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
$TYPEDEF

An abstract type for decision-focused learning algorithms.

This is the root of the algorithm type hierarchy in this file; `AbstractImitationAlgorithm`
subtypes it. It declares no methods itself.
"""
abstract type AbstractAlgorithm end
"""
$TYPEDEF

An abstract type for imitation learning algorithms.

All subtypes must implement:
- `train_policy!(algorithm::AbstractImitationAlgorithm, policy::DFLPolicy, train_data; epochs, metrics)`

Note: the convenience wrapper `train_policy(algorithm, benchmark; ...)` also passes a
`maximizer_kwargs` keyword (a function of a training sample) to `train_policy!`, so
implementations used through that wrapper must accept that keyword as well.
"""
abstract type AbstractImitationAlgorithm <: AbstractAlgorithm end
"""
$TYPEDSIGNATURES

Build a fresh `DFLPolicy` for `benchmark` and fit it with the imitation-learning
`algorithm`, returning `(history, policy)`.

Steps performed:
1. Generate `dataset_size` samples with `generate_dataset`, forwarding `target_policy`
   so that samples can be labeled.
2. Reject the dataset if any sample has `y === nothing` (an unlabeled sample).
3. Build the statistical model (`generate_statistical_model`, given `seed`) and the
   maximizer (`generate_maximizer`), and wrap them in a `DFLPolicy`.
4. Delegate to `train_policy!` with `epochs` and `metrics`, supplying
   `maximizer_kwargs` as a function that reads each sample's `context` field.

# Keywords
- `target_policy=nothing`: forwarded to `generate_dataset`. If samples come back
  unlabeled (e.g. because no target policy was given), the call errors.
- `dataset_size=30`: number of samples requested from `generate_dataset`.
- `epochs=100`: forwarded to `train_policy!`.
- `metrics::Tuple=()`: forwarded to `train_policy!`.
- `seed=nothing`: forwarded only to `generate_statistical_model`; it is not passed to
  `generate_dataset`.

# Returns
A tuple `(history, policy)`: whatever `train_policy!` returns, and the trained `DFLPolicy`.

# Throws
- `ErrorException` (raised via `error`) if any generated sample has `y === nothing`.

For dynamic benchmarks, use the algorithm-specific `train_policy` overload that accepts
environments and an anticipative policy.
"""
function train_policy(
    algorithm::AbstractImitationAlgorithm,
    benchmark::AbstractBenchmark;
    target_policy=nothing,
    dataset_size=30,
    epochs=100,
    metrics::Tuple=(),
    seed=nothing,
)
    samples = generate_dataset(benchmark, dataset_size; target_policy=target_policy)

    # Imitation learning needs a target `y` for every sample; fail fast otherwise.
    all(sample -> !isnothing(sample.y), samples) || error(
        "Training dataset contains unlabeled samples (y=nothing). " *
        "Provide a `target_policy` kwarg to label samples during dataset generation.",
    )

    statistical_model = generate_statistical_model(benchmark; seed=seed)
    combinatorial_layer = generate_maximizer(benchmark)
    fresh_policy = DFLPolicy(statistical_model, combinatorial_layer)

    # Per-sample maximizer keywords come from the sample's `context` field
    # (how `train_policy!` applies them is defined elsewhere).
    context_kwargs = sample -> sample.context

    training_history = train_policy!(
        algorithm,
        fresh_policy,
        samples;
        epochs=epochs,
        metrics=metrics,
        maximizer_kwargs=context_kwargs,
    )
    return (training_history, fresh_policy)
end