Skip to content

Commit 308b4e9

Browse files
committed
Preliminary cleanup
1 parent 7d8bec1 commit 308b4e9

File tree

17 files changed

+612
-585
lines changed

17 files changed

+612
-585
lines changed

Project.toml

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
name = "DecisionFocusedLearningAlgorithms"
22
uuid = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
33
authors = ["Members of JuliaDecisionFocusedLearning and contributors"]
4-
version = "0.0.1"
4+
version = "0.1.0"
5+
6+
[workspace]
7+
projects = ["docs", "test"]
58

69
[deps]
710
DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
@@ -15,7 +18,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
1518
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
1619

1720
[compat]
18-
DecisionFocusedLearningBenchmarks = "0.3.0"
21+
DecisionFocusedLearningBenchmarks = "0.4"
1922
Flux = "0.16.5"
2023
InferOpt = "0.7.1"
2124
MLUtils = "0.4.8"
@@ -25,14 +28,3 @@ Statistics = "1.11.1"
2528
UnicodePlots = "3.8.1"
2629
ValueHistories = "0.5.4"
2730
julia = "1.11"
28-
29-
[extras]
30-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
31-
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
32-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
33-
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
34-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
35-
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
36-
37-
[targets]
38-
test = ["Aqua", "Documenter", "JET", "JuliaFormatter", "Test", "TestItemRunner"]

debug_dagger.jl

Whitespace-only changes.

scripts/main.jl

Lines changed: 9 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,21 @@
11
using DecisionFocusedLearningAlgorithms
22
using DecisionFocusedLearningBenchmarks
3+
4+
using Flux
35
using MLUtils
4-
using Statistics
56
using Plots
67

7-
# ! metric(prediction, data_sample)
8-
98
b = ArgmaxBenchmark()
109
initial_model = generate_statistical_model(b)
1110
maximizer = generate_maximizer(b)
1211
dataset = generate_dataset(b, 100)
13-
train_dataset, val_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
14-
res, model = fyl_train_model(
15-
initial_model, maximizer, train_dataset, val_dataset; epochs=100
16-
)
17-
18-
res = fyl_train_model(StochasticVehicleSchedulingBenchmark(); epochs=100)
19-
plot(res.validation_loss; label="Validation Loss")
20-
plot!(res.training_loss; label="Training Loss")
21-
22-
baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
23-
DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
24-
25-
struct KleopatraPolicy{M}
26-
model::M
27-
end
28-
29-
function (m::KleopatraPolicy)(env)
30-
x, instance = observe(env)
31-
θ = m.model(x)
32-
return maximizer(θ; instance)
33-
end
34-
35-
b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false)
36-
dataset = generate_dataset(b, 100)
37-
train_instances, validation_instances, test_instances = splitobs(
38-
dataset; at=(0.3, 0.3, 0.4)
39-
)
40-
train_environments = generate_environments(b, train_instances; seed=0)
41-
validation_environments = generate_environments(b, validation_instances)
42-
test_environments = generate_environments(b, test_instances)
43-
44-
train_dataset = vcat(map(train_environments) do env
45-
v, y = generate_anticipative_solution(b, env; reset_env=true)
46-
return y
47-
end...)
48-
49-
val_dataset = vcat(map(validation_environments) do env
50-
v, y = generate_anticipative_solution(b, env; reset_env=true)
51-
return y
52-
end...)
12+
train_dataset, val_dataset, test_dataset = splitobs(dataset; at=(0.3, 0.3, 0.4))
5313

