Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
actions: write
contents: read
strategy:
fail-fast: false # TODO: toggle
fail-fast: true # TODO: toggle
matrix:
version:
- '1.10'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using ChainRulesCore:
RuleConfig,
frule_via_ad,
rrule_via_ad,
unthunk
unthunk,
@not_implemented
import DifferentiationInterface as DI

ruleconfig(backend::AutoChainRules) = backend.ruleconfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
function pullbackfunc(dy)
tx = DI.pullback(f, prep_same, backend, x, (dy,))
return (NoTangent(), only(tx))
function ChainRulesCore.rrule(
dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C}
) where {C}
(; f, backend, context_wrappers) = dw
y = f(x, contexts...)
wrapped_contexts = map(DI.call, context_wrappers, contexts)
prep_same = DI.prepare_pullback_same_point_nokwarg(
Val(false), f, backend, x, (y,), wrapped_contexts...
)
function diffwith_pullbackfunc(dy)
dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only
dc = map(contexts) do c
@not_implemented(
"""
Derivatives with respect to context arguments are not implemented.
"""
)
end
return (NoTangent(), dx, dc...)
end
return y, pullbackfunc
return y, diffwith_pullbackfunc
end
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
y, ty = DI.value_and_pushforward(f, backend, xval, tx)
return make_dual(T, y, ty)
end

function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using Mooncake:
value_and_pullback!!,
zero_dual,
zero_tangent,
zero_rdata,
rdata_type,
fdata,
rdata,
Expand All @@ -26,11 +27,13 @@ using Mooncake:
@is_primitive,
zero_fcodual,
MinimalCtx,
NoFData,
NoRData,
primal,
_copy_output,
_copy_to_output!!,
tangent_to_primal!!
tangent_to_primal!!,
increment!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any}
const NumberOrArray = Union{Number, AbstractArray{<:Number}}

# Mark DifferentiateWith with a range of context arities as primitives.
# For C contexts, the corresponding call tuple type is
# Tuple{DI.DifferentiateWith{C}, Any, Vararg{Any, C}}:
# one slot for the primal input x and C slots for contexts.
for C in 0:16
@eval @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{$C}, Vararg{Any, $(C + 1)}}
end
struct MooncakeDifferentiateWithError <: Exception
F::Type
X::Type
Expand All @@ -12,72 +19,87 @@ end
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
return print(
io,
"MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
"MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.",
)
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:Number},
contexts::Vararg{CoDual, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), rdata(dx), rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), rdata(dx), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
end

function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
)
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:AbstractArray{<:Number}},
contexts::Vararg{CoDual, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = x.dx
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), dy
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), dy, rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), NoRData(), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,19 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
return zero_tangent(x)
end
end

nanify(x::AbstractFloat) = convert(typeof(x), NaN)
nanify(x::AbstractArray) = map(nanify, x)
nanify(x::NamedTuple) = NamedTuple{keys(x)}(map(nanify, values(x)))
nanify(x::Tuple) = map(nanify, x)
nanify(::NoRData) = NoRData()

function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C}
primal_contexts = map(primal, contexts)
fdata_contexts = map(tangent, contexts)
zero_rdata_contexts = map(zero_rdata, primal_contexts)
foreach(fdata_contexts) do fc
increment!!(fc, nanify(fc))
end
return map(nanify, zero_rdata_contexts)
end
32 changes: 26 additions & 6 deletions DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

!!! warning

