Skip to content

Commit 8cf2cc3

Browse files
Merge pull request #145 from DanielVandH/master
Non-allocating GradientCache and support for Setfield
2 parents ffdac57 + 9b18754 commit 8cf2cc3

4 files changed

Lines changed: 281 additions & 182 deletions

File tree

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "2.13.1"
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
9+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
910
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1011
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1112

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,11 +380,12 @@ FiniteDiff.GradientCache(
380380

381381
```julia
382382
FiniteDiff.GradientCache(
383+
fx :: Union{Nothing,<:Number,AbstractArray{<:Number}},
383384
c1 :: Union{Nothing,AbstractArray{<:Number}},
384385
c2 :: Union{Nothing,AbstractArray{<:Number}},
385-
fx :: Union{Nothing,<:Number,AbstractArray{<:Number}} = nothing,
386+
c3 :: Union{Nothing,AbstractArray{<:Number}},
386387
fdtype :: Type{T1} = Val{:central},
387-
returntype :: Type{T2} = eltype(df),
388+
returntype :: Type{T2} = eltype(fx),
388389
inplace :: Type{Val{T3}} = Val{true})
389390
```
390391

@@ -399,7 +400,7 @@ into the differencing algorithm here.
399400
## Jacobians
400401

401402
Jacobians are for functions `f!(fx,x)` when using in-place `finite_difference_jacobian!`,
402-
and `fx = f(x)` when using out-of-place `finite_difference_jacobain`. The out-of-place
403+
and `fx = f(x)` when using out-of-place `finite_difference_jacobian`. The out-of-place
403404
jacobian will return a similar type as `jac_prototype` if it is not a `nothing`. For non-square
404405
Jacobians, a cache which specifies the vector `fx` is required.
405406

src/gradients.jl

Lines changed: 132 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
struct GradientCache{CacheType1, CacheType2, CacheType3, CacheType4, fdtype, returntype, inplace}
2-
fx :: CacheType1
3-
c1 :: CacheType2
4-
c2 :: CacheType3
5-
c3 :: CacheType4
1+
struct GradientCache{CacheType1,CacheType2,CacheType3,CacheType4,fdtype,returntype,inplace}
2+
fx::CacheType1
3+
c1::CacheType2
4+
c2::CacheType3
5+
c3::CacheType4
66
end
77

88
function GradientCache(
99
df,
1010
x,
11-
fdtype = Val(:central),
12-
returntype = eltype(df),
13-
inplace = Val(true))
11+
fdtype=Val(:central),
12+
returntype=eltype(df),
13+
inplace=Val(true))
1414

