|
1 | 1 | include("../../testutils.jl") |
2 | 2 |
|
3 | 3 | using DifferentiationInterface, DifferentiationInterfaceTest |
| 4 | +using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric |
4 | 5 | using Mooncake: Mooncake |
5 | 6 | using Test |
6 | 7 |
|
@@ -80,3 +81,34 @@ test_differentiation( |
80 | 81 | logging = LOGGING, |
81 | 82 | excluded = SECOND_ORDER |
82 | 83 | ) |
| 84 | + |
| 85 | +@testset "Friendly tangents structured matrices" begin |
| 86 | + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) |
| 87 | + inputs = ( |
| 88 | + Symmetric([2.0 1.0; 1.0 3.0]), |
| 89 | + Hermitian(ComplexF64[2 1 + im; 1 - im 3]), |
| 90 | + SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), |
| 91 | + ) |
| 92 | + f(x) = real(sum(abs2, x)) |
| 93 | + |
| 94 | + @testset "$(typeof(x))" for x in inputs |
| 95 | + grad = gradient(f, backend, x) |
| 96 | + y, grad2 = value_and_gradient(f, backend, x) |
| 97 | + pb = only(pullback(identity, backend, x, (x,))) |
| 98 | + |
| 99 | + @test grad isa Matrix |
| 100 | + @test grad2 isa Matrix |
| 101 | + @test pb isa Matrix |
| 102 | + @test grad == grad2 |
| 103 | + @test y == f(x) |
| 104 | + @test pb == Matrix(x) |
| 105 | + |
| 106 | + grad_dense = zero(Matrix(x)) |
| 107 | + @test gradient!(f, grad_dense, backend, x) === grad_dense |
| 108 | + @test grad_dense == grad |
| 109 | + |
| 110 | + tx_dense = (zero(Matrix(x)),) |
| 111 | + @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] |
| 112 | + @test tx_dense[1] == pb |
| 113 | + end |
| 114 | +end |
0 commit comments