-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathutils.jl
More file actions
35 lines (30 loc) · 1.14 KB
/
utils.jl
File metadata and controls
35 lines (30 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
    get_config(backend::AnyAutoMooncake)

Return the configuration object associated with `backend`.

- When the backend's type parameter is `Nothing`, a fresh `Config()` built with no
  arguments is returned.
- When the backend's type parameter is a subtype of `Config`, the backend's own `config`
  field is returned as is.
"""
function get_config(::AnyAutoMooncake{Nothing})
    return Config()
end

function get_config(backend::AnyAutoMooncake{<:Config})
    return backend.config
end
"""
    zero_tangent_unwrap(c::DI.Context)

Unwrap the context `c` with `DI.unwrap` and return `zero_tangent` of the unwrapped value.
"""
@inline function zero_tangent_unwrap(c::DI.Context)
    unwrapped = DI.unwrap(c)
    return zero_tangent(unwrapped)
end
"""
    first_unwrap(c, dc)

Return the 2-tuple `(DI.unwrap(c), dc)`: the first argument is unwrapped with `DI.unwrap`,
the second is passed through unchanged.
"""
@inline function first_unwrap(c, dc)
    unwrapped = DI.unwrap(c)
    return (unwrapped, dc)
end
"""
    call_and_return(f!, y, x, contexts...)

Call `f!(y, x, contexts...)` and return `y`.

Whatever `f!` itself returns is ignored; the caller always gets back the very same `y`
object that was passed in (after `f!` has had the chance to modify it).
"""
function call_and_return(f!::F, y, x, contexts...) where {F}
    # The value returned by `f!` is intentionally discarded.
    _ = f!(y, x, contexts...)
    return y
end
"""
    zero_tangent_or_primal(x, backend::AnyAutoMooncake)

Return a zero value associated with `x`, in a representation chosen by the backend's
configuration.

- If `get_config(backend).friendly_tangents` is false, return `zero_tangent(x)`.
- Otherwise, return `tangent_to_primal!!(_copy_output(x), zero_tangent(x))`, i.e. the zero
  tangent converted through `tangent_to_primal!!` using a copy of `x` (intended as a safer
  stand-in for `zero(x)`).
"""
function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
    if !get_config(backend).friendly_tangents
        return zero_tangent(x)
    end
    # zero(x) but safer: start from a copy of `x` rather than from `x` itself
    primal_copy = _copy_output(x)
    return tangent_to_primal!!(primal_copy, zero_tangent(x))
end
"""
    nanify(x)

Return an object with the same structure as `x` in which every floating-point leaf has
been replaced by `NaN` of the same type.

- `AbstractFloat`: `NaN` converted to `typeof(x)`.
- `AbstractArray`: `nanify` mapped over the elements.
- `NamedTuple`: `nanify` mapped over the values, keys preserved.
- `Tuple`: `nanify` mapped over the entries.
"""
function nanify(x::AbstractFloat)
    return oftype(x, NaN)
end

function nanify(x::AbstractArray)
    return map(nanify, x)
end

function nanify(x::NamedTuple)
    nan_values = map(nanify, values(x))
    return NamedTuple{keys(x)}(nan_values)
end

function nanify(x::Tuple)
    return map(nanify, x)
end
# `NoRData` carries nothing that could be turned into NaN: return a `NoRData()` unchanged.
function nanify(::NoRData)
    return NoRData()
end
"""
    nanify_fdata_and_rdata!!(contexts::CoDual...)

For every `CoDual` in `contexts`:

1. take `tangent(context)` (treated here as the context's fdata) and call
   `increment!!(fdata, nanify(fdata))` on it;
2. compute `zero_rdata(primal(context))` and pass it through `nanify`.

Return the tuple of nanified rdata values, one per context, in the same order as
`contexts`.
"""
function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual,C}) where {C}
    fdatas = map(tangent, contexts)
    rdata_zeros = map(context -> zero_rdata(primal(context)), contexts)
    # NOTE(review): the value returned by `increment!!` is discarded, so this presumably
    # only has an effect on fdata that is updated in place -- confirm.
    foreach(fdata -> increment!!(fdata, nanify(fdata)), fdatas)
    return map(nanify, rdata_zeros)
end