Skip to content

Commit 1b2b20e

Browse files
committed
fix doc
1 parent d8002b2 commit 1b2b20e

File tree

7 files changed

+189
-28
lines changed

7 files changed

+189
-28
lines changed

docs/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
[deps]
22
DecisionFocusedLearningAlgorithms = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
3+
DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
6+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
7+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
8+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

docs/make.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
using DecisionFocusedLearningAlgorithms
22
using Documenter
3+
using Literate
34

45
DocMeta.setdocmeta!(
56
DecisionFocusedLearningAlgorithms,
67
:DocTestSetup,
7-
:(using DecisionFocusedLearningAlgorithms);
8+
:(begin
9+
using DecisionFocusedLearningAlgorithms
10+
using DecisionFocusedLearningBenchmarks
11+
using Flux
12+
using MLUtils
13+
using Plots
14+
end),
815
recursive=true,
916
)
1017

18+
# Generate markdown files from tutorial scripts
1119
tutorial_dir = joinpath(@__DIR__, "src", "tutorials")
20+
tutorial_files = filter(f -> endswith(f, ".jl"), readdir(tutorial_dir))
1221

13-
include_tutorial = true
14-
15-
if include_tutorial
16-
for file in tutorial_files
17-
filepath = joinpath(tutorial_dir, file)
18-
Literate.markdown(filepath, md_dir; documenter=true, execute=false)
19-
end
22+
# Convert .jl tutorial files to markdown
23+
for file in tutorial_files
24+
filepath = joinpath(tutorial_dir, file)
25+
Literate.markdown(filepath, tutorial_dir; documenter=true, execute=false)
2026
end
2127

28+
# Get list of generated markdown files for the docs
29+
md_tutorial_files = [
30+
"tutorials/" * replace(file, ".jl" => ".md") for file in tutorial_files
31+
]
32+
2233
makedocs(;
2334
modules=[DecisionFocusedLearningAlgorithms],
2435
authors="Members of JuliaDecisionFocusedLearning and contributors",
@@ -28,7 +39,7 @@ makedocs(;
2839
edit_link="main",
2940
assets=String[],
3041
),
31-
pages=["Home" => "index.md", "Tutorials" => include_tutorial ? md_tutorial_files : []],
42+
pages=["Home" => "index.md", "Tutorials" => md_tutorial_files],
3243
)
3344