54-
model = generate_statistical_model(b; seed=0)
55-
maximizer = generate_maximizer(b)
56-
anticipative_policy = (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)
57-
58-
fyl_model = deepcopy(model)
59-
fyl_policy = Policy("fyl", "", KleopatraPolicy(fyl_model))
60-
61-
callbacks = [
62-
Metric(:obj, (data, ctx) -> mean(evaluate_policy!(fyl_policy, test_environments, 1)[1]))
63-
]
64-
65-
fyl_history = fyl_train_model!(
66-
fyl_model, maximizer, train_dataset, val_dataset; epochs=100, callbacks
67-
)
68-
69-
dagger_model = deepcopy(model)
70-
dagger_policy = Policy("dagger", "", KleopatraPolicy(dagger_model))
71-
72-
callbacks = [
73-
Metric(
74-
:obj, (data, ctx) -> mean(evaluate_policy!(dagger_policy, test_environments, 1)[1])
75-
),
76-
]
77-
78-
dagger_history = DAgger_train_model!(
79-
dagger_model,
80-
maximizer,
81-
train_environments,
82-
validation_environments,
83-
anticipative_policy;
84-
iterations=10,
85-
fyl_epochs=10,
86-
callbacks=callbacks,
87-
)
88-
89-
# Extract metric values for plotting
90-
fyl_epochs, fyl_obj_values = get(fyl_history, :val_obj)
91-
dagger_epochs, dagger_obj_values = get(dagger_history, :val_obj)
92-
93-
plot(
94-
[fyl_epochs, dagger_epochs],
95-
[fyl_obj_values, dagger_obj_values];
96-
labels=["FYL" "DAgger"],
97-
xlabel="Epoch",
98-
ylabel="Test Average Reward (1 scenario)",
14+
algorithm = PerturbedImitationAlgorithm(;
15+
nb_samples=20, ε=0.05, threaded=true, training_optimizer=Adam()
9916
)
10017

101-
using Statistics
102-
v_fyl, _ = evaluate_policy!(fyl_policy, test_environments, 100)
103-
v_dagger, _ = evaluate_policy!(dagger_policy, test_environments, 100)
104-
mean(v_fyl)
105-
mean(v_dagger)
106-
107-
anticipative_policy(test_environments[1]; reset_env=true)
18+
model = deepcopy(initial_model)
19+
history = train!(algorithm, model, maximizer, train_dataset, val_dataset; epochs=50)
20+
x, y = get(history, :training_loss)
21+
plot(x, y; xlabel="Epoch", ylabel="Training Loss", title="Training Loss over Epochs")

scripts/old/main.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Archived experimentation script ("old/main.jl"): exploratory runs of the
# FYL and DAgger training pipelines on several benchmarks. Not a library file —
# meant to be stepped through interactively (REPL-style).
using DecisionFocusedLearningAlgorithms
using DecisionFocusedLearningBenchmarks
using MLUtils
using Statistics
using Plots

# ! metric(prediction, data_sample)

# --- FYL on a toy argmax benchmark -------------------------------------------
b = ArgmaxBenchmark()
initial_model = generate_statistical_model(b)
maximizer = generate_maximizer(b)
dataset = generate_dataset(b, 100)
# 30% train / 30% validation / 40% unused
train_dataset, val_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
res, model = fyl_train_model(
    initial_model, maximizer, train_dataset, val_dataset; epochs=100
)

# --- FYL on the stochastic vehicle-scheduling benchmark ----------------------
# NOTE(review): rebinds `res` (now a single return value, not a tuple as above).
res = fyl_train_model(StochasticVehicleSchedulingBenchmark(); epochs=100)
plot(res.validation_loss; label="Validation Loss")
plot!(res.training_loss; label="Training Loss")

# Smoke-runs of the other training entry points on the dynamic VSP benchmark.
baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))

# Policy wrapper: applies the statistical model to an observation, then decodes
# the scores with the (global) `maximizer`.
struct KleopatraPolicy{M}
    model::M
end

# Functor: observe the environment, score features, decode into a decision.
# NOTE(review): captures the global `maximizer`, not a field — rebinding
# `maximizer` later changes this policy's behavior.
function (m::KleopatraPolicy)(env)
    x, instance = observe(env)
    θ = m.model(x)
    return maximizer(θ; instance)
end

# --- FYL vs. DAgger comparison on dynamic vehicle scheduling -----------------
b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false)
dataset = generate_dataset(b, 100)
train_instances, validation_instances, test_instances = splitobs(
    dataset; at=(0.3, 0.3, 0.4)
)
train_environments = generate_environments(b, train_instances; seed=0)
validation_environments = generate_environments(b, validation_instances)
test_environments = generate_environments(b, test_instances)

# Build supervised datasets from anticipative (hindsight-optimal) solutions.
train_dataset = vcat(map(train_environments) do env
    v, y = generate_anticipative_solution(b, env; reset_env=true)
    return y
end...)

val_dataset = vcat(map(validation_environments) do env
    v, y = generate_anticipative_solution(b, env; reset_env=true)
    return y
end...)

model = generate_statistical_model(b; seed=0)
maximizer = generate_maximizer(b)
# Expert policy handed to DAgger: the anticipative solver.
anticipative_policy = (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)

# FYL baseline: trains a copy of the model, tracked via a test-reward metric.
fyl_model = deepcopy(model)
fyl_policy = Policy("fyl", "", KleopatraPolicy(fyl_model))

callbacks = [
    Metric(:obj, (data, ctx) -> mean(evaluate_policy!(fyl_policy, test_environments, 1)[1]))
]

fyl_history = fyl_train_model!(
    fyl_model, maximizer, train_dataset, val_dataset; epochs=100, callbacks
)

# DAgger run: same starting model, same metric, imitating the anticipative expert.
dagger_model = deepcopy(model)
dagger_policy = Policy("dagger", "", KleopatraPolicy(dagger_model))

callbacks = [
    Metric(
        :obj, (data, ctx) -> mean(evaluate_policy!(dagger_policy, test_environments, 1)[1])
    ),
]

dagger_history = DAgger_train_model!(
    dagger_model,
    maximizer,
    train_environments,
    validation_environments,
    anticipative_policy;
    iterations=10,
    fyl_epochs=10,
    callbacks=callbacks,
)

