[ADD] Added RMA one-sided communication schemes

communication_v2
Stack-1 3 weeks ago
parent 52ad95461d
commit fb5ba59693

@ -16,6 +16,7 @@ set(PSB_base_source_files
comm/internals/psi_lovrl_upd.f90
comm/internals/psi_dswapdata_a.F90
comm/internals/psi_movrl_upd_a.f90
modules/comm/comm_schemes/psb_comm_rma_mod.F90
# comm/internals/psi_i2swaptran_a.F90
comm/internals/psi_dswaptran.F90
comm/internals/psi_covrl_save_a.f90

@ -83,6 +83,9 @@ submodule (psi_d_comm_v_mod) psi_d_swapdata_impl
use psb_desc_const_mod, only: psb_swap_start_, psb_swap_wait_
use psb_base_mod
use psb_error_mod, only: psb_get_debug_level, psb_get_debug_unit, psb_debug_ext_
use psb_comm_schemes_mod, only: psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
& psb_comm_persistent_ineighbor_alltoallv_, psb_comm_rma_pull_, psb_comm_rma_push_
use psb_comm_rma_mod, only: psb_comm_rma_handle
use psb_comm_factory_mod
contains
@ -191,6 +194,20 @@ contains
call psb_errpush(info,name,a_err='neighbor persistent swap')
goto 9999
end if
case(psb_comm_rma_pull_)
call psi_dswap_rma_pull_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,&
& y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='rma pull swap')
goto 9999
end if
case(psb_comm_rma_push_)
call psi_dswap_rma_push_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,&
& y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='rma push swap')
goto 9999
end if
case default
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible swap_status settings: no valid communication mode selected')
@ -911,6 +928,477 @@ contains
end subroutine psi_dswap_neighbor_persistent_topology_vect
!
! Subroutine: psi_dswap_rma_pull_vect
!   Halo exchange of a double precision vector using one-sided MPI_Get ("pull").
!
!   Each rank gathers the entries it owes its neighbours into
!   y%combuf(1:total_send), exposes y%combuf through an MPI window and then
!   reads every neighbour's send slice directly into its own receive slice
!   y%combuf(total_send+1:total_send+total_recv). The WAIT phase closes the
!   access epoch and scatters the receive slice back into y.
!
! Arguments:
!   ctxt          - communication context.
!   swap_status   - psb_comm_status_start_ runs the START phase only,
!                   psb_comm_status_wait_ the WAIT phase only,
!                   psb_comm_status_sync_ runs both.
!   beta          - forwarded unchanged to y%sct.
!   y             - vector being exchanged; its combuf is the exposed buffer.
!   comm_indexes  - exchange list, walked entry by entry with the psb_proc_id_,
!                   psb_n_elem_recv_ and psb_n_elem_send_ offsets.
!   num_neighbors - number of entries in the exchange list.
!   total_send    - total number of entries sent by this rank.
!   total_recv    - total number of entries received by this rank.
!   comm_handle   - must be a psb_comm_rma_handle; caches layout and window.
!   info          - psb_success_ on success, an error code otherwise.
!
! Synchronisation:
!   Passive-target locks do not order remote reads against the target's own
!   stores, so two barriers are used per exchange: one between the local gather
!   and the remote gets (a peer's send slice is complete before it is read) and
!   one in the WAIT phase (no rank regathers while a peer still reads from it).
!   Both barriers sit outside the buffer_size tests so that every rank of the
!   communicator reaches them.
!
!   NOTE(review): window creation/recreation below is decided per rank although
!   mpi_win_create/mpi_win_free are collective; confirm that all ranks always
!   take the same decision (e.g. no rank has an empty halo while others do not).
!
subroutine psi_dswap_rma_pull_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,comm_handle,info)
#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif
  type(psb_ctxt_type), intent(in)            :: ctxt
  integer(psb_ipk_), intent(in)              :: swap_status
  real(psb_dpk_), intent(in)                 :: beta
  class(psb_d_base_vect_type), intent(inout) :: y
  class(psb_i_base_vect_type), intent(inout) :: comm_indexes
  integer(psb_ipk_), intent(in)              :: num_neighbors, total_send, total_recv
  ! TARGET is required because rma_handle is pointer-associated with this dummy.
  class(psb_comm_handle_type), intent(inout), target :: comm_handle
  integer(psb_ipk_), intent(out)             :: info

  integer(psb_mpk_) :: np, my_rank, iret, element_bytes, icomm
  integer(psb_mpk_) :: proc_to_comm, prc_rank, recv_count, send_count, send_pos, recv_pos, list_pos
  integer(psb_mpk_) :: remote_base
  integer(kind=MPI_ADDRESS_KIND) :: remote_disp, exposed_bytes
  integer(psb_ipk_) :: err_act, neighbor_idx, buffer_size
  integer(psb_ipk_), allocatable :: peer_mpi_rank(:)
  logical :: do_start, do_wait, memory_buffer_layout_rebuild_needed
  type(psb_comm_rma_handle), pointer :: rma_handle
  character(len=30) :: name

  info = psb_success_
  name = 'psi_dswap_rma_pull_vect'
  call psb_erractionsave(err_act)
  nullify(rma_handle)

  call psb_info(ctxt,my_rank,np)
  if (np == -1) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  end if
  icomm = ctxt%get_mpic()

  ! This scheme only works with the RMA flavour of the communication handle.
  select type(ch => comm_handle)
  type is(psb_comm_rma_handle)
    rma_handle => ch
  class default
    info = psb_err_mpi_error_
    call psb_errpush(info,name,a_err='Expected RMA comm_handle for pull mode')
    goto 9999
  end select

  do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
  do_wait  = (swap_status == psb_comm_status_wait_)  .or. (swap_status == psb_comm_status_sync_)

  call comm_indexes%sync()

  ! START phase: build the layout once, prepare the exposed buffer, then issue RMA.
  if (do_start) then
    buffer_size = total_send + total_recv

    memory_buffer_layout_rebuild_needed = (.not. rma_handle%layout_ready) .or. &
         & (rma_handle%layout_nnbr /= num_neighbors) .or. &
         & (rma_handle%layout_send /= total_send) .or. &
         & (rma_handle%layout_recv /= total_recv)

    if (memory_buffer_layout_rebuild_needed) then
      ! Always allocate, possibly with zero size: the layout routine takes an
      ! assumed-shape array, which must not be an unallocated actual argument.
      allocate(peer_mpi_rank(max(0_psb_ipk_,num_neighbors)), stat=iret)
      if (iret /= 0) then
        info = psb_err_alloc_dealloc_
        call psb_errpush(info,name,a_err='RMA pull rank cache allocation')
        goto 9999
      end if

      ! Walk the exchange list once to translate process ids into MPI ranks.
      list_pos = 1
      do neighbor_idx = 1, num_neighbors
        proc_to_comm = comm_indexes%v(list_pos+psb_proc_id_)
        peer_mpi_rank(neighbor_idx) = psb_get_mpi_rank(ctxt,proc_to_comm)
        recv_count = comm_indexes%v(list_pos+psb_n_elem_recv_)
        send_count = comm_indexes%v(list_pos+recv_count+psb_n_elem_send_)
        list_pos = list_pos + recv_count + send_count + 3
      end do

      ! my_rank is psb_mpk_ here while the layout routine expects psb_ipk_.
      call rma_handle%init_memory_buffer_layout(info, comm_indexes%v, peer_mpi_rank, &
           & num_neighbors, total_send, total_recv, int(my_rank,psb_ipk_), icomm)
      if (info /= psb_success_) then
        call psb_errpush(info,name,a_err='RMA pull init_memory_buffer_layout failure')
        goto 9999
      end if
    end if

    if (buffer_size > 0) then
      if (.not. allocated(y%combuf)) then
        call y%new_buffer(buffer_size, info)
        if (info /= psb_success_) then
          call psb_errpush(psb_err_alloc_dealloc_,name)
          goto 9999
        end if
      else if (size(y%combuf) < buffer_size) then
        ! A larger exposed area is needed: tear down the window attached to the
        ! old buffer, reallocate combuf, and let the block below create a new one.
        if (rma_handle%window_open) then
          call mpi_win_unlock_all(rma_handle%win, iret)
          if (iret /= mpi_success) then
            info = psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
          rma_handle%window_open = .false.
        end if
        if (rma_handle%window_ready) then
          call mpi_win_free(rma_handle%win, iret)
          if (iret /= mpi_success) then
            info = psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
          rma_handle%window_ready = .false.
          rma_handle%win = mpi_win_null
        end if
        call y%new_buffer(buffer_size, info)
        if (info /= psb_success_) then
          call psb_errpush(psb_err_alloc_dealloc_,name)
          goto 9999
        end if
      end if
    end if

    if ((buffer_size > 0).and.(.not. rma_handle%window_ready)) then
      ! Expose combuf once; the displacement unit is one vector element, so
      ! remote displacements are expressed in elements rather than bytes.
      element_bytes = storage_size(y%combuf(1))/8
      exposed_bytes = int(size(y%combuf),kind=MPI_ADDRESS_KIND) * int(element_bytes,kind=MPI_ADDRESS_KIND)
      call mpi_win_create(y%combuf, exposed_bytes, element_bytes, &
           & mpi_info_null, icomm, rma_handle%win, iret)
      if (iret /= mpi_success) then
        info = psb_err_mpi_error_
        call psb_errpush(info,name,m_err=(/iret/))
        goto 9999
      end if
      rma_handle%window_ready = .true.
    end if

    ! Fill the local send slice that the neighbours are going to read.
    if (buffer_size > 0) then
      if (total_send > 0) then
        call y%gth(int(total_send,psb_mpk_), rma_handle%peer_send_indexes, y%combuf(1:total_send))
      end if
      call y%device_wait()
    end if

    ! No rank may read a peer's send slice before that peer has gathered into it.
    call mpi_barrier(icomm, iret)
    if (iret /= mpi_success) then
      info = psb_err_mpi_error_
      call psb_errpush(info,name,m_err=(/iret/))
      goto 9999
    end if

    if (buffer_size > 0) then
      ! Do not relock if a previous START was never followed by a WAIT.
      if (.not. rma_handle%window_open) then
        call mpi_win_lock_all(0, rma_handle%win, iret)
        if (iret /= mpi_success) then
          info = psb_err_mpi_error_
          call psb_errpush(info,name,m_err=(/iret/))
          goto 9999
        end if
        rma_handle%window_open = .true.
      end if

      ! Pull the data of every neighbour; all metadata comes from the cached layout.
      do neighbor_idx=1, num_neighbors
        proc_to_comm = rma_handle%peer_proc(neighbor_idx)
        recv_count   = rma_handle%peer_recv_counts(neighbor_idx)
        send_count   = rma_handle%peer_send_counts(neighbor_idx)
        prc_rank     = rma_handle%peer_mpi_rank(neighbor_idx)
        send_pos     = rma_handle%peer_send_displs(neighbor_idx) + 1
        recv_pos     = total_send + rma_handle%peer_recv_displs(neighbor_idx) + 1

        if (proc_to_comm /= my_rank) then
          ! One-based position, inside the peer's buffer, of the data meant for us.
          remote_base = rma_handle%peer_remote_send_displs(neighbor_idx)
          if (remote_base < 1) then
            info = psb_err_internal_error_
            call psb_errpush(info,name,a_err='Invalid remote metadata in RMA pull')
            goto 9999
          end if
          if ((recv_pos < 1) .or. (recv_count < 0) .or. &
               & (recv_pos+max(0,recv_count)-1 > size(y%combuf))) then
            info = psb_err_internal_error_
            call psb_errpush(info,name,a_err='RMA pull local receive bounds error')
            goto 9999
          end if
          if (recv_count > 0) then
            remote_disp = int(remote_base - 1, kind=MPI_ADDRESS_KIND)
            call mpi_get(y%combuf(recv_pos), recv_count, psb_mpi_r_dpk_, prc_rank, &
                 & remote_disp, recv_count, psb_mpi_r_dpk_, rma_handle%win, iret)
            if (iret /= mpi_success) then
              info = psb_err_mpi_error_
              call psb_errpush(info,name,m_err=(/iret/))
              goto 9999
            end if
          end if
          call mpi_win_flush(prc_rank, rma_handle%win, iret)
          if (iret /= mpi_success) then
            info = psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
        else
          ! Exchange with ourselves: plain copy from the send to the receive slice.
          if (send_count /= recv_count) then
            info = psb_err_internal_error_
            call psb_errpush(info,name,a_err='RMA pull self-copy mismatch')
            goto 9999
          end if
          y%combuf(recv_pos:recv_pos+recv_count-1) = y%combuf(send_pos:send_pos+send_count-1)
        end if
      end do
    end if
  end if

  ! WAIT phase: close the epoch and scatter the received slice back into Y.
  if (do_wait) then
    if (rma_handle%window_open) then
      call mpi_win_unlock_all(rma_handle%win, iret)
      if (iret /= mpi_success) then
        info = psb_err_mpi_error_
        call psb_errpush(info,name,m_err=(/iret/))
        goto 9999
      end if
      rma_handle%window_open = .false.
    end if

    ! Every rank has finished reading: send slices may be overwritten by the
    ! gather of the next exchange from here on.
    call mpi_barrier(icomm, iret)
    if (iret /= mpi_success) then
      info = psb_err_mpi_error_
      call psb_errpush(info,name,m_err=(/iret/))
      goto 9999
    end if

    if (total_recv > 0) then
      call y%sct(int(total_recv,psb_mpk_), rma_handle%peer_recv_indexes, &
           & y%combuf(total_send+1:total_send+total_recv), beta)
    end if
    call y%device_wait()
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)
  return
