1- struct GradientCache{CacheType1, CacheType2, CacheType3, CacheType4, fdtype, returntype, inplace}
2- fx :: CacheType1
3- c1 :: CacheType2
4- c2 :: CacheType3
5- c3 :: CacheType4
1+ struct GradientCache{CacheType1,CacheType2,CacheType3,CacheType4,fdtype,returntype,inplace}
2+ fx:: CacheType1
3+ c1:: CacheType2
4+ c2:: CacheType3
5+ c3:: CacheType4
66end
77
88function GradientCache (
99 df,
1010 x,
11- fdtype = Val (:central ),
12- returntype = eltype (df),
13- inplace = Val (true ))
11+ fdtype= Val (:central ),
12+ returntype= eltype (df),
13+ inplace= Val (true ))
1414
1515 fdtype isa Type && (fdtype = fdtype ())
1616 inplace isa Type && (inplace = inplace ())
17- if typeof (x)<: AbstractArray # the vector->scalar case
18- if fdtype!= Val (:complex ) # complex-mode FD only needs one cache, for x+eps*im
19- if typeof (x)<: StridedVector
20- if eltype (df)<: Complex && ! (eltype (x)<: Complex )
17+ if typeof (x) <: AbstractArray # the vector->scalar case
18+ if fdtype != Val (:complex ) # complex-mode FD only needs one cache, for x+eps*im
19+ if typeof (x) <: StridedVector
20+ if eltype (df) <: Complex && ! (eltype (x) <: Complex )
2121 _c1 = zero (Complex{eltype (x)}) .* x
2222 _c2 = nothing
2323 else
@@ -29,7 +29,7 @@ function GradientCache(
2929 _c2 = zero (real (eltype (x))) .* x
3030 end
3131 else
32- if ! (returntype<: Real )
32+ if ! (returntype <: Real )
3333 fdtype_error (returntype)
3434 else
3535 _c1 = x .+ zero (eltype (x)) .* im
@@ -50,19 +50,61 @@ function GradientCache(
5050 end
5151
5252 GradientCache{Nothing,typeof (_c1),typeof (_c2),typeof (_c3),fdtype,
53- returntype,inplace}(nothing ,_c1,_c2,_c3)
53+ returntype,inplace}(nothing , _c1, _c2, _c3)
5454
5555end
5656
57+ """
58+ GradientCache(c1, c2, c3, fx, fdtype = Val(:central), returntype = eltype(fx), inplace = Val(false))
59+
60+ Construct a non-allocating gradient cache.
61+
62+ # Arguments
63+ - `c1`, `c2`, `c3`: (Non-aliased) caches for the input vector.
64+ - `fx`: Cached function call.
65+ - `fdtype = Val(:central)`: Method for cmoputing the finite difference.
66+ - `returntype = eltype(fx)`: Element type for the returned function value.
67+ - `inplace = Val(false)`: Whether the function is computed in-place or not.
68+
69+ # Output
70+ The output is a [`GradientCache`](@ref) struct.
71+
72+ ```julia
73+ julia> x = [1.0, 3.0]
74+ 2-element Vector{Float64}:
75+ 1.0
76+ 3.0
77+
78+ julia> _f = x -> x[1] + x[2]
79+ #13 (generic function with 1 method)
80+
81+ julia> fx = _f(x)
82+ 4.0
83+
84+ julia> gradcache = GradientCache(copy(x), copy(x), copy(x), fx)
85+ GradientCache{Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}, Val{:central}(), Float64, Val{false}()}(4.0, [1.0, 3.0], [1.0, 3.0], [1.0, 3.0])
86+ ```
87+ """
88+ function GradientCache (
89+ fx:: Fx ,# match order in struct for Setfield
90+ c1:: T ,
91+ c2:: T ,
92+ c3:: T ,
93+ fdtype= Val (:central ),
94+ returntype= eltype (fx),
95+ inplace= Val (true )) where {T,Fx} # Val(false) isn't so important for vector -> scalar, it gets ignored in that case anyway.
96+ GradientCache {Fx,T,T,T,fdtype,returntype,inplace} (fx, c1, c2, c3)
97+ end
98+
5799function finite_difference_gradient (
58100 f,
59101 x,
60- fdtype = Val (:central ),
61- returntype = eltype (x),
62- inplace = Val (true ),
63- fx = nothing ,
64- c1 = nothing ,
65- c2 = nothing ;
102+ fdtype= Val (:central ),
103+ returntype= eltype (x),
104+ inplace= Val (true ),
105+ fx= nothing ,
106+ c1= nothing ,
107+ c2= nothing ;
66108 relstep= default_relstep (fdtype, eltype (x)),
67109 absstep= relstep,
68110 dir= true )
@@ -72,12 +114,15 @@ function finite_difference_gradient(
72114 df = zero (returntype) .* x
73115 else
74116 if inplace == Val (true )
75- if typeof (fx)== Nothing && typeof (c1)== Nothing && typeof (c2)== Nothing
117+ if typeof (fx) == Nothing && typeof (c1) == Nothing && typeof (c2) == Nothing
76118 error (" In the scalar->vector in-place map case, at least one of fx, c1 or c2 must be provided, otherwise we cannot infer the return size." )
77119 else
78- if c1 != nothing df = zero (c1)
79- elseif fx != nothing df = zero (fx)
80- elseif c2 != nothing df = zero (c2)
120+ if c1 != nothing
121+ df = zero (c1)
122+ elseif fx != nothing
123+ df = zero (fx)
124+ elseif c2 != nothing
125+ df = zero (c2)
81126 end
82127 end
83128 else
@@ -138,71 +183,71 @@ function finite_difference_gradient!(
138183 fx, c1, c2, c3 = cache. fx, cache. c1, cache. c2, cache. c3
139184 if fdtype != Val (:complex ) && ArrayInterfaceCore. fast_scalar_indexing (c2)
140185 @. c2 = compute_epsilon (fdtype, x, relstep, absstep, dir)
141- copyto! (c1,x)
186+ copyto! (c1, x)
142187 end
143- copyto! (c3,x)
188+ copyto! (c3, x)
144189 if fdtype == Val (:forward )
145190 @inbounds for i ∈ eachindex (x)
146191 if ArrayInterfaceCore. fast_scalar_indexing (c2)
147- epsilon = ArrayInterfaceCore. allowed_getindex (c2,i) * dir
192+ epsilon = ArrayInterfaceCore. allowed_getindex (c2, i) * dir
148193 else
149- epsilon = compute_epsilon (fdtype, x, relstep, absstep, dir)* dir
194+ epsilon = compute_epsilon (fdtype, x, relstep, absstep, dir) * dir
150195 end
151- c1_old = ArrayInterfaceCore. allowed_getindex (c1,i)
152- ArrayInterfaceCore. allowed_setindex! (c1,c1_old + epsilon,i)
196+ c1_old = ArrayInterfaceCore. allowed_getindex (c1, i)
197+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old + epsilon, i)
153198 if typeof (fx) != Nothing
154199 dfi = (f (c1) - fx) / epsilon
155200 else
156201 fx0 = f (x)
157202 dfi = (f (c1) - fx0) / epsilon
158203 end
159204 df_tmp = real (dfi)
160- if eltype (df)<: Complex
161- ArrayInterfaceCore. allowed_setindex! (c1,c1_old + im * epsilon,i)
205+ if eltype (df) <: Complex
206+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old + im * epsilon, i)
162207 if typeof (fx) != Nothing
163- dfi = (f (c1) - fx) / (im* epsilon)
208+ dfi = (f (c1) - fx) / (im * epsilon)
164209 else
165- dfi = (f (c1) - fx0) / (im* epsilon)
210+ dfi = (f (c1) - fx0) / (im * epsilon)
166211 end
167- ArrayInterfaceCore. allowed_setindex! (c1,c1_old,i)
212+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old, i)
168213 ArrayInterfaceCore. allowed_setindex! (df, df_tmp - im * imag (dfi), i)
169214 else
170215 ArrayInterfaceCore. allowed_setindex! (df, df_tmp, i)
171- ArrayInterfaceCore. allowed_setindex! (c1,c1_old,i)
216+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old, i)
172217 end
173218 end
174219 elseif fdtype == Val (:central )
175220 @inbounds for i ∈ eachindex (x)
176221 if ArrayInterfaceCore. fast_scalar_indexing (c2)
177- epsilon = ArrayInterfaceCore. allowed_getindex (c2,i) * dir
222+ epsilon = ArrayInterfaceCore. allowed_getindex (c2, i) * dir
178223 else
179- epsilon = compute_epsilon (fdtype, x, relstep, absstep, dir)* dir
224+ epsilon = compute_epsilon (fdtype, x, relstep, absstep, dir) * dir
180225 end
181- c1_old = ArrayInterfaceCore. allowed_getindex (c1,i)
182- ArrayInterfaceCore. allowed_setindex! (c1,c1_old + epsilon, i)
183- x_old = ArrayInterfaceCore. allowed_getindex (x,i)
184- ArrayInterfaceCore. allowed_setindex! (c3,x_old - epsilon,i)
185- df_tmp = real ((f (c1) - f (c3)) / (2 * epsilon))
186- if eltype (df)<: Complex
187- ArrayInterfaceCore. allowed_setindex! (c1,c1_old + im* epsilon,i)
188- ArrayInterfaceCore. allowed_setindex! (c3,x_old - im* epsilon,i)
189- df_tmp2 = im* imag ( (f (c1) - f (c3)) / (2 * im * epsilon) )
190- ArrayInterfaceCore. allowed_setindex! (df,df_tmp- df_tmp2,i)
226+ c1_old = ArrayInterfaceCore. allowed_getindex (c1, i)
227+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old + epsilon, i)
228+ x_old = ArrayInterfaceCore. allowed_getindex (x, i)
229+ ArrayInterfaceCore. allowed_setindex! (c3, x_old - epsilon, i)
230+ df_tmp = real ((f (c1) - f (c3)) / (2 * epsilon))
231+ if eltype (df) <: Complex
232+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old + im * epsilon, i)
233+ ArrayInterfaceCore. allowed_setindex! (c3, x_old - im * epsilon, i)
234+ df_tmp2 = im * imag ((f (c1) - f (c3)) / (2 * im * epsilon))
235+ ArrayInterfaceCore. allowed_setindex! (df, df_tmp - df_tmp2, i)
191236 else
192- ArrayInterfaceCore. allowed_setindex! (df,df_tmp,i)
237+ ArrayInterfaceCore. allowed_setindex! (df, df_tmp, i)
193238 end
194- ArrayInterfaceCore. allowed_setindex! (c1,c1_old, i)
195- ArrayInterfaceCore. allowed_setindex! (c3,x_old,i)
239+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old, i)
240+ ArrayInterfaceCore. allowed_setindex! (c3, x_old, i)
196241 end
197242 elseif fdtype == Val (:complex ) && returntype <: Real
198- copyto! (c1,x)
243+ copyto! (c1, x)
199244 epsilon_complex = eps (real (eltype (x)))
200245 # we use c1 here to avoid typing issues with x
201246 @inbounds for i ∈ eachindex (x)
202- c1_old = ArrayInterfaceCore. allowed_getindex (c1,i)
203- ArrayInterfaceCore. allowed_setindex! (c1,c1_old+ im * epsilon_complex,i)
204- ArrayInterfaceCore. allowed_setindex! (df,imag (f (c1)) / epsilon_complex,i)
205- ArrayInterfaceCore. allowed_setindex! (c1,c1_old,i)
247+ c1_old = ArrayInterfaceCore. allowed_getindex (c1, i)
248+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old + im * epsilon_complex, i)
249+ ArrayInterfaceCore. allowed_setindex! (df, imag (f (c1)) / epsilon_complex, i)
250+ ArrayInterfaceCore. allowed_setindex! (c1, c1_old, i)
206251 end
207252 else
208253 fdtype_error (returntype)
@@ -223,11 +268,11 @@ function finite_difference_gradient!(
223268 # c2 is Nothing
224269 fx, c1, c2, c3 = cache. fx, cache. c1, cache. c2, cache. c3
225270 if fdtype != Val (:complex )
226- if eltype (df)<: Complex && ! (eltype (x)<: Complex )
227- copyto! (c1,x)
271+ if eltype (df) <: Complex && ! (eltype (x) <: Complex )
272+ copyto! (c1, x)
228273 end
229274 end
230- copyto! (c3,x)
275+ copyto! (c3, x)
231276 if fdtype == Val (:forward )
232277 for i ∈ eachindex (x)
233278 epsilon = compute_epsilon (fdtype, x[i], relstep, absstep, dir)
@@ -244,21 +289,21 @@ function finite_difference_gradient!(
244289 end
245290
246291 df[i] = real (dfi)
247- if eltype (df)<: Complex
248- if eltype (x)<: Complex
292+ if eltype (df) <: Complex
293+ if eltype (x) <: Complex
249294 c3[i] += im * epsilon
250295 if typeof (fx) != Nothing
251- dfi = (f (c3) - fx) / (im* epsilon)
296+ dfi = (f (c3) - fx) / (im * epsilon)
252297 else
253- dfi = (f (c3) - fx0) / (im* epsilon)
298+ dfi = (f (c3) - fx0) / (im * epsilon)
254299 end
255300 c3[i] = x_old
256301 else
257302 c1[i] += im * epsilon
258303 if typeof (fx) != Nothing
259- dfi = (f (c1) - fx) / (im* epsilon)
304+ dfi = (f (c1) - fx) / (im * epsilon)
260305 else
261- dfi = (f (c1) - fx0) / (im* epsilon)
306+ dfi = (f (c1) - fx0) / (im * epsilon)
262307 end
263308 c1[i] = x_old
264309 end
@@ -274,33 +319,33 @@ function finite_difference_gradient!(
274319 c3[i] = x_old - epsilon
275320 dfi -= f (c3)
276321 c3[i] = x_old
277- df[i] = real (dfi / (2 * epsilon))
278- if eltype (df)<: Complex
279- if eltype (x)<: Complex
280- c3[i] += im* epsilon
322+ df[i] = real (dfi / (2 * epsilon))
323+ if eltype (df) <: Complex
324+ if eltype (x) <: Complex
325+ c3[i] += im * epsilon
281326 dfi = f (c3)
282- c3[i] = x_old - im* epsilon
327+ c3[i] = x_old - im * epsilon
283328 dfi -= f (c3)
284329 c3[i] = x_old
285330 else
286- c1[i] += im* epsilon
331+ c1[i] += im * epsilon
287332 dfi = f (c1)
288- c1[i] = x_old - im* epsilon
333+ c1[i] = x_old - im * epsilon
289334 dfi -= f (c1)
290335 c1[i] = x_old
291336 end
292- df[i] -= im* imag (dfi / (2 * im * epsilon))
337+ df[i] -= im * imag (dfi / (2 * im * epsilon))
293338 end
294339 end
295- elseif fdtype== Val (:complex ) && returntype<: Real && eltype (df)<: Real && eltype (x)<: Real
296- copyto! (c1,x)
340+ elseif fdtype == Val (:complex ) && returntype <: Real && eltype (df) <: Real && eltype (x) <: Real
341+ copyto! (c1, x)
297342 epsilon_complex = eps (real (eltype (x)))
298343 # we use c1 here to avoid typing issues with x
299344 @inbounds for i ∈ eachindex (x)
300345 c1_old = c1[i]
301- c1[i] += im* epsilon_complex
302- df[i] = imag (f (c1)) / epsilon_complex
303- c1[i] = c1_old
346+ c1[i] += im * epsilon_complex
347+ df[i] = imag (f (c1)) / epsilon_complex
348+ c1[i] = c1_old
304349 end
305350 else
306351 fdtype_error (returntype)
@@ -330,9 +375,9 @@ function finite_difference_gradient!(
330375 if fdtype == Val (:forward )
331376 epsilon = compute_epsilon (Val (:forward ), x, relstep, absstep, dir)
332377 if inplace == Val (true )
333- f (c1, x+ epsilon)
378+ f (c1, x + epsilon)
334379 else
335- _c1 = f (x+ epsilon)
380+ _c1 = f (x + epsilon)
336381 end
337382 if typeof (fx) != Nothing
338383 @. df = (_c1 - fx) / epsilon
@@ -347,19 +392,19 @@ function finite_difference_gradient!(
347392 elseif fdtype == Val (:central )
348393 epsilon = compute_epsilon (Val (:central ), x, relstep, absstep, dir)
349394 if inplace == Val (true )
350- f (c1, x+ epsilon)
351- f (c2, x- epsilon)
395+ f (c1, x + epsilon)
396+ f (c2, x - epsilon)
352397 else
353- _c1 = f (x+ epsilon)
354- _c2 = f (x- epsilon)
398+ _c1 = f (x + epsilon)
399+ _c2 = f (x - epsilon)
355400 end
356- @. df = (_c1 - _c2) / (2 * epsilon)
401+ @. df = (_c1 - _c2) / (2 * epsilon)
357402 elseif fdtype == Val (:complex ) && returntype <: Real
358403 epsilon_complex = eps (real (eltype (x)))
359404 if inplace == Val (true )
360- f (c1, x+ im * epsilon_complex)
405+ f (c1, x + im * epsilon_complex)
361406 else
362- _c1 = f (x+ im * epsilon_complex)
407+ _c1 = f (x + im * epsilon_complex)
363408 end
364409 @. df = imag (_c1) / epsilon_complex
365410 else
0 commit comments