# Extract metric values for plotting
fyl_epochs, fyl_obj_values = get(fyl_history, :val_obj)
dagger_epochs, dagger_obj_values = get(dagger_history, :val_obj)

plot(
    [fyl_epochs, dagger_epochs],
    [fyl_obj_values, dagger_obj_values];
    labels=["FYL" "DAgger"],
    xlabel="Epoch",
    ylabel="Test Average Reward (1 scenario)",
)

# Final evaluation over 100 scenarios per environment.
# NOTE(review): `using Statistics` is redundant here (already loaded at the top).
using Statistics
v_fyl, _ = evaluate_policy!(fyl_policy, test_environments, 100)
v_dagger, _ = evaluate_policy!(dagger_policy, test_environments, 100)
mean(v_fyl)
mean(v_dagger)

# Hindsight-optimal reference value on one test environment.
anticipative_policy(test_environments[1]; reset_env=true)
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module DecisionFocusedLearningAlgorithms
22

33
using DecisionFocusedLearningBenchmarks
4-
const DVSP = DecisionFocusedLearningBenchmarks.DynamicVehicleScheduling
54
using Flux: Flux, Adam
65
using InferOpt: InferOpt, FenchelYoungLoss, PerturbedAdditive
76
using MLUtils: splitobs
@@ -22,4 +21,6 @@ export fyl_train_model!,
2221
export TrainingCallback, Metric, on_epoch_end, get_metric_names, run_callbacks!
2322
export TrainingContext, update_context
2423

24+
export PerturbedImitationAlgorithm, train!
25+
2526
end

src/fyl.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,33 @@
55
# TODO: parallelize loss computation on validation set
66
# TODO: have supervised learning training method, where fyl_train calls it, therefore we can easily test new supervised losses if needed
77

8-
function fyl_train_model!(
8+
"""
    PerturbedImitationAlgorithm(; nb_samples, ε, threaded, training_optimizer, history)

Configuration for Fenchel-Young imitation learning with an additively perturbed
maximizer. `nb_samples`, `ε` and `threaded` are forwarded to
`InferOpt.PerturbedAdditive` inside `train!`; `training_optimizer` is handed to
`Flux.setup`. The `history` field accumulates training records, so repeated
`train!` calls append unless it is cleared with `reset!`.
"""
@kwdef struct PerturbedImitationAlgorithm{O}
    # number of perturbation samples for `PerturbedAdditive`
    nb_samples::Int = 10
    # perturbation amplitude for `PerturbedAdditive`
    ε::Float64 = 0.1
    # whether `PerturbedAdditive` runs its samples threaded
    threaded::Bool = true
    # Flux optimizer used by `train!` via `Flux.setup`
    training_optimizer::O = Adam()
    # mutable record of training metrics/losses across `train!` calls
    history::MVHistory = MVHistory()
end

# Clear the stored history so the next `train!` call starts from an empty record.
reset!(algorithm::PerturbedImitationAlgorithm) = empty!(algorithm.history)
17+
18+
function train!(
19+
algorithm::PerturbedImitationAlgorithm,
920
model,
1021
maximizer,
1122
train_dataset::AbstractArray{<:DataSample},
1223
validation_dataset;
1324
epochs=100,
1425
maximizer_kwargs=get_info,
1526
callbacks::Vector{<:TrainingCallback}=TrainingCallback[],
27+
reset=false,
1628
)
17-
perturbed = PerturbedAdditive(maximizer; nb_samples=10, ε=0.1, threaded=true) # ! hardcoded
29+
reset && reset!(algorithm)
30+
(; nb_samples, ε, threaded, training_optimizer, history) = algorithm
31+
perturbed = PerturbedAdditive(maximizer; nb_samples, ε, threaded)
1832
loss = FenchelYoungLoss(perturbed)
1933

20-
optimizer = Adam() # ! hardcoded
21-
opt_state = Flux.setup(optimizer, model)
22-
23-
# Initialize metrics storage with MVHistory
24-
history = MVHistory()
34+
opt_state = Flux.setup(training_optimizer, model)
2535

2636
# Compute initial losses
2737
initial_val_loss = mean([

src/metric.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""
    AbstractMetric

Abstract supertype for training metrics. The generic methods below are no-op
fallbacks returning `nothing`; concrete subtypes presumably override
`reset!`, `update!`, and `evaluate!` with real behavior — confirm against
callers elsewhere in the package.
"""
abstract type AbstractMetric end

# No-op fallback: clear a metric's internal state.
reset!(::AbstractMetric) = nothing
# No-op fallback: fold new observations (via keyword arguments) into the metric.
update!(::AbstractMetric; kwargs...) = nothing
# No-op fallback: evaluate the metric for a policy on a dataset.
evaluate!(::AbstractMetric, policy, dataset; kwargs...) = nothing

0 commit comments

Comments
 (0)