3445
deploydocs(;

docs/src/tutorials/tutorial.jl

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,65 @@ train_instances, validation_instances, test_instances = splitobs(
1313
model = generate_statistical_model(b; seed=0)
1414
maximizer = generate_maximizer(b)
1515

16-
compute_gap(b, test_instances, model, maximizer)
17-
18-
metrics_callbacks = (;
19-
:time => (model, maximizer, epoch) -> (epoch_time = time()),
20-
:gap => (;
21-
:val =>
22-
(model, maximizer, epoch) ->
23-
(gap = compute_gap(b, validation_instances, model, maximizer)),
24-
:test =>
25-
(model, maximizer, epoch) ->
26-
(gap = compute_gap(b, test_instances, model, maximizer)),
27-
),
16+
# Compute initial gap
17+
initial_gap = compute_gap(b, test_instances, model, maximizer)
18+
println("Initial test gap: $initial_gap")
19+
20+
# Configure the training algorithm
21+
algorithm = PerturbedImitationAlgorithm(;
22+
nb_samples=10, ε=0.1, threaded=true, seed=0
2823
)
2924

25+
# Define metrics to track during training
26+
validation_loss_metric = FYLLossMetric(validation_instances, :validation_loss)
27+
28+
# Validation gap metric
29+
val_gap_metric = FunctionMetric(:val_gap, validation_instances) do ctx, data
30+
compute_gap(b, data, ctx.model, ctx.maximizer)
31+
end
32+
33+
# Test gap metric
34+
test_gap_metric = FunctionMetric(:test_gap, test_instances) do ctx, data
35+
compute_gap(b, data, ctx.model, ctx.maximizer)
36+
end
37+
38+
# Combine metrics
39+
metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)
40+
41+
# Train the model
3042
fyl_model = deepcopy(model)
31-
log = fyl_train_model!(
43+
history = train_policy!(
44+
algorithm,
3245
fyl_model,
3346
maximizer,
3447
train_instances,
3548
validation_instances;
3649
epochs=100,
37-
metrics_callbacks,
50+
metrics=metrics,
3851
)
3952

40-
log[:gap]
53+
# Plot validation and test gaps
54+
val_gap_epochs, val_gap_values = get(history, :val_gap)
55+
test_gap_epochs, test_gap_values = get(history, :test_gap)
56+
4157
plot(
42-
[log[:gap].val, log[:gap].test];
58+
[val_gap_epochs, test_gap_epochs],
59+
[val_gap_values, test_gap_values];
4360
labels=["Val Gap" "Test Gap"],
4461
xlabel="Epoch",
4562
ylabel="Gap",
63+
title="Gap Evolution During Training",
64+
)
65+
66+
# Plot training and validation loss
67+
train_loss_epochs, train_loss_values = get(history, :training_loss)
68+
val_loss_epochs, val_loss_values = get(history, :validation_loss)
69+
70+
plot(
71+
[train_loss_epochs, val_loss_epochs],
72+
[train_loss_values, val_loss_values];
73+
labels=["Training Loss" "Validation Loss"],
74+
xlabel="Epoch",
75+
ylabel="Loss",
76+
title="Loss Evolution During Training",
4677
)
47-
plot(log[:validation_loss])

docs/src/tutorials/tutorial.md

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
```@meta
2+
EditURL = "tutorial.jl"
3+
```
4+
5+
Tutorial
6+
7+
````@example tutorial
8+
using DecisionFocusedLearningAlgorithms
9+
using DecisionFocusedLearningBenchmarks
10+
using MLUtils: splitobs
11+
using Plots
12+
13+
b = ArgmaxBenchmark()
14+
dataset = generate_dataset(b, 100)
15+
train_instances, validation_instances, test_instances = splitobs(
16+
dataset; at=(0.3, 0.3, 0.4)
17+
)
18+
19+
model = generate_statistical_model(b; seed=0)
20+
maximizer = generate_maximizer(b)
21+
````
22+
23+
Compute initial gap
24+
25+
````@example tutorial
26+
initial_gap = compute_gap(b, test_instances, model, maximizer)
27+
println("Initial test gap: $initial_gap")
28+
````
29+
30+
Configure the training algorithm
31+
32+
````@example tutorial
33+
algorithm = PerturbedImitationAlgorithm(;
34+
nb_samples=10, ε=0.1, threaded=true, seed=0
35+
)
36+
````
37+
38+
Define metrics to track during training
39+
40+
````@example tutorial
41+
validation_loss_metric = FYLLossMetric(validation_instances, :validation_loss)
42+
````
43+
44+
Validation gap metric
45+
46+
````@example tutorial
47+
val_gap_metric = FunctionMetric(:val_gap, validation_instances) do ctx, data
48+
compute_gap(b, data, ctx.model, ctx.maximizer)
49+
end
50+
````
51+
52+
Test gap metric
53+
54+
````@example tutorial
55+
test_gap_metric = FunctionMetric(:test_gap, test_instances) do ctx, data
56+
compute_gap(b, data, ctx.model, ctx.maximizer)
57+
end
58+
````
59+
60+
Combine metrics
61+
62+
````@example tutorial
63+
metrics = (validation_loss_metric, val_gap_metric, test_gap_metric)
64+
````
65+
66+
Train the model
67+
68+
````@example tutorial
69+
fyl_model = deepcopy(model)
70+
history = train_policy!(
71+
algorithm,
72+
fyl_model,
73+
maximizer,
74+
train_instances,
75+
validation_instances;
76+
epochs=100,
77+
metrics=metrics,
78+
)
79+
````
80+
81+
Plot validation and test gaps
82+
83+
````@example tutorial
84+
val_gap_epochs, val_gap_values = get(history, :val_gap)
85+
test_gap_epochs, test_gap_values = get(history, :test_gap)
86+
87+
plot(
88+
[val_gap_epochs, test_gap_epochs],
89+
[val_gap_values, test_gap_values];
90+
labels=["Val Gap" "Test Gap"],
91+
xlabel="Epoch",
92+
ylabel="Gap",
93+
title="Gap Evolution During Training",
94+
)
95+
````
96+
97+
Plot training and validation loss
98+
99+
````@example tutorial
100+
train_loss_epochs, train_loss_values = get(history, :training_loss)
101+
val_loss_epochs, val_loss_values = get(history, :validation_loss)
102+
103+
plot(
104+
[train_loss_epochs, val_loss_epochs],
105+
[train_loss_values, val_loss_values];
106+
labels=["Training Loss" "Validation Loss"],
107+
xlabel="Epoch",
108+
ylabel="Loss",
109+
title="Loss Evolution During Training",
110+
)
111+
````
112+
113+
---
114+
115+
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
116+

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ include("metrics/accumulators.jl")
1919
include("metrics/function_metric.jl")
2020
include("metrics/periodic.jl")
2121

22-
include("fyl.jl")
23-
include("dagger.jl")
22+
include("algorithms/fyl.jl")
23+
include("algorithms/dagger.jl")
2424

2525
export TrainingContext
2626

File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)