Skip to content

Commit 9081eec

Browse files
committed
Take into account the comments of Guillaume
1 parent a071e46 commit 9081eec

3 files changed

Lines changed: 23 additions & 68 deletions

File tree

src/decompression.jl

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -530,29 +530,16 @@ end
530530

531531
## TreeSetColoringResult
532532

533-
function compute_tree_value(is_star::Val{false}, B::AbstractMatrix, i::Integer, j::Integer, color::AbstractVector{<:Integer}, buffer::AbstractVector{<:Real})
534-
# The tree is not a star
535-
val = B[i, color[j]] - buffer[i]
536-
buffer[j] = buffer[j] + val
537-
return val
538-
end
539-
540-
function compute_tree_value(is_star::Val{true}, B::AbstractMatrix, i::Integer, j::Integer, color::AbstractVector{<:Integer}, buffer::AbstractVector{<:Real})
541-
# The tree is a star (trivial or non-trivial)
542-
val = B[i, color[j]]
543-
return val
544-
end
545-
546533
function decompress!(
547534
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
548535
)
549-
(; ag, color, reverse_bfs_orders, is_star, buffer) = result
536+
(; ag, color, reverse_bfs_orders, buffer) = result
550537
(; S) = ag
551538
uplo == :F && check_same_pattern(A, S)
552539
R = eltype(A)
553540
fill!(A, zero(R))
554541

