Commit 0f3dfbe ("big update")
1 parent 77613cb

33 files changed: +719, -1015 lines

.gitignore

Lines changed: 1 addition & 0 deletions

````diff
@@ -8,3 +8,4 @@ tensorboard_logs
 .vscode
 Manifest.toml
 examples
+scripts
````

README.md

Lines changed: 40 additions & 0 deletions

````diff
@@ -6,3 +6,43 @@
 [![Coverage](https://codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl)
 [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
 [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
+
+> [!WARNING]
+> This package is currently under active development. The API may change in future releases.
+> Please refer to the [documentation](https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningAlgorithms.jl/stable/) for the latest updates.
+
+## Overview
+
+This package provides a unified interface for training decision-focused learning algorithms that combine machine learning with combinatorial optimization. It implements several state-of-the-art algorithms for learning to predict parameters of optimization problems.
+
+### Key Features
+
+- **Unified Interface**: Consistent API across all algorithms via `train_policy!`
+- **Policy-Centric Design**: `DFLPolicy` encapsulates statistical models and optimizers
+- **Flexible Metrics**: Track custom metrics during training
+- **Benchmark Integration**: Seamless integration with DecisionFocusedLearningBenchmarks.jl
+
+### Quick Start
+
+```julia
+using DecisionFocusedLearningAlgorithms
+using DecisionFocusedLearningBenchmarks
+
+# Create a policy
+benchmark = ArgmaxBenchmark()
+model = generate_statistical_model(benchmark)
+maximizer = generate_maximizer(benchmark)
+policy = DFLPolicy(model, maximizer)
+
+# Train with FYL algorithm
+algorithm = PerturbedFenchelYoungLossImitation()
+result = train_policy(algorithm, benchmark; epochs=50)
+```
+
+See the [documentation](https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningAlgorithms.jl/stable/) for more details.
+
+## Available Algorithms
+
+- **Perturbed Fenchel-Young Loss Imitation**: Differentiable imitation learning with perturbed optimization
+- **AnticipativeImitation**: Imitation of anticipative solutions for dynamic problems
+- **DAgger**: DAgger algorithm for dynamic problems
````

docs/make.jl

Lines changed: 12 additions & 22 deletions

````diff
@@ -2,21 +2,6 @@ using DecisionFocusedLearningAlgorithms
 using Documenter
 using Literate

-DocMeta.setdocmeta!(
-    DecisionFocusedLearningAlgorithms,
-    :DocTestSetup,
-    :(
-        begin
-            using DecisionFocusedLearningAlgorithms
-            using DecisionFocusedLearningBenchmarks
-            using Flux
-            using MLUtils
-            using Plots
-        end
-    );
-    recursive=true,
-)
-
 # Generate markdown files from tutorial scripts
 tutorial_dir = joinpath(@__DIR__, "src", "tutorials")
 tutorial_files = filter(f -> endswith(f, ".jl"), readdir(tutorial_dir))
@@ -29,22 +14,27 @@ end

 # Get list of generated markdown files for the docs
 md_tutorial_files = [
-    "tutorials/" * replace(file, ".jl" => ".md") for file in tutorial_files
+    joinpath("tutorials", replace(file, ".jl" => ".md")) for file in tutorial_files
 ]

 makedocs(;
     modules=[DecisionFocusedLearningAlgorithms],
     authors="Members of JuliaDecisionFocusedLearning and contributors",
     sitename="DecisionFocusedLearningAlgorithms.jl",
-    format=Documenter.HTML(;
-        canonical="https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningAlgorithms.jl",
-        edit_link="main",
-        assets=String[],
-    ),
-    pages=["Home" => "index.md", "Tutorials" => md_tutorial_files],
+    format=Documenter.HTML(; size_threshold=typemax(Int)),
+    pages=[
+        "Home" => "index.md",
+        "Interface Guide" => "interface.md",
+        "Tutorials" => md_tutorial_files,
+        "API Reference" => "api.md",
+    ],
 )

 deploydocs(;
     repo="github.com/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl",
     devbranch="main",
 )
+
+for file in md_tutorial_files
+    rm(joinpath(@__DIR__, "src", file))
+end
````
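The change from string concatenation to `joinpath` in the comprehension above makes the generated tutorial paths platform-independent. A minimal Base-Julia sketch of that mapping, using a hypothetical file list in place of `readdir(tutorial_dir)`:

```julia
# Hypothetical directory listing standing in for readdir(tutorial_dir)
files = ["tutorial.jl", "advanced.jl", "README.md"]

# Keep only Literate scripts, mirroring the filter in docs/make.jl
tutorial_files = filter(f -> endswith(f, ".jl"), files)

# Map each script to the markdown page Documenter will include
md_tutorial_files = [
    joinpath("tutorials", replace(f, ".jl" => ".md")) for f in tutorial_files
]
```

On Unix-like systems this yields `["tutorials/tutorial.md", "tutorials/advanced.md"]`; `joinpath` picks the right separator on Windows.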

docs/src/api.md

Lines changed: 6 additions & 0 deletions

````diff
@@ -0,0 +1,6 @@
+```@index
+```
+
+```@autodocs
+Modules = [DecisionFocusedLearningAlgorithms]
+```
````

docs/src/index.md

Lines changed: 33 additions & 4 deletions

````diff
@@ -2,9 +2,38 @@

 Documentation for [DecisionFocusedLearningAlgorithms](https://github.com/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl).

-```@index
-```
+## Overview
+
+This package provides a unified interface for training decision-focused learning algorithms that combine machine learning with combinatorial optimization. It implements several state-of-the-art algorithms for learning to predict parameters of optimization problems.
+
+### Key Features
+
+- **Unified Interface**: Consistent API across all algorithms via `train_policy!`
+- **Policy-Centric Design**: `DFLPolicy` encapsulates statistical models and optimizers
+- **Flexible Metrics**: Track custom metrics during training
+- **Benchmark Integration**: Seamless integration with DecisionFocusedLearningBenchmarks.jl
+
+### Quick Start

-```@autodocs
-Modules = [DecisionFocusedLearningAlgorithms]
+```julia
+using DecisionFocusedLearningAlgorithms
+using DecisionFocusedLearningBenchmarks
+
+# Create a policy
+benchmark = ArgmaxBenchmark()
+model = generate_statistical_model(benchmark)
+maximizer = generate_maximizer(benchmark)
+policy = DFLPolicy(model, maximizer)
+
+# Train with FYL algorithm
+algorithm = PerturbedFenchelYoungLossImitation()
+result = train_policy(algorithm, benchmark; epochs=50)
 ```
+
+See the [Interface Guide](interface.md) and [Tutorials](tutorials/tutorial.md) for more details.
+
+## Available Algorithms
+
+- **Perturbed Fenchel-Young Loss Imitation**: Differentiable imitation learning with perturbed optimization
+- **AnticipativeImitation**: Imitation of anticipative solutions for dynamic problems
+- **DAgger**: DAgger algorithm for dynamic problems
````

docs/src/interface.md

Lines changed: 100 additions & 0 deletions

````diff
@@ -0,0 +1,100 @@
+# Algorithm Interface
+
+This page describes the unified interface for Decision-Focused Learning algorithms provided by this package.
+
+## Core Concepts
+
+### DFLPolicy
+
+The [`DFLPolicy`](@ref) is the central abstraction that encapsulates a decision-focused learning policy. It combines:
+- A **statistical model** (typically a neural network) that predicts parameters from input features
+- A **combinatorial optimizer** (maximizer) that solves optimization problems using the predicted parameters
+
+```julia
+policy = DFLPolicy(
+    Chain(Dense(input_dim => hidden_dim, relu), Dense(hidden_dim => output_dim)),
+    my_optimizer
+)
+```
+
+### Training Interface
+
+All algorithms in this package follow a unified training interface with two main functions:
+
+#### Core Training Method
+
+```julia
+history = train_policy!(algorithm, policy, training_data; epochs=100, metrics=(), maximizer_kwargs=get_info)
+```
+
+**Arguments:**
+- `algorithm`: An algorithm instance (e.g., `PerturbedFenchelYoungLossImitation`, `DAgger`, `AnticipativeImitation`)
+- `policy::DFLPolicy`: The policy to train (contains the model and maximizer)
+- `training_data`: Either a dataset of `DataSample` objects or an `Environment` (depends on the algorithm)
+- `epochs::Int`: Number of training epochs (default: 100)
+- `metrics::Tuple`: Metrics to evaluate during training (default: empty)
+- `maximizer_kwargs::Function`: Function that extracts keyword arguments for the maximizer from data samples (default: `get_info`)
+
+**Returns:**
+- `history::MVHistory`: Training history containing loss values and metric evaluations
+
+#### Benchmark Convenience Wrapper
+
+```julia
+result = train_policy(algorithm, benchmark; dataset_size=30, split_ratio=(0.3, 0.3), epochs=100, metrics=())
+```
+
+This high-level function handles all setup from a benchmark and returns a trained policy along with training history.
+
+**Arguments:**
+- `algorithm`: An algorithm instance
+- `benchmark::AbstractBenchmark`: A benchmark from DecisionFocusedLearningBenchmarks.jl
+- `dataset_size::Int`: Number of instances to generate
+- `split_ratio::Tuple`: Train/validation/test split ratios
+- `epochs::Int`: Number of training epochs
+- `metrics::Tuple`: Metrics to track during training
+
+**Returns:**
+- `(; policy, history)`: Named tuple with trained policy and training history
+
+## Metrics
+
+Metrics allow you to track additional quantities during training.
+
+### Built-in Metrics
+
+#### FYLLossMetric
+
+Evaluates the Fenchel-Young loss on a validation dataset.
+
+```julia
+val_metric = FYLLossMetric(validation_data, :validation_loss)
+```
+
+#### FunctionMetric
+
+Custom metric defined by a function.
+
+```julia
+# Simple metric (no stored data)
+epoch_metric = FunctionMetric(ctx -> ctx.epoch, :epoch)
+
+# Metric with stored data
+gap_metric = FunctionMetric(:validation_gap, validation_data) do ctx, data
+    compute_gap(benchmark, data, ctx.policy.statistical_model, ctx.policy.maximizer)
+end
+```
+
+### TrainingContext
+
+Metrics receive a `TrainingContext` object containing:
+- `policy::DFLPolicy`: The policy being trained
+- `epoch::Int`: Current epoch number
+- `maximizer_kwargs::Function`: Maximizer kwargs extractor
+- `other_fields`: Algorithm-specific fields (e.g., `loss` for FYL)
+
+Access policy components:
+```julia
+ctx.policy.statistical_model  # Neural network
+ctx.policy.maximizer          # Combinatorial optimizer
+```
````
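The `FunctionMetric` examples in the new interface page rely on a two-argument callback `(ctx, data)` built with Julia's `do`-block syntax. A rough Base-Julia sketch of that callback shape, using a mock struct and a mock context rather than the package's actual `FunctionMetric` and `TrainingContext` types:

```julia
# Mock of the FunctionMetric pattern: a callback, a name to log under, stored data.
# (MockMetric is hypothetical; it only illustrates the shape, not the real API.)
struct MockMetric{F,D}
    f::F
    name::Symbol
    data::D
end

# Mock training context: a named tuple with fields like those of TrainingContext
ctx = (; epoch = 3, policy = (; statistical_model = x -> 2x, maximizer = identity))

# do-block syntax passes the anonymous function as the first constructor argument
metric = MockMetric(:mock_gap, [1.0, 2.0, 3.0]) do c, d
    sum(c.policy.statistical_model, d)  # apply the model to each stored datum and sum
end

value = metric.f(ctx, metric.data)  # 2*1 + 2*2 + 2*3 = 12.0
```

The same shape is what the real metric presumably evaluates once per epoch, with the trained policy supplied through the context.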

docs/src/tutorials/tutorial.jl

Lines changed: 24 additions & 26 deletions

````diff
@@ -1,48 +1,46 @@
-# Tutorial
+# # Basic Tutorial: Training with FYL on Argmax Benchmark
+#
+# This tutorial demonstrates the basic workflow for training a policy
+# using the Perturbed Fenchel-Young Loss algorithm.
+
+# ## Setup
 using DecisionFocusedLearningAlgorithms
 using DecisionFocusedLearningBenchmarks
 using MLUtils: splitobs
 using Plots

+# ## Create Benchmark and Data
 b = ArgmaxBenchmark()
 dataset = generate_dataset(b, 100)
-train_instances, validation_instances, test_instances = splitobs(
-    dataset; at=(0.3, 0.3, 0.4)
-)
+train_data, val_data, test_data = splitobs(dataset; at=(0.3, 0.3, 0.4))

+# ## Create Policy
 model = generate_statistical_model(b; seed=0)
 maximizer = generate_maximizer(b)
+policy = DFLPolicy(model, maximizer)

-# Compute initial gap
-initial_gap = compute_gap(b, test_instances, model, maximizer)
-println("Initial test gap: $initial_gap")
-
-# Configure the training algorithm
-algorithm = PerturbedImitationAlgorithm(; nb_samples=10, ε=0.1, threaded=true, seed=0)
+# ## Configure Algorithm
+algorithm = PerturbedFenchelYoungLossImitation(;
+    nb_samples=10, ε=0.1, threaded=true, seed=0
+)

-# Define metrics to track during training
-validation_loss_metric = FYLLossMetric(validation_instances, :validation_loss)
+# ## Define Metrics to track during training
+validation_loss_metric = FYLLossMetric(val_data, :validation_loss)

-# Validation gap metric
-val_gap_metric = FunctionMetric(:val_gap, validation_instances) do ctx, data
-    compute_gap(b, data, ctx.model, ctx.maximizer)
+val_gap_metric = FunctionMetric(:val_gap, val_data) do ctx, data
+    compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer)
 end

-# Test gap metric
-test_gap_metric = FunctionMetric(:test_gap, test_instances) do ctx, data
-    compute_gap(b, data, ctx.model, ctx.maximizer)
+test_gap_metric = FunctionMetric(:test_gap, test_data) do ctx, data
+    compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer)
 end

-# Combine metrics
 metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)

-# Train the model
-fyl_model = deepcopy(model)
-history = train_policy!(
-    algorithm, fyl_model, maximizer, train_instances; epochs=100, metrics=metrics
-)
+# ## Train the Policy
+history = train_policy!(algorithm, policy, train_data; epochs=100, metrics=metrics)

-# Plot validation and test gaps
+# ## Plot Results
 val_gap_epochs, val_gap_values = get(history, :val_gap)
 test_gap_epochs, test_gap_values = get(history, :test_gap)

@@ -55,7 +53,7 @@ plot(
     title="Gap Evolution During Training",
 )

-# Plot validation loss
+# Plot loss evolution
 train_loss_epochs, train_loss_values = get(history, :training_loss)
 val_loss_epochs, val_loss_values = get(history, :validation_loss)

````
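The tutorial's `splitobs(dataset; at=(0.3, 0.3, 0.4))` call comes from MLUtils and produces a 30/30/40 train/validation/test split. A Base-Julia sketch of the equivalent contiguous split (sizes only, no shuffling; the dataset here is a stand-in for `generate_dataset(b, 100)`):

```julia
# Stand-in dataset of 100 observations
dataset = collect(1:100)
n = length(dataset)

# Split sizes from the (0.3, 0.3, 0.4) ratios
n_train = floor(Int, 0.3 * n)
n_val = floor(Int, 0.3 * n)

# Contiguous slices, as splitobs does for array-like data
train_data = dataset[1:n_train]
val_data = dataset[(n_train + 1):(n_train + n_val)]
test_data = dataset[(n_train + n_val + 1):end]

sizes = length.((train_data, val_data, test_data))  # (30, 30, 40)
```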
