@@ -82,11 +82,25 @@ Transform `x` into a `Vector`, and return a closure which inverts the transforma
8282"""
8383to_vec (x:: Real ) = ([x], first)
8484
85- # Arrays.
85+ # Vectors
8686to_vec (x:: Vector{<:Real} ) = (x, identity)
87- to_vec (x:: Array ) = vec (x), x_vec-> reshape (x_vec, size (x))
87+ function to_vec (x:: Vector )
88+ x_vecs_and_backs = map (to_vec, x)
89+ x_vecs, backs = first .(x_vecs_and_backs), last .(x_vecs_and_backs)
90+ return vcat (x_vecs... ), function (x_vec)
91+ sz = cumsum ([map (length, x_vecs)... ])
92+ return [backs[n](x_vec[sz[n]- length (x_vecs[n])+ 1 : sz[n]]) for n in eachindex (x)]
93+ end
94+ end
95+
96+ # Arrays
97+ to_vec (x:: Array{<:Real} ) = vec (x), x_vec-> reshape (x_vec, size (x))
98+ function to_vec (x:: Array )
99+ x_vec, back = to_vec (reshape (x, :))
100+ return x_vec, x_vec-> reshape (back (x_vec), size (x))
101+ end
88102
89- # AbstractArrays.
103+ # AbstractArrays
90104function to_vec (x:: T ) where {T<: LinearAlgebra.AbstractTriangular }
91105 x_vec, back = to_vec (Matrix (x))
92106 return x_vec, x_vec-> T (reshape (back (x_vec), size (x)))
@@ -99,11 +113,22 @@ function to_vec(X::T) where T<:Union{Adjoint,Transpose}
99113 return vec (Matrix (X)), x_vec-> U (permutedims (reshape (x_vec, size (X))))
100114end
101115
102- # Non-array data structures.
116+ # Non-array data structures
117+
103118function to_vec (x:: Tuple )
104119 x_vecs, x_backs = zip (map (to_vec, x)... )
105120 sz = cumsum ([map (length, x_vecs)... ])
106121 return vcat (x_vecs... ), function (v)
107122 return ntuple (n-> x_backs[n](v[sz[n]- length (x_vecs[n])+ 1 : sz[n]]), length (x))
108123 end
109124end
125+
126+ # Convert to a vector-of-vectors to make use of existing functionality.
127+ function to_vec (d:: Dict )
128+ d_vec_vec = [val for val in values (d)]
129+ d_vec, back = to_vec (d_vec_vec)
130+ return d_vec, function (v)
131+ v_vec_vec = back (v)
132+ return Dict ([(key, v_vec_vec[n]) for (n, key) in enumerate (keys (d))])
133+ end
134+ end
0 commit comments