Skip to content

Commit 57703c1

Browse files
authored
Merge pull request #181 from JuliaDiff/mz/struct-to-vec
try to handle structs with custom constructors
2 parents 4d30c43 + 8fdf7e6 commit 57703c1

3 files changed

Lines changed: 35 additions & 2 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.12.14"
3+
version = "0.12.15"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/to_vec.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ end
2121
# Base case -- if x is already a Vector{<:Real} there's no conversion necessary.
2222
to_vec(x::Vector{<:Real}) = (x, identity)
2323

24+
# get around the constructors and make the type directly
25+
# Note this is moderately evil accessing julia's internals
26+
if VERSION >= v"1.3"
27+
@generated function _force_construct(T, args...)
28+
return Expr(:splatnew, :T, :args)
29+
end
30+
else
31+
@generated function _force_construct(T, args...)
32+
return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...)
33+
end
34+
end
35+
2436
# Fallback method for `to_vec`. Won't always do what you wanted, but should be fine a decent
2537
# chunk of the time.
2638
function to_vec(x::T) where {T}
@@ -35,7 +47,11 @@ function to_vec(x::T) where {T}
3547
function structtype_from_vec(v::Vector{<:Real})
3648
val_vecs = vals_from_vec(v)
3749
values = map((b, v) -> b(v), backs, val_vecs)
38-
return T(values...)
50+
try
51+
T(values...)
52+
catch MethodError
53+
return _force_construct(T, values...)
54+
end
3955
end
4056
return v, structtype_from_vec
4157
end

test/to_vec.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ end
4444
Base.size(a::WrapperArray) = size(a.data)
4545
Base.getindex(a::WrapperArray, inds...) = getindex(a.data, inds...)
4646

47+
# can not construct it from: cca = CustomConstructorArray(rand(2, 2))
48+
# T = typeof(cca) # CustomConstructorArray{Float64, 2, Matrix{Float64}}
49+
# T(rand(2, 3)) # errors
50+
struct CustomConstructorArray{T, N, A<:AbstractArray{T, N}} <: AbstractArray{T, N}
51+
data::A
52+
function CustomConstructorArray(data::A) where {T, N, A<:AbstractArray{T, N}}
53+
return new{T, N, A}(data)
54+
end
55+
end
56+
Base.size(a::CustomConstructorArray) = size(a.data)
57+
Base.getindex(a::CustomConstructorArray, inds...) = getindex(a.data, inds...)
58+
4759
function test_to_vec(x::T; check_inferred=true) where {T}
4860
check_inferred && @inferred to_vec(x)
4961
x_vec, back = to_vec(x)
@@ -195,4 +207,9 @@ end
195207
wa = WrapperArray(rand(4, 5))
196208
test_to_vec(wa; check_inferred=false)
197209
end
210+
211+
@testset "CustomConstructorArray" begin
212+
cca = CustomConstructorArray(rand(2, 3))
213+
test_to_vec(cca; check_inferred=false)
214+
end
198215
end

0 commit comments

Comments
 (0)