Skip to content

Commit c31eec3

Browse files
author
Stuart Daines
committed
Add fast path for _colorediteration! with sparse J
Partial fix for JuliaDiff/SparseDiffTools.jl#138 JuliaDiff/SparseDiffTools.jl#100 Adds a fast path for the case where J and sparsity are are both SparseMatrixCSC and have the same number of columns and stored values.
1 parent 0c79fe7 commit c31eec3

2 files changed

Lines changed: 47 additions & 5 deletions

File tree

src/iteration_utils.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,37 @@ end
1616
end
1717
end
1818

19+
# fast version for the case where J and sparsity have the same sparsity pattern
20+
@inline function _colorediteration!(Jsparsity::SparseMatrixCSC,vfx,colorvec,color_i,ncols)
21+
@inbounds for col_index in 1:ncols
22+
if colorvec[col_index] == color_i
23+
@inbounds for spidx in nzrange(Jsparsity, col_index)
24+
row_index = Jsparsity.rowval[spidx]
25+
Jsparsity.nzval[spidx]=vfx[row_index]
26+
end
27+
end
28+
end
29+
end
30+
1931
#override default setting of using findstructralnz
2032
_use_findstructralnz(sparsity) = ArrayInterface.has_sparsestruct(sparsity)
2133
_use_findstructralnz(::SparseMatrixCSC) = false
2234

35+
# test if J, sparsity are both SparseMatrixCSC and have the same size storage arrays,
36+
# if so, update J so they can share the same sparsity pattern
37+
_use_sparseCSC_common_sparsity!(J, sparsity) = false
38+
function _use_sparseCSC_common_sparsity!(J::SparseMatrixCSC, sparsity::SparseMatrixCSC)
39+
common_sparsity = (length(J.colptr) == length(sparsity.colptr) &&
40+
length(J.nzval) == length(sparsity.nzval))
41+
42+
if common_sparsity
43+
J.colptr .= sparsity.colptr
44+
J.rowval .= sparsity.rowval
45+
end
46+
47+
return common_sparsity
48+
end
49+
2350
function __init__()
2451
@require BlockBandedMatrices="ffab5731-97b5-5995-9138-79e8c1846df0" begin
2552
@require BlockArrays="8e7c35d0-a365-5155-bbbb-fb81a777f24e" begin

src/jacobians.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ function finite_difference_jacobian!(
348348
fill!(J,false)
349349
end
350350

351+
# fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
352+
sparseCSC_common_sparsity = _use_sparseCSC_common_sparsity!(J, sparsity)
353+
351354
if fdtype == Val(:forward)
352355
vfx1 = _vec(fx1)
353356

@@ -378,7 +381,11 @@ function finite_difference_jacobian!(
378381
# J is a sparse matrix, so decompress on the fly
379382
@. vfx1 = (vfx1 - vfx) / epsilon
380383
if ArrayInterface.fast_scalar_indexing(x1)
381-
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
384+
if sparseCSC_common_sparsity
385+
_colorediteration!(J,vfx1,colorvec,color_i,n)
386+
else
387+
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
388+
end
382389
else
383390
#=
384391
J.nzval[rows_index] .+= (colorvec[cols_index] .== color_i) .* vfx1[rows_index]
@@ -417,8 +424,12 @@ function finite_difference_jacobian!(
417424
f(fx1, x1)
418425
f(fx, x)
419426
@. vfx1 = (vfx1 - vfx) / 2epsilon
420-
if ArrayInterface.fast_scalar_indexing(x1)
421-
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
427+
if ArrayInterface.fast_scalar_indexing(x1)
428+
if sparseCSC_common_sparsity
429+
_colorediteration!(J,vfx1,colorvec,color_i,n)
430+
else
431+
_colorediteration!(J,sparsity,rows_index,cols_index,vfx1,colorvec,color_i,n)
432+
end
422433
else
423434
if J isa SparseMatrixCSC
424435
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx1,), rows_index), rows_index)
@@ -443,8 +454,12 @@ function finite_difference_jacobian!(
443454
@. x1 = x1 + im * epsilon * (_color == color_i)
444455
f(fx,x1)
445456
@. vfx = imag(vfx) / epsilon
446-
if ArrayInterface.fast_scalar_indexing(x1)
447-
_colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n)
457+
if ArrayInterface.fast_scalar_indexing(x1)
458+
if sparseCSC_common_sparsity
459+
_colorediteration!(J,vfx,colorvec,color_i,n)
460+
else
461+
_colorediteration!(J,sparsity,rows_index,cols_index,vfx,colorvec,color_i,n)
462+
end
448463
else
449464
if J isa SparseMatrixCSC
450465
@. void_setindex!((J.nzval,), getindex((J.nzval,), rows_index) + (getindex((_color,), cols_index) == color_i) * getindex((vfx,),rows_index), rows_index)

0 commit comments

Comments
 (0)