Skip to content

Commit b5a242c

Browse files
committed
Check if a buffer for the decompression of acyclic coloring is needed
1 parent f1bf8ea commit b5a242c

3 files changed

Lines changed: 22 additions & 3 deletions

File tree

src/decompression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ function decompress!(
523523
R = eltype(A)
524524
fill!(A, zero(R))
525525

526-
if eltype(buffer) == R
526+
if eltype(buffer) == R || isempty(buffer)
527527
buffer_right_type = buffer
528528
else
529529
buffer_right_type = similar(buffer, R)
@@ -599,7 +599,7 @@ function decompress!(
599599
nzA = nonzeros(A)
600600
uplo == :F && check_same_pattern(A, S)
601601

602-
if eltype(buffer) == R
602+
if eltype(buffer) == R || isempty(buffer)
603603
buffer_right_type = buffer
604604
else
605605
buffer_right_type = similar(buffer, R)

src/result.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,8 @@ function TreeSetColoringResult(
362362

363363
# buffer holds the sum of edge values for subtrees in a tree.
364364
# For each vertex i, buffer[i] is the sum of edge values in the subtree rooted at i.
365-
buffer = Vector{R}(undef, nvertices)
365+
# Note that we don't need a buffer is all trees are stars.
366+
buffer = all(is_star) ? R[] : Vector{R}(undef, nvertices)
366367

367368
return TreeSetColoringResult(
368369
A,

test/allocations.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,21 @@ end;
132132
test_noallocs_structured_decompression(1000; structure, partition, decompression)
133133
end
134134
end;
135+
136+
@testset "Multi-precision acyclic decompression" begin
137+
@testset "$format" for format in ("dense", "sparse")
138+
A = [0 0 1; 0 1 0; 1 0 0]
139+
if format == "sparse"
140+
A = sparse(A)
141+
end
142+
problem = ColoringProblem(; structure=:symmetric, partition=:column)
143+
result = coloring(A, problem, GreedyColoringAlgorithm{:substitution}())
144+
@test isempty(result.buffer)
145+
for T in (Float32, Float64)
146+
C = rand(T) * T.(A)
147+
B = compress(C, result)
148+
bench_multiprecision = @be decompress!(C, B, result)
149+
@test minimum(bench_multiprecision).allocs == 0
150+
end
151+
end
152+
end

0 commit comments

Comments
 (0)