Skip to content

Commit c43f01e

Browse files
author
Roger-luo
committed
revise implementation & test
1 parent f5d031e commit c43f01e

2 files changed

Lines changed: 61 additions & 34 deletions

File tree

src/grad.jl

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,62 @@
1+
export grad, jacobian
2+
function replace_arg(x, xs::Tuple, k::Int)
3+
return ntuple(length(xs)) do p
4+
if p == k
5+
x
6+
else
7+
xs[p]
8+
end
9+
end
10+
end
11+
112
"""
2-
grad(fdm, f, x::AbstractVector)
13+
grad(fdm, f, xs...)
314
4-
Approximate the gradient of `f` at `x` using `fdm`. Assumes that `f(x)` is scalar.
15+
Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)` is scalar.
516
"""
6-
function grad(fdm, f, x::Vector{T}) where T<:Real
7-
v, dx, tmp = fill(zero(T), size(x)), similar(x), similar(x)
8-
for n in eachindex(x)
9-
v[n] = one(T)
10-
dx[n] = fdm(function(ϵ)
11-
tmp .= x .+ ϵ .* v
12-
return f(tmp)
13-
end,
14-
zero(T),
15-
)
16-
v[n] = zero(T)
17+
function grad end
18+
19+
function grad(fdm, f, x::AbstractArray{T}) where T
20+
dx, tmp = similar(x), similar(x)
21+
for k in eachindex(x)
22+
dx[k] = fdm(zero(T)) do ϵ
23+
tmp .= x
24+
tmp[k] += ϵ
25+
return f(tmp)
26+
end
1727
end
1828
return dx
1929
end
2030

31+
grad(fdm, f, x::Real) = fdm(f, x)
32+
33+
function grad(fdm, f, xs...)
34+
return ntuple(length(xs)) do k
35+
grad(fdm, x->f(replace_arg(x, xs, k)...), xs[k])
36+
end
37+
end
38+
2139
"""
22-
jacobian(fdm, f, x::AbstractVector{<:Real}, D::Int)
23-
jacobian(fdm, f, x::AbstractVector{<:Real})
40+
jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}[; dim::Int=length(f(x))])
2441
2542
Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `D` vector. If
2643
`D` is not provided, then `f(x)` is computed once to determine the output size.
2744
"""
28-
function jacobian(fdm, f, x::Vector{T}, D::Int) where {T<:Real}
29-
J = Matrix{T}(undef, D, length(x))
30-
for d in 1:D
31-
J[d, :] = grad(fdm, x->f(x)[d], x)
45+
function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; dim::Int=length(f(x))) where {T <: Real}
46+
J = Matrix{float(T)}(undef, dim, length(x))
47+
for d in 1:dim
48+
gs = grad(fdm, x->f(x)[d], x)
49+
for k in 1:length(x)
50+
J[d, k] = gs[k]
51+
end
3252
end
3353
return J
3454
end
35-
jacobian(fdm, f, x::Vector{<:Real}) = jacobian(fdm, f, x, length(f(x)))
36-
37-
function jacobian(fdm, f, x::Real, D::Int)
38-
x_vec, vec_to_x = to_vec(x)
39-
return jacobian(fdm, x->f(vec_to_x(x)), x_vec, D)
40-
end
41-
42-
replace_arg(k, xs::Tuple, x) = (xs[1:k-1]..., x, xs[k+1:end]...)
4355

44-
function jacobian(fdm, f, xs...)
45-
D = length(f(xs...))
46-
N = length(xs)
47-
return ntuple(k->jacobian(fdm, x->f(replace_arg(k, xs, x)), xs[k], D), N)
56+
function jacobian(fdm, f, xs...; dim::Int=length(f(xs...)))
57+
return ntuple(length(xs)) do k
58+
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; dim=dim)
59+
end
4860
end
4961

5062
"""
@@ -59,7 +71,7 @@ _jvp(fdm, f, x::Vector{<:Real}, ẋ::AV{<:Real}) = jacobian(fdm, f, x) * ẋ
5971
6072
Convenience function to compute `jacobian(f, x)' * ȳ`.
6173
"""
62-
_j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x, length(ȳ))' *
74+
_j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x; dim=length(ȳ))' *
6375

6476
"""
6577
jvp(fdm, f, x, ẋ)

test/grad.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ Base.length(x::DummyType) = size(x.X, 1)
2626

2727
function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
2828
xc = copy(x)
29-
@test jacobian(fdm, f, x, length(ȳ)) J_exact
30-
@test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
29+
@test jacobian(fdm, f, x; dim=length(ȳ)) J_exact
30+
@test jacobian(fdm, f, x) == jacobian(fdm, f, x; dim=length(ȳ))
3131
@test _jvp(fdm, f, x, ẋ) J_exact *
3232
@test _j′vp(fdm, f, ȳ, x) J_exact' *
3333
@test xc == x
@@ -44,6 +44,21 @@ Base.length(x::DummyType) = size(x.X, 1)
4444
@test Ac == A
4545
end
4646

47+
@testset "multi vars jacobian" begin
48+
fdm = central_fdm(5, 1)
49+
f1(x, y) = x * y + x
50+
x, y = rand(3, 3), rand(3, 3)
51+
jac_xs = jacobian(fdm, f1, x, y)
52+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
53+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
54+
55+
# mixed scalar and matrices
56+
x, y = rand(3, 3), 2
57+
jac_xs = jacobian(fdm, f1, x, y)
58+
@test jac_xs[1] jacobian(fdm, x->f1(x, y), x)
59+
@test jac_xs[2] jacobian(fdm, y->f1(x, y), y)
60+
end
61+
4762
function test_to_vec(x)
4863
x_vec, back = to_vec(x)
4964
@test x_vec isa Vector

0 commit comments

Comments
 (0)