Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add high-level API Scatter for splitting 1D array #816

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/collective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,38 @@ Scatter(sendbuf, T, comm; root::Integer=Cint(0)) =
Scatter(sendbuf, ::Type{T}, root::Integer, comm::Comm) where {T} =
Scatter!(sendbuf, Ref{T}(), root, comm)[]

"""
Scatter(arr::AbstractVector, comm::Comm; root::Cint=0)

Splits a 1D array `arr` with elements of the same type in the `root` process into `nprocs=Comm_size(comm)` smaller 1D arrays.
The array `arr` is splitted in rank order, and the number of elements `n=length(arr)` can be not divisible by `nprocs`.
Each process with the rank `j` returns a smaller array with the number of elements
`j < rem(n,nprocs) ? div(n,nprocs) + 1 : div(n,nprocs)`.
"""
function Scatter(arr::Union{Nothing, AbstractVector}, comm; root=Cint(0))
rank = MPI.Comm_rank(comm)
nprocs = MPI.Comm_size(comm)

arr_len = 0
elm_t = nothing
if rank == root
arr_len = length(arr)
elm_t = eltype(arr)
end
arr_len = MPI.Bcast(arr_len, root, comm)
elm_t = MPI.bcast(elm_t, root, comm)
Comment on lines +188 to +189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is one of these Bcast and the other bcast?

Copy link
Author

@ykkan ykkan Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, elm_t is not a isbits type and can only be broadcasted with bcast. My idea was to use Bcast whenever possible, as its overhead should be less than bcast(?). Of course, using bcast for both elm_t and arr_len should work. I think my idea behind is rather a matter of style. What do you think?


q,r = divrem(arr_len, nprocs)
count = rank < r ? (q+1) : q
local_arr = Vector{elm_t}(undef, count)
if rank == root
counts = [i < r ? (q+1) : q for i = 0:(nprocs - 1)]
return MPI.Scatterv!(MPI.VBuffer(arr, counts), MPI.Buffer(local_arr), root, comm)
else
return MPI.Scatterv!(nothing, MPI.Buffer(local_arr), root, comm)
end
end

"""
scatter(objs::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0)

Expand Down
Loading