From b5a242c2ac32a3f6376563d152ba75cef95ca9ef Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 17 Feb 2025 11:51:31 -0600 Subject: [PATCH] Check if a buffer for the decompression of acyclic coloring is needed --- src/decompression.jl | 4 ++-- src/result.jl | 3 ++- test/allocations.jl | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/decompression.jl b/src/decompression.jl index 4b26f7d3..a50517f7 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -523,7 +523,7 @@ function decompress!( R = eltype(A) fill!(A, zero(R)) - if eltype(buffer) == R + if eltype(buffer) == R || isempty(buffer) buffer_right_type = buffer else buffer_right_type = similar(buffer, R) @@ -599,7 +599,7 @@ function decompress!( nzA = nonzeros(A) uplo == :F && check_same_pattern(A, S) - if eltype(buffer) == R + if eltype(buffer) == R || isempty(buffer) buffer_right_type = buffer else buffer_right_type = similar(buffer, R) diff --git a/src/result.jl b/src/result.jl index c7f13e5e..16527b1f 100644 --- a/src/result.jl +++ b/src/result.jl @@ -362,7 +362,8 @@ function TreeSetColoringResult( # buffer holds the sum of edge values for subtrees in a tree. # For each vertex i, buffer[i] is the sum of edge values in the subtree rooted at i. - buffer = Vector{R}(undef, nvertices) + # Note that we don't need a buffer is all trees are stars. + buffer = all(is_star) ? R[] : Vector{R}(undef, nvertices) return TreeSetColoringResult( A, diff --git a/test/allocations.jl b/test/allocations.jl index 81272c90..1246ec27 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -132,3 +132,21 @@ end; test_noallocs_structured_decompression(1000; structure, partition, decompression) end end; + +@testset "Multi-precision acyclic decompression" begin + @testset "$format" for format in ("dense", "sparse") + A = [0 0 1; 0 1 0; 1 0 0] + if format == "sparse" + A = sparse(A) + end + problem = ColoringProblem(; structure=:symmetric, partition=:column) + result = coloring(A, problem, GreedyColoringAlgorithm{:substitution}()) + @test isempty(result.buffer) + for T in (Float32, Float64) + C = rand(T) * T.(A) + B = compress(C, result) + bench_multiprecision = @be decompress!(C, B, result) + @test minimum(bench_multiprecision).allocs == 0 + end + end +end