`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
`DifferentiateWith` only supports out-of-place functions `y = f(x, contexts...)`, where the derivatives with respect to `contexts` can be safely ignored in the rest of your code.
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules.
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).

Expand All @@ -25,16 +25,17 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be

# Fields

- `f`: the function in question, with signature `f(x)`
- `f`: the function in question, with signature `f(x, contexts...)`
- `backend::AbstractADType`: the substitute backend to use for differentiation
- `context_wrappers::NTuple`: a tuple like `(Constant, Cache)`, meaning that `f(x, a, b)` will be differentiated with `Constant(a)` and `Cache(b)` as contexts.

!!! note

For the substitute AD backend to be called under the hood, its package needs to be loaded in addition to the package of the true AD backend.

# Constructor

DifferentiateWith(f, backend)
DifferentiateWith(f, backend, context_wrappers)

# Example

Expand Down Expand Up @@ -69,22 +70,41 @@ julia> Zygote.gradient(alg, [3.0, 5.0])[1]
70.0
```
"""
struct DifferentiateWith{F, B <: AbstractADType}
struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}}
f::F
backend::B
context_wrappers::N

function DifferentiateWith(
f::F,
backend::B,
context_wrappers::NTuple{C, <:Union{Function, Type}},
) where {F, B <: AbstractADType, C}
return new{C, F, B, typeof(context_wrappers)}(
f,
backend,
context_wrappers,
)
end
end

Comment thread
gdalle marked this conversation as resolved.
(dw::DifferentiateWith)(x) = dw.f(x)
DifferentiateWith(f::F, backend::AbstractADType) where {F} = DifferentiateWith(f, backend, ())

function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C}
return dw.f(x, args...)
end

function Base.show(io::IO, dw::DifferentiateWith)
(; f, backend) = dw
(; f, backend, context_wrappers) = dw
return print(
io,
DifferentiateWith,
"(",
repr(f; context = io),
", ",
repr(backend; context = io),
", ",
repr(context_wrappers; context = io),
")",
)
end
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/utils/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,5 @@ Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are
"""
@inline fix_tail(f::F) where {F} = f
fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...)

@inline call(f::F, x) where {F} = f(x)
38 changes: 32 additions & 6 deletions DifferentiationInterface/test/Back/DifferentiateWith/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ function (adb::ADBreaker)(x::AbstractArray)
return adb.f(x)
end

function differentiatewith_scenarios()
outofplace_scens = filter(DIT.default_scenarios()) do scen
DIT.function_place(scen) == :out
# TODO: break Mooncake with overlay?

function differentiatewith_scenarios(; kwargs...)
outofplace_scens = filter(DIT.default_scenarios(; kwargs...)) do scen
DIT.function_place(scen) == :out &&
# save some time
!isa(scen.x, AbstractMatrix) &&
!isa(scen.y, AbstractMatrix)
end
# with bad_scens, everything would break
bad_scens = map(outofplace_scens) do scen
Expand All @@ -44,7 +49,26 @@ test_differentiation(
differentiatewith_scenarios();
excluded = SECOND_ORDER,
logging = LOGGING,
testset_name = "DI tests",
testset_name = "DI tests - normal",
)

test_differentiation(
[AutoZygote(), AutoMooncake(; config = nothing)],
map(DIT.constantify, differentiatewith_scenarios());
excluded = SECOND_ORDER,
logging = LOGGING,
testset_name = "DI tests - Constant",
)

test_differentiation(
[AutoMooncake(; config = nothing)],
map(differentiatewith_scenarios()) do s
s = DIT.cachify(s; use_tuples = true)
DIT.change_function(s, DifferentiateWith(s.f, AutoFiniteDiff(), (Cache,)))
end;
excluded = SECOND_ORDER,
logging = LOGGING,
testset_name = "DI tests - Cache",
)

@testset "ChainRules tests" begin
Expand All @@ -69,9 +93,9 @@ end;
MooncakeDifferentiateWithError =
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError

e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
e = MooncakeDifferentiateWithError(identity, (1.0,), 2.0)
@test sprint(showerror, e) ==
"MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."
"MooncakeDifferentiateWithError: For the function type `typeof(identity)` and input types `Tuple{Float64}`, the output type `Float64` is currently not supported."

f_num2tup(x::Number) = (x,)
f_vec2tup(x::Vector) = (first(x),)
Expand Down Expand Up @@ -103,3 +127,5 @@ end;
([2.0],),
)
end

@test_throws MethodError DifferentiateWith(exp, AutoForwardDiff(), (3,))
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/run_backend.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test
group = ENV["JULIA_DI_TEST_GROUP"]
@testset "$group" begin
@testset verbose = true "$group" begin
include(joinpath(@__DIR__, group, "test.jl"))
end
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Core/Internals/display.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ detector = DenseSparsityDetector(AutoForwardDiff(); atol = 1.0e-23)
"DenseSparsityDetector(AutoForwardDiff(); atol=1.0e-23, method=:iterative)"

diffwith = DifferentiateWith(exp, AutoForwardDiff())
@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff())"
@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff(), ())"

@test required_packages(AutoForwardDiff()) == ["ForwardDiff"]
@test required_packages(AutoZygote()) == ["Zygote"]
Expand Down
Loading