Skip to content

Commit 28d2aa6

Browse files
committed
cleanup dagger
1 parent 6a98761 commit 28d2aa6

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/algorithms/supervised/dagger.jl

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -89,11 +89,10 @@ function train_policy!(
8989
for env in train_environments
9090
DecisionFocusedLearningBenchmarks.reset!(env; reset_rng=false)
9191
while !is_terminated(env)
92-
x_before = copy(observe(env)[1])
9392
anticipative_solution = anticipative_policy(env; reset_env=false)
9493
p = rand(rng)
9594
target = anticipative_solution[1]
96-
x, state = observe(env)
95+
x, _ = observe(env)
9796
if size(target.x) != size(x)
9897
@error "Mismatch between expert and observed state" size(target.x) size(
9998
x
@@ -103,7 +102,6 @@ function train_policy!(
103102
if p < α
104103
action = target.y
105104
else
106-
x, state = observe(env)
107105
θ = statistical_model(x)
108106
action = maximizer(θ; maximizer_kwargs(target)...)
109107
end
@@ -134,13 +132,13 @@ function train_policy(
134132
benchmark::ExogenousDynamicBenchmark;
135133
dataset_size=30,
136134
metrics::Tuple=(),
137-
seed=0,
135+
seed=nothing,
138136
)
139137
# Generate environments
140138
train_environments = generate_environments(benchmark, dataset_size; seed)
141139

142140
# Initialize model and create policy
143-
model = generate_statistical_model(benchmark)
141+
model = generate_statistical_model(benchmark; seed)
144142
maximizer = generate_maximizer(benchmark)
145143
policy = DFLPolicy(model, maximizer)
146144

0 commit comments

Comments (0)