# Algorithm Interface

This page describes the unified interface for Decision-Focused Learning algorithms provided by this package.

## Core Concepts

### DFLPolicy

The [`DFLPolicy`](@ref) is the central abstraction that encapsulates a decision-focused learning policy. It combines:
- A **statistical model** (typically a neural network) that predicts parameters from input features
- A **combinatorial optimizer** (maximizer) that solves optimization problems using the predicted parameters

```julia
policy = DFLPolicy(
    Chain(Dense(input_dim => hidden_dim, relu), Dense(hidden_dim => output_dim)),
    my_optimizer
)
```

### Training Interface

All algorithms in this package follow a unified training interface with two main functions:

#### Core Training Method

```julia
history = train_policy!(algorithm, policy, training_data; epochs=100, metrics=(), maximizer_kwargs=get_info)
```

**Arguments:**
- `algorithm`: An algorithm instance (e.g., `PerturbedFenchelYoungLossImitation`, `DAgger`, `AnticipativeImitation`)
- `policy::DFLPolicy`: The policy to train (contains the model and maximizer)
- `training_data`: Either a dataset of `DataSample` objects or an `Environment`, depending on the algorithm
- `epochs::Int`: Number of training epochs (default: 100)
- `metrics::Tuple`: Metrics to evaluate during training (default: empty)
- `maximizer_kwargs::Function`: Function that extracts keyword arguments for the maximizer from data samples (default: `get_info`)

**Returns:**
- `history::MVHistory`: Training history containing loss values and metric evaluations
#### Benchmark Convenience Wrapper

```julia
result = train_policy(algorithm, benchmark; dataset_size=30, split_ratio=(0.3, 0.3), epochs=100, metrics=())
```

This high-level function handles all setup from a benchmark and returns a trained policy along with the training history.

**Arguments:**
- `algorithm`: An algorithm instance
- `benchmark::AbstractBenchmark`: A benchmark from DecisionFocusedLearningBenchmarks.jl
- `dataset_size::Int`: Number of instances to generate
- `split_ratio::Tuple`: Train/validation/test split ratios
- `epochs::Int`: Number of training epochs
- `metrics::Tuple`: Metrics to track during training

**Returns:**
- `(; policy, history)`: Named tuple with the trained policy and training history
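Putting the pieces together, a hypothetical end-to-end run (the benchmark name `ArgmaxBenchmark` and the zero-argument algorithm constructor are assumptions for illustration):

```julia
using DecisionFocusedLearningBenchmarks

benchmark = ArgmaxBenchmark()                     # assumed benchmark name
algorithm = PerturbedFenchelYoungLossImitation()  # assumed default constructor
(; policy, history) = train_policy(algorithm, benchmark; dataset_size=30, epochs=50)
```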
| 59 | + |
| 60 | +## Metrics |
| 61 | + |
| 62 | +Metrics allow you to track additional quantities during training. |
| 63 | + |
| 64 | +### Built-in Metrics |
| 65 | + |
| 66 | +#### FYLLossMetric |
| 67 | + |
| 68 | +Evaluates Fenchel-Young loss on a validation dataset. |
| 69 | + |
| 70 | +```julia |
| 71 | +val_metric = FYLLossMetric(validation_data, :validation_loss) |
| 72 | +``` |
| 73 | + |
| 74 | +#### FunctionMetric |
| 75 | + |
| 76 | +Custom metric defined by a function. |
| 77 | + |
| 78 | +```julia |
| 79 | +# Simple metric (no stored data) |
| 80 | +epoch_metric = FunctionMetric(ctx -> ctx.epoch, :epoch) |
| 81 | + |
| 82 | +# Metric with stored data |
| 83 | +gap_metric = FunctionMetric(:validation_gap, validation_data) do ctx, data |
| 84 | + compute_gap(benchmark, data, ctx.policy.statistical_model, ctx.policy.maximizer) |
| 85 | +end |
| 86 | +``` |
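Metrics are passed to the training call as a tuple via the `metrics` keyword. A sketch combining the two kinds above (the `:validation_loss` and `:epoch` labels are illustrative names):

```julia
metrics = (
    FYLLossMetric(validation_data, :validation_loss),
    FunctionMetric(ctx -> ctx.epoch, :epoch),
)
history = train_policy!(algorithm, policy, training_data; epochs=100, metrics=metrics)
```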
| 87 | + |
| 88 | +### TrainingContext |
| 89 | + |
| 90 | +Metrics receive a `TrainingContext` object containing: |
| 91 | +- `policy::DFLPolicy`: The policy being trained |
| 92 | +- `epoch::Int`: Current epoch number |
| 93 | +- `maximizer_kwargs::Function`: Maximizer kwargs extractor |
| 94 | +- `other_fields`: Algorithm-specific fields (e.g., `loss` for FYL) |
| 95 | + |
| 96 | +Access policy components: |
| 97 | +```julia |
| 98 | +ctx.policy.statistical_model # Neural network |
| 99 | +ctx.policy.maximizer # Combinatorial optimizer |
| 100 | +``` |
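As an example of using these context fields, here is a hypothetical `FunctionMetric` that tracks the L2 norm of the network weights at each epoch (it assumes a Flux model, so that `Flux.params` applies to `ctx.policy.statistical_model`):

```julia
using Flux

# Hypothetical metric: total L2 norm of all model parameters
weight_metric = FunctionMetric(:weight_norm) do ctx
    sqrt(sum(sum(abs2, p) for p in Flux.params(ctx.policy.statistical_model)))
end
```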