end subroutine psi_dswap_rma_pull_vect
!
! Subroutine: psi_dswap_rma_push_vect
!   Halo exchange of a double precision vector using one-sided MPI_Put ("push").
!
!   Each rank gathers the entries it owes its neighbours into
!   y%combuf(1:total_send), exposes y%combuf through an MPI window and writes
!   every neighbour's share directly into that neighbour's receive slice. The
!   WAIT phase closes the access epoch, waits for all ranks, and scatters the
!   local receive slice y%combuf(total_send+1:total_send+total_recv) into y.
!
! Arguments:
!   ctxt          - communication context.
!   swap_status   - psb_comm_status_start_ runs the START phase only,
!                   psb_comm_status_wait_ the WAIT phase only,
!                   psb_comm_status_sync_ runs both.
!   beta          - forwarded unchanged to y%sct.
!   y             - vector being exchanged; its combuf is the exposed buffer.
!   comm_indexes  - exchange list, walked entry by entry with the psb_proc_id_,
!                   psb_n_elem_recv_ and psb_n_elem_send_ offsets.
!   num_neighbors - number of entries in the exchange list.
!   total_send    - total number of entries sent by this rank.
!   total_recv    - total number of entries received by this rank.
!   comm_handle   - must be a psb_comm_rma_handle; caches layout and window.
!   info          - psb_success_ on success, an error code otherwise.
!
! Synchronisation:
!   Two barriers are used per exchange. The one in the START phase keeps a fast
!   rank from overwriting a peer's receive slice while that peer is still
!   scattering the previous exchange; the one in the WAIT phase guarantees that
!   all puts have landed before the receive slice is consumed. Both sit outside
!   the buffer_size tests so that every rank of the communicator reaches them.
!
!   NOTE(review): window creation/recreation below is decided per rank although
!   mpi_win_create/mpi_win_free are collective; confirm that all ranks always
!   take the same decision (e.g. no rank has an empty halo while others do not).
!
subroutine psi_dswap_rma_push_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,comm_handle,info)
#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif
  type(psb_ctxt_type), intent(in)            :: ctxt
  integer(psb_ipk_), intent(in)              :: swap_status
  real(psb_dpk_), intent(in)                 :: beta
  class(psb_d_base_vect_type), intent(inout) :: y
  class(psb_i_base_vect_type), intent(inout) :: comm_indexes
  integer(psb_ipk_), intent(in)              :: num_neighbors, total_send, total_recv
  ! TARGET is required because rma_handle is pointer-associated with this dummy.
  class(psb_comm_handle_type), intent(inout), target :: comm_handle
  integer(psb_ipk_), intent(out)             :: info

  integer(psb_mpk_) :: np, my_rank, iret, element_bytes, icomm
  integer(psb_mpk_) :: proc_to_comm, prc_rank, recv_count, send_count, send_pos, recv_pos, list_pos
  integer(psb_mpk_) :: remote_base
  integer(kind=MPI_ADDRESS_KIND) :: remote_disp, exposed_bytes
  integer(psb_ipk_) :: err_act, neighbor_idx, buffer_size
  integer(psb_ipk_), allocatable :: peer_mpi_rank(:)
  logical :: do_start, do_wait, memory_buffer_layout_rebuild_needed
  type(psb_comm_rma_handle), pointer :: rma_handle
  character(len=30) :: name

  info = psb_success_
  name = 'psi_dswap_rma_push_vect'
  call psb_erractionsave(err_act)
  nullify(rma_handle)

  call psb_info(ctxt,my_rank,np)
  if (np == -1) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  end if
  icomm = ctxt%get_mpic()

  ! This scheme only works with the RMA flavour of the communication handle.
  select type(ch => comm_handle)
  type is(psb_comm_rma_handle)
    rma_handle => ch
  class default
    info = psb_err_mpi_error_
    call psb_errpush(info,name,a_err='Expected RMA comm_handle for push mode')
    goto 9999
  end select

  do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
  do_wait  = (swap_status == psb_comm_status_wait_)  .or. (swap_status == psb_comm_status_sync_)

  call comm_indexes%sync()

  ! START phase: same layout handling as the pull variant, but the remote
  ! metadata used afterwards describes the peer's receive side.
  if (do_start) then
    buffer_size = total_send + total_recv

    memory_buffer_layout_rebuild_needed = (.not. rma_handle%layout_ready) .or. &
         & (rma_handle%layout_nnbr /= num_neighbors) .or. &
         & (rma_handle%layout_send /= total_send) .or. &
         & (rma_handle%layout_recv /= total_recv)

    if (memory_buffer_layout_rebuild_needed) then
      ! Always allocate, possibly with zero size: the layout routine takes an
      ! assumed-shape array, which must not be an unallocated actual argument.
      allocate(peer_mpi_rank(max(0_psb_ipk_,num_neighbors)), stat=iret)
      if (iret /= 0) then
        info = psb_err_alloc_dealloc_
        call psb_errpush(info,name,a_err='RMA put rank cache allocation')
        goto 9999
      end if

      ! Walk the exchange list once to translate process ids into MPI ranks.
      list_pos = 1
      do neighbor_idx = 1, num_neighbors
        proc_to_comm = comm_indexes%v(list_pos+psb_proc_id_)
        peer_mpi_rank(neighbor_idx) = psb_get_mpi_rank(ctxt,proc_to_comm)
        recv_count = comm_indexes%v(list_pos+psb_n_elem_recv_)
        send_count = comm_indexes%v(list_pos+recv_count+psb_n_elem_send_)
        list_pos = list_pos + recv_count + send_count + 3
      end do

      ! my_rank is psb_mpk_ here while the layout routine expects psb_ipk_.
      call rma_handle%init_memory_buffer_layout(info, comm_indexes%v, peer_mpi_rank, &
           & num_neighbors, total_send, total_recv, int(my_rank,psb_ipk_), icomm)
      if (info /= psb_success_) then
        call psb_errpush(info,name,a_err='RMA put ini_memory_buffer_layout')
        goto 9999
      end if
    end if

    if (buffer_size > 0) then
      if (.not. allocated(y%combuf)) then
        call y%new_buffer(buffer_size, info)
        if (info /= psb_success_) then
          call psb_errpush(psb_err_alloc_dealloc_,name)
          goto 9999
        end if
      else if (size(y%combuf) < buffer_size) then
        ! A larger exposed area is needed: tear down the window attached to the
        ! old buffer, reallocate combuf, and let the block below create a new one.
        if (rma_handle%window_open) then
          call mpi_win_unlock_all(rma_handle%win, iret)
          if (iret /= mpi_success) then
            info = psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
          rma_handle%window_open = .false.
        end if
        if (rma_handle%window_ready) then
          call mpi_win_free(rma_handle%win, iret)
          if (iret /= mpi_success) then
            info = psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
          rma_handle%window_ready = .false.
          rma_handle%win = mpi_win_null
        end if
        call y%new_buffer(buffer_size, info)
        if (info /= psb_success_) then
          call psb_errpush(psb_err_alloc_dealloc_,name)
          goto 9999
        end if
      end if
    end if

    if ((buffer_size > 0).and.(.not. rma_handle%window_ready)) then
      ! The window is created once and reused; the displacement unit is one
      ! vector element, so remote displacements are expressed in elements.
      element_bytes = storage_size(y%combuf(1))/8
      exposed_bytes = int(size(y%combuf),kind=MPI_ADDRESS_KIND) * int(element_bytes,kind=MPI_ADDRESS_KIND)
      call mpi_win_create(y%combuf, exposed_bytes, element_bytes, &
           & mpi_info_null, icomm, rma_handle%win, iret)
      if (iret /= mpi_success) then
        info = psb_err_mpi_error_
        call psb_errpush(info,name,m_err=(/iret/))
        goto 9999
      end if
      rma_handle%window_ready = .true.
    end if

    ! Fill the local send slice that is about to be pushed.
    if (buffer_size > 0) then
      if (total_send > 0) then
        call y%gth(int(total_send,psb_mpk_), rma_handle%peer_send_indexes, y%combuf(1:total_send))
      end if
      call y%device_wait()
    end if

    ! Every rank has finished scattering the previous exchange once it gets
    ! here, so its receive slice may now be overwritten by remote puts.
    call mpi_barrier(icomm, iret)
    if (iret /= mpi_success) then
      info = psb_err_mpi_error_
      call psb_errpush(info,name,m_err=(/iret/))
      goto 9999
    end if

    if (buffer_size > 0) then
      ! Do not relock if a previous START was never followed by a WAIT.
      if (.not. rma_handle%window_open) then
        call mpi_win_lock_all(0, rma_handle%win, iret)
        if (iret /= mpi_success) then
          info = psb_err_mpi_error_
          call psb_errpush(info,name,m_err=(/iret/))
          goto 9999
        end if
        rma_handle%window_open = .true.
      end if

      ! Push the data to every neighbour; all metadata comes from the cached layout.
      do neighbor_idx=1, num_neighbors
        proc_to_comm = rma_handle%peer_proc(neighbor_idx)
        recv_count   = rma_handle%peer_recv_counts(neighbor_idx)
        send_count   = rma_handle%peer_send_counts(neighbor_idx)
        prc_rank     = rma_handle%peer_mpi_rank(neighbor_idx)
        send_pos     = rma_handle%peer_send_displs(neighbor_idx) + 1
        recv_pos     = total_send + rma_handle%peer_recv_displs(neighbor_idx) + 1

        if (proc_to_comm /= my_rank) then
          ! One-based position, inside the peer's buffer, where our data must land.
          remote_base = rma_handle%peer_remote_recv_displs(neighbor_idx)
          if (remote_base < 1) then
            info = psb_err_internal_error_
            call psb_errpush(info,name,a_err='Invalid remote metadata in RMA push')
            goto 9999
          end if
          if ((send_pos < 1) .or. (send_count < 0) .or. &
               & (send_pos+max(0,send_count)-1 > size(y%combuf))) then
            info = psb_err_internal_error_
            call psb_errpush(info,name,a_err='RMA push local send bounds error')
            goto 9999
          end if
          if (send_count > 0) then
            remote_disp = int(remote_base - 1, kind=MPI_ADDRESS_KIND)
            call mpi_put(y%combuf(send_pos), send_count, psb_mpi_r_dpk_, prc_rank, &
                 & remote_disp, send_count, psb_mpi_r_dpk_, rma_handle%win, iret)
            if (iret /= mpi_success) then
              info = psb_err_mpi_error_
              call psb_errpush(info,name,m_err=(/iret/))
              goto 9999
            end if
          end if
          call mpi_win_flush(prc_rank, rma_handle%win, iret)
          if (iret /= mpi_success) then
            info = psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
        else
          ! Exchange with ourselves: plain copy from the send to the receive slice.
          if (send_count /= recv_count) then
            info = psb_err_internal_error_
            call psb_errpush(info,name,a_err='RMA push self-copy mismatch')
            goto 9999
          end if
          y%combuf(recv_pos:recv_pos+recv_count-1) = y%combuf(send_pos:send_pos+send_count-1)
        end if
      end do
    end if
  end if

  ! WAIT phase: close the epoch and apply the receive-side scatter.
  if (do_wait) then
    if (rma_handle%window_open) then
      call mpi_win_unlock_all(rma_handle%win, iret)
      if (iret /= mpi_success) then
        info = psb_err_mpi_error_
        call psb_errpush(info,name,m_err=(/iret/))
        goto 9999
      end if
      rma_handle%window_open = .false.
    end if

    ! All ranks have closed their epoch: every put aimed at us has landed.
    call mpi_barrier(icomm, iret)
    if (iret /= mpi_success) then
      info = psb_err_mpi_error_
      call psb_errpush(info,name,m_err=(/iret/))
      goto 9999
    end if

    if (total_recv > 0) then
      call y%sct(int(total_recv,psb_mpk_), rma_handle%peer_recv_indexes, &
           & y%combuf(total_send+1:total_send+total_recv), beta)
    end if
    call y%device_wait()
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)
  return
