
Commit fa615d3
update and cleanup

1 parent: f8c3968

17 files changed (+888, -664)

scripts/example_new_metrics.jl

This file was deleted (0 additions, 44 deletions).

scripts/main.jl

Lines changed: 33 additions & 10 deletions

@@ -6,27 +6,48 @@ using InferOpt
 using MLUtils
 using Plots

-b = ArgmaxBenchmark()
+b = ArgmaxBenchmark(; seed=42)
 initial_model = generate_statistical_model(b; seed=0)
 maximizer = generate_maximizer(b)
 dataset = generate_dataset(b, 100; seed=0);
 train_dataset, val_dataset = splitobs(dataset; at=(0.5, 0.5));

 algorithm = PerturbedImitationAlgorithm(;
-    nb_samples=20, ε=0.1, threaded=true, training_optimizer=Adam()
+    nb_samples=20, ε=0.1, threaded=true, training_optimizer=Adam(), seed=0
 )

-validation_metric = FYLLossMetric(algorithm, val_dataset, :validation_loss, maximizer);
+validation_metric = FYLLossMetric(val_dataset, :validation_loss);
+epoch_metric = FunctionMetric(ctx -> ctx.epoch, :current_epoch)
+
+dual_gap_metric = FunctionMetric(:dual_gap, (train_dataset, val_dataset)) do ctx, datasets
+    _train_dataset, _val_dataset = datasets
+    train_gap = compute_gap(b, _train_dataset, ctx.model, ctx.maximizer)
+    val_gap = compute_gap(b, _val_dataset, ctx.model, ctx.maximizer)
+    return (train_gap=train_gap, val_gap=val_gap)
+end
+
+gap_metric = FunctionMetric(:validation_gap, val_dataset) do ctx, data
+    compute_gap(b, data, ctx.model, ctx.maximizer)
+end
+periodic_gap = PeriodicMetric(gap_metric, 5)
+
+gap_metric_offset = FunctionMetric(:delayed_gap, val_dataset) do ctx, data
+    compute_gap(b, data, ctx.model, ctx.maximizer)
+end
+delayed_periodic_gap = PeriodicMetric(gap_metric_offset, 5; offset=10)
+
+# Combine metrics
+metrics = (
+    validation_metric,
+    epoch_metric,
+    dual_gap_metric,       # Outputs both train_gap and val_gap every epoch
+    periodic_gap,          # Outputs validation_gap every 5 epochs
+    delayed_periodic_gap,  # Outputs delayed_gap every 5 epochs starting at epoch 10
+);

 model = deepcopy(initial_model)
 history = train_policy!(
-    algorithm,
-    model,
-    maximizer,
-    train_dataset,
-    val_dataset;
-    epochs=50,
-    metrics=(validation_metric,),
+    algorithm, model, maximizer, train_dataset, val_dataset; epochs=50, metrics=metrics
 )
 X_train, Y_train = get(history, :training_loss)
 X_val, Y_val = get(history, :validation_loss)

@@ -44,3 +65,5 @@ plot!(
     label="Validation Loss",
     title="Validation Loss over Epochs",
 )
+
+plot(get(history, :validation_gap); xlabel="Epoch", title="Validation Gap over Epochs")
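The comment on dual_gap_metric suggests that a FunctionMetric returning a NamedTuple is logged as one history series per field. Under that assumption, the two gap curves could be plotted together; the sketch below is not part of the commit, and the series names :train_gap and :val_gap are inferred from the returned tuple:

# Sketch: assumes NamedTuple fields are stored as separate history series
X_tg, Y_tg = get(history, :train_gap)
X_vg, Y_vg = get(history, :val_gap)
plot(X_tg, Y_tg; label="Train Gap", xlabel="Epoch", ylabel="Gap")
plot!(X_vg, Y_vg; label="Validation Gap", title="Optimality Gap over Epochs")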

scripts/main_dagger.jl

Lines changed: 74 additions & 0 deletions (new file)

@@ -0,0 +1,74 @@
+using DecisionFocusedLearningAlgorithms
+using DecisionFocusedLearningBenchmarks
+
+using Flux
+using InferOpt
+using MLUtils
+using Plots
+
+# Create Dynamic Vehicle Scheduling Problem benchmark
+b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=true)
+
+# Generate dataset and environments
+dataset = generate_dataset(b, 9)
+train_instances, val_instances, test_instances = splitobs(dataset; at=(0.5, 0.3, 0.2))
+
+train_envs = generate_environments(b, train_instances; seed=0)
+val_envs = generate_environments(b, val_instances; seed=1)
+
+# Initialize model and maximizer
+initial_model = generate_statistical_model(b; seed=0)
+maximizer = generate_maximizer(b)
+
+# Define anticipative (expert) policy
+anticipative_policy = (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)
+
+# Configure training algorithm
+algorithm = PerturbedImitationAlgorithm(;
+    nb_samples=10, ε=0.1, threaded=true, training_optimizer=Adam(0.001), seed=0
+)
+
+# Define metrics to track during training
+epoch_metric = FunctionMetric(ctx -> ctx.epoch, :current_epoch)

+# You can add validation metrics if you have a validation function
+# For now, we'll just track epochs
+metrics = (epoch_metric,)
+
+# Train using DAgger
+println("Starting DAgger training on Dynamic Vehicle Scheduling Problem...")
+model = deepcopy(initial_model)
+
+history = DAgger_train_model!(
+    model,
+    maximizer,
+    train_envs,
+    val_envs,
+    anticipative_policy;
+    iterations=5,
+    fyl_epochs=10,
+    metrics=metrics,
+    algorithm=algorithm,
+)
+
+# Plot training progress
+X_train, Y_train = get(history, :training_loss)
+plot(
+    X_train,
+    Y_train;
+    xlabel="Epoch",
+    ylabel="Training Loss",
+    label="Training Loss",
+    title="DAgger Training on Dynamic VSP",
+    legend=:topright,
+)
+
+# Plot epoch tracking if available
+if haskey(history, :current_epoch)
+    X_epoch, Y_epoch = get(history, :current_epoch)
+    println("Tracked epochs: ", Y_epoch)
+end
+
+println("\nTraining completed!")
+println("Final training loss: ", Y_train[end])
+println("Total epochs: ", length(Y_train) - 1) # -1 because epoch 0 is included

src/DecisionFocusedLearningAlgorithms.jl

Lines changed: 19 additions & 8 deletions

@@ -12,19 +12,30 @@ using ValueHistories: MVHistory

 include("utils.jl")
 include("training_context.jl")
-# include("dfl_policy.jl")
-# include("callbacks.jl")
-include("metric.jl")
+
+# Metrics subsystem
+include("metrics/interface.jl")
+include("metrics/accumulators.jl")
+include("metrics/function_metric.jl")
+include("metrics/periodic.jl")
+
 include("fyl.jl")
 include("dagger.jl")

-export fyl_train_model!,
-    fyl_train_model, baty_train_model, DAgger_train_model!, DAgger_train_model
-export TrainingCallback, Metric, on_epoch_end, get_metric_names, run_callbacks!
-export TrainingContext, update_context
+export TrainingContext

 export AbstractMetric,
-    FYLLossMetric, FunctionMetric, LossAccumulator, reset!, update!, evaluate!, compute
+    FYLLossMetric,
+    FunctionMetric,
+    PeriodicMetric,
+    LossAccumulator,
+    reset!,
+    update!,
+    evaluate!,
+    compute,
+    run_metrics!
+
+export fyl_train_model, baty_train_model, DAgger_train_model!, DAgger_train_model
 export PerturbedImitationAlgorithm, train_policy!

 end
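The export list hints at the metric lifecycle: reset! clears state, update!/evaluate! run against the training context, and compute returns the logged value. A hypothetical custom metric against that interface might look like the sketch below; the subtype requirement and method signatures are assumptions inferred from the exported names, not confirmed by this diff:

# Hypothetical custom metric extending the package's generic functions.
# Assumes evaluate! receives the TrainingContext and that the context
# carries the current epoch loss under ctx.loss (an assumption).
using DecisionFocusedLearningAlgorithms
import DecisionFocusedLearningAlgorithms: reset!, evaluate!, compute

struct BestLossMetric <: AbstractMetric
    best::Base.RefValue{Float64}  # mutable slot inside an immutable struct
end
BestLossMetric() = BestLossMetric(Ref(Inf))

function reset!(m::BestLossMetric)
    m.best[] = Inf
    return m
end

function evaluate!(m::BestLossMetric, ctx)
    m.best[] = min(m.best[], ctx.loss)
    return m
end

compute(m::BestLossMetric) = m.best[]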
