Skip to content

Commit 75d3c05

Browse files
committed
fix Mooncake friendly_tangents compatibility
1 parent a5ecbe0 commit 75d3c05

2 files changed

Lines changed: 57 additions & 2 deletions

File tree

  • DifferentiationInterface

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,32 @@ end
1111

1212
function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
1313
if get_config(backend).friendly_tangents
14-
# zero(x) but safer
15-
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
14+
# Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the
15+
# `tangent_to_friendly!!` framework. For this internal backup we still
16+
# need a primal-shaped value, so use the `AsPrimal` path when
17+
# available and fall back for older Mooncake releases.
18+
return tangent_to_user_primal(zero_tangent(x), x)
1619
else
1720
return zero_tangent(x)
1821
end
1922
end
23+
24+
@inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing
25+
26+
const mooncake_tangent_to_friendly = maybe_getfield(Mooncake, Symbol("tangent_to_friendly!!"))
27+
const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache)
28+
const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal)
29+
const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache)
30+
31+
function tangent_to_user_primal(tx, x)
32+
if !isnothing(mooncake_tangent_to_friendly) &&
33+
!isnothing(mooncake_friendly_tangent_cache) &&
34+
!isnothing(mooncake_as_primal) &&
35+
!isnothing(mooncake_no_cache)
36+
dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x))
37+
cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}()
38+
return mooncake_tangent_to_friendly(dest, x, tx, cache)
39+
else
40+
return tangent_to_primal!!(_copy_output(x), tx)
41+
end
42+
end

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
include("../../testutils.jl")
22

33
using DifferentiationInterface, DifferentiationInterfaceTest
4+
using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric
45
using Mooncake: Mooncake
56
using Test
67

@@ -80,3 +81,34 @@ test_differentiation(
8081
logging = LOGGING,
8182
excluded = SECOND_ORDER
8283
)
84+
85+
@testset "Friendly tangents structured matrices" begin
86+
backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
87+
inputs = (
88+
Symmetric([2.0 1.0; 1.0 3.0]),
89+
Hermitian(ComplexF64[2 1 + im; 1 - im 3]),
90+
SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]),
91+
)
92+
f(x) = real(sum(abs2, x))
93+
94+
@testset "$(typeof(x))" for x in inputs
95+
grad = gradient(f, backend, x)
96+
y, grad2 = value_and_gradient(f, backend, x)
97+
pb = only(pullback(identity, backend, x, (x,)))
98+
99+
@test grad isa Matrix
100+
@test grad2 isa Matrix
101+
@test pb isa Matrix
102+
@test grad == grad2
103+
@test y == f(x)
104+
@test pb == Matrix(x)
105+
106+
grad_dense = zero(Matrix(x))
107+
@test gradient!(f, grad_dense, backend, x) === grad_dense
108+
@test grad_dense == grad
109+
110+
tx_dense = (zero(Matrix(x)),)
111+
@test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1]
112+
@test tx_dense[1] == pb
113+
end
114+
end

0 commit comments

Comments
 (0)