end subroutine psi_dswap_rma_push_vect
!
!
@ -1007,6 +1495,10 @@ contains
ineighbor_a2av = .true.
case(psb_comm_persistent_ineighbor_alltoallv_)
ineighbor_a2av_persistent = .true.
case(psb_comm_rma_pull_, psb_comm_rma_push_)
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='RMA swap is not yet enabled for multivectors')
goto 9999
case default
baseline = .true.
end select

@ -31,6 +31,7 @@ SERIAL_MODS=serial/psb_s_serial_mod.o serial/psb_d_serial_mod.o \
comm/comm_schemes/psb_comm_schemes_mod.o \
comm/comm_schemes/psb_comm_baseline_mod.o \
comm/comm_schemes/psb_comm_neighbor_impl_mod.o \
comm/comm_schemes/psb_comm_rma_mod.o \
comm/comm_schemes/psb_comm_factory_mod.o \
serial/psb_i_base_vect_mod.o serial/psb_i_vect_mod.o\
serial/psb_l_base_vect_mod.o serial/psb_l_vect_mod.o\
@ -193,7 +194,9 @@ comm/comm_schemes/psb_comm_baseline_mod.o: comm/comm_schemes/psb_comm_schemes_mo
comm/comm_schemes/psb_comm_neighbor_impl_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o comm/comm_schemes/psb_comm_schemes_mod.o
comm/comm_schemes/psb_comm_factory_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_baseline_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o
comm/comm_schemes/psb_comm_rma_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o psb_const_mod.o desc/psb_desc_const_mod.o
comm/comm_schemes/psb_comm_factory_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_baseline_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o comm/comm_schemes/psb_comm_rma_mod.o
comm/psb_neighbor_topology_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o

@ -2,10 +2,12 @@ module psb_comm_factory_mod
use psb_const_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, &
& psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, &
& psb_comm_rma_pull_, psb_comm_rma_push_, &
& psb_comm_unknown_, psb_comm_status_start_, psb_comm_status_wait_, &
& psb_comm_status_sync_, psb_comm_status_unknown_
use psb_comm_baseline_mod, only: psb_comm_baseline_handle
use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
use psb_comm_rma_mod, only: psb_comm_rma_handle
implicit none
contains
@ -37,6 +39,8 @@ contains
type is(psb_comm_neighbor_handle)
h%comm_type = comm_type
h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_)
type is(psb_comm_rma_handle)
h%comm_type = comm_type
class default
! nothing else to configure
end select
@ -60,6 +64,17 @@ contains
h%comm_type = comm_type
h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_)
end select
case(psb_comm_rma_pull_, psb_comm_rma_push_)
allocate(psb_comm_rma_handle :: handle, stat=info)
if (info /= 0) return
call handle%init(info)
if (info /= 0) return
handle%id = old_id
handle%swap_status = old_swap_status
select type(h => handle)
type is(psb_comm_rma_handle)
h%comm_type = comm_type
end select
case default
allocate(psb_comm_baseline_handle :: handle, stat=info)
if (info /= 0) return

