Skip to content

Commit 13d2398

Browse files
authored
Merge pull request #3 from JuliaDecisionFocusedLearning/first-draft
First draft at implementing generic SIL and generic DAgger
2 parents b13a4db + 0f3dfbe commit 13d2398

30 files changed

+2079
-50
lines changed

.JuliaFormatter.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options
21
style = "blue"

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@
44
/Manifest*.toml
55
/docs/Manifest*.toml
66
/docs/build/
7+
tensorboard_logs
78
.vscode
9+
Manifest.toml
10+
examples
11+
scripts

Project.toml

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
11
name = "DecisionFocusedLearningAlgorithms"
22
uuid = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
3+
version = "0.1.0"
34
authors = ["Members of JuliaDecisionFocusedLearning and contributors"]
4-
version = "0.0.1"
5+
6+
[workspace]
7+
projects = ["docs", "test"]
58

69
[deps]
10+
DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
11+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
12+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
13+
InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
14+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
15+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
16+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
18+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
19+
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
720

821
[compat]
22+
DecisionFocusedLearningBenchmarks = "0.4"
23+
DocStringExtensions = "0.9.5"
24+
Flux = "0.16.5"
25+
InferOpt = "0.7.1"
26+
MLUtils = "0.4.8"
27+
ProgressMeter = "1.11.0"
28+
Random = "1.11.0"
29+
Statistics = "1.11.1"
30+
UnicodePlots = "3.8.1"
31+
ValueHistories = "0.5.4"
932
julia = "1.11"
10-
11-
[extras]
12-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
13-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
14-
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
15-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
16-
17-
[targets]
18-
test = ["Aqua", "JET", "JuliaFormatter", "Test"]

