|
28 | 28 | function get_f_and_df_prepared!( |
29 | 29 | df, f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B} |
30 | 30 | ) where {F, M, B} |
31 | | - #= |
32 | | - It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`. |
33 | | - - In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`. |
34 | | - - In reverse mode, `df` gets incremented but it does not influence the input cotangent `dx`. |
35 | | - =# |
36 | | - if B == 1 |
37 | | - return Duplicated(f, df) |
| 31 | + if isnothing(df) |
| 32 | + return Const(f) |
38 | 33 | else |
39 | | - return BatchDuplicated(f, df) |
| 34 | + if B == 1 |
| 35 | + return Duplicated(f, df) |
| 36 | + else |
| 37 | + return BatchDuplicated(f, df) |
| 38 | + end |
40 | 39 | end |
41 | 40 | end |
42 | 41 |
|
43 | 42 | function function_shadow( |
44 | | - ::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Val{B} |
| 43 | + ::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Mode, ::Val{B} |
45 | 44 | ) where {M, B, F} |
46 | 45 | return nothing |
47 | 46 | end |
48 | 47 |
|
49 | | -function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}) where {F, M, B} |
50 | | - if B == 1 |
51 | | - return make_zero(f) |
| 48 | +function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, mode::Mode, ::Val{B}) where {F, M, B} |
| 49 | + IA = guess_activity(F, mode) |
| 50 | + return if IA <: Const |
| 51 | + nothing |
52 | 52 | else |
53 | | - return ntuple(_ -> make_zero(f), Val(B)) |
| 53 | + if B == 1 |
| 54 | + return make_zero(f) |
| 55 | + else |
| 56 | + return ntuple(_ -> make_zero(f), Val(B)) |
| 57 | + end |
54 | 58 | end |
55 | 59 | end |
56 | 60 |
|
@@ -87,13 +91,13 @@ function _shadow( |
87 | 91 | end |
88 | 92 |
|
89 | 93 | function _shadow( |
90 | | - backend::AutoEnzyme{M, <:Union{Const, Nothing}}, |
91 | | - ::Mode, |
| 94 | + backend::AutoEnzyme, |
| 95 | + mode::Mode, |
92 | 96 | ::Val{B}, |
93 | 97 | c_wrapped::DI.FunctionContext, |
94 | | - ) where {M, B} |
| 98 | + ) where {B} |
95 | 99 | f = DI.unwrap(c_wrapped) |
96 | | - return function_shadow(f, backend, Val(B)) |
| 100 | + return function_shadow(f, backend, mode, Val(B)) |
97 | 101 | end |
98 | 102 |
|
99 | 103 | function make_context_shadows( |
|
122 | 126 | function _translate_prepared!( |
123 | 127 | dc, c_wrapped::Union{DI.ConstantOrCache, DI.FunctionContext}, ::Val{B} |
124 | 128 | ) where {B} |
125 | | - #= |
126 | | - It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts. |
127 | | - - In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`. |
128 | | - - In reverse mode, `dc` gets incremented but it does not influence the input cotangent `dx`. |
129 | | - =# |
130 | 129 | c = DI.unwrap(c_wrapped) |
131 | 130 | if isnothing(dc) |
132 | 131 | return Const(c) |
|
0 commit comments