@ -0,0 +1,240 @@
module psb_comm_rma_mod
use psb_const_mod
use psb_desc_const_mod, only: psb_proc_id_, psb_n_elem_recv_, psb_elem_recv_, &
& psb_n_elem_send_, psb_elem_send_
use psb_error_mod
#ifdef PSB_MPI_MOD
use mpi
#endif
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_rma_pull_, psb_comm_rma_push_, &
& psb_comm_unknown_
implicit none
integer(psb_mpk_), parameter :: psb_rma_meta_tag = 913
#ifdef PSB_MPI_H
include 'mpif.h'
#endif
! Communication handle for the one-sided (RMA) halo exchange schemes.
! It carries the MPI window handle, the state flags of that window, and a
! cached description of the exchange buffer layout built by
! init_memory_buffer_layout.
type, extends(psb_comm_handle_type) :: psb_comm_rma_handle
! MPI window handle; mpi_win_null while no window exists.
integer(psb_mpk_) :: win = mpi_win_null
! .true. once the swap routines have created the window successfully.
logical :: window_ready = .false.
! .true. between mpi_win_lock_all and the matching mpi_win_unlock_all.
logical :: window_open = .false.
! .true. when the cached layout below is valid.
logical :: layout_ready = .false.
! Neighbour count and send/receive totals the cached layout was built for;
! -1 while no layout is cached.
integer(psb_ipk_) :: layout_nnbr = -1
integer(psb_ipk_) :: layout_send = -1
integer(psb_ipk_) :: layout_recv = -1
! Per-neighbour data, one entry per exchange-list entry:
! process id as read from the exchange list
integer(psb_ipk_), allocatable :: peer_proc(:)
! number of entries sent to / received from the neighbour
integer(psb_ipk_), allocatable :: peer_send_counts(:)
integer(psb_ipk_), allocatable :: peer_recv_counts(:)
! zero-based offsets of the neighbour's share within the send / receive slice
integer(psb_ipk_), allocatable :: peer_send_displs(:)
integer(psb_ipk_), allocatable :: peer_recv_displs(:)
! rank of the neighbour as supplied by the caller of init_memory_buffer_layout
integer(psb_ipk_), allocatable :: peer_mpi_rank(:)
! one-based positions reported by the neighbour for its own buffer:
! where its data for us starts (send) and where our data goes (recv)
integer(psb_ipk_), allocatable :: peer_remote_send_displs(:)
integer(psb_ipk_), allocatable :: peer_remote_recv_displs(:)
! Index lists of all neighbours, concatenated in exchange-list order; the swap
! routines pass them to the vector gather (send) and scatter (recv) methods.
integer(psb_ipk_), allocatable :: peer_send_indexes(:)
integer(psb_ipk_), allocatable :: peer_recv_indexes(:)
contains
procedure, pass :: init => psb_comm_rma_init
procedure, pass :: free => psb_comm_rma_free
procedure, pass :: clear_memory_buffer_layout => psb_comm_rma_clear_memory_buffer_layout
procedure, pass :: init_memory_buffer_layout => psb_comm_rma_ini_memory_buffer_layout
procedure, pass :: set_swap_status => psb_comm_rma_set_swap_status
procedure, pass :: get_swap_status => psb_comm_rma_get_swap_status
end type psb_comm_rma_handle
contains
! Put the handle into its pristine state: unknown scheme, no window, no cached
! layout. An existing window is NOT freed here; the handle is simply
! overwritten. info carries the result of the layout clean-up.
subroutine psb_comm_rma_init(this, info)
  class(psb_comm_rma_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info

  info = 0

  ! Window bookkeeping.
  this%window_open  = .false.
  this%window_ready = .false.
  this%win          = mpi_win_null

  ! Generic handle state.
  this%swap_status = 0
  this%id          = 0
  this%comm_type   = psb_comm_unknown_

  ! Layout bookkeeping, followed by the release of the cached arrays.
  this%layout_recv  = -1
  this%layout_send  = -1
  this%layout_nnbr  = -1
  this%layout_ready = .false.
  call this%clear_memory_buffer_layout(info)
end subroutine psb_comm_rma_init
! Drop the cached buffer layout: mark it invalid, reset the sizes it was built
! for, and release every per-neighbour array that is currently allocated.
! info is always returned as psb_success_.
subroutine psb_comm_rma_clear_memory_buffer_layout(this, info)
  class(psb_comm_rma_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info

  info = psb_success_

  ! Invalidate the layout descriptor first.
  this%layout_ready = .false.
  this%layout_recv  = -1
  this%layout_send  = -1
  this%layout_nnbr  = -1

  ! Flattened gather/scatter index lists.
  if (allocated(this%peer_recv_indexes)) then
    deallocate(this%peer_recv_indexes)
  end if
  if (allocated(this%peer_send_indexes)) then
    deallocate(this%peer_send_indexes)
  end if

  ! Positions reported by the peers.
  if (allocated(this%peer_remote_recv_displs)) then
    deallocate(this%peer_remote_recv_displs)
  end if
  if (allocated(this%peer_remote_send_displs)) then
    deallocate(this%peer_remote_send_displs)
  end if

  ! Local per-neighbour description.
  if (allocated(this%peer_mpi_rank)) then
    deallocate(this%peer_mpi_rank)
  end if
  if (allocated(this%peer_recv_displs)) then
    deallocate(this%peer_recv_displs)
  end if
  if (allocated(this%peer_send_displs)) then
    deallocate(this%peer_send_displs)
  end if
  if (allocated(this%peer_recv_counts)) then
    deallocate(this%peer_recv_counts)
  end if
  if (allocated(this%peer_send_counts)) then
    deallocate(this%peer_send_counts)
  end if
  if (allocated(this%peer_proc)) then
    deallocate(this%peer_proc)
  end if
end subroutine psb_comm_rma_clear_memory_buffer_layout
! Build and cache the buffer layout used by the RMA swap routines.
!
! The exchange list comm_list is walked once. For every neighbour the routine
! records its process id, MPI rank, send/receive counts and the zero-based
! offsets of its share inside the send and receive slices, and copies its index
! lists into peer_send_indexes / peer_recv_indexes. It then swaps four integers
! with each remote neighbour (one-based send position, send count, one-based
! receive position, receive count, all relative to the sender's own buffer) and
! keeps the two positions, which the swap routines turn into window
! displacements.
!
! Arguments:
!   this          - handle whose cached layout is rebuilt (cleared first).
!   info          - psb_success_, psb_err_alloc_dealloc_, psb_err_mpi_error_, or
!                   psb_err_internal_error_ when the list is inconsistent with
!                   num_neighbors/total_send/total_recv.
!   comm_list     - exchange list, read with the psb_proc_id_, psb_n_elem_recv_,
!                   psb_elem_recv_, psb_n_elem_send_ and psb_elem_send_ offsets.
!   peer_mpi_rank - MPI rank of each neighbour, in list order.
!   num_neighbors - number of entries in comm_list.
!   total_send    - expected total number of sent entries.
!   total_recv    - expected total number of received entries.
!   my_rank       - process id of the caller; entries naming it are handled
!                   locally without any message.
!   icomm         - MPI communicator used for the metadata exchange.
!
! NOTE(review): the metadata exchange is a blocking pairwise mpi_sendrecv in
! list order; it relies on the neighbour ordering being compatible across ranks
! and on every neighbour rebuilding its layout at the same time. Confirm.
subroutine psb_comm_rma_ini_memory_buffer_layout(this, info, comm_list, peer_mpi_rank, &
     & num_neighbors, total_send, total_recv, my_rank, icomm)
  class(psb_comm_rma_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info
  integer(psb_ipk_), intent(in)  :: comm_list(:)
  integer(psb_ipk_), intent(in)  :: peer_mpi_rank(:)
  integer(psb_ipk_), intent(in)  :: num_neighbors, total_send, total_recv
  integer(psb_ipk_), intent(in)  :: my_rank
  integer(psb_mpk_), intent(in)  :: icomm

  integer(psb_mpk_) :: iret, p2pstat(mpi_status_size), prc_rank
  ! The metadata travels as psb_mpi_mpk_, so the buffers must be psb_mpk_ too.
  integer(psb_mpk_) :: local_meta(4), remote_meta(4)
  integer(psb_ipk_) :: n_neighbors, send_total, recv_total
  integer(psb_ipk_) :: neighbor_idx, list_len
  integer(psb_ipk_) :: list_pos, send_offset, recv_offset
  integer(psb_ipk_) :: proc_to_comm, recv_count, send_count
  integer(psb_ipk_) :: first_recv, first_send

  call this%clear_memory_buffer_layout(info)
  if (info /= psb_success_) return

  n_neighbors = num_neighbors
  send_total  = total_send
  recv_total  = total_recv
  list_len    = size(comm_list)

  ! One rank per neighbour is required.
  if (size(peer_mpi_rank) < n_neighbors) then
    info = psb_err_internal_error_
    return
  end if

  if (n_neighbors > 0) then
    allocate(this%peer_proc(n_neighbors), this%peer_send_counts(n_neighbors), &
         & this%peer_recv_counts(n_neighbors), this%peer_send_displs(n_neighbors), &
         & this%peer_recv_displs(n_neighbors), this%peer_mpi_rank(n_neighbors), &
         & this%peer_remote_send_displs(n_neighbors), &
         & this%peer_remote_recv_displs(n_neighbors), stat=iret)
    if (iret /= 0) then
      info = psb_err_alloc_dealloc_
      return
    end if
  end if
  if (send_total > 0) then
    allocate(this%peer_send_indexes(send_total), stat=iret)
    if (iret /= 0) then
      info = psb_err_alloc_dealloc_
      return
    end if
  end if
  if (recv_total > 0) then
    allocate(this%peer_recv_indexes(recv_total), stat=iret)
    if (iret /= 0) then
      info = psb_err_alloc_dealloc_
      return
    end if
  end if

  ! Pass 1: decode the exchange list into the per-neighbour tables.
  list_pos    = 1
  send_offset = 0
  recv_offset = 0
  do neighbor_idx = 1, n_neighbors
    ! Every read from comm_list and every write to the index caches is
    ! validated first, so a malformed list yields an error, not a stray access.
    if ((list_pos + psb_proc_id_ > list_len) .or. &
         & (list_pos + psb_n_elem_recv_ > list_len)) then
      info = psb_err_internal_error_
      return
    end if
    proc_to_comm = comm_list(list_pos + psb_proc_id_)
    recv_count   = comm_list(list_pos + psb_n_elem_recv_)
    if ((recv_count < 0) .or. (recv_offset + recv_count > recv_total) .or. &
         & (list_pos + recv_count + psb_n_elem_send_ > list_len)) then
      info = psb_err_internal_error_
      return
    end if
    send_count = comm_list(list_pos + recv_count + psb_n_elem_send_)
    if ((send_count < 0) .or. (send_offset + send_count > send_total) .or. &
         & (list_pos + recv_count + psb_elem_send_ + send_count - 1 > list_len)) then
      info = psb_err_internal_error_
      return
    end if

    this%peer_proc(neighbor_idx)        = proc_to_comm
    this%peer_recv_counts(neighbor_idx) = recv_count
    this%peer_send_counts(neighbor_idx) = send_count
    this%peer_recv_displs(neighbor_idx) = recv_offset
    this%peer_send_displs(neighbor_idx) = send_offset
    this%peer_mpi_rank(neighbor_idx)    = peer_mpi_rank(neighbor_idx)

    ! The guards also keep unallocated index caches (zero totals) untouched.
    if (recv_count > 0) then
      first_recv = list_pos + psb_elem_recv_
      this%peer_recv_indexes(recv_offset+1:recv_offset+recv_count) = &
           & comm_list(first_recv:first_recv+recv_count-1)
    end if
    if (send_count > 0) then
      first_send = list_pos + recv_count + psb_elem_send_
      this%peer_send_indexes(send_offset+1:send_offset+send_count) = &
           & comm_list(first_send:first_send+send_count-1)
    end if

    recv_offset = recv_offset + recv_count
    send_offset = send_offset + send_count
    list_pos    = list_pos + recv_count + send_count + 3
  end do

  ! Pass 2: learn from every remote neighbour where, in ITS buffer, the data
  ! exchanged with us lives.
  do neighbor_idx = 1, n_neighbors
    proc_to_comm = this%peer_proc(neighbor_idx)
    if (proc_to_comm /= my_rank) then
      prc_rank = this%peer_mpi_rank(neighbor_idx)
      local_meta(1) = int(this%peer_send_displs(neighbor_idx) + 1, psb_mpk_)
      local_meta(2) = int(this%peer_send_counts(neighbor_idx), psb_mpk_)
      local_meta(3) = int(this%peer_recv_displs(neighbor_idx) + 1 + send_total, psb_mpk_)
      local_meta(4) = int(this%peer_recv_counts(neighbor_idx), psb_mpk_)
      call mpi_sendrecv(local_meta, 4, psb_mpi_mpk_, prc_rank, psb_rma_meta_tag, &
           & remote_meta, 4, psb_mpi_mpk_, prc_rank, psb_rma_meta_tag, icomm, p2pstat, iret)
      if (iret /= mpi_success) then
        info = psb_err_mpi_error_
        return
      end if
      this%peer_remote_send_displs(neighbor_idx) = remote_meta(1)
      this%peer_remote_recv_displs(neighbor_idx) = remote_meta(3)
    else
      ! Talking to ourselves: the "remote" positions are our own.
      this%peer_remote_send_displs(neighbor_idx) = this%peer_send_displs(neighbor_idx) + 1
      this%peer_remote_recv_displs(neighbor_idx) = this%peer_recv_displs(neighbor_idx) + 1 + send_total
    end if
  end do

  ! The list must account for exactly the announced totals.
  if ((send_offset /= send_total) .or. (recv_offset /= recv_total)) then
    info = psb_err_internal_error_
    return
  end if

  this%layout_nnbr  = n_neighbors
  this%layout_send  = send_total
  this%layout_recv  = recv_total
  this%layout_ready = .true.
end subroutine psb_comm_rma_ini_memory_buffer_layout
! Release everything owned by the handle: close a pending access epoch, free
! the MPI window and drop the cached layout.
!
! All steps are attempted even if an earlier one fails, so the handle always
! ends up in its "no window, no layout" state. info is psb_success_ when every
! step succeeded, psb_err_mpi_error_ when an MPI call reported a failure, or
! otherwise the code returned by the layout clean-up.
subroutine psb_comm_rma_free(this, info)
#ifdef PSB_MPI_MOD
  use mpi
#endif
  class(psb_comm_rma_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info

  integer(psb_mpk_) :: iret
  integer(psb_ipk_) :: clear_info

  info = psb_success_

  ! An epoch left open by a START without WAIT must be closed before the
  ! window can be freed.
  if (this%window_open) then
    call mpi_win_unlock_all(this%win, iret)
    if (iret /= mpi_success) info = psb_err_mpi_error_
    this%window_open = .false.
  end if

  if (this%win /= mpi_win_null) then
    call mpi_win_free(this%win, iret)
    if (iret /= mpi_success) info = psb_err_mpi_error_
    this%win = mpi_win_null
  end if
  this%window_ready = .false.

  ! The layout clean-up also resets layout_ready and the layout sizes. Its
  ! status goes into a separate variable so that it cannot mask an MPI failure.
  call this%clear_memory_buffer_layout(clear_info)
  if (info == psb_success_) info = clear_info
end subroutine psb_comm_rma_free
! Store flag as the current swap status of the handle; info is always 0.
subroutine psb_comm_rma_set_swap_status(this, flag, info)
  class(psb_comm_rma_handle), intent(inout) :: this
  integer(psb_ipk_), intent(in)  :: flag
  integer(psb_ipk_), intent(out) :: info

  this%swap_status = flag
  info = 0
end subroutine psb_comm_rma_set_swap_status
! Return the swap status currently stored in the handle; info is always 0.
subroutine psb_comm_rma_get_swap_status(this, flag, info)
  class(psb_comm_rma_handle), intent(in) :: this
  integer(psb_ipk_), intent(out) :: flag
  integer(psb_ipk_), intent(out) :: info

  flag = this%swap_status
  info = 0
end subroutine psb_comm_rma_get_swap_status
end module psb_comm_rma_mod

@ -11,6 +11,8 @@ module psb_comm_schemes_mod
enumerator psb_comm_isend_irecv_
enumerator psb_comm_ineighbor_alltoallv_
enumerator psb_comm_persistent_ineighbor_alltoallv_
enumerator psb_comm_rma_pull_
enumerator psb_comm_rma_push_
end enum
enum, bind(c)

@ -534,7 +534,7 @@ contains
if (allocated(a%irn)) deallocate(a%irn)
if (allocated(a%ja)) deallocate(a%ja)
if (allocated(a%val)) deallocate(a%val)
if (allocated(a%val)) deallocate(a%hkoffs)
if (allocated(a%hkoffs)) deallocate(a%hkoffs)
call a%set_null()
call a%set_nrows(izero)
call a%set_ncols(izero)

@ -534,7 +534,7 @@ contains
if (allocated(a%irn)) deallocate(a%irn)
if (allocated(a%ja)) deallocate(a%ja)
if (allocated(a%val)) deallocate(a%val)
if (allocated(a%val)) deallocate(a%hkoffs)
if (allocated(a%hkoffs)) deallocate(a%hkoffs)
call a%set_null()
call a%set_nrows(izero)
call a%set_ncols(izero)

@ -534,7 +534,7 @@ contains
if (allocated(a%irn)) deallocate(a%irn)
if (allocated(a%ja)) deallocate(a%ja)
if (allocated(a%val)) deallocate(a%val)
if (allocated(a%val)) deallocate(a%hkoffs)
if (allocated(a%hkoffs)) deallocate(a%hkoffs)
call a%set_null()
call a%set_nrows(izero)
call a%set_ncols(izero)

@ -534,7 +534,7 @@ contains
if (allocated(a%irn)) deallocate(a%irn)
if (allocated(a%ja)) deallocate(a%ja)
if (allocated(a%val)) deallocate(a%val)
if (allocated(a%val)) deallocate(a%hkoffs)
if (allocated(a%hkoffs)) deallocate(a%hkoffs)
call a%set_null()
call a%set_nrows(izero)
call a%set_ncols(izero)

@ -29,7 +29,7 @@ program psb_comm_cg_test
integer(psb_ipk_) :: desc_me, desc_np
integer(psb_ipk_) :: idim, itmax, itrace, istop, iter
integer(psb_ipk_) :: scheme_idx, prec_idx, rep, nrep, nwarm
integer(psb_ipk_), parameter :: n_schemes=3, n_precs=2
integer(psb_ipk_), parameter :: n_schemes=5, n_precs=2
integer(psb_ipk_), allocatable :: iter_count(:,:,:), solve_info(:,:,:)
integer(psb_ipk_) :: scheme_type(n_schemes)
real(psb_dpk_) :: eps, err, t_start, t_elapsed
@ -71,10 +71,13 @@ program psb_comm_cg_test
use_gpu = .false.
#endif
scheme_type = (/ psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
& psb_comm_persistent_ineighbor_alltoallv_ /)
& psb_comm_persistent_ineighbor_alltoallv_ , psb_comm_rma_pull_, psb_comm_rma_push_ /)
scheme_name(1) = 'isend_irecv'
scheme_name(2) = 'ineighbor_alltoallv'
scheme_name(3) = 'persistent_ineighbor_a2av'
scheme_name(4) = 'psb_comm_rma_pull_'
scheme_name(5) = 'psb_comm_rma_push_'
prec_type(1) = 'NONE'
prec_type(2) = 'DIAG'
prec_name(1) = 'none'
@ -201,6 +204,10 @@ program psb_comm_cg_test
do prec_idx = 1, n_precs
do scheme_idx = 1, n_schemes
do rep = 1, nrep
t_start = psb_wtime()
call psb_comm_set(scheme_type(scheme_idx),x%v%comm_handle,info)
comm_set_time(prec_idx,scheme_idx,rep) = psb_wtime() - t_start
call psb_geaxpby(dzero,b,dzero,x,desc_a,info)
if (info /= psb_success_) goto 9999
@ -225,9 +232,7 @@ program psb_comm_cg_test
goto 9999
end if
t_start = psb_wtime()
call psb_comm_set(scheme_type(scheme_idx),x%v%comm_handle,info)
comm_set_time(prec_idx,scheme_idx,rep) = psb_wtime() - t_start
call psb_amx(ctxt,comm_set_time(prec_idx,scheme_idx,rep))
if (info /= psb_success_) goto 9999

@ -5,10 +5,12 @@ module psb_spmv_overlap_test
use psb_base_mod
use psb_util_mod
use psb_comm_factory_mod, only: psb_comm_set
use psb_comm_schemes_mod, only: psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
& psb_comm_persistent_ineighbor_alltoallv_
& psb_comm_persistent_ineighbor_alltoallv_, psb_comm_rma_pull_, psb_comm_rma_push_, &
& psb_comm_handle_type
use psb_comm_baseline_mod, only: psb_comm_baseline_handle
use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
use psb_comm_rma_mod, only: psb_comm_rma_handle
#ifdef PSB_HAVE_CUDA
use psb_cuda_mod
#endif
@ -529,7 +531,7 @@ contains
return
end subroutine psb_d_gen_pde3d
subroutine run_spmv_kernel(ctxt,use_gpu,matrix_file,matrix_fmt,cpu_fmt,gpu_fmt,idim_in,times_in,do_swap)
subroutine run_spmv_kernel(ctxt,use_gpu,matrix_file,matrix_fmt,cpu_fmt,gpu_fmt,idim_in,times_in,do_swap,comm_mode)
use psb_base_mod
#ifdef PSB_HAVE_CUDA
use psb_cuda_mod
@ -544,6 +546,7 @@ contains
character(len=*), intent(in) :: gpu_fmt
integer(psb_ipk_), intent(in) :: idim_in, times_in
logical, intent(in) :: do_swap
character(len=*), intent(in) :: comm_mode
type(psb_dspmat_type) :: a
type(psb_d_vect_type) :: x, y
@ -556,16 +559,36 @@ contains
real(psb_dpk_) :: alpha, beta, t0, t1, dt, avg_t
logical :: use_external_matrix
integer(psb_ipk_) :: comm_type
#ifdef PSB_HAVE_CUDA
type(psb_d_vect_cuda) :: cuda_vector_mold
type(psb_i_vect_cuda) :: cuda_index_mold
type(psb_d_cuda_elg_sparse_mat), target :: cuda_ell_sparse_mold
type(psb_d_cuda_csrg_sparse_mat), target :: cuda_csr_sparse_mold
type(psb_d_cuda_hdiag_sparse_mat), target :: cuda_hdia_sparse_mold
type(psb_d_cuda_hlg_sparse_mat), target :: cuda_hll_sparse_mold
class(psb_d_base_sparse_mat), pointer :: cuda_sparse_mold
type(psb_d_vect_cuda) :: cuda_vector_mold
type(psb_i_vect_cuda) :: cuda_index_mold
type(psb_d_cuda_elg_sparse_mat), target :: cuda_ell_sparse_mold
type(psb_d_cuda_csrg_sparse_mat), target :: cuda_csr_sparse_mold
type(psb_d_cuda_hdiag_sparse_mat), target :: cuda_hdia_sparse_mold
type(psb_d_cuda_hlg_sparse_mat), target :: cuda_hll_sparse_mold
class(psb_d_base_sparse_mat), pointer :: cuda_sparse_mold
#endif
select case(psb_toupper(trim(comm_mode)))
case('P2P','ISEND_IRECV')
comm_type = psb_comm_isend_irecv_
case('NEIGHBOR','INEIGHBOR_ALLTOALLV')
comm_type = psb_comm_ineighbor_alltoallv_
case('PNEIGHBOR','PERSISTENT','PERSISTENT_INEIGHBOR_A2AV')
comm_type = psb_comm_persistent_ineighbor_alltoallv_
case('MPI_GET','RMA_PULL')
comm_type = psb_comm_rma_pull_
case('MPI_PUT','RMA_PUSH')
comm_type = psb_comm_rma_push_
case default
comm_type = psb_comm_isend_irecv_
if (my_rank == psb_root_) then
write(psb_err_unit,'("Unknown comm backend: ",a,", defaulting to P2P")') trim(comm_mode)
end if
end select
info = psb_success_
afmt = psb_toupper(trim(cpu_fmt))
if (len_trim(afmt) == 0) afmt = 'CSR'
@ -613,6 +636,9 @@ contains
end if
if (info /= psb_success_) goto 9999
call psb_comm_set(comm_type,x%v%comm_handle,info)
if (info /= psb_success_) goto 9999
#ifdef PSB_HAVE_CUDA
if (use_gpu) then
select case(psb_toupper(trim(gpu_fmt)))
@ -669,6 +695,7 @@ contains
end if
write(psb_out_unit,'(" global unknowns : ",i0)') n_global
write(psb_out_unit,'(" repetitions : ",i0)') times
write(psb_out_unit,'(" comm backend : ",a)') trim(psb_toupper(trim(comm_mode)))
write(psb_out_unit,'(" total time [s] : ",es12.5)') dt
write(psb_out_unit,'(" avg time [s] : ",es12.5)') avg_t
end if
@ -764,8 +791,13 @@ program psb_spmv_kernel
character(len=8) :: gpu_fmt
integer(psb_ipk_) :: idim_arg, times_arg
logical :: do_swap
idim_arg = -1
times_arg = -1
integer :: kmode
integer, parameter :: n_comm_modes = 5
character(len=20), parameter :: comm_modes(n_comm_modes) = [character(len=20) :: &
& 'P2P', 'NEIGHBOR', 'PNEIGHBOR', 'MPI_GET', 'MPI_PUT']
idim_arg = -1
times_arg = -1
matrix_file = ''
matrix_fmt = 'MM'
@ -869,12 +901,18 @@ program psb_spmv_kernel
write(psb_out_unit,*) 'Welcome to PSBLAS version: ', psb_version_string_
write(psb_out_unit,*) 'This is the psb_spmv_kernel sample program'
write(psb_out_unit,'("GPU enabled : ",l1)') use_gpu
write(psb_out_unit,'("Usage: ./psb_spmv_kernel [--gpu=TRUE|FALSE] [--dim=N] [--times=N] ",&
&"[--cpu_fmt=CSR|COO|CSC|ELL|HLL] [--gpu_fmt=HLL|ELL|CSR|HDIA] [--matrix=<path>] [--fmt=MM|HB] ",&
&"[--overlap|--nooverlap]")')
write(psb_out_unit,'("Usage: ./psb_spmv_kernel [--gpu=TRUE|FALSE] [--dim=N] [--times=N] ",&
&"[--cpu_fmt=CSR|COO|CSC|ELL|HLL] [--gpu_fmt=HLL|ELL|CSR|HDIA] [--matrix=<path>] [--fmt=MM|HB] ",&
&"[--overlap|--nooverlap] (runs all comm backends)")')
end if
call run_spmv_kernel(ctxt,use_gpu,matrix_file,matrix_fmt,cpu_fmt,gpu_fmt,idim_arg,times_arg,do_swap)
do kmode = 1, n_comm_modes
if (my_rank == psb_root_) then
write(psb_out_unit,'(/,"=== Backend sweep: ",a," ===")') trim(comm_modes(kmode))
end if
call run_spmv_kernel(ctxt, use_gpu, matrix_file, matrix_fmt, cpu_fmt, gpu_fmt, &
& idim_arg, times_arg, do_swap, comm_modes(kmode))
end do
#ifdef PSB_HAVE_CUDA
if (use_gpu) call psb_cuda_exit()

@ -22,7 +22,7 @@ program psb_comm_test
use psi_mod
use psb_comm_factory_mod, only: psb_comm_set, psb_comm_free
use psb_comm_schemes_mod, only: psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, &
& psb_comm_isend_irecv_
& psb_comm_isend_irecv_, psb_comm_rma_pull_, psb_comm_rma_push_
use psb_comm_schemes_mod, only: psb_comm_status_start_, psb_comm_status_wait_, psb_comm_status_unknown_
implicit none
@ -47,11 +47,12 @@ program psb_comm_test
type(psb_ldspmat_type) :: aux_a
! ---- vectors ----
type(psb_d_vect_type) :: v_baseline, v_neighbor, v_neighbor_persistent
type(psb_d_vect_type) :: v_baseline, v_neighbor, v_neighbor_persistent, v_rma_get, v_rma_put
! ---- temporary / comparison arrays ----
real(psb_dpk_), allocatable :: vals(:)
real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:), result_persistent(:)
real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:), result_persistent(:), &
& result_rma_get(:), result_rma_put(:)
real(psb_dpk_), allocatable :: expected(:)
! ---- halo index bookkeeping ----
@ -60,11 +61,16 @@ program psb_comm_test
! ---- error / reporting ----
integer(psb_ipk_) :: n_pass, n_total, imode
logical :: run_baseline, run_neighbor, run_persistent
logical :: run_baseline, run_neighbor, run_persistent, run_rma_get, run_rma_put
logical :: mat_allocated
logical :: comm_ok
real(psb_dpk_) :: err, tol
real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent
real(psb_dpk_) :: first_swap_baseline, first_swap_neighbor, first_swap_persistent, &
& first_swap_rma_get, first_swap_rma_put
real(psb_dpk_) :: comm_setup_time_baseline, comm_setup_time_neighbor, comm_setup_time_persistent, &
& comm_setup_time_rma_get, comm_setup_time_rma_put
real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent, &
& tsum_rma_get, tsum_rma_put
integer(psb_lpk_), allocatable :: glob_col(:)
character(len=40) :: name
real(psb_dpk_) :: huge_d
@ -136,21 +142,33 @@ program psb_comm_test
run_baseline = .false.
run_neighbor = .false.
run_persistent = .false.
run_rma_get = .false.
run_rma_put = .false.
select case (trim(adjustl(mode)))
case ('both','all')
run_baseline = .true.
run_neighbor = .true.
run_persistent = .true.
run_baseline = .true.
run_neighbor = .true.
run_persistent = .true.
run_rma_get = .true.
run_rma_put = .true.
case ('baseline')
run_baseline = .true.
case ('neighbor')
run_neighbor = .true.
case ('persistent','persistent_neighbor','persistent-neighbor')
run_persistent = .true.
case ('rma_get')
run_rma_get = .true.
case ('rma_put')
run_rma_put = .true.
case default
run_baseline = .true.
run_neighbor = .true.
run_persistent = .true.
run_rma_get = .true.
run_rma_put = .true.
end select
if ((.not.use_external_matrix) .and. (idim <= 0)) then
@ -267,6 +285,18 @@ program psb_comm_test
write(psb_err_unit,*) my_rank, 'geall persistent-neighbor error:', info
call psb_abort(ctxt)
end if
call psb_geall(v_rma_get, desc_a, info)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geall rma-get error:', info
call psb_abort(ctxt)
end if
call psb_geall(v_rma_put, desc_a, info)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geall rma-put error:', info
call psb_abort(ctxt)
end if
call psb_geasb(v_baseline, desc_a, info, scratch=.true.)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geasb baseline error:', info
@ -282,6 +312,18 @@ program psb_comm_test
write(psb_err_unit,*) my_rank, 'geasb persistent-neighbor error:', info
call psb_abort(ctxt)
end if
call psb_geasb(v_rma_get, desc_a, info, scratch=.true.)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geasb rma-get error:', info
call psb_abort(ctxt)
end if
call psb_geasb(v_rma_put, desc_a, info, scratch=.true.)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geasb rma-put error:', info
call psb_abort(ctxt)
end if
! Fill owned entries with the global index value
allocate(vals(ncol))
@ -292,6 +334,8 @@ program psb_comm_test
call v_baseline%set_vect(vals)
call v_neighbor%set_vect(vals)
call v_neighbor_persistent%set_vect(vals)
call v_rma_get%set_vect(vals)
call v_rma_put%set_vect(vals)
deallocate(vals)
! ==================================================================
@ -304,15 +348,39 @@ program psb_comm_test
do i = 1, ncol
expected(i) = real(glob_col(i), psb_dpk_)
end do
allocate(result_baseline(ncol), result_neighbor(ncol), result_persistent(ncol))
allocate(result_baseline(ncol), result_neighbor(ncol), result_persistent(ncol), &
& result_rma_get(ncol), result_rma_put(ncol))
result_baseline = huge_d
result_neighbor = huge_d
result_persistent = huge_d
result_rma_get = huge_d
result_rma_put = huge_d
first_swap_baseline = 0.0_psb_dpk_
first_swap_neighbor = 0.0_psb_dpk_
first_swap_persistent = 0.0_psb_dpk_
first_swap_rma_get = 0.0_psb_dpk_
first_swap_rma_put = 0.0_psb_dpk_
comm_setup_time_baseline = 0.0_psb_dpk_
comm_setup_time_neighbor = 0.0_psb_dpk_
comm_setup_time_persistent = 0.0_psb_dpk_
comm_setup_time_rma_get = 0.0_psb_dpk_
comm_setup_time_rma_put = 0.0_psb_dpk_
! ==================================================================
! 6. Baseline halo exchange (Isend/Irecv in one call)
! ==================================================================
if (run_baseline) then
comm_setup_time_baseline = psb_wtime()
call psb_comm_set(psb_comm_isend_irecv_, v_baseline%v%comm_handle, info)
if (info /= 0) then
write(psb_err_unit,*) my_rank, 'psb_comm_set baseline error:', info
call psb_abort(ctxt)
end if
comm_setup_time_baseline = psb_wtime() - comm_setup_time_baseline
first_swap_baseline = psb_wtime()
call psi_swapdata( &
swap_status=psb_comm_status_start_, &
beta=dzero, &
@ -336,6 +404,7 @@ program psb_comm_test
write(psb_err_unit,*) my_rank, 'baseline swap error:', info
call psb_abort(ctxt)
end if
first_swap_baseline = psb_wtime() - first_swap_baseline
end if
@ -343,11 +412,16 @@ program psb_comm_test
! 7. Neighbor topology halo exchange (start + wait)
! ==================================================================
if (run_neighbor) then
comm_setup_time_neighbor = psb_wtime()
call psb_comm_set(psb_comm_ineighbor_alltoallv_, v_neighbor%v%comm_handle, info)
if (info /= 0) then
write(psb_err_unit,*) my_rank, 'psb_comm_set neighbor error:', info
call psb_abort(ctxt)
end if
comm_setup_time_neighbor = psb_wtime() - comm_setup_time_neighbor
first_swap_neighbor = psb_wtime()
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'neighbor start error:', info
@ -359,17 +433,23 @@ program psb_comm_test
write(psb_err_unit,*) my_rank, 'neighbor wait error:', info
call psb_abort(ctxt)
end if
first_swap_neighbor = psb_wtime() - first_swap_neighbor
end if
! ==================================================================
! 7b. Persistent-neighbor halo exchange (start + wait)
! ==================================================================
if (run_persistent) then
comm_setup_time_persistent = psb_wtime()
call psb_comm_set(psb_comm_persistent_ineighbor_alltoallv_, v_neighbor_persistent%v%comm_handle, info)
if (info /= 0) then
write(psb_err_unit,*) my_rank, 'psb_comm_set persistent-neighbor error:', info
call psb_abort(ctxt)
end if
comm_setup_time_persistent = psb_wtime() - comm_setup_time_persistent
first_swap_persistent = psb_wtime()
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'persistent-neighbor start error:', info
@ -380,8 +460,62 @@ program psb_comm_test
write(psb_err_unit,*) my_rank, 'persistent-neighbor wait error:', info
call psb_abort(ctxt)
end if
first_swap_persistent = psb_wtime() - first_swap_persistent
end if
if(run_rma_get) then
comm_setup_time_rma_get = psb_wtime()
call psb_comm_set(psb_comm_rma_pull_, v_rma_get%v%comm_handle, info)
if (info /= 0) then
write(psb_err_unit,*) my_rank, 'psb_comm_set RMA get error:', info
call psb_abort(ctxt)
end if
comm_setup_time_rma_get = psb_wtime() - comm_setup_time_rma_get
first_swap_rma_get = psb_wtime()
call psi_swapdata(psb_comm_status_start_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'RMA get start error:', info
call psb_abort(ctxt)
end if
call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'RMA get wait error:', info
call psb_abort(ctxt)
end if
first_swap_rma_get = psb_wtime() - first_swap_rma_get
end if
if(run_rma_put) then
comm_setup_time_rma_put = psb_wtime()
call psb_comm_set(psb_comm_rma_push_, v_rma_put%v%comm_handle, info)
if (info /= 0) then
write(psb_err_unit,*) my_rank, 'psb_comm_set RMA put error:', info
call psb_abort(ctxt)
end if
comm_setup_time_rma_put = psb_wtime() - comm_setup_time_rma_put
first_swap_rma_put = psb_wtime()
call psi_swapdata(psb_comm_status_start_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'RMA put start error:', info
call psb_abort(ctxt)
end if
call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'RMA put wait error:', info
call psb_abort(ctxt)
end if
first_swap_rma_put = psb_wtime() - first_swap_rma_put
end if
! ==================================================================
! 8. Performance: repeat exchanges and measure timings
! ==================================================================
@ -392,6 +526,8 @@ program psb_comm_test
tsum_baseline = 0.0_psb_dpk_
tsum_neighbor = 0.0_psb_dpk_
tsum_neighbor_persistent = 0.0_psb_dpk_
tsum_rma_get = 0.0_psb_dpk_
tsum_rma_put = 0.0_psb_dpk_
do i = 1, iters
if (run_baseline) then
@ -423,194 +559,119 @@ program psb_comm_test
call psb_amx(ctxt, dt)
tsum_neighbor_persistent = tsum_neighbor_persistent + dt
end if
if (run_rma_get) then
t0 = psb_wtime()
call psi_swapdata(psb_comm_status_start_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_)
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
tsum_rma_get = tsum_rma_get + dt
end if
if (run_rma_put) then
t0 = psb_wtime()
call psi_swapdata(psb_comm_status_start_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_)
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
tsum_rma_put = tsum_rma_put + dt
end if
end do
! Reduce every timing figure printed by rank 0 below to its maximum over
! all processes, so each reported value describes the slowest rank.
! The tsum_rma_* accumulators need no extra reduction here: they are sums of
! per-iteration dt values that were already reduced with psb_amx in the loop.
call psb_amx(ctxt, tsum_baseline)
call psb_amx(ctxt, tsum_neighbor)
call psb_amx(ctxt, tsum_neighbor_persistent)
call psb_amx(ctxt, first_swap_baseline)
call psb_amx(ctxt, first_swap_neighbor)
call psb_amx(ctxt, first_swap_persistent)
! The RMA first-swap timers were previously printed without being reduced,
! which made them rank-0 local and not comparable with the other backends.
call psb_amx(ctxt, first_swap_rma_get)
call psb_amx(ctxt, first_swap_rma_put)
call psb_amx(ctxt, comm_setup_time_baseline)
call psb_amx(ctxt, comm_setup_time_neighbor)
call psb_amx(ctxt, comm_setup_time_persistent)
call psb_amx(ctxt, comm_setup_time_rma_get)
call psb_amx(ctxt, comm_setup_time_rma_put)
if (my_rank == 0) then
if (run_baseline) then
write(psb_out_unit,'(" Avg baseline time : ",es12.5)') (tsum_baseline / real(iters,psb_dpk_))
write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline
write(psb_out_unit,'(" First baseline time: ",es12.5)') first_swap_baseline
write(psb_out_unit,'(" Baseline comm setup: ",es12.5)') comm_setup_time_baseline
end if
if (run_neighbor) then
write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') (tsum_neighbor / real(iters,psb_dpk_))
write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor
write(psb_out_unit,'(" First neighbor time: ",es12.5)') first_swap_neighbor
write(psb_out_unit,'(" Neighbor comm setup: ",es12.5)') comm_setup_time_neighbor
end if
if (run_persistent) then
write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') (tsum_neighbor_persistent / real(iters,psb_dpk_))
write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_neighbor_persistent
write(psb_out_unit,'(" First pers-neigh time: ",es12.5)') first_swap_persistent
write(psb_out_unit,'(" Persistent comm setup: ",es12.5)') comm_setup_time_persistent
end if
if (run_rma_get) then
write(psb_out_unit,'(" Avg RMA get time : ",es12.5)') (tsum_rma_get / real(iters,psb_dpk_))
write(psb_out_unit,'(" Tot RMA get time : ",es12.5)') tsum_rma_get
write(psb_out_unit,'(" First RMA get time : ",es12.5)') first_swap_rma_get
write(psb_out_unit,'(" RMA get comm setup : ",es12.5)') comm_setup_time_rma_get
end if
if (run_rma_put) then
write(psb_out_unit,'(" Avg RMA put time : ",es12.5)') (tsum_rma_put / real(iters,psb_dpk_))
write(psb_out_unit,'(" Tot RMA put time : ",es12.5)') tsum_rma_put
write(psb_out_unit,'(" First RMA put time : ",es12.5)') first_swap_rma_put
write(psb_out_unit,'(" RMA put comm setup : ",es12.5)') comm_setup_time_rma_put
end if
end if
! ==================================================================
! 8b. Extract results and compare
! ==================================================================
result_baseline = v_baseline%v%v
result_neighbor = v_neighbor%v%v
result_baseline = v_baseline%v%v
result_neighbor = v_neighbor%v%v
result_persistent = v_neighbor_persistent%v%v
result_rma_get = v_rma_get%v%v
result_rma_put = v_rma_put%v%v
! Debug: Check if results are properly populated
if (my_rank == 0 .and. debug_swapdata) then
write(psb_out_unit,'("DEBUG: ncol=",i0," nrow=",i0)') ncol, nrow
write(psb_out_unit,'("DEBUG: size(result_baseline)=",i0)') size(result_baseline)
if (ncol > 0) then
write(psb_out_unit,'("DEBUG: result_baseline(1:min(5,ncol))=",5(es12.5,1x))') &
& result_baseline(1:min(5,ncol))
write(psb_out_unit,'("DEBUG: expected(1:min(5,ncol))=",5(es12.5,1x))') &
& expected(1:min(5,ncol))
end if
end if
if (run_baseline .and. run_neighbor) then
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_baseline(1:ncol) - result_neighbor(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] cross-check baseline vs neighbor : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] cross-check baseline vs neighbor : err = ",es12.5)') err
end if
end if
end if
if (run_baseline) then
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_baseline(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] baseline absolute correctness : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] baseline absolute correctness : err = ",es12.5)') err
end if
end if
end if
if (run_neighbor) then
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] neighbor absolute correctness : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] neighbor absolute correctness : err = ",es12.5)') err
end if
end if
end if
! --- Cross Checks ---
if (run_baseline .and. run_neighbor) &
call check_result("cross-check baseline vs neighbor", result_baseline, result_neighbor, ncol, tol)
if (run_baseline .and. run_persistent) then
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] cross-check baseline vs pers-nei : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] cross-check baseline vs pers-nei : err = ",es12.5)') err
end if
end if
end if
if (run_baseline .and. run_persistent) &
call check_result("cross-check baseline vs pers-nei", result_baseline, result_persistent, ncol, tol)
if (run_persistent) then
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] pers-neigh absolute correctness : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] pers-neigh absolute correctness : err = ",es12.5)') err
end if
end if
end if
if (run_baseline .and. run_rma_get) &
call check_result("cross-check baseline vs rma-get", result_baseline, result_rma_get, ncol, tol)
if (run_baseline .and. run_rma_put) &
call check_result("cross-check baseline vs rma-put", result_baseline, result_rma_put, ncol, tol)
if (run_neighbor) then
! ---- Test 6: repeat neighbor exchange (topology reuse) ----
do i = nrow+1, ncol
result_neighbor(i) = dzero
end do
call v_neighbor%set_vect(result_neighbor)
! --- Absolute Correctness Checks against Expected ---
if (run_baseline) &
call check_result("baseline absolute correctness", result_baseline, expected, ncol, tol)
if (run_neighbor) &
call check_result("neighbor absolute correctness", result_neighbor, expected, ncol, tol)
if (run_persistent) &
call check_result("pers-neigh absolute correctness", result_persistent, expected, ncol, tol)
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
if (run_rma_get) &
call check_result("rma_get absolute correctness", result_rma_get, expected, ncol, tol)
result_neighbor = v_neighbor%v%v
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] neighbor topology reuse : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] neighbor topology reuse : err = ",es12.5)') err
end if
end if
end if
if (run_persistent) then
! ---- Test 7: repeat persistent-neighbor exchange (buffer reuse) ----
do i = nrow+1, ncol
result_persistent(i) = dzero
end do
call v_neighbor_persistent%set_vect(result_persistent)
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_)
result_persistent = v_neighbor_persistent%v%v
n_total = n_total + 1
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] pers-neigh buffer reuse : err = ",es12.5)') err
n_pass = n_pass + 1
else
write(psb_out_unit,'(" [FAIL] pers-neigh buffer reuse : err = ",es12.5)') err
end if
end if
end if
if (run_rma_put) &
call check_result("rma_put absolute correctness", result_rma_put, expected, ncol, tol)
! ==================================================================
! 9. Summary
@ -640,4 +701,36 @@ program psb_comm_test
call psb_cdfree(desc_a, info)
call psb_exit(ctxt)
contains
! Compare the leading n entries of two arrays and report the outcome.
!
! The largest absolute element-wise difference is computed locally (taken as
! zero when n <= 0), reduced across processes with psb_amx, and then printed
! by rank 0 as a [PASS] or [FAIL] line.
!
! Arguments:
!   test_name  label printed in the report line
!   arr1,arr2  arrays to compare; only elements 1..n are read
!   n          number of leading entries to compare
!   tolerance  strict upper bound on the reduced error for a pass
!
! Host-associated side effects: n_total is incremented on every call (on every
! process); n_pass is incremented only on rank 0 and only when the test passes.
subroutine check_result(test_name, arr1, arr2, n, tolerance)
  character(len=*), intent(in)  :: test_name
  real(psb_dpk_), intent(in)    :: arr1(:), arr2(:)
  integer(psb_ipk_), intent(in) :: n
  real(psb_dpk_), intent(in)    :: tolerance

  real(psb_dpk_) :: max_diff
  logical        :: within_tol

  n_total = n_total + 1

  ! Local error; an empty comparison range counts as zero error.
  max_diff = 0.0_psb_dpk_
  if (n > 0) max_diff = maxval(abs(arr1(1:n) - arr2(1:n)))

  ! Collective reduction: every process must reach this call.
  call psb_amx(ctxt, max_diff)

  ! Only rank 0 reports and keeps the pass counter.
  if (my_rank /= 0) return

  ! The explicit ">= 0" test also rejects a NaN error value.
  within_tol = (max_diff >= 0.0_psb_dpk_) .and. (max_diff < tolerance)
  if (within_tol) then
    write(psb_out_unit,'(" [PASS] ", a, t45, ": err = ",es12.5)') test_name, max_diff
    n_pass = n_pass + 1
  else
    write(psb_out_unit,'(" [FAIL] ", a, t45, ": err = ",es12.5)') test_name, max_diff
  end if
end subroutine check_result
end program psb_comm_test
Loading…
Cancel
Save