From fb5ba596933a09a3ba5e460ebb9c60120771f5d8 Mon Sep 17 00:00:00 2001 From: Stack-1 Date: Fri, 24 Apr 2026 15:10:06 +0200 Subject: [PATCH] [ADD] Added RMA one sided communication schemes --- base/CMakeLists.txt | 1 + base/comm/internals/psi_dswapdata.F90 | 492 ++++++++++++++++++ base/modules/Makefile | 5 +- .../comm_schemes/psb_comm_factory_mod.F90 | 15 + .../comm/comm_schemes/psb_comm_rma_mod.F90 | 240 +++++++++ .../comm_schemes/psb_comm_schemes_mod.F90 | 2 + ext/psb_c_hll_mat_mod.f90 | 2 +- ext/psb_d_hll_mat_mod.f90 | 2 +- ext/psb_s_hll_mat_mod.f90 | 2 +- ext/psb_z_hll_mat_mod.f90 | 2 +- test/comm/cg/psb_comm_cg_test.F90 | 15 +- test/comm/spmv/psb_spmv_test.f90 | 70 ++- test/comm/swapdata/psb_comm_test.F90 | 433 +++++++++------ 13 files changed, 1085 insertions(+), 196 deletions(-) create mode 100644 base/modules/comm/comm_schemes/psb_comm_rma_mod.F90 diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt index 1833a93c..ebac7d33 100644 --- a/base/CMakeLists.txt +++ b/base/CMakeLists.txt @@ -16,6 +16,7 @@ set(PSB_base_source_files comm/internals/psi_lovrl_upd.f90 comm/internals/psi_dswapdata_a.F90 comm/internals/psi_movrl_upd_a.f90 + modules/comm/comm_schemes/psb_comm_rma_mod.F90 # comm/internals/psi_i2swaptran_a.F90 comm/internals/psi_dswaptran.F90 comm/internals/psi_covrl_save_a.f90 diff --git a/base/comm/internals/psi_dswapdata.F90 b/base/comm/internals/psi_dswapdata.F90 index d8c7fff9..a19c17e3 100644 --- a/base/comm/internals/psi_dswapdata.F90 +++ b/base/comm/internals/psi_dswapdata.F90 @@ -83,6 +83,9 @@ submodule (psi_d_comm_v_mod) psi_d_swapdata_impl use psb_desc_const_mod, only: psb_swap_start_, psb_swap_wait_ use psb_base_mod use psb_error_mod, only: psb_get_debug_level, psb_get_debug_unit, psb_debug_ext_ + use psb_comm_schemes_mod, only: psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, & + & psb_comm_persistent_ineighbor_alltoallv_, psb_comm_rma_pull_, psb_comm_rma_push_ + use psb_comm_rma_mod, only: psb_comm_rma_handle use psb_comm_factory_mod 
contains @@ -191,6 +194,20 @@ contains call psb_errpush(info,name,a_err='neighbor persistent swap') goto 9999 end if + case(psb_comm_rma_pull_) + call psi_dswap_rma_pull_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,& + & y%comm_handle,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='rma pull swap') + goto 9999 + end if + case(psb_comm_rma_push_) + call psi_dswap_rma_push_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,& + & y%comm_handle,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='rma push swap') + goto 9999 + end if case default info = psb_err_mpi_error_ call psb_errpush(info,name,a_err='Incompatible swap_status settings: no valid communication mode selected') @@ -911,6 +928,477 @@ contains end subroutine psi_dswap_neighbor_persistent_topology_vect + subroutine psi_dswap_rma_pull_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,comm_handle,info) +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_ipk_), intent(in) :: swap_status + real(psb_dpk_), intent(in) :: beta + class(psb_d_base_vect_type), intent(inout) :: y + class(psb_i_base_vect_type), intent(inout) :: comm_indexes + integer(psb_ipk_), intent(in) :: num_neighbors, total_send, total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle + integer(psb_ipk_), intent(out) :: info + + integer(psb_mpk_) :: np, my_rank, iret, element_bytes, icomm + integer(psb_mpk_) :: proc_to_comm, prc_rank, recv_count, send_count, send_pos, recv_pos, list_pos + integer(psb_mpk_) :: remote_base + integer(kind=MPI_ADDRESS_KIND) :: remote_disp, exposed_bytes + integer(psb_ipk_) :: err_act, neighbor_idx, buffer_size + integer(psb_ipk_), allocatable :: peer_mpi_rank(:) + logical :: do_start, do_wait, memory_buffer_layout_rebuild_needed + type(psb_comm_rma_handle), pointer 
:: rma_handle + character(len=30) :: name + + info = psb_success_ + name = 'psi_dswap_rma_pull_vect' + call psb_erractionsave(err_act) + call psb_info(ctxt,my_rank,np) + if (np == -1) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + end if + icomm = ctxt%get_mpic() + + select type(ch => comm_handle) + type is(psb_comm_rma_handle) + rma_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected RMA comm_handle for pull mode') + goto 9999 + end select + + do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_) + do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_) + + call comm_indexes%sync() + + ! START phase: build the layout once, prepare the exposed buffer, then issue RMA. + if (do_start) then + buffer_size = total_send + total_recv + + memory_buffer_layout_rebuild_needed = (.not. rma_handle%layout_ready) .or. & + & (rma_handle%layout_nnbr /= num_neighbors) .or. & + & (rma_handle%layout_send /= total_send) .or. 
& + & (rma_handle%layout_recv /= total_recv) + + if (memory_buffer_layout_rebuild_needed) then + if (allocated(peer_mpi_rank)) deallocate(peer_mpi_rank) + if (num_neighbors > 0) then + allocate(peer_mpi_rank(num_neighbors), stat=iret) + if (iret /= 0) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name,a_err='RMA pull rank cache allocation') + goto 9999 + end if + end if + + list_pos = 1 + do neighbor_idx = 1, num_neighbors + proc_to_comm = comm_indexes%v(list_pos+psb_proc_id_) + peer_mpi_rank(neighbor_idx) = psb_get_mpi_rank(ctxt,proc_to_comm) + recv_count = comm_indexes%v(list_pos+psb_n_elem_recv_) + send_count = comm_indexes%v(list_pos+recv_count+psb_n_elem_send_) + list_pos = list_pos + recv_count + send_count + 3 + end do + + call rma_handle%init_memory_buffer_layout(info, comm_indexes%v, peer_mpi_rank, & + & num_neighbors, total_send, total_recv, my_rank, icomm) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='RMA pull init_memory_buffer_layout failure') + goto 9999 + end if + end if + + if (buffer_size > 0) then + if (.not. allocated(y%combuf)) then + call y%new_buffer(buffer_size, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_alloc_dealloc_,name) + goto 9999 + end if + else if (size(y%combuf) < buffer_size) then + ! Need a larger exposed memory area: recreate the RMA window first, + ! then reallocate combuf and lazily create a new window below. + if (rma_handle%window_open) then + call mpi_win_unlock_all(rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_open = .false. + end if + if (rma_handle%window_ready) then + call mpi_win_free(rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_ready = .false. 
+ rma_handle%win = mpi_win_null + end if + call y%new_buffer(buffer_size, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_alloc_dealloc_,name) + goto 9999 + end if + end if + end if + + if ((buffer_size > 0).and.(.not. rma_handle%window_ready)) then + ! Expose combuf once and keep the window around until the descriptor changes. + element_bytes = storage_size(y%combuf(1))/8 + exposed_bytes = int(size(y%combuf),kind=MPI_ADDRESS_KIND) * int(element_bytes,kind=MPI_ADDRESS_KIND) + call mpi_win_create(y%combuf, exposed_bytes, element_bytes, & + & mpi_info_null, ctxt%get_mpic(), rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_ready = .true. + end if + + if (buffer_size > 0) then + if (total_send > 0) then + call y%gth(int(total_send,psb_mpk_), rma_handle%peer_send_indexes, y%combuf(1:total_send)) + end if + call y%device_wait() + + call mpi_win_lock_all(0, rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_open = .true. + + ! Pull data from each peer. The metadata exchange stays local and simple. + do neighbor_idx=1, num_neighbors + proc_to_comm = rma_handle%peer_proc(neighbor_idx) + recv_count = rma_handle%peer_recv_counts(neighbor_idx) + send_count = rma_handle%peer_send_counts(neighbor_idx) + prc_rank = rma_handle%peer_mpi_rank(neighbor_idx) + send_pos = rma_handle%peer_send_displs(neighbor_idx) + 1 + recv_pos = total_send + rma_handle%peer_recv_displs(neighbor_idx) + 1 + + if (proc_to_comm /= my_rank) then + remote_base = rma_handle%peer_remote_send_displs(neighbor_idx) + if (remote_base < 1) then + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='Invalid remote metadata in RMA pull') + goto 9999 + end if + if ((recv_pos < 1) .or. (recv_count < 0) .or. 
(recv_pos+max(0,recv_count)-1 > size(y%combuf))) then + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='RMA pull local receive bounds error') + goto 9999 + end if + + if (recv_count > 0) then + remote_disp = int(remote_base - 1, kind=MPI_ADDRESS_KIND) + call mpi_get(y%combuf(recv_pos), recv_count, psb_mpi_r_dpk_, prc_rank, remote_disp, recv_count, psb_mpi_r_dpk_, & + & rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + end if + call mpi_win_flush(prc_rank, rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + else + if (send_count /= recv_count) then + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='RMA pull self-copy mismatch') + goto 9999 + end if + y%combuf(recv_pos:recv_pos+recv_count-1) = y%combuf(send_pos:send_pos+send_count-1) + end if + end do + end if + end if + + ! WAIT phase: close the epoch and scatter the received slice back into Y. + if (do_wait) then + if (rma_handle%window_open) then + call mpi_win_unlock_all(rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_open = .false. 
+ end if + + if (total_recv > 0) then + call y%sct(int(total_recv,psb_mpk_), rma_handle%peer_recv_indexes, y%combuf(total_send+1:total_send+total_recv), beta) + end if + + call y%device_wait() + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + return + end subroutine psi_dswap_rma_pull_vect + + + subroutine psi_dswap_rma_push_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,comm_handle,info) +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_ipk_), intent(in) :: swap_status + real(psb_dpk_), intent(in) :: beta + class(psb_d_base_vect_type), intent(inout) :: y + class(psb_i_base_vect_type), intent(inout) :: comm_indexes + integer(psb_ipk_), intent(in) :: num_neighbors, total_send, total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle + integer(psb_ipk_), intent(out) :: info + + integer(psb_mpk_) :: np, my_rank, iret, element_bytes, icomm + integer(psb_mpk_) :: proc_to_comm, prc_rank, recv_count, send_count, send_pos, recv_pos, list_pos + integer(psb_mpk_) :: remote_base + integer(kind=MPI_ADDRESS_KIND) :: remote_disp, exposed_bytes + integer(psb_ipk_) :: err_act, neighbor_idx, buffer_size + integer(psb_ipk_), allocatable :: peer_mpi_rank(:) + logical :: do_start, do_wait, memory_buffer_layout_rebuild_needed + type(psb_comm_rma_handle), pointer :: rma_handle + character(len=30) :: name + + info = psb_success_ + name = 'psi_dswap_rma_push_vect' + call psb_erractionsave(err_act) + call psb_info(ctxt,my_rank,np) + if (np == -1) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + end if + icomm = ctxt%get_mpic() + + select type(ch => comm_handle) + type is(psb_comm_rma_handle) + rma_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected RMA comm_handle for push mode') + goto 9999 + end select + 
+ do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_) + do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_) + + call comm_indexes%sync() + + ! START phase: identical layout handling, but the remote metadata describes the receive side. + if (do_start) then + buffer_size = total_send + total_recv + + memory_buffer_layout_rebuild_needed = (.not. rma_handle%layout_ready) .or. & + & (rma_handle%layout_nnbr /= num_neighbors) .or. & + & (rma_handle%layout_send /= total_send) .or. & + & (rma_handle%layout_recv /= total_recv) + + if (memory_buffer_layout_rebuild_needed) then + if (allocated(peer_mpi_rank)) deallocate(peer_mpi_rank) + if (num_neighbors > 0) then + allocate(peer_mpi_rank(num_neighbors), stat=iret) + if (iret /= 0) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name,a_err='RMA put rank cache allocation') + goto 9999 + end if + end if + + list_pos = 1 + do neighbor_idx = 1, num_neighbors + proc_to_comm = comm_indexes%v(list_pos+psb_proc_id_) + peer_mpi_rank(neighbor_idx) = psb_get_mpi_rank(ctxt,proc_to_comm) + recv_count = comm_indexes%v(list_pos+psb_n_elem_recv_) + send_count = comm_indexes%v(list_pos+recv_count+psb_n_elem_send_) + list_pos = list_pos + recv_count + send_count + 3 + end do + + call rma_handle%init_memory_buffer_layout(info, comm_indexes%v, peer_mpi_rank, & + & num_neighbors, total_send, total_recv, my_rank, icomm) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='RMA put ini_memory_buffer_layout') + goto 9999 + end if + end if + + if (buffer_size > 0) then + if (.not. allocated(y%combuf)) then + call y%new_buffer(buffer_size, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_alloc_dealloc_,name) + goto 9999 + end if + else if (size(y%combuf) < buffer_size) then + ! Need a larger exposed memory area: recreate the RMA window first, + ! then reallocate combuf and lazily create a new window below. 
+ if (rma_handle%window_open) then + call mpi_win_unlock_all(rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_open = .false. + end if + if (rma_handle%window_ready) then + call mpi_win_free(rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_ready = .false. + rma_handle%win = mpi_win_null + end if + call y%new_buffer(buffer_size, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_alloc_dealloc_,name) + goto 9999 + end if + end if + end if + + if ((buffer_size > 0).and.(.not. rma_handle%window_ready)) then + ! Keep the window alive across repetitions: it is created once and reused. + element_bytes = storage_size(y%combuf(1))/8 + exposed_bytes = int(size(y%combuf),kind=MPI_ADDRESS_KIND) * int(element_bytes,kind=MPI_ADDRESS_KIND) + call mpi_win_create(y%combuf, exposed_bytes, element_bytes, & + & mpi_info_null, ctxt%get_mpic(), rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_ready = .true. + end if + + if (buffer_size > 0) then + if (total_send > 0) then + call y%gth(int(total_send,psb_mpk_), rma_handle%peer_send_indexes, y%combuf(1:total_send)) + end if + call y%device_wait() + + call mpi_win_lock_all(0, rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_open = .true. + + ! Push data to each peer. Only the base displacement and count are exchanged. 
+ do neighbor_idx=1, num_neighbors + proc_to_comm = rma_handle%peer_proc(neighbor_idx) + recv_count = rma_handle%peer_recv_counts(neighbor_idx) + send_count = rma_handle%peer_send_counts(neighbor_idx) + prc_rank = rma_handle%peer_mpi_rank(neighbor_idx) + send_pos = rma_handle%peer_send_displs(neighbor_idx) + 1 + recv_pos = total_send + rma_handle%peer_recv_displs(neighbor_idx) + 1 + + if (proc_to_comm /= my_rank) then + remote_base = rma_handle%peer_remote_recv_displs(neighbor_idx) + if (remote_base < 1) then + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='Invalid remote metadata in RMA push') + goto 9999 + end if + if ((send_pos < 1) .or. (send_count < 0) .or. (send_pos+max(0,send_count)-1 > size(y%combuf))) then + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='RMA push local send bounds error') + goto 9999 + end if + + if (send_count > 0) then + remote_disp = int(remote_base - 1, kind=MPI_ADDRESS_KIND) + call mpi_put(y%combuf(send_pos), send_count, psb_mpi_r_dpk_, prc_rank, remote_disp, send_count, psb_mpi_r_dpk_, & + & rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + end if + call mpi_win_flush(prc_rank, rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + else + if (send_count /= recv_count) then + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='RMA push self-copy mismatch') + goto 9999 + end if + y%combuf(recv_pos:recv_pos+recv_count-1) = y%combuf(send_pos:send_pos+send_count-1) + end if + end do + end if + end if + + ! WAIT phase: close the epoch and apply the receive-side scatter. 
+ if (do_wait) then + if (rma_handle%window_open) then + call mpi_win_unlock_all(rma_handle%win, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + rma_handle%window_open = .false. + end if + + call mpi_barrier(icomm, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,m_err=(/iret/)) + goto 9999 + end if + + if (total_recv > 0) then + call y%sct(int(total_recv,psb_mpk_), rma_handle%peer_recv_indexes, y%combuf(total_send+1:total_send+total_recv), beta) + end if + + call y%device_wait() + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + return + end subroutine psi_dswap_rma_push_vect + + ! ! @@ -1007,6 +1495,10 @@ contains ineighbor_a2av = .true. case(psb_comm_persistent_ineighbor_alltoallv_) ineighbor_a2av_persistent = .true. + case(psb_comm_rma_pull_, psb_comm_rma_push_) + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='RMA swap is not yet enabled for multivectors') + goto 9999 case default baseline = .true. 
end select diff --git a/base/modules/Makefile b/base/modules/Makefile index 7e6f0e7c..7d4a1b36 100644 --- a/base/modules/Makefile +++ b/base/modules/Makefile @@ -31,6 +31,7 @@ SERIAL_MODS=serial/psb_s_serial_mod.o serial/psb_d_serial_mod.o \ comm/comm_schemes/psb_comm_schemes_mod.o \ comm/comm_schemes/psb_comm_baseline_mod.o \ comm/comm_schemes/psb_comm_neighbor_impl_mod.o \ + comm/comm_schemes/psb_comm_rma_mod.o \ comm/comm_schemes/psb_comm_factory_mod.o \ serial/psb_i_base_vect_mod.o serial/psb_i_vect_mod.o\ serial/psb_l_base_vect_mod.o serial/psb_l_vect_mod.o\ @@ -193,7 +194,9 @@ comm/comm_schemes/psb_comm_baseline_mod.o: comm/comm_schemes/psb_comm_schemes_mo comm/comm_schemes/psb_comm_neighbor_impl_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o comm/comm_schemes/psb_comm_schemes_mod.o -comm/comm_schemes/psb_comm_factory_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_baseline_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o +comm/comm_schemes/psb_comm_rma_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o psb_const_mod.o desc/psb_desc_const_mod.o + +comm/comm_schemes/psb_comm_factory_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_baseline_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o comm/comm_schemes/psb_comm_rma_mod.o comm/psb_neighbor_topology_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o diff --git a/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 index 03c12e72..fbb63358 100644 --- a/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 +++ b/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 @@ -2,10 +2,12 @@ module psb_comm_factory_mod use psb_const_mod use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, & & psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, & + & psb_comm_rma_pull_, psb_comm_rma_push_, & & psb_comm_unknown_, psb_comm_status_start_, 
psb_comm_status_wait_, & & psb_comm_status_sync_, psb_comm_status_unknown_ use psb_comm_baseline_mod, only: psb_comm_baseline_handle use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle + use psb_comm_rma_mod, only: psb_comm_rma_handle implicit none contains @@ -37,6 +39,8 @@ contains type is(psb_comm_neighbor_handle) h%comm_type = comm_type h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_) + type is(psb_comm_rma_handle) + h%comm_type = comm_type class default ! nothing else to configure end select @@ -60,6 +64,17 @@ contains h%comm_type = comm_type h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_) end select + case(psb_comm_rma_pull_, psb_comm_rma_push_) + allocate(psb_comm_rma_handle :: handle, stat=info) + if (info /= 0) return + call handle%init(info) + if (info /= 0) return + handle%id = old_id + handle%swap_status = old_swap_status + select type(h => handle) + type is(psb_comm_rma_handle) + h%comm_type = comm_type + end select case default allocate(psb_comm_baseline_handle :: handle, stat=info) if (info /= 0) return diff --git a/base/modules/comm/comm_schemes/psb_comm_rma_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_rma_mod.F90 new file mode 100644 index 00000000..30822617 --- /dev/null +++ b/base/modules/comm/comm_schemes/psb_comm_rma_mod.F90 @@ -0,0 +1,240 @@ +module psb_comm_rma_mod + use psb_const_mod + use psb_desc_const_mod, only: psb_proc_id_, psb_n_elem_recv_, psb_elem_recv_, & + & psb_n_elem_send_, psb_elem_send_ + use psb_error_mod +#ifdef PSB_MPI_MOD + use mpi +#endif + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_rma_pull_, psb_comm_rma_push_, & + & psb_comm_unknown_ + implicit none + integer(psb_mpk_), parameter :: psb_rma_meta_tag = 913 +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type, extends(psb_comm_handle_type) :: psb_comm_rma_handle + integer(psb_mpk_) :: win = mpi_win_null + logical :: window_ready = .false. 
+ logical :: window_open = .false. + logical :: layout_ready = .false. + integer(psb_ipk_) :: layout_nnbr = -1 + integer(psb_ipk_) :: layout_send = -1 + integer(psb_ipk_) :: layout_recv = -1 + integer(psb_ipk_), allocatable :: peer_proc(:) + integer(psb_ipk_), allocatable :: peer_send_counts(:) + integer(psb_ipk_), allocatable :: peer_recv_counts(:) + integer(psb_ipk_), allocatable :: peer_send_displs(:) + integer(psb_ipk_), allocatable :: peer_recv_displs(:) + integer(psb_ipk_), allocatable :: peer_mpi_rank(:) + integer(psb_ipk_), allocatable :: peer_remote_send_displs(:) + integer(psb_ipk_), allocatable :: peer_remote_recv_displs(:) + integer(psb_ipk_), allocatable :: peer_send_indexes(:) + integer(psb_ipk_), allocatable :: peer_recv_indexes(:) + contains + procedure, pass :: init => psb_comm_rma_init + procedure, pass :: free => psb_comm_rma_free + procedure, pass :: clear_memory_buffer_layout => psb_comm_rma_clear_memory_buffer_layout + procedure, pass :: init_memory_buffer_layout => psb_comm_rma_ini_memory_buffer_layout + procedure, pass :: set_swap_status => psb_comm_rma_set_swap_status + procedure, pass :: get_swap_status => psb_comm_rma_get_swap_status + end type psb_comm_rma_handle + +contains + + subroutine psb_comm_rma_init(this, info) + class(psb_comm_rma_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + info = 0 + this%comm_type = psb_comm_unknown_ + this%id = 0 + this%swap_status = 0 + this%win = mpi_win_null + this%window_ready = .false. + this%window_open = .false. + this%layout_ready = .false. 
+ this%layout_nnbr = -1 + this%layout_send = -1 + this%layout_recv = -1 + call this%clear_memory_buffer_layout(info) + end subroutine psb_comm_rma_init + + subroutine psb_comm_rma_clear_memory_buffer_layout(this, info) + class(psb_comm_rma_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + if (allocated(this%peer_proc)) deallocate(this%peer_proc) + if (allocated(this%peer_send_counts)) deallocate(this%peer_send_counts) + if (allocated(this%peer_recv_counts)) deallocate(this%peer_recv_counts) + if (allocated(this%peer_send_displs)) deallocate(this%peer_send_displs) + if (allocated(this%peer_recv_displs)) deallocate(this%peer_recv_displs) + if (allocated(this%peer_mpi_rank)) deallocate(this%peer_mpi_rank) + if (allocated(this%peer_remote_send_displs)) deallocate(this%peer_remote_send_displs) + if (allocated(this%peer_remote_recv_displs)) deallocate(this%peer_remote_recv_displs) + if (allocated(this%peer_send_indexes)) deallocate(this%peer_send_indexes) + if (allocated(this%peer_recv_indexes)) deallocate(this%peer_recv_indexes) + + this%layout_ready = .false. 
+ this%layout_nnbr = -1 + this%layout_send = -1 + this%layout_recv = -1 + end subroutine psb_comm_rma_clear_memory_buffer_layout + ! Cache per-peer counts, offsets and index lists from comm_list, then swap buffer offsets with each peer. + subroutine psb_comm_rma_ini_memory_buffer_layout(this, info, comm_list, peer_mpi_rank, & + & num_neighbors, total_send, total_recv, my_rank, icomm) + class(psb_comm_rma_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in) :: comm_list(:) + integer(psb_ipk_), intent(in) :: peer_mpi_rank(:) + integer(psb_ipk_), intent(in) :: num_neighbors, total_send, total_recv + integer(psb_ipk_), intent(in) :: my_rank + integer(psb_mpk_), intent(in) :: icomm + + integer(psb_mpk_) :: iret, p2pstat(mpi_status_size), prc_rank + integer(psb_ipk_) :: n_neighbors, send_total, recv_total + integer(psb_ipk_) :: neighbor_idx, item_idx + integer(psb_ipk_) :: list_pos, send_offset, recv_offset + integer(psb_ipk_) :: proc_to_comm, recv_count, send_count + integer(psb_ipk_) :: local_meta(4), remote_meta(4) + + call this%clear_memory_buffer_layout(info) + if (info /= psb_success_) return + + n_neighbors = num_neighbors + send_total = total_send + recv_total = total_recv + + if (n_neighbors > 0) then + allocate(this%peer_proc(n_neighbors), this%peer_send_counts(n_neighbors), & + & this%peer_recv_counts(n_neighbors), this%peer_send_displs(n_neighbors), & + & this%peer_recv_displs(n_neighbors), this%peer_mpi_rank(n_neighbors), & + & this%peer_remote_send_displs(n_neighbors), this%peer_remote_recv_displs(n_neighbors), stat=iret) + if (iret /= 0) then + info = psb_err_alloc_dealloc_ + return + end if + end if + + if (send_total > 0) then + allocate(this%peer_send_indexes(send_total), stat=iret) + if (iret /= 0) then + info = psb_err_alloc_dealloc_ + return + end if + end if + + if (recv_total > 0) then + allocate(this%peer_recv_indexes(recv_total), stat=iret) + if (iret /= 0) then + info = psb_err_alloc_dealloc_ + return + end if + end if + + list_pos = 1 + send_offset = 0 + recv_offset = 0 + do neighbor_idx = 1, 
n_neighbors + proc_to_comm = comm_list(list_pos + psb_proc_id_) + recv_count = comm_list(list_pos + psb_n_elem_recv_) + send_count = comm_list(list_pos + recv_count + psb_n_elem_send_) + + this%peer_proc(neighbor_idx) = proc_to_comm + this%peer_recv_counts(neighbor_idx) = recv_count + this%peer_send_counts(neighbor_idx) = send_count + this%peer_recv_displs(neighbor_idx) = recv_offset + this%peer_send_displs(neighbor_idx) = send_offset + this%peer_mpi_rank(neighbor_idx) = peer_mpi_rank(neighbor_idx) + + if (recv_count > 0) then + do item_idx = 1, recv_count + this%peer_recv_indexes(recv_offset + item_idx) = comm_list(list_pos + psb_elem_recv_ + item_idx - 1) + end do + end if + + if (send_count > 0) then + do item_idx = 1, send_count + this%peer_send_indexes(send_offset + item_idx) = comm_list(list_pos + recv_count + psb_elem_send_ + item_idx - 1) + end do + end if + + recv_offset = recv_offset + recv_count + send_offset = send_offset + send_count + list_pos = list_pos + recv_count + send_count + 3 + end do + ! Metadata buffers are psb_ipk_, so exchange them as psb_mpi_ipk_; psb_mpi_mpk_ only matches when both kinds coincide. + do neighbor_idx = 1, n_neighbors + proc_to_comm = this%peer_proc(neighbor_idx) + if (proc_to_comm /= my_rank) then + prc_rank = this%peer_mpi_rank(neighbor_idx) + local_meta = (/ this%peer_send_displs(neighbor_idx)+1, this%peer_send_counts(neighbor_idx), & + & this%peer_recv_displs(neighbor_idx)+1+send_total, this%peer_recv_counts(neighbor_idx) /) + call mpi_sendrecv(local_meta, 4, psb_mpi_ipk_, prc_rank, psb_rma_meta_tag, & + & remote_meta, 4, psb_mpi_ipk_, prc_rank, psb_rma_meta_tag, icomm, p2pstat, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + return + end if + this%peer_remote_send_displs(neighbor_idx) = remote_meta(1) + this%peer_remote_recv_displs(neighbor_idx) = remote_meta(3) + else + this%peer_remote_send_displs(neighbor_idx) = this%peer_send_displs(neighbor_idx)+1 + this%peer_remote_recv_displs(neighbor_idx) = this%peer_recv_displs(neighbor_idx)+1+send_total + end if + end do + + if ((send_offset /= send_total) .or. 
(recv_offset /= recv_total)) then + info = psb_err_internal_error_ + return + end if + + this%layout_nnbr = n_neighbors + this%layout_send = send_total + this%layout_recv = recv_total + this%layout_ready = .true. + end subroutine psb_comm_rma_ini_memory_buffer_layout + + subroutine psb_comm_rma_free(this, info) +#ifdef PSB_MPI_MOD + use mpi +#endif + class(psb_comm_rma_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + integer(psb_mpk_) :: iret + + info = 0 + if (this%window_open) then + call mpi_win_unlock_all(this%win, iret) + this%window_open = .false. + end if + if (this%win /= mpi_win_null) then + call mpi_win_free(this%win, iret) + this%win = mpi_win_null + end if + this%window_ready = .false. + this%layout_ready = .false. + this%layout_nnbr = -1 + this%layout_send = -1 + this%layout_recv = -1 + call this%clear_memory_buffer_layout(info) + end subroutine psb_comm_rma_free + + subroutine psb_comm_rma_set_swap_status(this, flag, info) + class(psb_comm_rma_handle), intent(inout) :: this + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + info = 0 + this%swap_status = flag + end subroutine psb_comm_rma_set_swap_status + + subroutine psb_comm_rma_get_swap_status(this, flag, info) + class(psb_comm_rma_handle), intent(in) :: this + integer(psb_ipk_), intent(out) :: flag + integer(psb_ipk_), intent(out) :: info + info = 0 + flag = this%swap_status + end subroutine psb_comm_rma_get_swap_status + +end module psb_comm_rma_mod \ No newline at end of file diff --git a/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 index 1340f73c..d3467a2c 100644 --- a/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 +++ b/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 @@ -11,6 +11,8 @@ module psb_comm_schemes_mod enumerator psb_comm_isend_irecv_ enumerator psb_comm_ineighbor_alltoallv_ enumerator psb_comm_persistent_ineighbor_alltoallv_ + enumerator 
psb_comm_rma_pull_ + enumerator psb_comm_rma_push_ end enum enum, bind(c) diff --git a/ext/psb_c_hll_mat_mod.f90 b/ext/psb_c_hll_mat_mod.f90 index 966b60f5..343d779e 100644 --- a/ext/psb_c_hll_mat_mod.f90 +++ b/ext/psb_c_hll_mat_mod.f90 @@ -534,7 +534,7 @@ contains if (allocated(a%irn)) deallocate(a%irn) if (allocated(a%ja)) deallocate(a%ja) if (allocated(a%val)) deallocate(a%val) - if (allocated(a%val)) deallocate(a%hkoffs) + if (allocated(a%hkoffs)) deallocate(a%hkoffs) call a%set_null() call a%set_nrows(izero) call a%set_ncols(izero) diff --git a/ext/psb_d_hll_mat_mod.f90 b/ext/psb_d_hll_mat_mod.f90 index acc3b312..2307db95 100644 --- a/ext/psb_d_hll_mat_mod.f90 +++ b/ext/psb_d_hll_mat_mod.f90 @@ -534,7 +534,7 @@ contains if (allocated(a%irn)) deallocate(a%irn) if (allocated(a%ja)) deallocate(a%ja) if (allocated(a%val)) deallocate(a%val) - if (allocated(a%val)) deallocate(a%hkoffs) + if (allocated(a%hkoffs)) deallocate(a%hkoffs) call a%set_null() call a%set_nrows(izero) call a%set_ncols(izero) diff --git a/ext/psb_s_hll_mat_mod.f90 b/ext/psb_s_hll_mat_mod.f90 index 735091c8..d42478d0 100644 --- a/ext/psb_s_hll_mat_mod.f90 +++ b/ext/psb_s_hll_mat_mod.f90 @@ -534,7 +534,7 @@ contains if (allocated(a%irn)) deallocate(a%irn) if (allocated(a%ja)) deallocate(a%ja) if (allocated(a%val)) deallocate(a%val) - if (allocated(a%val)) deallocate(a%hkoffs) + if (allocated(a%hkoffs)) deallocate(a%hkoffs) call a%set_null() call a%set_nrows(izero) call a%set_ncols(izero) diff --git a/ext/psb_z_hll_mat_mod.f90 b/ext/psb_z_hll_mat_mod.f90 index 98eb403f..b825e1c8 100644 --- a/ext/psb_z_hll_mat_mod.f90 +++ b/ext/psb_z_hll_mat_mod.f90 @@ -534,7 +534,7 @@ contains if (allocated(a%irn)) deallocate(a%irn) if (allocated(a%ja)) deallocate(a%ja) if (allocated(a%val)) deallocate(a%val) - if (allocated(a%val)) deallocate(a%hkoffs) + if (allocated(a%hkoffs)) deallocate(a%hkoffs) call a%set_null() call a%set_nrows(izero) call a%set_ncols(izero) diff --git a/test/comm/cg/psb_comm_cg_test.F90 
b/test/comm/cg/psb_comm_cg_test.F90
index 357d0048..38652f82 100644
--- a/test/comm/cg/psb_comm_cg_test.F90
+++ b/test/comm/cg/psb_comm_cg_test.F90
@@ -29,7 +29,7 @@ program psb_comm_cg_test
   integer(psb_ipk_) :: desc_me, desc_np
   integer(psb_ipk_) :: idim, itmax, itrace, istop, iter
   integer(psb_ipk_) :: scheme_idx, prec_idx, rep, nrep, nwarm
-  integer(psb_ipk_), parameter :: n_schemes=3, n_precs=2
+  integer(psb_ipk_), parameter :: n_schemes=5, n_precs=2
   integer(psb_ipk_), allocatable :: iter_count(:,:,:), solve_info(:,:,:)
   integer(psb_ipk_) :: scheme_type(n_schemes)
   real(psb_dpk_) :: eps, err, t_start, t_elapsed
@@ -71,10 +71,13 @@ program psb_comm_cg_test
   use_gpu = .false.
 #endif
   scheme_type = (/ psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
-       & psb_comm_persistent_ineighbor_alltoallv_ /)
+       & psb_comm_persistent_ineighbor_alltoallv_ , psb_comm_rma_pull_, psb_comm_rma_push_ /)
   scheme_name(1) = 'isend_irecv'
   scheme_name(2) = 'ineighbor_alltoallv'
   scheme_name(3) = 'persistent_ineighbor_a2av'
+  scheme_name(4) = 'psb_comm_rma_pull_'
+  scheme_name(5) = 'psb_comm_rma_push_'
+
   prec_type(1) = 'NONE'
   prec_type(2) = 'DIAG'
   prec_name(1) = 'none'
@@ -201,6 +204,12 @@ program psb_comm_cg_test
   do prec_idx = 1, n_precs
     do scheme_idx = 1, n_schemes
       do rep = 1, nrep
+        t_start = psb_wtime()
+        call psb_comm_set(scheme_type(scheme_idx),x%v%comm_handle,info)
+        comm_set_time(prec_idx,scheme_idx,rep) = psb_wtime() - t_start
+        ! Check psb_comm_set here: the psb_geaxpby call below overwrites info.
+        if (info /= psb_success_) goto 9999
+
         call psb_geaxpby(dzero,b,dzero,x,desc_a,info)
         if (info /= psb_success_) goto 9999
@@ -225,9 +234,7 @@ program psb_comm_cg_test
           goto 9999
         end if
-        t_start = psb_wtime()
-        call psb_comm_set(scheme_type(scheme_idx),x%v%comm_handle,info)
-        comm_set_time(prec_idx,scheme_idx,rep) = psb_wtime() - t_start
+        call psb_amx(ctxt,comm_set_time(prec_idx,scheme_idx,rep))
         if (info /= psb_success_) goto 9999
diff --git a/test/comm/spmv/psb_spmv_test.f90 b/test/comm/spmv/psb_spmv_test.f90
index a6964956..4e20b0b4 100644
--- a/test/comm/spmv/psb_spmv_test.f90
+++
b/test/comm/spmv/psb_spmv_test.f90 @@ -5,10 +5,12 @@ module psb_spmv_overlap_test use psb_base_mod use psb_util_mod - use psb_comm_factory_mod, only: psb_comm_set use psb_comm_schemes_mod, only: psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, & - & psb_comm_persistent_ineighbor_alltoallv_ + & psb_comm_persistent_ineighbor_alltoallv_, psb_comm_rma_pull_, psb_comm_rma_push_, & + & psb_comm_handle_type + use psb_comm_baseline_mod, only: psb_comm_baseline_handle use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle + use psb_comm_rma_mod, only: psb_comm_rma_handle #ifdef PSB_HAVE_CUDA use psb_cuda_mod #endif @@ -529,7 +531,7 @@ contains return end subroutine psb_d_gen_pde3d - subroutine run_spmv_kernel(ctxt,use_gpu,matrix_file,matrix_fmt,cpu_fmt,gpu_fmt,idim_in,times_in,do_swap) + subroutine run_spmv_kernel(ctxt,use_gpu,matrix_file,matrix_fmt,cpu_fmt,gpu_fmt,idim_in,times_in,do_swap,comm_mode) use psb_base_mod #ifdef PSB_HAVE_CUDA use psb_cuda_mod @@ -544,6 +546,7 @@ contains character(len=*), intent(in) :: gpu_fmt integer(psb_ipk_), intent(in) :: idim_in, times_in logical, intent(in) :: do_swap + character(len=*), intent(in) :: comm_mode type(psb_dspmat_type) :: a type(psb_d_vect_type) :: x, y @@ -556,16 +559,36 @@ contains real(psb_dpk_) :: alpha, beta, t0, t1, dt, avg_t logical :: use_external_matrix + integer(psb_ipk_) :: comm_type + #ifdef PSB_HAVE_CUDA - type(psb_d_vect_cuda) :: cuda_vector_mold - type(psb_i_vect_cuda) :: cuda_index_mold - type(psb_d_cuda_elg_sparse_mat), target :: cuda_ell_sparse_mold - type(psb_d_cuda_csrg_sparse_mat), target :: cuda_csr_sparse_mold - type(psb_d_cuda_hdiag_sparse_mat), target :: cuda_hdia_sparse_mold - type(psb_d_cuda_hlg_sparse_mat), target :: cuda_hll_sparse_mold - class(psb_d_base_sparse_mat), pointer :: cuda_sparse_mold + type(psb_d_vect_cuda) :: cuda_vector_mold + type(psb_i_vect_cuda) :: cuda_index_mold + type(psb_d_cuda_elg_sparse_mat), target :: cuda_ell_sparse_mold + type(psb_d_cuda_csrg_sparse_mat), 
target :: cuda_csr_sparse_mold + type(psb_d_cuda_hdiag_sparse_mat), target :: cuda_hdia_sparse_mold + type(psb_d_cuda_hlg_sparse_mat), target :: cuda_hll_sparse_mold + class(psb_d_base_sparse_mat), pointer :: cuda_sparse_mold #endif + select case(psb_toupper(trim(comm_mode))) + case('P2P','ISEND_IRECV') + comm_type = psb_comm_isend_irecv_ + case('NEIGHBOR','INEIGHBOR_ALLTOALLV') + comm_type = psb_comm_ineighbor_alltoallv_ + case('PNEIGHBOR','PERSISTENT','PERSISTENT_INEIGHBOR_A2AV') + comm_type = psb_comm_persistent_ineighbor_alltoallv_ + case('MPI_GET','RMA_PULL') + comm_type = psb_comm_rma_pull_ + case('MPI_PUT','RMA_PUSH') + comm_type = psb_comm_rma_push_ + case default + comm_type = psb_comm_isend_irecv_ + if (my_rank == psb_root_) then + write(psb_err_unit,'("Unknown comm backend: ",a,", defaulting to P2P")') trim(comm_mode) + end if + end select + info = psb_success_ afmt = psb_toupper(trim(cpu_fmt)) if (len_trim(afmt) == 0) afmt = 'CSR' @@ -613,6 +636,9 @@ contains end if if (info /= psb_success_) goto 9999 + call psb_comm_set(comm_type,x%v%comm_handle,info) + if (info /= psb_success_) goto 9999 + #ifdef PSB_HAVE_CUDA if (use_gpu) then select case(psb_toupper(trim(gpu_fmt))) @@ -669,6 +695,7 @@ contains end if write(psb_out_unit,'(" global unknowns : ",i0)') n_global write(psb_out_unit,'(" repetitions : ",i0)') times + write(psb_out_unit,'(" comm backend : ",a)') trim(psb_toupper(trim(comm_mode))) write(psb_out_unit,'(" total time [s] : ",es12.5)') dt write(psb_out_unit,'(" avg time [s] : ",es12.5)') avg_t end if @@ -764,8 +791,13 @@ program psb_spmv_kernel character(len=8) :: gpu_fmt integer(psb_ipk_) :: idim_arg, times_arg logical :: do_swap - idim_arg = -1 - times_arg = -1 + integer :: kmode + integer, parameter :: n_comm_modes = 5 + character(len=20), parameter :: comm_modes(n_comm_modes) = [character(len=20) :: & + & 'P2P', 'NEIGHBOR', 'PNEIGHBOR', 'MPI_GET', 'MPI_PUT'] + + idim_arg = -1 + times_arg = -1 matrix_file = '' matrix_fmt = 'MM' @@ -869,12 
+901,18 @@ program psb_spmv_kernel write(psb_out_unit,*) 'Welcome to PSBLAS version: ', psb_version_string_ write(psb_out_unit,*) 'This is the psb_spmv_kernel sample program' write(psb_out_unit,'("GPU enabled : ",l1)') use_gpu - write(psb_out_unit,'("Usage: ./psb_spmv_kernel [--gpu=TRUE|FALSE] [--dim=N] [--times=N] ",& - &"[--cpu_fmt=CSR|COO|CSC|ELL|HLL] [--gpu_fmt=HLL|ELL|CSR|HDIA] [--matrix=] [--fmt=MM|HB] ",& - &"[--overlap|--nooverlap]")') + write(psb_out_unit,'("Usage: ./psb_spmv_kernel [--gpu=TRUE|FALSE] [--dim=N] [--times=N] ",& + &"[--cpu_fmt=CSR|COO|CSC|ELL|HLL] [--gpu_fmt=HLL|ELL|CSR|HDIA] [--matrix=] [--fmt=MM|HB] ",& + &"[--overlap|--nooverlap] (runs all comm backends)")') end if - call run_spmv_kernel(ctxt,use_gpu,matrix_file,matrix_fmt,cpu_fmt,gpu_fmt,idim_arg,times_arg,do_swap) + do kmode = 1, n_comm_modes + if (my_rank == psb_root_) then + write(psb_out_unit,'(/,"=== Backend sweep: ",a," ===")') trim(comm_modes(kmode)) + end if + call run_spmv_kernel(ctxt, use_gpu, matrix_file, matrix_fmt, cpu_fmt, gpu_fmt, & + & idim_arg, times_arg, do_swap, comm_modes(kmode)) + end do #ifdef PSB_HAVE_CUDA if (use_gpu) call psb_cuda_exit() diff --git a/test/comm/swapdata/psb_comm_test.F90 b/test/comm/swapdata/psb_comm_test.F90 index 40e83734..0de15b0d 100644 --- a/test/comm/swapdata/psb_comm_test.F90 +++ b/test/comm/swapdata/psb_comm_test.F90 @@ -22,7 +22,7 @@ program psb_comm_test use psi_mod use psb_comm_factory_mod, only: psb_comm_set, psb_comm_free use psb_comm_schemes_mod, only: psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, & - & psb_comm_isend_irecv_ + & psb_comm_isend_irecv_, psb_comm_rma_pull_, psb_comm_rma_push_ use psb_comm_schemes_mod, only: psb_comm_status_start_, psb_comm_status_wait_, psb_comm_status_unknown_ implicit none @@ -47,11 +47,12 @@ program psb_comm_test type(psb_ldspmat_type) :: aux_a ! 
---- vectors ---- - type(psb_d_vect_type) :: v_baseline, v_neighbor, v_neighbor_persistent + type(psb_d_vect_type) :: v_baseline, v_neighbor, v_neighbor_persistent, v_rma_get, v_rma_put ! ---- temporary / comparison arrays ---- real(psb_dpk_), allocatable :: vals(:) - real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:), result_persistent(:) + real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:), result_persistent(:), & + & result_rma_get(:), result_rma_put(:) real(psb_dpk_), allocatable :: expected(:) ! ---- halo index bookkeeping ---- @@ -60,11 +61,16 @@ program psb_comm_test ! ---- error / reporting ---- integer(psb_ipk_) :: n_pass, n_total, imode - logical :: run_baseline, run_neighbor, run_persistent + logical :: run_baseline, run_neighbor, run_persistent, run_rma_get, run_rma_put logical :: mat_allocated logical :: comm_ok real(psb_dpk_) :: err, tol - real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent + real(psb_dpk_) :: first_swap_baseline, first_swap_neighbor, first_swap_persistent, & + & first_swap_rma_get, first_swap_rma_put + real(psb_dpk_) :: comm_setup_time_baseline, comm_setup_time_neighbor, comm_setup_time_persistent, & + & comm_setup_time_rma_get, comm_setup_time_rma_put + real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent, & + & tsum_rma_get, tsum_rma_put integer(psb_lpk_), allocatable :: glob_col(:) character(len=40) :: name real(psb_dpk_) :: huge_d @@ -136,21 +142,33 @@ program psb_comm_test run_baseline = .false. run_neighbor = .false. run_persistent = .false. + run_rma_get = .false. + run_rma_put = .false. + + select case (trim(adjustl(mode))) case ('both','all') - run_baseline = .true. - run_neighbor = .true. - run_persistent = .true. + run_baseline = .true. + run_neighbor = .true. + run_persistent = .true. + run_rma_get = .true. + run_rma_put = .true. case ('baseline') run_baseline = .true. case ('neighbor') run_neighbor = .true. 
case ('persistent','persistent_neighbor','persistent-neighbor') run_persistent = .true. + case ('rma_get') + run_rma_get = .true. + case ('rma_put') + run_rma_put = .true. case default run_baseline = .true. run_neighbor = .true. run_persistent = .true. + run_rma_get = .true. + run_rma_put = .true. end select if ((.not.use_external_matrix) .and. (idim <= 0)) then @@ -267,6 +285,18 @@ program psb_comm_test write(psb_err_unit,*) my_rank, 'geall persistent-neighbor error:', info call psb_abort(ctxt) end if + call psb_geall(v_rma_get, desc_a, info) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geall rma-get error:', info + call psb_abort(ctxt) + end if + call psb_geall(v_rma_put, desc_a, info) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geall rma-put error:', info + call psb_abort(ctxt) + end if + + call psb_geasb(v_baseline, desc_a, info, scratch=.true.) if (info /= psb_success_) then write(psb_err_unit,*) my_rank, 'geasb baseline error:', info @@ -282,6 +312,18 @@ program psb_comm_test write(psb_err_unit,*) my_rank, 'geasb persistent-neighbor error:', info call psb_abort(ctxt) end if + call psb_geasb(v_rma_get, desc_a, info, scratch=.true.) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geasb rma-get error:', info + call psb_abort(ctxt) + end if + call psb_geasb(v_rma_put, desc_a, info, scratch=.true.) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geasb rma-put error:', info + call psb_abort(ctxt) + end if + + ! Fill owned entries with the global index value allocate(vals(ncol)) @@ -292,6 +334,8 @@ program psb_comm_test call v_baseline%set_vect(vals) call v_neighbor%set_vect(vals) call v_neighbor_persistent%set_vect(vals) + call v_rma_get%set_vect(vals) + call v_rma_put%set_vect(vals) deallocate(vals) ! 
================================================================== @@ -304,15 +348,39 @@ program psb_comm_test do i = 1, ncol expected(i) = real(glob_col(i), psb_dpk_) end do - allocate(result_baseline(ncol), result_neighbor(ncol), result_persistent(ncol)) + allocate(result_baseline(ncol), result_neighbor(ncol), result_persistent(ncol), & + & result_rma_get(ncol), result_rma_put(ncol)) result_baseline = huge_d result_neighbor = huge_d result_persistent = huge_d + result_rma_get = huge_d + result_rma_put = huge_d + + first_swap_baseline = 0.0_psb_dpk_ + first_swap_neighbor = 0.0_psb_dpk_ + first_swap_persistent = 0.0_psb_dpk_ + first_swap_rma_get = 0.0_psb_dpk_ + first_swap_rma_put = 0.0_psb_dpk_ + + comm_setup_time_baseline = 0.0_psb_dpk_ + comm_setup_time_neighbor = 0.0_psb_dpk_ + comm_setup_time_persistent = 0.0_psb_dpk_ + comm_setup_time_rma_get = 0.0_psb_dpk_ + comm_setup_time_rma_put = 0.0_psb_dpk_ ! ================================================================== ! 6. Baseline halo exchange (Isend/Irecv in one call) ! ================================================================== if (run_baseline) then + comm_setup_time_baseline = psb_wtime() + call psb_comm_set(psb_comm_isend_irecv_, v_baseline%v%comm_handle, info) + if (info /= 0) then + write(psb_err_unit,*) my_rank, 'psb_comm_set baseline error:', info + call psb_abort(ctxt) + end if + comm_setup_time_baseline = psb_wtime() - comm_setup_time_baseline + + first_swap_baseline = psb_wtime() call psi_swapdata( & swap_status=psb_comm_status_start_, & beta=dzero, & @@ -336,6 +404,7 @@ program psb_comm_test write(psb_err_unit,*) my_rank, 'baseline swap error:', info call psb_abort(ctxt) end if + first_swap_baseline = psb_wtime() - first_swap_baseline end if @@ -343,11 +412,16 @@ program psb_comm_test ! 7. Neighbor topology halo exchange (start + wait) ! 
================================================================== if (run_neighbor) then + comm_setup_time_neighbor = psb_wtime() call psb_comm_set(psb_comm_ineighbor_alltoallv_, v_neighbor%v%comm_handle, info) if (info /= 0) then write(psb_err_unit,*) my_rank, 'psb_comm_set neighbor error:', info call psb_abort(ctxt) end if + comm_setup_time_neighbor = psb_wtime() - comm_setup_time_neighbor + + + first_swap_neighbor = psb_wtime() call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) if (info /= psb_success_) then write(psb_err_unit,*) my_rank, 'neighbor start error:', info @@ -359,17 +433,23 @@ program psb_comm_test write(psb_err_unit,*) my_rank, 'neighbor wait error:', info call psb_abort(ctxt) end if + first_swap_neighbor = psb_wtime() - first_swap_neighbor end if ! ================================================================== ! 7b. Persistent-neighbor halo exchange (start + wait) ! ================================================================== if (run_persistent) then + comm_setup_time_persistent = psb_wtime() call psb_comm_set(psb_comm_persistent_ineighbor_alltoallv_, v_neighbor_persistent%v%comm_handle, info) if (info /= 0) then write(psb_err_unit,*) my_rank, 'psb_comm_set persistent-neighbor error:', info call psb_abort(ctxt) end if + comm_setup_time_persistent = psb_wtime() - comm_setup_time_persistent + + + first_swap_persistent = psb_wtime() call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) if (info /= psb_success_) then write(psb_err_unit,*) my_rank, 'persistent-neighbor start error:', info @@ -380,8 +460,62 @@ program psb_comm_test write(psb_err_unit,*) my_rank, 'persistent-neighbor wait error:', info call psb_abort(ctxt) end if + first_swap_persistent = psb_wtime() - first_swap_persistent end if + + + if(run_rma_get) then + comm_setup_time_rma_get = psb_wtime() + call psb_comm_set(psb_comm_rma_pull_, v_rma_get%v%comm_handle, info) + if (info 
/= 0) then + write(psb_err_unit,*) my_rank, 'psb_comm_set RMA get error:', info + call psb_abort(ctxt) + end if + comm_setup_time_rma_get = psb_wtime() - comm_setup_time_rma_get + + first_swap_rma_get = psb_wtime() + call psi_swapdata(psb_comm_status_start_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'RMA get start error:', info + call psb_abort(ctxt) + end if + call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'RMA get wait error:', info + call psb_abort(ctxt) + end if + first_swap_rma_get = psb_wtime() - first_swap_rma_get + end if + + + + + if(run_rma_put) then + comm_setup_time_rma_put = psb_wtime() + call psb_comm_set(psb_comm_rma_push_, v_rma_put%v%comm_handle, info) + if (info /= 0) then + write(psb_err_unit,*) my_rank, 'psb_comm_set RMA put error:', info + call psb_abort(ctxt) + end if + comm_setup_time_rma_put = psb_wtime() - comm_setup_time_rma_put + + first_swap_rma_put = psb_wtime() + call psi_swapdata(psb_comm_status_start_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'RMA put start error:', info + call psb_abort(ctxt) + end if + call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'RMA put wait error:', info + call psb_abort(ctxt) + end if + first_swap_rma_put = psb_wtime() - first_swap_rma_put + end if + + + ! ================================================================== ! 8. Performance: repeat exchanges and measure timings ! 
==================================================================
@@ -392,6 +526,8 @@ program psb_comm_test
   tsum_baseline = 0.0_psb_dpk_
   tsum_neighbor = 0.0_psb_dpk_
   tsum_neighbor_persistent = 0.0_psb_dpk_
+  tsum_rma_get = 0.0_psb_dpk_
+  tsum_rma_put = 0.0_psb_dpk_
   do i = 1, iters
     if (run_baseline) then
@@ -423,194 +559,124 @@ program psb_comm_test
       call psb_amx(ctxt, dt)
       tsum_neighbor_persistent = tsum_neighbor_persistent + dt
     end if
+
+
+    if (run_rma_get) then
+      t0 = psb_wtime()
+      call psi_swapdata(psb_comm_status_start_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_)
+      call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_get%v, desc_a, info, data=psb_comm_halo_)
+      t1 = psb_wtime()
+      dt = t1 - t0
+      call psb_amx(ctxt, dt)
+      tsum_rma_get = tsum_rma_get + dt
+    end if
+
+
+    if (run_rma_put) then
+      t0 = psb_wtime()
+      call psi_swapdata(psb_comm_status_start_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_)
+      call psi_swapdata(psb_comm_status_wait_, dzero, v_rma_put%v, desc_a, info, data=psb_comm_halo_)
+      t1 = psb_wtime()
+      dt = t1 - t0
+      call psb_amx(ctxt, dt)
+      tsum_rma_put = tsum_rma_put + dt
+    end if
+
+
   end do
+
+
+  ! Reduce every timer, RMA ones included, so rank 0 reports the same kind of figure for all schemes.
+  call psb_amx(ctxt, tsum_baseline)
+  call psb_amx(ctxt, tsum_neighbor)
+  call psb_amx(ctxt, tsum_neighbor_persistent)
+  call psb_amx(ctxt, tsum_rma_get)
+  call psb_amx(ctxt, tsum_rma_put)
+  call psb_amx(ctxt, first_swap_baseline)
+  call psb_amx(ctxt, first_swap_neighbor)
+  call psb_amx(ctxt, first_swap_persistent)
+  call psb_amx(ctxt, first_swap_rma_get)
+  call psb_amx(ctxt, first_swap_rma_put)
+  call psb_amx(ctxt, comm_setup_time_baseline)
+  call psb_amx(ctxt, comm_setup_time_neighbor)
+  call psb_amx(ctxt, comm_setup_time_persistent)
+  call psb_amx(ctxt, comm_setup_time_rma_get)
+  call psb_amx(ctxt, comm_setup_time_rma_put)
+
+
+
+
   if (my_rank == 0) then
     if (run_baseline) then
       write(psb_out_unit,'(" Avg baseline time : ",es12.5)') (tsum_baseline / real(iters,psb_dpk_))
       write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline
+      write(psb_out_unit,'(" First baseline time: ",es12.5)') first_swap_baseline
+      write(psb_out_unit,'(" Baseline comm setup: ",es12.5)')
comm_setup_time_baseline end if if (run_neighbor) then write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') (tsum_neighbor / real(iters,psb_dpk_)) write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor + write(psb_out_unit,'(" First neighbor time: ",es12.5)') first_swap_neighbor + write(psb_out_unit,'(" Neighbor comm setup: ",es12.5)') comm_setup_time_neighbor end if if (run_persistent) then write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') (tsum_neighbor_persistent / real(iters,psb_dpk_)) write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_neighbor_persistent + write(psb_out_unit,'(" First pers-neigh time: ",es12.5)') first_swap_persistent + write(psb_out_unit,'(" Persistent comm setup: ",es12.5)') comm_setup_time_persistent + end if + if (run_rma_get) then + write(psb_out_unit,'(" Avg RMA get time : ",es12.5)') (tsum_rma_get / real(iters,psb_dpk_)) + write(psb_out_unit,'(" Tot RMA get time : ",es12.5)') tsum_rma_get + write(psb_out_unit,'(" First RMA get time : ",es12.5)') first_swap_rma_get + write(psb_out_unit,'(" RMA get comm setup : ",es12.5)') comm_setup_time_rma_get + end if + if (run_rma_put) then + write(psb_out_unit,'(" Avg RMA put time : ",es12.5)') (tsum_rma_put / real(iters,psb_dpk_)) + write(psb_out_unit,'(" Tot RMA put time : ",es12.5)') tsum_rma_put + write(psb_out_unit,'(" First RMA put time : ",es12.5)') first_swap_rma_put + write(psb_out_unit,'(" RMA put comm setup : ",es12.5)') comm_setup_time_rma_put end if end if ! ================================================================== ! 8. Extract results and compare ! ================================================================== - result_baseline = v_baseline%v%v - result_neighbor = v_neighbor%v%v + result_baseline = v_baseline%v%v + result_neighbor = v_neighbor%v%v result_persistent = v_neighbor_persistent%v%v + result_rma_get = v_rma_get%v%v + result_rma_put = v_rma_put%v%v - ! Debug: Check if results are properly populated - if (my_rank == 0 .and. 
debug_swapdata) then - write(psb_out_unit,'("DEBUG: ncol=",i0," nrow=",i0)') ncol, nrow - write(psb_out_unit,'("DEBUG: size(result_baseline)=",i0)') size(result_baseline) - if (ncol > 0) then - write(psb_out_unit,'("DEBUG: result_baseline(1:min(5,ncol))=",5(es12.5,1x))') & - & result_baseline(1:min(5,ncol)) - write(psb_out_unit,'("DEBUG: expected(1:min(5,ncol))=",5(es12.5,1x))') & - & expected(1:min(5,ncol)) - end if - end if - - if (run_baseline .and. run_neighbor) then - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_baseline(1:ncol) - result_neighbor(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then - write(psb_out_unit,'(" [PASS] cross-check baseline vs neighbor : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] cross-check baseline vs neighbor : err = ",es12.5)') err - end if - end if - end if - - if (run_baseline) then - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_baseline(1:ncol) - expected(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then - write(psb_out_unit,'(" [PASS] baseline absolute correctness : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] baseline absolute correctness : err = ",es12.5)') err - end if - end if - end if - - if (run_neighbor) then - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. 
(err < tol)) then - write(psb_out_unit,'(" [PASS] neighbor absolute correctness : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] neighbor absolute correctness : err = ",es12.5)') err - end if - end if - end if + ! --- Cross Checks --- + if (run_baseline .and. run_neighbor) & + call check_result("cross-check baseline vs neighbor", result_baseline, result_neighbor, ncol, tol) - if (run_baseline .and. run_persistent) then - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then - write(psb_out_unit,'(" [PASS] cross-check baseline vs pers-nei : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] cross-check baseline vs pers-nei : err = ",es12.5)') err - end if - end if - end if + if (run_baseline .and. run_persistent) & + call check_result("cross-check baseline vs pers-nei", result_baseline, result_persistent, ncol, tol) - if (run_persistent) then - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then - write(psb_out_unit,'(" [PASS] pers-neigh absolute correctness : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] pers-neigh absolute correctness : err = ",es12.5)') err - end if - end if - end if + if (run_baseline .and. run_rma_get) & + call check_result("cross-check baseline vs rma-get", result_baseline, result_rma_get, ncol, tol) + + if (run_baseline .and. run_rma_put) & + call check_result("cross-check baseline vs rma-put", result_baseline, result_rma_put, ncol, tol) - if (run_neighbor) then - ! 
---- Test 6: repeat neighbor exchange (topology reuse) ---- - do i = nrow+1, ncol - result_neighbor(i) = dzero - end do - call v_neighbor%set_vect(result_neighbor) + ! --- Absolute Correctness Checks against Expected --- + if (run_baseline) & + call check_result("baseline absolute correctness", result_baseline, expected, ncol, tol) + + if (run_neighbor) & + call check_result("neighbor absolute correctness", result_neighbor, expected, ncol, tol) + + if (run_persistent) & + call check_result("pers-neigh absolute correctness", result_persistent, expected, ncol, tol) - call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) - call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + if (run_rma_get) & + call check_result("rma_get absolute correctness", result_rma_get, expected, ncol, tol) - result_neighbor = v_neighbor%v%v - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then - write(psb_out_unit,'(" [PASS] neighbor topology reuse : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] neighbor topology reuse : err = ",es12.5)') err - end if - end if - end if - - if (run_persistent) then - ! 
---- Test 7: repeat persistent-neighbor exchange (buffer reuse) ---- - do i = nrow+1, ncol - result_persistent(i) = dzero - end do - call v_neighbor_persistent%set_vect(result_persistent) - - call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) - call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) - - result_persistent = v_neighbor_persistent%v%v - n_total = n_total + 1 - err = huge_d - if (ncol > 0) then - err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) - else - err = 0.0_psb_dpk_ - end if - call psb_amx(ctxt, err) - if (my_rank == 0) then - if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then - write(psb_out_unit,'(" [PASS] pers-neigh buffer reuse : err = ",es12.5)') err - n_pass = n_pass + 1 - else - write(psb_out_unit,'(" [FAIL] pers-neigh buffer reuse : err = ",es12.5)') err - end if - end if - end if + if (run_rma_put) & + call check_result("rma_put absolute correctness", result_rma_put, expected, ncol, tol) ! ================================================================== ! 9. Summary @@ -640,4 +701,36 @@ program psb_comm_test call psb_cdfree(desc_a, info) call psb_exit(ctxt) + +contains + + ! Helper routine to compare two arrays, reduce the error across MPI ranks, and print the result + subroutine check_result(test_name, arr1, arr2, n, tolerance) + character(len=*), intent(in) :: test_name + real(psb_dpk_), intent(in) :: arr1(:), arr2(:) + integer(psb_ipk_), intent(in):: n + real(psb_dpk_), intent(in) :: tolerance + real(psb_dpk_) :: err_val + + n_total = n_total + 1 + + if (n > 0) then + err_val = maxval(abs(arr1(1:n) - arr2(1:n))) + else + err_val = 0.0_psb_dpk_ + end if + + ! Get max error across all processes + call psb_amx(ctxt, err_val) + + if (my_rank == 0) then + if ((err_val >= 0.0_psb_dpk_) .and. 
(err_val < tolerance)) then + write(psb_out_unit,'(" [PASS] ", a, t45, ": err = ",es12.5)') test_name, err_val + n_pass = n_pass + 1 + else + write(psb_out_unit,'(" [FAIL] ", a, t45, ": err = ",es12.5)') test_name, err_val + end if + end if + end subroutine check_result + end program psb_comm_test \ No newline at end of file