555-
if eltype(buffer) == R || isempty(buffer)
542+
if eltype(buffer) == R
556543
buffer_right_type = buffer
557544
else
558545
buffer_right_type = similar(buffer, R)
@@ -569,22 +556,17 @@ function decompress!(
569556

570557
# Recover the off-diagonal coefficients of A
571558
for k in eachindex(reverse_bfs_orders)
572-
is_star_k = is_star[k]
573-
val_is_star_k = Val(is_star_k)
574-
575-
# We need the buffer only when the tree is not a star (trivial or non-trivial)
576-
if !is_star_k
577-
# Reset the buffer to zero for all vertices in a tree (except the root)
578-
for (vertex, _) in reverse_bfs_orders[k]
579-
buffer_right_type[vertex] = zero(R)
580-
end
581-
# Reset the buffer to zero for the root vertex
582-
(_, root) = reverse_bfs_orders[k][end]
583-
buffer_right_type[root] = zero(R)
559+
# Reset the buffer to zero for all vertices in a tree (except the root)
560+
for (vertex, _) in reverse_bfs_orders[k]
561+
buffer_right_type[vertex] = zero(R)
584562
end
563+
# Reset the buffer to zero for the root vertex
564+
(_, root) = reverse_bfs_orders[k][end]
565+
buffer_right_type[root] = zero(R)
585566

586567
for (i, j) in reverse_bfs_orders[k]
587-
val = compute_tree_value(val_is_star_k, B, i, j, color, buffer_right_type)
568+
val = B[i, color[j]] - buffer_right_type[i]
569+
buffer_right_type[j] = buffer_right_type[j] + val
588570

589571
if in_triangle(i, j, uplo)
590572
A[i, j] = val
@@ -607,7 +589,6 @@ function decompress!(
607589
ag,
608590
color,
609591
reverse_bfs_orders,
610-
is_star,
611592
diagonal_indices,
612593
diagonal_nzind,
613594
lower_triangle_offsets,
@@ -619,7 +600,7 @@ function decompress!(
619600
nzA = nonzeros(A)
620601
uplo == :F && check_same_pattern(A, S)
621602

622-
if eltype(buffer) == R || isempty(buffer)
603+
if eltype(buffer) == R
623604
buffer_right_type = buffer
624605
else
625606
buffer_right_type = similar(buffer, R)
@@ -652,23 +633,18 @@ function decompress!(
652633

653634
# Recover the off-diagonal coefficients of A
654635
for k in eachindex(reverse_bfs_orders)
655-
is_star_k = is_star[k]
656-
val_is_star_k = Val(is_star_k)
657-
658-
# We need the buffer only when the tree is not a star (trivial or non-trivial)
659-
if !is_star_k
660-
# Reset the buffer to zero for all vertices in a tree (except the root)
661-
for (vertex, _) in reverse_bfs_orders[k]
662-
buffer_right_type[vertex] = zero(R)
663-
end
664-
# Reset the buffer to zero for the root vertex
665-
(_, root) = reverse_bfs_orders[k][end]
666-
buffer_right_type[root] = zero(R)
636+
# Reset the buffer to zero for all vertices in a tree (except the root)
637+
for (vertex, _) in reverse_bfs_orders[k]
638+
buffer_right_type[vertex] = zero(R)
667639
end
640+
# Reset the buffer to zero for the root vertex
641+
(_, root) = reverse_bfs_orders[k][end]
642+
buffer_right_type[root] = zero(R)
668643

669644
for (i, j) in reverse_bfs_orders[k]
670645
counter += 1
671-
compute_tree_value(val_is_star_k, B, i, j, color, buffer_right_type)
646+
val = B[i, color[j]] - buffer_right_type[i]
647+
buffer_right_type[j] = buffer_right_type[j] + val
672648

673649
#! format: off
674650
# A[i,j] is in the lower triangular part of A

src/result.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,6 @@ struct TreeSetColoringResult{M<:AbstractMatrix,G<:AdjacencyGraph,V,R} <:
279279
color::Vector{Int}
280280
group::V
281281
reverse_bfs_orders::Vector{Vector{Tuple{Int,Int}}}
282-
is_star::Vector{Bool}
283282
diagonal_indices::Vector{Int}
284283
diagonal_nzind::Vector{Int}
285284
lower_triangle_offsets::Vector{Int}
@@ -294,7 +293,7 @@ function TreeSetColoringResult(
294293
tree_set::TreeSet,
295294
decompression_eltype::Type{R},
296295
) where {R}
297-
(; reverse_bfs_orders, is_star) = tree_set
296+
(; reverse_bfs_orders) = tree_set
298297
(; S) = ag
299298
nvertices = length(color)
300299
group = group_by_color(color)
@@ -362,16 +361,14 @@ function TreeSetColoringResult(
362361

363362
# buffer holds the sum of edge values for subtrees in a tree.
364363
# For each vertex i, buffer[i] is the sum of edge values in the subtree rooted at i.
365-
# Note that we don't need a buffer is all trees are stars.
366-
buffer = all(is_star) ? R[] : Vector{R}(undef, nvertices)
364+
buffer = Vector{R}(undef, nvertices)
367365

368366
return TreeSetColoringResult(
369367
A,
370368
ag,
371369
color,
372370
group,
373371
reverse_bfs_orders,
374-
is_star,
375372
diagonal_indices,
376373
diagonal_nzind,
377374
lower_triangle_offsets,

test/allocations.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919

2020
@testset "Distance-2 coloring" begin
2121
test_noallocs_distance2_coloring(1000)
22-
end;
22+
end
2323

2424
function test_noallocs_sparse_decompression(
2525
n::Integer; structure::Symbol, partition::Symbol, decompression::Symbol
@@ -121,7 +121,7 @@ end
121121
]
122122
test_noallocs_sparse_decompression(1000; structure, partition, decompression)
123123
end
124-
end;
124+
end
125125

126126
@testset "Structured decompression" begin
127127
@testset "$structure - $partition - $decompression" for (
@@ -131,22 +131,4 @@ end;
131131
]
132132
test_noallocs_structured_decompression(1000; structure, partition, decompression)
133133
end
134-
end;
135-
136-
@testset "Multi-precision acyclic decompression" begin
137-
@testset "$format" for format in ("dense", "sparse")
138-
A = [0 0 1; 0 1 0; 1 0 0]
139-
if format == "sparse"
140-
A = sparse(A)
141-
end
142-
problem = ColoringProblem(; structure=:symmetric, partition=:column)
143-
result = coloring(A, problem, GreedyColoringAlgorithm{:substitution}())
144-
@test isempty(result.buffer)
145-
for T in (Float32, Float64)
146-
C = rand(T) * T.(A)
147-
B = compress(C, result)
148-
bench_multiprecision = @be decompress!(C, B, result)
149-
@test minimum(bench_multiprecision).allocs == 0
150-
end
151-
end
152134
end

0 commit comments

Comments
 (0)