-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathdifferentiate_with.jl
More file actions
110 lines (83 loc) · 4.2 KB
/
differentiate_with.jl
File metadata and controls
110 lines (83 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
    DifferentiateWith

Function wrapper that enforces differentiation with a "substitute" AD backend, possibly different from the "true" AD backend that is called.

For instance, suppose a function `f` is not differentiable with Zygote because it involves mutation, but you know that it is differentiable with Enzyme.
Then `f2 = DifferentiateWith(f, AutoEnzyme())` is a new function that behaves like `f`, except that `f2` is differentiable with Zygote (thanks to a chain rule which calls Enzyme under the hood).
Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be differentiable with Zygote (as long as `f` was the only Zygote blocker).

!!! tip
    This is mainly relevant for package developers who want to produce differentiable code at low cost, without writing the differentiation rules themselves.
    If you sprinkle a few `DifferentiateWith` in places where some AD backends may struggle, end users can pick from a wider variety of packages to differentiate your algorithms.

!!! warning
    `DifferentiateWith` only supports out-of-place functions `y = f(x, contexts...)`, where the derivatives with respect to `contexts` can be safely ignored in the rest of your code.
    It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or a backend that automatically imports rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
    For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

!!! warning
    When using `DifferentiateWith(f, AutoSomething())`, the function `f` must not close over any active data.
    As of now, we cannot differentiate with respect to parameters stored inside `f`.

# Fields

- `f`: the function in question, with signature `f(x, contexts...)`
- `backend::AbstractADType`: the substitute backend to use for differentiation
- `context_wrappers::NTuple`: a tuple like `(Constant, Cache)`, meaning that `f(x, a, b)` will be differentiated with `Constant(a)` and `Cache(b)` as contexts. Each entry may be any function or type, and entries of different kinds may be mixed.

!!! note
    For the substitute AD backend to be called under the hood, its package needs to be loaded in addition to the package of the true AD backend.

# Constructors

    DifferentiateWith(f, backend, context_wrappers)
    DifferentiateWith(f, backend)

The two-argument form is shorthand for `DifferentiateWith(f, backend, ())`, i.e. a wrapper around a function without context arguments.
Calling the resulting wrapper requires exactly `length(context_wrappers)` context arguments after `x`.

# Example

```jldoctest
julia> using DifferentiationInterface

julia> using FiniteDiff: FiniteDiff
       using ForwardDiff: ForwardDiff
       using Zygote: Zygote

julia> function f(x::Vector{Float64})
           a = Vector{Float64}(undef, 1)  # type constraint breaks ForwardDiff
           a[1] = sum(abs2, x)  # mutation breaks Zygote
           return a[1]
       end;

julia> f2 = DifferentiateWith(f, AutoFiniteDiff());

julia> f([3.0, 5.0]) == f2([3.0, 5.0])
true

julia> alg(x) = 7 * f2(x);

julia> ForwardDiff.gradient(alg, [3.0, 5.0])
2-element Vector{Float64}:
 42.0
 70.0

julia> Zygote.gradient(alg, [3.0, 5.0])[1]
2-element Vector{Float64}:
 42.0
 70.0
```
"""
struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}}
    f::F
    backend::B
    context_wrappers::N

    # `C` (the number of context arguments) comes first among the type parameters,
    # which lets methods dispatch on `DifferentiateWith{C}` alone.
    #
    # The wrappers are annotated with a fixed element type,
    # `Tuple{Vararg{Union{Function, Type}, C}}`, rather than
    # `NTuple{C, <:Union{Function, Type}}`. The latter binds a single type variable
    # for every element. Julia's diagonal subtyping rule then forces that variable to
    # be concrete, so a mixed tuple such as `(Constant, identity)` would not match.
    function DifferentiateWith(
        f::F,
        backend::B,
        context_wrappers::Tuple{Vararg{Union{Function, Type}, C}},
    ) where {F, B <: AbstractADType, C}
        return new{C, F, B, typeof(context_wrappers)}(f, backend, context_wrappers)
    end
end
# Shorthand for a wrapper without context arguments: forwards to the
# three-argument constructor with an empty tuple of context wrappers.
function DifferentiateWith(f::F, backend::AbstractADType) where {F}
    return DifferentiateWith(f, backend, ())
end
# Calling the wrapper forwards `x` and the context arguments to the wrapped
# function unchanged. The `Vararg{Any, C}` signature only matches calls that pass
# exactly `C` context arguments, i.e. one per entry of `context_wrappers`.
function (dw::DifferentiateWith{C})(x, contexts::Vararg{Any, C}) where {C}
    (; f) = dw
    return f(x, contexts...)
end
# Print the wrapper as `DifferentiateWith(f, backend, context_wrappers)`.
# Each field is rendered with `repr(...; context = io)` so that it inherits the
# properties of the destination stream (e.g. `:compact`).
function Base.show(io::IO, dw::DifferentiateWith)
    rendered = map((dw.f, dw.backend, dw.context_wrappers)) do field
        repr(field; context = io)
    end
    print(io, DifferentiateWith, "(", join(rendered, ", "), ")")
    return nothing
end