@@ -26,8 +26,8 @@ Base.length(x::DummyType) = size(x.X, 1)
2626
2727 function check_jac_and_jvp_and_j′vp (fdm, f, ȳ, x, ẋ, J_exact)
2828 xc = copy (x)
29- @test jacobian (fdm, f, x, length (ȳ)) ≈ J_exact
30- @test jacobian (fdm, f, x) == jacobian (fdm, f, x, length (ȳ))
29+ @test jacobian (fdm, f, x; len = length (ȳ)) ≈ J_exact
30+ @test jacobian (fdm, f, x) == jacobian (fdm, f, x; len = length (ȳ))
3131 @test _jvp (fdm, f, x, ẋ) ≈ J_exact * ẋ
3232 @test _j′vp (fdm, f, ȳ, x) ≈ transpose (J_exact) * ȳ
3333 @test xc == x
@@ -44,6 +44,62 @@ Base.length(x::DummyType) = size(x.X, 1)
4444 @test Ac == A
4545 end
4646
47+ @testset " multi vars jacobian/grad" begin
48+ rng, fdm = MersenneTwister (123456 ), central_fdm (5 , 1 )
49+
50+ f1 (x, y) = x * y + x
51+ f2 (x, y) = sum (x * y + x)
52+ f3 (x:: Tuple ) = sum (x[1 ]) + x[2 ]
53+ f4 (d:: Dict ) = sum (d[:x ]) + d[:y ]
54+
55+ @testset " jacobian" begin
56+ @testset " check multiple matrices" begin
57+ x, y = rand (rng, 3 , 3 ), rand (rng, 3 , 3 )
58+ jac_xs = jacobian (fdm, f1, x, y)
59+ @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)
60+ @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)
61+ end
62+
63+ @testset " check mixed scalar and matrices" begin
64+ x, y = rand (3 , 3 ), 2
65+ jac_xs = jacobian (fdm, f1, x, y)
66+ @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)
67+ @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)
68+ end
69+ end
70+
71+ @testset " grad" begin
72+ @testset " check multiple matrices" begin
73+ x, y = rand (rng, 3 , 3 ), rand (rng, 3 , 3 )
74+ dxs = grad (fdm, f2, x, y)
75+ @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)
76+ @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)
77+ end
78+
79+ @testset " check mixed scalar & matrices" begin
80+ x, y = rand (rng, 3 , 3 ), 2
81+ dxs = grad (fdm, f2, x, y)
82+ @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)
83+ @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)
84+ end
85+
86+ @testset " check tuple" begin
87+ x, y = rand (rng, 3 , 3 ), 2
88+ dxs = grad (fdm, f3, (x, y))
89+ @test dxs[1 ] ≈ grad (fdm, x-> f3 ((x, y)), x)
90+ @test dxs[2 ] ≈ grad (fdm, y-> f3 ((x, y)), y)
91+ end
92+
93+ @testset " check dict" begin
94+ x, y = rand (rng, 3 , 3 ), 2
95+ d = Dict (:x => x, :y => y)
96+ dxs = grad (fdm, f4, d)
97+ @test dxs[:x ] ≈ grad (fdm, x-> f3 ((x, y)), x)
98+ @test dxs[:y ] ≈ grad (fdm, y-> f3 ((x, y)), y)
99+ end
100+ end
101+ end
102+
47103 function test_to_vec (x)
48104 x_vec, back = to_vec (x)
49105 @test x_vec isa Vector
0 commit comments