-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathtest.jl
More file actions
131 lines (115 loc) · 3.98 KB
/
test.jl
File metadata and controls
131 lines (115 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
include("../../testutils.jl")
using ChainRulesTestUtils: ChainRulesTestUtils
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
using StableRNGs
using Test
"""
    ADBreaker(f)

Callable wrapper around `f`.

Calling an `ADBreaker` first copies its argument into a freshly allocated `Float64`
buffer and then returns `f(x)`. The copy result is discarded; per the inline notes it
exists only to break ForwardDiff and Zygote when they differentiate through the call.
"""
struct ADBreaker{Fn}
    f::Fn
end

# Scalar input: copy `x` into a one-element `Float64` vector, then evaluate `f`.
function (breaker::ADBreaker)(x::Number)
    scratch = Float64[0]
    copyto!(scratch, x)  # break ForwardDiff and Zygote
    return breaker.f(x)
end

# Array input: copy `x` into a `Float64` array obtained from `similar`, then evaluate `f`.
function (breaker::ADBreaker)(x::AbstractArray)
    scratch = similar(x, Float64)
    copyto!(scratch, x)  # break ForwardDiff and Zygote
    return breaker.f(x)
end
# TODO: break Mooncake with overlay?
"""
    differentiatewith_scenarios(; kwargs...)

Build the scenarios used to test `DifferentiateWith`.

Starting from `DIT.default_scenarios(; kwargs...)`, keep the out-of-place scenarios whose
`x` and `y` are not matrices, replace each scenario function by an `ADBreaker` around it,
and finally replace that by `DifferentiateWith(…, AutoFiniteDiff())`. The resulting
collection is returned.
"""
function differentiatewith_scenarios(; kwargs...)
    # Matrix inputs/outputs are skipped to save some time.
    function keep(s)
        return DIT.function_place(s) == :out &&
               !(s.x isa AbstractMatrix) &&
               !(s.y isa AbstractMatrix)
    end
    candidates = filter(keep, DIT.default_scenarios(; kwargs...))

    # With these scenarios, everything would break.
    broken = map(s -> DIT.change_function(s, ADBreaker(s.f)), candidates)

    # With these scenarios, everything is fixed.
    repaired = map(broken) do s
        wrapped = DifferentiateWith(s.f, AutoFiniteDiff())
        DIT.change_function(s, wrapped)
    end
    return repaired
end
# Run the `DifferentiateWith` scenarios under each outer backend, passing `SECOND_ORDER`
# as the excluded operators.
let backends = [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config = nothing)],
    scenarios = differentiatewith_scenarios()

    test_differentiation(
        backends,
        scenarios;
        excluded = SECOND_ORDER,
        logging = LOGGING,
        testset_name = "DI tests - normal",
    )
end
# Same scenarios after `DIT.constantify` has been applied to each of them, run under
# Zygote and Mooncake only.
let backends = [AutoZygote(), AutoMooncake(; config = nothing)],
    scenarios = map(DIT.constantify, differentiatewith_scenarios())

    test_differentiation(
        backends,
        scenarios;
        excluded = SECOND_ORDER,
        logging = LOGGING,
        testset_name = "DI tests - Constant",
    )
end
# Scenarios passed through `DIT.cachify` (with tuples), whose function is then re-wrapped
# in a `DifferentiateWith` that is given `(Cache,)` as third argument; Mooncake only.
let backends = [AutoMooncake(; config = nothing)]
    function with_cache(scen)
        cached = DIT.cachify(scen; use_tuples = true)
        rewrapped = DifferentiateWith(cached.f, AutoFiniteDiff(), (Cache,))
        return DIT.change_function(cached, rewrapped)
    end

    test_differentiation(
        backends,
        map(with_cache, differentiatewith_scenarios());
        excluded = SECOND_ORDER,
        logging = LOGGING,
        testset_name = "DI tests - Cache",
    )
end
@testset "ChainRules tests" begin
    # Only the pullback scenarios are checked with ChainRulesTestUtils.
    is_pullback(s) = DIT.operator(s) == :pullback
    pullback_scens = filter(is_pullback, differentiatewith_scenarios())

    @testset for scen in pullback_scens
        ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol = 1.0e-4)
    end
end;
@testset "Mooncake tests" begin
    # Only the pullback scenarios are checked with Mooncake's rule tester.
    is_pullback(s) = DIT.operator(s) == :pullback
    pullback_scens = filter(is_pullback, differentiatewith_scenarios())

    @testset for scen in pullback_scens
        # A fresh `StableRNG(0)` is created for every scenario.
        Mooncake.TestUtils.test_rule(
            StableRNG(0),
            scen.f,
            scen.x;
            is_primitive = true,
            mode = Mooncake.ReverseMode,
        )
    end
end;
@testset "Mooncake errors" begin
    # The error type is looked up in the Mooncake package extension.
    ext = Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt)
    MooncakeDifferentiateWithError = ext.MooncakeDifferentiateWithError

    # Rendering of the error message.
    err = MooncakeDifferentiateWithError(identity, (1.0,), 2.0)
    expected_msg = "MooncakeDifferentiateWithError: For the function type `typeof(identity)` and input types `Tuple{Float64}`, the output type `Float64` is currently not supported."
    @test sprint(showerror, err) == expected_msg

    # Functions with tuple outputs or tuple inputs.
    f_num2tup(x::Number) = (x,)
    f_vec2tup(x::Vector) = (first(x),)
    f_tup2num(x::Tuple{<:Number}) = only(x)
    f_tup2vec(x::Tuple{<:Number}) = [only(x)]

    outer = AutoMooncake(; config = nothing)
    inner = AutoFiniteDiff()

    # Tuple outputs: expected to throw `MooncakeDifferentiateWithError`.
    @test_throws MooncakeDifferentiateWithError pullback(
        DifferentiateWith(f_num2tup, inner), outer, 1.0, ((2.0,),)
    )
    @test_throws MooncakeDifferentiateWithError pullback(
        DifferentiateWith(f_vec2tup, inner), outer, [1.0], ((2.0,),)
    )

    # Tuple inputs: expected to throw `MethodError`.
    @test_throws MethodError pullback(
        DifferentiateWith(f_tup2num, inner), outer, (1.0,), (2.0,)
    )
    @test_throws MethodError pullback(
        DifferentiateWith(f_tup2vec, inner), outer, (1.0,), ([2.0],)
    )
end
# Constructing `DifferentiateWith` with `(3,)` as its third argument is expected to
# throw a `MethodError`.
@test_throws(MethodError, DifferentiateWith(exp, AutoForwardDiff(), (3,)))