1515
fdtype isa Type && (fdtype = fdtype())
1616
inplace isa Type && (inplace = inplace())
17-
if typeof(x)<:AbstractArray # the vector->scalar case
18-
if fdtype!=Val(:complex) # complex-mode FD only needs one cache, for x+eps*im
19-
if typeof(x)<:StridedVector
20-
if eltype(df)<:Complex && !(eltype(x)<:Complex)
17+
if typeof(x) <: AbstractArray # the vector->scalar case
18+
if fdtype != Val(:complex) # complex-mode FD only needs one cache, for x+eps*im
19+
if typeof(x) <: StridedVector
20+
if eltype(df) <: Complex && !(eltype(x) <: Complex)
2121
_c1 = zero(Complex{eltype(x)}) .* x
2222
_c2 = nothing
2323
else
@@ -29,7 +29,7 @@ function GradientCache(
2929
_c2 = zero(real(eltype(x))) .* x
3030
end
3131
else
32-
if !(returntype<:Real)
32+
if !(returntype <: Real)
3333
fdtype_error(returntype)
3434
else
3535
_c1 = x .+ zero(eltype(x)) .* im
@@ -50,19 +50,61 @@ function GradientCache(
5050
end
5151

5252
GradientCache{Nothing,typeof(_c1),typeof(_c2),typeof(_c3),fdtype,
53-
returntype,inplace}(nothing,_c1,_c2,_c3)
53+
returntype,inplace}(nothing, _c1, _c2, _c3)
5454

5555
end
5656

57+
"""
58+
GradientCache(c1, c2, c3, fx, fdtype = Val(:central), returntype = eltype(fx), inplace = Val(false))
59+
60+
Construct a non-allocating gradient cache.
61+
62+
# Arguments
63+
- `c1`, `c2`, `c3`: (Non-aliased) caches for the input vector.
64+
- `fx`: Cached function call.
65+
- `fdtype = Val(:central)`: Method for cmoputing the finite difference.
66+
- `returntype = eltype(fx)`: Element type for the returned function value.
67+
- `inplace = Val(false)`: Whether the function is computed in-place or not.
68+
69+
# Output
70+
The output is a [`GradientCache`](@ref) struct.
71+
72+
```julia
73+
julia> x = [1.0, 3.0]
74+
2-element Vector{Float64}:
75+
1.0
76+
3.0
77+
78+
julia> _f = x -> x[1] + x[2]
79+
#13 (generic function with 1 method)
80+
81+
julia> fx = _f(x)
82+
4.0
83+
84+
julia> gradcache = GradientCache(copy(x), copy(x), copy(x), fx)
85+
GradientCache{Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}, Val{:central}(), Float64, Val{false}()}(4.0, [1.0, 3.0], [1.0, 3.0], [1.0, 3.0])
86+
```
87+
"""
88+
function GradientCache(
89+
fx::Fx,# match order in struct for Setfield
90+
c1::T,
91+
c2::T,
92+
c3::T,
93+
fdtype=Val(:central),
94+
returntype=eltype(fx),
95+
inplace=Val(true)) where {T,Fx} # Val(false) isn't so important for vector -> scalar, it gets ignored in that case anyway.
96+
GradientCache{Fx,T,T,T,fdtype,returntype,inplace}(fx, c1, c2, c3)
97+
end
98+
5799
function finite_difference_gradient(
58100
f,
59101
x,
60-
fdtype = Val(:central),
61-
returntype = eltype(x),
62-
inplace = Val(true),
63-
fx = nothing,
64-
c1 = nothing,
65-
c2 = nothing;
102+
fdtype=Val(:central),
103+
returntype=eltype(x),
104+
inplace=Val(true),
105+
fx=nothing,
106+
c1=nothing,
107+
c2=nothing;
66108
relstep=default_relstep(fdtype, eltype(x)),
67109
absstep=relstep,
68110
dir=true)
@@ -72,12 +114,15 @@ function finite_difference_gradient(
72114
df = zero(returntype) .* x
73115
else
74116
if inplace == Val(true)
75-
if typeof(fx)==Nothing && typeof(c1)==Nothing && typeof(c2)==Nothing
117+
if typeof(fx) == Nothing && typeof(c1) == Nothing && typeof(c2) == Nothing
76118
error("In the scalar->vector in-place map case, at least one of fx, c1 or c2 must be provided, otherwise we cannot infer the return size.")
77119
else
78-
if c1 != nothing df = zero(c1)
79-
elseif fx != nothing df = zero(fx)
80-
elseif c2 != nothing df = zero(c2)
120+
if c1 != nothing
121+
df = zero(c1)
122+
elseif fx != nothing
123+
df = zero(fx)
124+
elseif c2 != nothing
125+
df = zero(c2)
81126
end
82127
end
83128
else
@@ -138,71 +183,71 @@ function finite_difference_gradient!(
138183
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
139184
if fdtype != Val(:complex) && ArrayInterfaceCore.fast_scalar_indexing(c2)
140185
@. c2 = compute_epsilon(fdtype, x, relstep, absstep, dir)
141-
copyto!(c1,x)
186+
copyto!(c1, x)
142187
end
143-
copyto!(c3,x)
188+
copyto!(c3, x)
144189
if fdtype == Val(:forward)
145190
@inbounds for i eachindex(x)
146191
if ArrayInterfaceCore.fast_scalar_indexing(c2)
147-
epsilon = ArrayInterfaceCore.allowed_getindex(c2,i)*dir
192+
epsilon = ArrayInterfaceCore.allowed_getindex(c2, i) * dir
148193
else
149-
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)*dir
194+
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir) * dir
150195
end
151-
c1_old = ArrayInterfaceCore.allowed_getindex(c1,i)
152-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old + epsilon,i)
196+
c1_old = ArrayInterfaceCore.allowed_getindex(c1, i)
197+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old + epsilon, i)
153198
if typeof(fx) != Nothing
154199
dfi = (f(c1) - fx) / epsilon
155200
else
156201
fx0 = f(x)
157202
dfi = (f(c1) - fx0) / epsilon
158203
end
159204
df_tmp = real(dfi)
160-
if eltype(df)<:Complex
161-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old + im * epsilon,i)
205+
if eltype(df) <: Complex
206+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old + im * epsilon, i)
162207
if typeof(fx) != Nothing
163-
dfi = (f(c1) - fx) / (im*epsilon)
208+
dfi = (f(c1) - fx) / (im * epsilon)
164209
else
165-
dfi = (f(c1) - fx0) / (im*epsilon)
210+
dfi = (f(c1) - fx0) / (im * epsilon)
166211
end
167-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old,i)
212+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old, i)
168213
ArrayInterfaceCore.allowed_setindex!(df, df_tmp - im * imag(dfi), i)
169214
else
170215
ArrayInterfaceCore.allowed_setindex!(df, df_tmp, i)
171-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old,i)
216+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old, i)
172217
end
173218
end
174219
elseif fdtype == Val(:central)
175220
@inbounds for i eachindex(x)
176221
if ArrayInterfaceCore.fast_scalar_indexing(c2)
177-
epsilon = ArrayInterfaceCore.allowed_getindex(c2,i)*dir
222+
epsilon = ArrayInterfaceCore.allowed_getindex(c2, i) * dir
178223
else
179-
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir)*dir
224+
epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir) * dir
180225
end
181-
c1_old = ArrayInterfaceCore.allowed_getindex(c1,i)
182-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old + epsilon, i)
183-
x_old = ArrayInterfaceCore.allowed_getindex(x,i)
184-
ArrayInterfaceCore.allowed_setindex!(c3,x_old - epsilon,i)
185-
df_tmp = real((f(c1) - f(c3)) / (2*epsilon))
186-
if eltype(df)<:Complex
187-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old + im*epsilon,i)
188-
ArrayInterfaceCore.allowed_setindex!(c3,x_old - im*epsilon,i)
189-
df_tmp2 = im*imag( (f(c1) - f(c3)) / (2*im*epsilon) )
190-
ArrayInterfaceCore.allowed_setindex!(df,df_tmp-df_tmp2,i)
226+
c1_old = ArrayInterfaceCore.allowed_getindex(c1, i)
227+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old + epsilon, i)
228+
x_old = ArrayInterfaceCore.allowed_getindex(x, i)
229+
ArrayInterfaceCore.allowed_setindex!(c3, x_old - epsilon, i)
230+
df_tmp = real((f(c1) - f(c3)) / (2 * epsilon))
231+
if eltype(df) <: Complex
232+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old + im * epsilon, i)
233+
ArrayInterfaceCore.allowed_setindex!(c3, x_old - im * epsilon, i)
234+
df_tmp2 = im * imag((f(c1) - f(c3)) / (2 * im * epsilon))
235+
ArrayInterfaceCore.allowed_setindex!(df, df_tmp - df_tmp2, i)
191236
else
192-
ArrayInterfaceCore.allowed_setindex!(df,df_tmp,i)
237+
ArrayInterfaceCore.allowed_setindex!(df, df_tmp, i)
193238
end
194-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old, i)
195-
ArrayInterfaceCore.allowed_setindex!(c3,x_old,i)
239+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old, i)
240+
ArrayInterfaceCore.allowed_setindex!(c3, x_old, i)
196241
end
197242
elseif fdtype == Val(:complex) && returntype <: Real
198-
copyto!(c1,x)
243+
copyto!(c1, x)
199244
epsilon_complex = eps(real(eltype(x)))
200245
# we use c1 here to avoid typing issues with x
201246
@inbounds for i eachindex(x)
202-
c1_old = ArrayInterfaceCore.allowed_getindex(c1,i)
203-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old+im*epsilon_complex,i)
204-
ArrayInterfaceCore.allowed_setindex!(df,imag(f(c1)) / epsilon_complex,i)
205-
ArrayInterfaceCore.allowed_setindex!(c1,c1_old,i)
247+
c1_old = ArrayInterfaceCore.allowed_getindex(c1, i)
248+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old + im * epsilon_complex, i)
249+
ArrayInterfaceCore.allowed_setindex!(df, imag(f(c1)) / epsilon_complex, i)
250+
ArrayInterfaceCore.allowed_setindex!(c1, c1_old, i)
206251
end
207252
else
208253
fdtype_error(returntype)
@@ -223,11 +268,11 @@ function finite_difference_gradient!(
223268
# c2 is Nothing
224269
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
225270
if fdtype != Val(:complex)
226-
if eltype(df)<:Complex && !(eltype(x)<:Complex)
227-
copyto!(c1,x)
271+
if eltype(df) <: Complex && !(eltype(x) <: Complex)
272+
copyto!(c1, x)
228273
end
229274
end
230-
copyto!(c3,x)
275+
copyto!(c3, x)
231276
if fdtype == Val(:forward)
232277
for i eachindex(x)
233278
epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
@@ -244,21 +289,21 @@ function finite_difference_gradient!(
244289
end
245290

246291
df[i] = real(dfi)
247-
if eltype(df)<:Complex
248-
if eltype(x)<:Complex
292+
if eltype(df) <: Complex
293+
if eltype(x) <: Complex
249294
c3[i] += im * epsilon
250295
if typeof(fx) != Nothing
251-
dfi = (f(c3) - fx) / (im*epsilon)
296+
dfi = (f(c3) - fx) / (im * epsilon)
252297
else
253-
dfi = (f(c3) - fx0) / (im*epsilon)
298+
dfi = (f(c3) - fx0) / (im * epsilon)
254299
end
255300
c3[i] = x_old
256301
else
257302
c1[i] += im * epsilon
258303
if typeof(fx) != Nothing
259-
dfi = (f(c1) - fx) / (im*epsilon)
304+
dfi = (f(c1) - fx) / (im * epsilon)
260305
else
261-
dfi = (f(c1) - fx0) / (im*epsilon)
306+
dfi = (f(c1) - fx0) / (im * epsilon)
262307
end
263308
c1[i] = x_old
264309
end
@@ -274,33 +319,33 @@ function finite_difference_gradient!(
274319
c3[i] = x_old - epsilon
275320
dfi -= f(c3)
276321
c3[i] = x_old
277-
df[i] = real(dfi / (2*epsilon))
278-
if eltype(df)<:Complex
279-
if eltype(x)<:Complex
280-
c3[i] += im*epsilon
322+
df[i] = real(dfi / (2 * epsilon))
323+
if eltype(df) <: Complex
324+
if eltype(x) <: Complex
325+
c3[i] += im * epsilon
281326
dfi = f(c3)
282-
c3[i] = x_old - im*epsilon
327+
c3[i] = x_old - im * epsilon
283328
dfi -= f(c3)
284329
c3[i] = x_old
285330
else
286-
c1[i] += im*epsilon
331+
c1[i] += im * epsilon
287332
dfi = f(c1)
288-
c1[i] = x_old - im*epsilon
333+
c1[i] = x_old - im * epsilon
289334
dfi -= f(c1)
290335
c1[i] = x_old
291336
end
292-
df[i] -= im*imag(dfi / (2*im*epsilon))
337+
df[i] -= im * imag(dfi / (2 * im * epsilon))
293338
end
294339
end
295-
elseif fdtype==Val(:complex) && returntype<:Real && eltype(df)<:Real && eltype(x)<:Real
296-
copyto!(c1,x)
340+
elseif fdtype == Val(:complex) && returntype <: Real && eltype(df) <: Real && eltype(x) <: Real
341+
copyto!(c1, x)
297342
epsilon_complex = eps(real(eltype(x)))
298343
# we use c1 here to avoid typing issues with x
299344
@inbounds for i eachindex(x)
300345
c1_old = c1[i]
301-
c1[i] += im*epsilon_complex
302-
df[i] = imag(f(c1)) / epsilon_complex
303-
c1[i] = c1_old
346+
c1[i] += im * epsilon_complex
347+
df[i] = imag(f(c1)) / epsilon_complex
348+
c1[i] = c1_old
304349
end
305350
else
306351
fdtype_error(returntype)
@@ -330,9 +375,9 @@ function finite_difference_gradient!(
330375
if fdtype == Val(:forward)
331376
epsilon = compute_epsilon(Val(:forward), x, relstep, absstep, dir)
332377
if inplace == Val(true)
333-
f(c1, x+epsilon)
378+
f(c1, x + epsilon)
334379
else
335-
_c1 = f(x+epsilon)
380+
_c1 = f(x + epsilon)
336381
end
337382
if typeof(fx) != Nothing
338383
@. df = (_c1 - fx) / epsilon
@@ -347,19 +392,19 @@ function finite_difference_gradient!(
347392
elseif fdtype == Val(:central)
348393
epsilon = compute_epsilon(Val(:central), x, relstep, absstep, dir)
349394
if inplace == Val(true)
350-
f(c1, x+epsilon)
351-
f(c2, x-epsilon)
395+
f(c1, x + epsilon)
396+
f(c2, x - epsilon)
352397
else
353-
_c1 = f(x+epsilon)
354-
_c2 = f(x-epsilon)
398+
_c1 = f(x + epsilon)
399+
_c2 = f(x - epsilon)
355400
end
356-
@. df = (_c1 - _c2) / (2*epsilon)
401+
@. df = (_c1 - _c2) / (2 * epsilon)
357402
elseif fdtype == Val(:complex) && returntype <: Real
358403
epsilon_complex = eps(real(eltype(x)))
359404
if inplace == Val(true)
360-
f(c1, x+im*epsilon_complex)
405+
f(c1, x + im * epsilon_complex)
361406
else
362-
_c1 = f(x+im*epsilon_complex)
407+
_c1 = f(x + im * epsilon_complex)
363408
end
364409
@. df = imag(_c1) / epsilon_complex
365410
else

0 commit comments

Comments
 (0)