-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathcontext.jl
More file actions
183 lines (130 loc) · 4.65 KB
/
context.jl
File metadata and controls
183 lines (130 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
Context
Abstract supertype for additional context arguments, which can be passed to differentiation operators after the active input `x` but are not differentiated.
# Subtypes
- [`Constant`](@ref)
- [`Cache`](@ref)
- [`ConstantOrCache`](@ref)
"""
abstract type Context end
abstract type GeneralizedConstant <: Context end
unwrap(c::Context) = c.data
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
## Public contexts
"""
Constant
Concrete type of [`Context`](@ref) argument which is kept constant during differentiation.
Note that an operator can be prepared with an arbitrary value of the constant.
However, same-point preparation must occur with the exact value that will be reused later.
!!! warning
Some backends require any `Constant` context to be a `Number` or an `AbstractArray`.
# Example
```jldoctest
julia> using DifferentiationInterface
julia> using ForwardDiff: ForwardDiff
julia> f(x, c) = c * sum(abs2, x);
julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(10))
2-element Vector{Float64}:
20.0
40.0
julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(100))
2-element Vector{Float64}:
200.0
400.0
```
"""
struct Constant{T} <: GeneralizedConstant
data::T
end
# Plain-function form of the `Constant` constructor.
# NOTE(review): presumably a named function is used instead of the type itself so that
# the tuple of makers stored in `Rewrap` has an informative concrete type — confirm.
function constant_maker(c)
    return Constant(c)
end

# Return the function able to wrap a raw value back into a `Constant`.
function maker(::Constant)
    return constant_maker
end

# A `Constant` is handed back untouched, whatever element type is requested.
function adapt_eltype(c::Constant, ::Type)
    return c
end
"""
Cache
Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation.
The initial values present inside the cache do not matter.
For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl).
!!! warning
Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s.
# Example
```jldoctest
julia> using DifferentiationInterface
julia> using ForwardDiff: ForwardDiff
julia> f(x, c) = sum(copyto!(c, x));
julia> prep = prepare_gradient(f, AutoForwardDiff(), [1.0, 2.0], Cache(zeros(2)));
julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2)))
2-element Vector{Float64}:
1.0
1.0
```
"""
struct Cache{T} <: Context
data::T
end
# Plain-function form of the `Cache` constructor.
function cache_maker(c)
    return Cache(c)
end

# Return the function able to wrap a raw value back into a `Cache`.
function maker(::Cache)
    return cache_maker
end

# Build a new `Cache` around `recursive_similar(storage, T)`.
# NOTE(review): `recursive_similar` is defined outside this file; presumably it allocates
# storage shaped like the original with element type `T` — confirm.
function adapt_eltype(c::Cache, ::Type{T}) where {T}
    storage = unwrap(c)
    return Cache(recursive_similar(storage, T))
end
"""
ConstantOrCache
Concrete type of [`Context`](@ref) argument which can contain a mixture of constants and caches, passed along to the backend without modification.
Unlike for [`Cache`](@ref), it is up to the user to ensure that the internal storage can adapt to the required element types, for instance by using [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) directly.
"""
struct ConstantOrCache{T} <: Context
data::T
end
# Plain-function form of the `ConstantOrCache` constructor.
function constantorcache_maker(c)
    return ConstantOrCache(c)
end

# Return the function able to wrap a raw value back into a `ConstantOrCache`.
function maker(::ConstantOrCache)
    return constantorcache_maker
end

# A `ConstantOrCache` is handed back untouched, whatever element type is requested.
function adapt_eltype(c::ConstantOrCache, ::Type)
    return c
end
## Internal contexts for passing auxiliary objects around
"""
FunctionContext
Private type of [`Context`](@ref) argument used for passing functions inside second-order differentiation.
Behaves differently for Enzyme only, where the function can be annotated.
"""
struct FunctionContext{T} <: GeneralizedConstant
data::T
end
## Context manipulation
"""
Rewrap
Utility for recording context types of additional arguments (e.g. `Constant` or `Cache`) and re-wrapping them into their types after they have been unwrapped.
Useful for second-order differentiation.
"""
struct Rewrap{C, T}
context_makers::T
function Rewrap(contexts::Vararg{Context, C}) where {C}
context_makers = map(maker, contexts)
return new{C, typeof(context_makers)}(context_makers)
end
end
# Without any recorded context there is nothing to wrap again.
function (::Rewrap{0})()
    return ()
end

# Apply each stored maker to the raw value found at the same position; exactly `C` values
# are expected, one per context recorded at construction.
function (r::Rewrap{C, T})(unannotated_contexts::Vararg{Any, C}) where {C, T}
    return map((wrap, value) -> wrap(value), r.context_makers, unannotated_contexts)
end
## Closures
"""
FixTail
Closure around a function `f` and a set of tail argument `tail_args` such that
```
(ft::FixTail)(args...) = ft.f(args..., ft.tail_args...)
```
"""
struct FixTail{F, A <: Tuple}
f::F
tail_args::A
function FixTail(f::F, tail_args::Vararg{Any, N}) where {F, N}
return new{F, typeof(tail_args)}(f, tail_args)
end
end
function (ft::FixTail)(args::Vararg{Any, N}) where {N}
return ft.f(args..., ft.tail_args...)
end
"""
fix_tail(f, tail_args...)
Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are no tail arguments.
"""
@inline fix_tail(f::F) where {F} = f
fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...)
# Apply `f` to the single argument `x`.
@inline function call(f::F, x) where {F}
    return f(x)
end