Skip to content

Commit f420234

Browse files
committed
format Mooncake fix
1 parent 75d3c05 commit f420234

2 files changed

Lines changed: 24 additions & 30 deletions

File tree

  • DifferentiationInterface

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,21 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
2121
end
2222
end
2323

24-
@inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing
24+
@inline maybe_getfield(mod, name::Symbol) =
25+
isdefined(mod, name) ? getfield(mod, name) : nothing
2526

26-
const mooncake_tangent_to_friendly = maybe_getfield(Mooncake, Symbol("tangent_to_friendly!!"))
27+
const mooncake_tangent_to_friendly = maybe_getfield(
28+
Mooncake, Symbol("tangent_to_friendly!!")
29+
)
2730
const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache)
2831
const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal)
2932
const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache)
3033

3134
function tangent_to_user_primal(tx, x)
3235
if !isnothing(mooncake_tangent_to_friendly) &&
33-
!isnothing(mooncake_friendly_tangent_cache) &&
34-
!isnothing(mooncake_as_primal) &&
35-
!isnothing(mooncake_no_cache)
36+
!isnothing(mooncake_friendly_tangent_cache) &&
37+
!isnothing(mooncake_as_primal) &&
38+
!isnothing(mooncake_no_cache)
3639
dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x))
3740
cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}()
3841
return mooncake_tangent_to_friendly(dest, x, tx, cache)

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ nomatrix(scens) = filter(s -> !(s.x isa AbstractMatrix) && !(s.y isa AbstractMat
1313
backends = [
1414
AutoMooncake(),
1515
AutoMooncakeForward(),
16-
AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)),
17-
AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)),
16+
AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)),
17+
AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)),
1818
]
1919

2020
for backend in backends
@@ -23,31 +23,25 @@ for backend in backends
2323
end
2424

2525
test_differentiation(
26-
backends[3:4],
27-
default_scenarios();
28-
excluded = SECOND_ORDER,
29-
logging = LOGGING,
26+
backends[3:4], default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING
3027
);
3128

3229
test_differentiation(
3330
backends[3:4],
3431
nomatrix(
3532
default_scenarios(;
36-
include_normal = false,
37-
include_constantified = true,
38-
include_cachified = true,
39-
use_tuples = true
40-
)
33+
include_normal=false,
34+
include_constantified=true,
35+
include_cachified=true,
36+
use_tuples=true,
37+
),
4138
);
42-
excluded = SECOND_ORDER,
43-
logging = LOGGING,
39+
excluded=SECOND_ORDER,
40+
logging=LOGGING,
4441
);
4542

4643
test_differentiation(
47-
backends[1:2],
48-
nomatrix(default_scenarios());
49-
excluded = SECOND_ORDER,
50-
logging = LOGGING,
44+
backends[1:2], nomatrix(default_scenarios()); excluded=SECOND_ORDER, logging=LOGGING
5145
);
5246

5347
EXCLUDED = @static if VERSION v"1.11-" && VERSION v"1.12-"
@@ -63,27 +57,24 @@ end
6357
test_differentiation(
6458
[SecondOrder(AutoMooncakeForward(), AutoMooncake())],
6559
nomatrix(default_scenarios());
66-
excluded = EXCLUDED,
67-
logging = LOGGING,
60+
excluded=EXCLUDED,
61+
logging=LOGGING,
6862
)
6963

7064
@testset "NamedTuples" begin
71-
ps = (; A = rand(5), B = rand(5))
65+
ps = (; A=rand(5), B=rand(5))
7266
myfun(ps) = sum(ps.A .* ps.B)
7367
grad = gradient(myfun, backends[1], ps)
7468
@test grad.A == ps.B
7569
@test grad.B == ps.A
7670
end
7771

7872
test_differentiation(
79-
backends[3:4],
80-
nomatrix(static_scenarios());
81-
logging = LOGGING,
82-
excluded = SECOND_ORDER
73+
backends[3:4], nomatrix(static_scenarios()); logging=LOGGING, excluded=SECOND_ORDER
8374
)
8475

8576
@testset "Friendly tangents structured matrices" begin
86-
backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
77+
backend = AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true))
8778
inputs = (
8879
Symmetric([2.0 1.0; 1.0 3.0]),
8980
Hermitian(ComplexF64[2 1 + im; 1 - im 3]),

0 commit comments

Comments
 (0)