README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,43 @@
66
[![Coverage](https://codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl)
77
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
88
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
9+
10+
> [!WARNING]
11+
> This package is currently under active development. The API may change in future releases.
12+
> Please refer to the [documentation](https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningAlgorithms.jl/stable/) for the latest updates.
13+
14+
## Overview
15+
16+
This package provides a unified interface for training decision-focused learning algorithms that combine machine learning with combinatorial optimization. It implements several state-of-the-art algorithms for learning to predict parameters of optimization problems.
17+
18+
### Key Features
19+
20+
- **Unified Interface**: Consistent API across all algorithms via `train_policy!`
21+
- **Policy-Centric Design**: `DFLPolicy` encapsulates statistical models and optimizers
22+
- **Flexible Metrics**: Track custom metrics during training
23+
- **Benchmark Integration**: Seamless integration with DecisionFocusedLearningBenchmarks.jl
24+
25+
### Quick Start
26+
27+
```julia
28+
using DecisionFocusedLearningAlgorithms
29+
using DecisionFocusedLearningBenchmarks
30+
31+
# Create a policy
32+
benchmark = ArgmaxBenchmark()
33+
model = generate_statistical_model(benchmark)
34+
maximizer = generate_maximizer(benchmark)
35+
policy = DFLPolicy(model, maximizer)
36+
37+
# Train with FYL algorithm
38+
algorithm = PerturbedFenchelYoungLossImitation()
39+
result = train_policy(algorithm, benchmark; epochs=50)
40+
```
41+
42+
See the [documentation](https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningAlgorithms.jl/stable/) for more details.
43+
44+
## Available Algorithms
45+
46+
- **Perturbed Fenchel-Young Loss Imitation**: Differentiable imitation learning with perturbed optimization
47+
- **Anticipative Imitation**: Imitation of anticipative solutions for dynamic problems
48+
- **DAgger**: Dataset Aggregation (DAgger) imitation learning for dynamic problems

docs/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
[deps]
22
DecisionFocusedLearningAlgorithms = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
3+
DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
5+
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
6+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
7+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
8+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

docs/make.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,40 @@
11
using DecisionFocusedLearningAlgorithms
22
using Documenter
3+
using Literate
34

4-
DocMeta.setdocmeta!(
5-
DecisionFocusedLearningAlgorithms,
6-
:DocTestSetup,
7-
:(using DecisionFocusedLearningAlgorithms);
8-
recursive=true,
9-
)
5+
# Generate markdown files from tutorial scripts
6+
tutorial_dir = joinpath(@__DIR__, "src", "tutorials")
7+
tutorial_files = filter(f -> endswith(f, ".jl"), readdir(tutorial_dir))
8+
9+
# Convert .jl tutorial files to markdown
10+
for file in tutorial_files
11+
filepath = joinpath(tutorial_dir, file)
12+
Literate.markdown(filepath, tutorial_dir; documenter=true, execute=false)
13+
end
14+
15+
# Get list of generated markdown files for the docs
16+
md_tutorial_files = [
17+
joinpath("tutorials", replace(file, ".jl" => ".md")) for file in tutorial_files
18+
]
1019

1120
makedocs(;
1221
modules=[DecisionFocusedLearningAlgorithms],
1322
authors="Members of JuliaDecisionFocusedLearning and contributors",
1423
sitename="DecisionFocusedLearningAlgorithms.jl",
15-
format=Documenter.HTML(;
16-
canonical="https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningAlgorithms.jl",
17-
edit_link="main",
18-
assets=String[],
19-
),
20-
pages=["Home" => "index.md"],
24+
format=Documenter.HTML(; size_threshold=typemax(Int)),
25+
pages=[
26+
"Home" => "index.md",
27+
"Interface Guide" => "interface.md",
28+
"Tutorials" => md_tutorial_files,
29+
"API Reference" => "api.md",
30+
],
2131
)
2232

2333
deploydocs(;
2434
repo="github.com/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl",
2535
devbranch="main",
2636
)
37+
38+
for file in md_tutorial_files
39+
rm(joinpath(@__DIR__, "src", file))
40+
end

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
```@index
2+
```
3+
4+
```@autodocs
5+
Modules = [DecisionFocusedLearningAlgorithms]
6+
```

docs/src/index.md

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,39 @@
1-
```@meta
2-
CurrentModule = DecisionFocusedLearningAlgorithms
3-
```
4-
51
# DecisionFocusedLearningAlgorithms
62

73
Documentation for [DecisionFocusedLearningAlgorithms](https://github.com/JuliaDecisionFocusedLearning/DecisionFocusedLearningAlgorithms.jl).
84

9-
```@index
10-
```
5+
## Overview
6+
7+
This package provides a unified interface for training decision-focused learning algorithms that combine machine learning with combinatorial optimization. It implements several state-of-the-art algorithms for learning to predict parameters of optimization problems.
8+
9+
### Key Features
10+
11+
- **Unified Interface**: Consistent API across all algorithms via `train_policy!`
12+
- **Policy-Centric Design**: `DFLPolicy` encapsulates statistical models and optimizers
13+
- **Flexible Metrics**: Track custom metrics during training
14+
- **Benchmark Integration**: Seamless integration with DecisionFocusedLearningBenchmarks.jl
1115

12-
```@autodocs
13-
Modules = [DecisionFocusedLearningAlgorithms]
16+
### Quick Start
17+
18+
```julia
19+
using DecisionFocusedLearningAlgorithms
20+
using DecisionFocusedLearningBenchmarks
21+
22+
# Create a policy
23+
benchmark = ArgmaxBenchmark()
24+
model = generate_statistical_model(benchmark)
25+
maximizer = generate_maximizer(benchmark)
26+
policy = DFLPolicy(model, maximizer)
27+
28+
# Train with FYL algorithm
29+
algorithm = PerturbedFenchelYoungLossImitation()
30+
result = train_policy(algorithm, benchmark; epochs=50)
1431
```
32+
33+
See the [Interface Guide](interface.md) and [Tutorials](tutorials/tutorial.md) for more details.
34+
35+
## Available Algorithms
36+
37+
- **Perturbed Fenchel-Young Loss Imitation**: Differentiable imitation learning with perturbed optimization
38+
- **Anticipative Imitation**: Imitation of anticipative solutions for dynamic problems
39+
- **DAgger**: Dataset Aggregation (DAgger) imitation learning for dynamic problems

docs/src/interface.md

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Algorithm Interface
2+
3+
This page describes the unified interface for Decision-Focused Learning algorithms provided by this package.
4+
5+
## Core Concepts
6+
7+
### DFLPolicy
8+
9+
The [`DFLPolicy`](@ref) is the central abstraction that encapsulates a decision-focused learning policy. It combines:
10+
- A **statistical model** (typically a neural network) that predicts parameters from input features
11+
- A **combinatorial optimizer** (maximizer) that solves optimization problems using the predicted parameters
12+
13+
```julia
14+
policy = DFLPolicy(
15+
Chain(Dense(input_dim => hidden_dim, relu), Dense(hidden_dim => output_dim)),
16+
my_optimizer
17+
)
18+
```
19+
20+
### Training Interface
21+
22+
All algorithms in this package follow a unified training interface with two main functions:
23+
24+
#### Core Training Method
25+
26+
```julia
27+
history = train_policy!(algorithm, policy, training_data; epochs=100, metrics=(), maximizer_kwargs=get_info)
28+
```
29+
30+
**Arguments:**
31+
- `algorithm`: An algorithm instance (e.g., `PerturbedFenchelYoungLossImitation`, `DAgger`, `AnticipativeImitation`)
32+
- `policy::DFLPolicy`: The policy to train (contains the model and maximizer)
33+
- `training_data`: Either a dataset of `DataSample` objects or `Environment` (depends on algorithm)
34+
- `epochs::Int`: Number of training epochs (default: 100)
35+
- `metrics::Tuple`: Metrics to evaluate during training (default: empty)
36+
- `maximizer_kwargs::Function`: Function that extracts keyword arguments for the maximizer from data samples (default: `get_info`)
37+
38+
**Returns:**
39+
- `history::MVHistory`: Training history containing loss values and metric evaluations
40+
41+
#### Benchmark Convenience Wrapper
42+
43+
```julia
44+
result = train_policy(algorithm, benchmark; dataset_size=30, split_ratio=(0.3, 0.3), epochs=100, metrics=())
45+
```
46+
47+
This high-level function handles all setup from a benchmark and returns a trained policy along with training history.
48+
49+
**Arguments:**
50+
- `algorithm`: An algorithm instance
51+
- `benchmark::AbstractBenchmark`: A benchmark from DecisionFocusedLearningBenchmarks.jl
52+
- `dataset_size::Int`: Number of instances to generate
53+
- `split_ratio::Tuple`: Fractions of the generated dataset used for training and validation; the remaining fraction is held out as the test set (e.g. `(0.3, 0.3)` leaves 40% for testing)
54+
- `epochs::Int`: Number of training epochs
55+
- `metrics::Tuple`: Metrics to track during training
56+
57+
**Returns:**
58+
- `(; policy, history)`: Named tuple with trained policy and training history
59+
60+
## Metrics
61+
62+
Metrics allow you to track additional quantities during training.
63+
64+
### Built-in Metrics
65+
66+
#### FYLLossMetric
67+
68+
Evaluates Fenchel-Young loss on a validation dataset.
69+
70+
```julia
71+
val_metric = FYLLossMetric(validation_data, :validation_loss)
72+
```
73+
74+
#### FunctionMetric
75+
76+
Custom metric defined by a function.
77+
78+
```julia
79+
# Simple metric (no stored data)
80+
epoch_metric = FunctionMetric(ctx -> ctx.epoch, :epoch)
81+
82+
# Metric with stored data
83+
gap_metric = FunctionMetric(:validation_gap, validation_data) do ctx, data
84+
compute_gap(benchmark, data, ctx.policy.statistical_model, ctx.policy.maximizer)
85+
end
86+
```
87+
88+
### TrainingContext
89+
90+
Metrics receive a `TrainingContext` object containing:
91+
- `policy::DFLPolicy`: The policy being trained
92+
- `epoch::Int`: Current epoch number
93+
- `maximizer_kwargs::Function`: Maximizer kwargs extractor
94+
- `other_fields`: Algorithm-specific fields (e.g., `loss` for FYL)
95+
96+
Access policy components:
97+
```julia
98+
ctx.policy.statistical_model # Neural network
99+
ctx.policy.maximizer # Combinatorial optimizer
100+
```

docs/src/tutorials/tutorial.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# # Basic Tutorial: Training with FYL on Argmax Benchmark
2+
#
3+
# This tutorial demonstrates the basic workflow for training a policy
4+
# using the Perturbed Fenchel-Young Loss algorithm.
5+
6+
# ## Setup
7+
using DecisionFocusedLearningAlgorithms
8+
using DecisionFocusedLearningBenchmarks
9+
using MLUtils: splitobs
10+
using Plots
11+
12+
# ## Create Benchmark and Data
13+
b = ArgmaxBenchmark()
14+
dataset = generate_dataset(b, 100)
15+
train_data, val_data, test_data = splitobs(dataset; at=(0.3, 0.3, 0.4))
16+
17+
# ## Create Policy
18+
model = generate_statistical_model(b; seed=0)
19+
maximizer = generate_maximizer(b)
20+
policy = DFLPolicy(model, maximizer)
21+
22+
# ## Configure Algorithm
23+
algorithm = PerturbedFenchelYoungLossImitation(;
24+
nb_samples=10, ε=0.1, threaded=true, seed=0
25+
)
26+
27+
# ## Define Metrics to track during training
28+
validation_loss_metric = FYLLossMetric(val_data, :validation_loss)
29+
30+
val_gap_metric = FunctionMetric(:val_gap, val_data) do ctx, data
31+
compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer)
32+
end
33+
34+
test_gap_metric = FunctionMetric(:test_gap, test_data) do ctx, data
35+
compute_gap(b, data, ctx.policy.statistical_model, ctx.policy.maximizer)
36+
end
37+
38+
metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)
39+
40+
# ## Train the Policy
41+
history = train_policy!(algorithm, policy, train_data; epochs=100, metrics=metrics)
42+
43+
# ## Plot Results
44+
val_gap_epochs, val_gap_values = get(history, :val_gap)
45+
test_gap_epochs, test_gap_values = get(history, :test_gap)
46+
47+
plot(
48+
[val_gap_epochs, test_gap_epochs],
49+
[val_gap_values, test_gap_values];
50+
labels=["Val Gap" "Test Gap"],
51+
xlabel="Epoch",
52+
ylabel="Gap",
53+
title="Gap Evolution During Training",
54+
)
55+
56+
# Plot loss evolution
57+
train_loss_epochs, train_loss_values = get(history, :training_loss)
58+
val_loss_epochs, val_loss_values = get(history, :validation_loss)
59+
60+
plot(
61+
[train_loss_epochs, val_loss_epochs],
62+
[train_loss_values, val_loss_values];
63+
labels=["Training Loss" "Validation Loss"],
64+
xlabel="Epoch",
65+
ylabel="Loss",
66+
title="Loss Evolution During Training",
67+
)

0 commit comments

Comments
 (0)