diff --git a/base/comm/internals/psi_dswapdata.F90 b/base/comm/internals/psi_dswapdata.F90 index 890cfc30..dd208d57 100644 --- a/base/comm/internals/psi_dswapdata.F90 +++ b/base/comm/internals/psi_dswapdata.F90 @@ -92,7 +92,7 @@ submodule (psi_d_comm_v_mod) psi_d_swapdata_impl use psb_base_mod contains - subroutine psi_dswapdata_vect(flag,beta,y,desc_a,work,info,data) + module subroutine psi_dswapdata_vect(flag,beta,y,desc_a,work,info,data) #ifdef PSB_MPI_MOD use mpi @@ -122,6 +122,8 @@ contains call psb_erractionsave(err_act) ctxt = desc_a%get_context() + + ! get communicator from context -- this can be eliminated since it is not passed to psi_swapdata icomm = ctxt%get_mpic() call psb_info(ctxt,me,np) if (np == -1) then @@ -160,19 +162,21 @@ contains end subroutine psi_dswapdata_vect - ! - ! - ! Subroutine: psi_dswap_vidx_vect - ! Data exchange among processes. - ! - ! Takes care of Y an exanspulated vector. Relies on the gather/scatter methods - ! of vectors. - ! - ! The real workhorse: the outer routine will only choose the index list - ! this one takes the index list and does the actual exchange. - ! - ! - ! +! +! +! Subroutine: psi_dswap_vidx_vect +! Data exchange among processes. +! +! Takes care of Y, an encapsulated vector. Relies on the gather/scatter methods +! of vectors. +! +! The real workhorse: the outer routine will only choose the index list +! this one takes the index list and does the actual exchange. +! +! This is a wrapper subroutine that calls different communication schemes depending +! on the flag variable. +! +! 
module subroutine psi_dswap_vidx_vect(ctxt,flag,beta,y,idx, & & totxch,totsnd,totrcv,work,info) @@ -184,233 +188,492 @@ contains include 'mpif.h' #endif - type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag - integer(psb_ipk_), intent(out) :: info - class(psb_d_base_vect_type) :: y - real(psb_dpk_) :: beta - real(psb_dpk_), target :: work(:) - class(psb_i_base_vect_type), intent(inout) :: idx - integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv - - ! locals - integer(psb_mpk_) :: np, me - integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size),& - & iret, nesd, nerv - integer(psb_mpk_) :: icomm - integer(psb_mpk_), allocatable :: prcid(:) - integer(psb_ipk_) :: err_act, i, idx_pt, totsnd_, totrcv_,& - & snd_pt, rcv_pt, pnti, n - logical :: swap_mpi, swap_sync, swap_send, swap_recv,& - & albf,do_send,do_recv - logical, parameter :: usersend=.false., debug=.false. - character(len=20) :: name - - info=psb_success_ - name='psi_swap_datav' - call psb_erractionsave(err_act) - call psb_info(ctxt,me,np) - if (np == -1) then - info=psb_err_context_error_ - call psb_errpush(info,name) + type(psb_ctxt_type), intent(in) :: ctxt + !integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + class(psb_d_base_vect_type) :: y + real(psb_dpk_), intent(in) :: beta + real(psb_dpk_), target :: work(:) + class(psb_i_base_vect_type), intent(inout) :: idx + integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + + ! local variables used to detect the communication scheme + logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait + logical :: baseline, neighbor_a2av + + ! local variable used for get the communicator + integer(psb_mpk_) :: icomm + + ! 
error handling variables + integer(psb_ipk_) :: err_act + integer(psb_mpk_) :: me, np + character(len=30) :: name + + + info=psb_success_ + name='psi_dswap_vidx_vect' + call psb_erractionsave(err_act) + call psb_info(ctxt,me,np) + if (np == -1) then + info=psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + + swap_mpi = iand(flag,psb_swap_mpi_) /= 0 + swap_sync = iand(flag,psb_swap_sync_) /= 0 + swap_send = iand(flag,psb_swap_send_) /= 0 + swap_recv = iand(flag,psb_swap_recv_) /= 0 + swap_start = iand(flag,psb_swap_start_) /= 0 + swap_wait = iand(flag,psb_swap_wait_) /= 0 + + baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync + neighbor_a2av = swap_start .or. swap_wait + + icomm = ctxt%get_mpic() + + if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then + info=psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + goto 9999 + end if + + + if (baseline) then + call psi_dswap_baseline_vect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='baseline swap') goto 9999 - endif - icomm = ctxt%get_mpic() + end if + else if (neighbor_a2av) then + call psi_dswap_neighbor_topology_vect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='neighbor a2av swap') + goto 9999 + end if + else + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + goto 9999 + end if + + call psb_erractionrestore(err_act) + return - n=1 - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - do_send = swap_mpi .or. swap_sync .or. swap_send - do_recv = swap_mpi .or. swap_sync .or. 
swap_recv - - totrcv_ = totrcv * n - totsnd_ = totsnd * n - call idx%sync() - - if (debug) write(*,*) me,'Internal buffer' - if (do_send) then - if (allocated(y%comid)) then - if (any(y%comid /= mpi_request_null)) then - ! - ! Unfinished communication? Something is wrong.... - ! - info=psb_err_mpi_error_ - call psb_errpush(info,name,m_err=(/-2/)) - goto 9999 - end if - end if - if (debug) write(*,*) me,'do_send start' - call y%new_buffer(ione*size(idx%v),info) - call y%new_comid(totxch,info) - y%comid = mpi_request_null - call psb_realloc(totxch,prcid,info) - ! First I post all the non blocking receives - pnti = 1 - do i=1, totxch - proc_to_comm = idx%v(pnti+psb_proc_id_) - nerv = idx%v(pnti+psb_n_elem_recv_) - nesd = idx%v(pnti+nerv+psb_n_elem_send_) - - rcv_pt = 1+pnti+psb_n_elem_recv_ - prcid(i) = psb_get_mpi_rank(ctxt,proc_to_comm) - if ((nerv>0).and.(proc_to_comm /= me)) then - if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt - p2ptag = psb_double_swap_tag - call mpi_irecv(y%combuf(rcv_pt),nerv,& - & psb_mpi_r_dpk_,prcid(i),& - & p2ptag, icomm,y%comid(i,2),iret) - end if - pnti = pnti + nerv + nesd + 3 - end do - if (debug) write(*,*) me,' Gather ' - ! - ! Then gather for sending. - ! - pnti = 1 - do i=1, totxch - nerv = idx%v(pnti+psb_n_elem_recv_) - nesd = idx%v(pnti+nerv+psb_n_elem_send_) - snd_pt = 1+pnti+nerv+psb_n_elem_send_ - rcv_pt = 1+pnti+psb_n_elem_recv_ - idx_pt = snd_pt - call y%gth(idx_pt,nesd,idx) - pnti = pnti + nerv + nesd + 3 - end do +9999 call psb_error_handler(ctxt,err_act) - ! - ! Then wait - ! - call y%device_wait() + return + end subroutine psi_dswap_vidx_vect - if (debug) write(*,*) me,' isend' - ! - ! Then send - ! 
!
! Subroutine: psi_dswap_baseline_vect
!   Baseline data exchange for a single vector Y, built on non-blocking
!   point-to-point MPI calls (mpi_irecv / mpi_isend / mpi_wait).
!
!   The exchange list idx%v is read as totxch consecutive records. For the
!   record starting at position pnti:
!     idx%v(pnti+psb_proc_id_)            process to exchange with
!     idx%v(pnti+psb_n_elem_recv_)        nerv, number of entries to receive
!     idx%v(pnti+nerv+psb_n_elem_send_)   nesd, number of entries to send
!   and the next record starts at pnti+nerv+nesd+3. The positions that follow
!   the two counters are used both as offsets into y%combuf and as the
!   starting points handed to y%gth / y%sct.
!
! Arguments:
!   ctxt    - communication context (process info, rank translation, errors).
!   icomm   - MPI communicator used for the point-to-point calls.
!   flag    - bit mask selecting the phases to run:
!               psb_swap_mpi_ or psb_swap_sync_ : send phase and receive phase
!               psb_swap_send_                  : send phase only
!               psb_swap_recv_                  : receive phase only
!   beta    - handed unchanged to y%sct when the received data is scattered.
!   y       - vector being exchanged; its combuf and comid components hold
!             the communication buffer and the MPI requests.
!   idx     - exchange list (see above); synchronised with idx%sync() first.
!   totxch  - number of records in the exchange list.
!   totsnd, totrcv, work - not referenced by this routine.
!   info    - psb_success_ on success, an error code otherwise.
!
subroutine psi_dswap_baseline_vect(ctxt,icomm,flag,beta,y,idx, &
     & totxch,totsnd,totrcv,work,info)

#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif

  type(psb_ctxt_type), intent(in)            :: ctxt
  integer(psb_mpk_), intent(in)              :: icomm
  integer(psb_ipk_), intent(in)              :: flag
  integer(psb_ipk_), intent(out)             :: info
  class(psb_d_base_vect_type)                :: y
  real(psb_dpk_), intent(in)                 :: beta
  real(psb_dpk_), target                     :: work(:)
  class(psb_i_base_vect_type), intent(inout) :: idx
  integer(psb_ipk_), intent(in)              :: totxch,totsnd, totrcv

  ! locals
  integer(psb_mpk_) :: np, me
  integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size),&
       & iret, nesd, nerv
  integer(psb_mpk_), allocatable :: prcid(:)
  integer(psb_ipk_) :: err_act, i, idx_pt, snd_pt, rcv_pt, pnti
  logical :: swap_mpi, swap_sync, swap_send, swap_recv, do_send, do_recv
  logical, parameter :: debug=.false.
  ! len=30 so that the routine name below is stored untruncated
  character(len=30) :: name

  info = psb_success_
  name = 'psi_dswap_baseline_vect'
  call psb_erractionsave(err_act)
  call psb_info(ctxt,me,np)
  if (np == -1) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  swap_mpi  = iand(flag,psb_swap_mpi_)  /= 0
  swap_sync = iand(flag,psb_swap_sync_) /= 0
  swap_send = iand(flag,psb_swap_send_) /= 0
  swap_recv = iand(flag,psb_swap_recv_) /= 0

  do_send = swap_mpi .or. swap_sync .or. swap_send
  do_recv = swap_mpi .or. swap_sync .or. swap_recv

  p2ptag = psb_double_swap_tag

  call idx%sync()

  if (debug) write(*,*) me,'Internal buffer'
  if (do_send) then
    if (allocated(y%comid)) then
      if (any(y%comid /= mpi_request_null)) then
        !
        ! Unfinished communication? Something is wrong....
        !
        info=psb_err_mpi_error_
        call psb_errpush(info,name,a_err='Unfinished communication? Something is wrong....')
        goto 9999
      end if
    end if
    if (debug) write(*,*) me,'do_send start'
    !
    ! Allocate buffer, request table and rank table; stop at the
    ! first failure instead of carrying on with unallocated storage.
    !
    call y%new_buffer(ione*size(idx%v),info)
    if (info == 0) call y%new_comid(totxch,info)
    if (info == 0) call psb_realloc(totxch,prcid,info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_,name)
      goto 9999
    end if
    y%comid = mpi_request_null

    ! First I post all the non blocking receives
    pnti = 1
    do i=1, totxch
      proc_to_comm = idx%v(pnti+psb_proc_id_)
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      rcv_pt = 1+pnti+psb_n_elem_recv_
      prcid(i) = psb_get_mpi_rank(ctxt,proc_to_comm)
      if ((nerv>0).and.(proc_to_comm /= me)) then
        if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt
        call mpi_irecv(y%combuf(rcv_pt),nerv,&
             & psb_mpi_r_dpk_,prcid(i),&
             & p2ptag, icomm,y%comid(i,2),iret)
        if (iret /= mpi_success) then
          info=psb_err_mpi_error_
          call psb_errpush(info,name,m_err=(/iret/))
          goto 9999
        end if
      end if
      pnti = pnti + nerv + nesd + 3
    end do

    if (debug) write(*,*) me,' Gather '
    !
    ! Then gather for sending.
    !
    pnti = 1
    do i=1, totxch
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      idx_pt = 1+pnti+nerv+psb_n_elem_send_
      call y%gth(idx_pt,nesd,idx)
      pnti = pnti + nerv + nesd + 3
    end do

    !
    ! Then wait
    !
    call y%device_wait()

    if (debug) write(*,*) me,' isend'
    !
    ! Then send
    !
    pnti = 1
    do i=1, totxch
      proc_to_comm = idx%v(pnti+psb_proc_id_)
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      snd_pt = 1+pnti+nerv+psb_n_elem_send_

      if ((nesd>0).and.(proc_to_comm /= me)) then
        call mpi_isend(y%combuf(snd_pt),nesd,&
             & psb_mpi_r_dpk_,prcid(i),&
             & p2ptag,icomm,y%comid(i,1),iret)
        ! iret is only meaningful when a send was actually posted,
        ! so it is tested here and not after the if block.
        if (iret /= mpi_success) then
          info=psb_err_mpi_error_
          call psb_errpush(info,name,m_err=(/iret/))
          goto 9999
        end if
      end if

      pnti = pnti + nerv + nesd + 3
    end do
  end if

  if (do_recv) then
    if (debug) write(*,*) me,' do_Recv'
    if (.not.allocated(y%comid)) then
      !
      ! No matching send? Something is wrong....
      !
      info=psb_err_mpi_error_
      call psb_errpush(info,name,m_err=(/-2/))
      goto 9999
    end if
    ! prcid is only read by the debug trace in the scatter loop below
    call psb_realloc(totxch,prcid,info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_,name)
      goto 9999
    end if

    if (debug) write(*,*) me,' wait'
    pnti = 1
    do i=1, totxch
      proc_to_comm = idx%v(pnti+psb_proc_id_)
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      snd_pt = 1+pnti+nerv+psb_n_elem_send_
      rcv_pt = 1+pnti+psb_n_elem_recv_

      if (proc_to_comm /= me) then
        if (nesd>0) then
          call mpi_wait(y%comid(i,1),p2pstat,iret)
          if (iret /= mpi_success) then
            info=psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
        end if
        if (nerv>0) then
          call mpi_wait(y%comid(i,2),p2pstat,iret)
          if (iret /= mpi_success) then
            info=psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
        end if
      else
        ! Exchange with myself: plain copy inside the buffer. A size
        ! mismatch is reported on psb_err_unit but not treated as an error.
        if (nesd /= nerv) then
          write(psb_err_unit,*) &
               & 'Fatal error in swapdata: mismatch on self send',&
               & nerv,nesd
        end if
        y%combuf(rcv_pt:rcv_pt+nerv-1) = y%combuf(snd_pt:snd_pt+nesd-1)
      end if
      pnti = pnti + nerv + nesd + 3
    end do

    if (debug) write(*,*) me,' scatter'
    pnti = 1
    do i=1, totxch
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      rcv_pt = 1+pnti+psb_n_elem_recv_

      if (debug) write(0,*)me,' Received from: ',prcid(i),&
           & y%combuf(rcv_pt:rcv_pt+nerv-1)
      call y%sct(rcv_pt,nerv,idx,beta)
      pnti = pnti + nerv + nesd + 3
    end do
    !
    ! Waited for everybody, clean up
    !
    y%comid = mpi_request_null

    !
    ! Then wait for device
    !
    if (debug) write(*,*) me,' wait'
    call y%device_wait()
    if (debug) write(*,*) me,' free buffer'
    call y%maybe_free_buffer(info)
    if (info == 0) call y%free_comid(info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_,name)
      goto 9999
    end if
    if (debug) write(*,*) me,' done'
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return
end subroutine psi_dswap_baseline_vect


!
! Subroutine: psi_dswap_neighbor_topology_vect
!   Data exchange for a single vector Y through one non-blocking
!   neighbourhood collective (mpi_ineighbor_alltoallv) posted on
!   y%neighbor_topology%graph_comm.
!
!   flag selects the phases to run; both bits may be set in one call:
!     psb_swap_start_ : initialise y%neighbor_topology if it is not yet
!                       initialised, allocate the buffer, gather the
!                       entries to send and post the collective;
!     psb_swap_wait_  : wait for the posted collective, scatter the
!                       received entries into Y and release the buffer.
!
!   Buffer layout (ts = total_send, tr = total_recv of the topology):
!     y%combuf(1    : ts)      send area
!     y%combuf(ts+1 : ts+tr)   receive area
!
! Arguments:
!   ctxt    - communication context (process info, error handling).
!   icomm   - MPI communicator; only forwarded to y%neighbor_topology%init.
!   flag    - phase selector, see above.
!   beta    - handed unchanged to y%sct when the received data is scattered.
!   y       - vector being exchanged; owns neighbor_topology, combuf and
!             communication_handle.
!   idx     - exchange list; idx%v is forwarded to the topology init.
!   totxch, totsnd, totrcv - forwarded to the topology init.
!   work    - not referenced by this routine.
!   info    - psb_success_ on success, an error code otherwise.
!
subroutine psi_dswap_neighbor_topology_vect(ctxt,icomm,flag,beta,y,idx, &
     & totxch,totsnd,totrcv,work,info)

#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif

  type(psb_ctxt_type), intent(in)            :: ctxt
  integer(psb_mpk_), intent(in)              :: icomm
  integer(psb_ipk_), intent(in)              :: flag
  integer(psb_ipk_), intent(out)             :: info
  class(psb_d_base_vect_type)                :: y
  real(psb_dpk_), intent(in)                 :: beta
  real(psb_dpk_), target                     :: work(:)
  class(psb_i_base_vect_type), intent(inout) :: idx
  integer(psb_ipk_), intent(in)              :: totxch,totsnd, totrcv

  ! locals
  integer(psb_mpk_) :: np, me
  integer(psb_mpk_) :: iret, p2pstat(mpi_status_size)
  integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size
  logical :: do_start, do_wait
  logical, parameter :: debug = .false.
  character(len=30) :: name


  info = psb_success_
  name = 'psi_dswap_nbr_vect'
  call psb_erractionsave(err_act)
  call psb_info(ctxt,me,np)
  if (np == -1) then
    info=psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  do_start = iand(flag,psb_swap_start_) /= 0
  do_wait  = iand(flag,psb_swap_wait_)  /= 0

  call idx%sync()

  ! ---------------------------------------------------------
  ! START phase: build topology (if needed), gather, post MPI
  ! ---------------------------------------------------------
  if (do_start) then
    if(debug) write(*,*) me,' nbr_vect: starting data exchange'
    ! Lazy initialization: build the topology on first call
    if (.not. y%neighbor_topology%is_initialized) then
      if (debug) write(*,*) me,' nbr_vect: building topology'
      call y%neighbor_topology%init(idx%v, totxch, totsnd, totrcv, &
           & ctxt, icomm, info)
      if (info /= psb_success_) then
        call psb_errpush(psb_err_internal_error_, name, &
             & a_err='neighbor_topology_init')
        goto 9999
      end if
    end if

    topology_total_send = y%neighbor_topology%total_send
    topology_total_recv = y%neighbor_topology%total_recv

    ! Buffer layout:
    !   combuf(1 : total_send)                          = send area
    !   combuf(total_send+1 : total_send+total_recv)    = recv area
    buffer_size = topology_total_send + topology_total_recv

    call y%new_buffer(buffer_size, info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_, name)
      goto 9999
    end if
    y%communication_handle = mpi_request_null

    ! Gather send data into contiguous send buffer (polymorphic for GPU)
    if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems'
    call y%gth(int(topology_total_send,psb_mpk_), &
         & y%neighbor_topology%send_indexes, &
         & y%combuf(1:topology_total_send))

    ! Wait for device (important for GPU subclasses)
    call y%device_wait()

    ! Post non-blocking neighborhood alltoallv
    if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv'
    call mpi_ineighbor_alltoallv( &
         & y%combuf(1), &                          ! send buffer
         & y%neighbor_topology%send_counts, &
         & y%neighbor_topology%send_displs, &
         & psb_mpi_r_dpk_, &
         & y%combuf(topology_total_send + 1), &    ! recv buffer
         & y%neighbor_topology%recv_counts, &
         & y%neighbor_topology%recv_displs, &
         & psb_mpi_r_dpk_, &
         & y%neighbor_topology%graph_comm, &
         & y%communication_handle, iret)
    if (iret /= mpi_success) then
      info = psb_err_mpi_error_
      call psb_errpush(info, name, m_err=(/iret/))
      goto 9999
    end if

  end if ! do_start

  ! ---------------------------------------------------------
  ! WAIT phase: complete MPI, scatter received data
  ! ---------------------------------------------------------
  if (do_wait) then

    if (y%communication_handle == mpi_request_null) then
      ! No matching start? Something is wrong
      info = psb_err_mpi_error_
      call psb_errpush(info, name, m_err=(/-2/))
      goto 9999
    end if

    topology_total_send = y%neighbor_topology%total_send
    topology_total_recv = y%neighbor_topology%total_recv

    ! Wait for the non-blocking collective to complete
    if (debug) write(*,*) me,' nbr_vect: waiting on MPI request'
    call mpi_wait(y%communication_handle, p2pstat, iret)
    if (iret /= mpi_success) then
      info = psb_err_mpi_error_
      call psb_errpush(info, name, m_err=(/iret/))
      goto 9999
    end if

    ! Scatter received data to local vector positions (polymorphic for GPU)
    if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems'
    call y%sct(int(topology_total_recv,psb_mpk_), &
         & y%neighbor_topology%recv_indexes, &
         & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), &
         & beta)

    ! Clean up
    y%communication_handle = mpi_request_null
    call y%device_wait()
    call y%maybe_free_buffer(info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_, name)
      goto 9999
    end if
    if (debug) write(*,*) me,' nbr_vect: done'

  end if ! do_wait

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return
end subroutine psi_dswap_neighbor_topology_vect


  !
  !
@@ -511,237 +774,493 @@ contains include 'mpif.h' #endif - type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag - integer(psb_ipk_), intent(out) :: info - class(psb_d_base_multivect_type) :: y - real(psb_dpk_) :: beta - real(psb_dpk_), target :: work(:) - class(psb_i_base_vect_type), intent(inout) :: idx - integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + type(psb_ctxt_type), intent(in) :: ctxt + !integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + class(psb_d_base_multivect_type) :: y + real(psb_dpk_), intent(in) :: beta + real(psb_dpk_), target :: work(:) + class(psb_i_base_vect_type), intent(inout) :: idx + integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + + ! local variables used to detect the communication scheme + logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait + logical :: baseline, neighbor_a2av + + ! local variable used to get communicator + integer(psb_mpk_) :: icomm + + ! error handling variables + integer(psb_ipk_) :: err_act + integer(psb_mpk_) :: me, np + character(len=30) :: name + + + info=psb_success_ + name='psi_dswap_vidx_multivect' + call psb_erractionsave(err_act) + call psb_info(ctxt,me,np) + if (np == -1) then + info=psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + + swap_mpi = iand(flag,psb_swap_mpi_) /= 0 + swap_sync = iand(flag,psb_swap_sync_) /= 0 + swap_send = iand(flag,psb_swap_send_) /= 0 + swap_recv = iand(flag,psb_swap_recv_) /= 0 + swap_start = iand(flag,psb_swap_start_) /= 0 + swap_wait = iand(flag,psb_swap_wait_) /= 0 + + baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync + neighbor_a2av = swap_start .or. swap_wait + + icomm = ctxt%get_mpic() + + if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) 
) then + info=psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + goto 9999 + end if + + + if (baseline) then + call psi_dswap_baseline_multivect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='baseline swap') + goto 9999 + end if + else if (neighbor_a2av) then + call psi_dswap_neighbor_topology_multivect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='neighbor a2av swap') + goto 9999 + end if + else + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + goto 9999 + end if + + call psb_erractionrestore(err_act) + return - ! locals - integer(psb_mpk_) :: np, me, nesd, nerv, n - integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret - integer(psb_mpk_) :: icomm - integer(psb_mpk_), allocatable :: prcid(:) - integer(psb_ipk_) :: err_act, i, idx_pt, totsnd_, totrcv_,& - & snd_pt, rcv_pt, pnti - logical :: swap_mpi, swap_sync, swap_send, swap_recv,& - & albf,do_send,do_recv - logical, parameter :: usersend=.false., debug=.false. - character(len=20) :: name +9999 call psb_error_handler(ctxt,err_act) - info=psb_success_ - name='psi_swap_datav' - call psb_erractionsave(err_act) - call psb_info(ctxt,me,np) - if (np == -1) then - info=psb_err_context_error_ - call psb_errpush(info,name) - goto 9999 - endif - icomm = ctxt%get_mpic() + return + end subroutine psi_dswap_vidx_multivect - n = y%get_ncols() - - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - do_send = swap_mpi .or. swap_sync .or. swap_send - do_recv = swap_mpi .or. swap_sync .or. 
swap_recv - - totrcv_ = totrcv * n - totsnd_ = totsnd * n - - call idx%sync() - - if (debug) write(*,*) me,'Internal buffer' - if (do_send) then - if (allocated(y%comid)) then - if (any(y%comid /= mpi_request_null)) then - ! - ! Unfinished communication? Something is wrong.... - ! - info=psb_err_mpi_error_ - call psb_errpush(info,name,m_err=(/-2/)) - goto 9999 - end if - end if - if (debug) write(*,*) me,'do_send start' - call y%new_buffer(ione*size(idx%v),info) - call y%new_comid(totxch,info) - y%comid = mpi_request_null - call psb_realloc(totxch,prcid,info) - ! First I post all the non blocking receives - pnti = 1 - snd_pt = totrcv_+1 - rcv_pt = 1 - do i=1, totxch - proc_to_comm = idx%v(pnti+psb_proc_id_) - nerv = idx%v(pnti+psb_n_elem_recv_) - nesd = idx%v(pnti+nerv+psb_n_elem_send_) - prcid(i) = psb_get_mpi_rank(ctxt,proc_to_comm) - if ((nerv>0).and.(proc_to_comm /= me)) then - if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt - p2ptag = psb_double_swap_tag - call mpi_irecv(y%combuf(rcv_pt),n*nerv,& - & psb_mpi_r_dpk_,prcid(i),& - & p2ptag, icomm,y%comid(i,2),iret) - end if - rcv_pt = rcv_pt + n*nerv - snd_pt = snd_pt + n*nesd - pnti = pnti + nerv + nesd + 3 - end do - if (debug) write(*,*) me,' Gather ' - ! - ! Then gather for sending. - ! - pnti = 1 - snd_pt = totrcv_+1 - rcv_pt = 1 - do i=1, totxch - nerv = idx%v(pnti+psb_n_elem_recv_) - nesd = idx%v(pnti+nerv+psb_n_elem_send_) - idx_pt = 1+pnti+nerv+psb_n_elem_send_ - call y%gth(idx_pt,snd_pt,nesd,idx) - rcv_pt = rcv_pt + n*nerv - snd_pt = snd_pt + n*nesd - pnti = pnti + nerv + nesd + 3 - end do - ! - ! Then wait for device - ! - call y%device_wait() - if (debug) write(*,*) me,' isend' - ! - ! Then send - ! 
!
! Subroutine: psi_dswap_baseline_multivect
!   Baseline data exchange for a multivector Y, built on non-blocking
!   point-to-point MPI calls (mpi_irecv / mpi_isend / mpi_wait).
!
!   The exchange list idx%v is read as totxch consecutive records. For the
!   record starting at position pnti:
!     idx%v(pnti+psb_proc_id_)            process to exchange with
!     idx%v(pnti+psb_n_elem_recv_)        nerv, number of entries to receive
!     idx%v(pnti+nerv+psb_n_elem_send_)   nesd, number of entries to send
!   and the next record starts at pnti+nerv+nesd+3.
!
!   With n = y%get_ncols(), every record moves n*nerv / n*nesd values.
!   Buffer offsets are running counters: the receive offset starts at 1,
!   the send offset starts at n*totrcv+1, and they advance by n*nerv and
!   n*nesd respectively after each record.
!
! Arguments:
!   ctxt    - communication context (process info, rank translation, errors).
!   icomm   - MPI communicator used for the point-to-point calls.
!   flag    - bit mask selecting the phases to run:
!               psb_swap_mpi_ or psb_swap_sync_ : send phase and receive phase
!               psb_swap_send_                  : send phase only
!               psb_swap_recv_                  : receive phase only
!   beta    - handed unchanged to y%sct when the received data is scattered.
!   y       - multivector being exchanged; its combuf and comid components
!             hold the communication buffer and the MPI requests.
!   idx     - exchange list (see above); synchronised with idx%sync() first.
!   totxch  - number of records in the exchange list.
!   totrcv  - total number of entries to receive; fixes where the send
!             area of the buffer begins.
!   totsnd, work - not referenced by this routine.
!   info    - psb_success_ on success, an error code otherwise.
!
subroutine psi_dswap_baseline_multivect(ctxt,icomm,flag,beta,y,idx, &
     & totxch,totsnd,totrcv,work,info)

#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif

  type(psb_ctxt_type), intent(in)            :: ctxt
  integer(psb_mpk_), intent(in)              :: icomm
  integer(psb_ipk_), intent(in)              :: flag
  integer(psb_ipk_), intent(out)             :: info
  class(psb_d_base_multivect_type)           :: y
  real(psb_dpk_), intent(in)                 :: beta
  real(psb_dpk_), target                     :: work(:)
  class(psb_i_base_vect_type), intent(inout) :: idx
  integer(psb_ipk_), intent(in)              :: totxch,totsnd, totrcv

  ! locals
  integer(psb_mpk_) :: np, me, nesd, nerv, n
  integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret
  integer(psb_mpk_), allocatable :: prcid(:)
  integer(psb_ipk_) :: err_act, i, idx_pt, totrcv_, snd_pt, rcv_pt, pnti
  logical :: swap_mpi, swap_sync, swap_send, swap_recv, do_send, do_recv
  logical, parameter :: debug=.false.
  ! len=30 so that the routine name below is stored untruncated
  character(len=30) :: name

  info = psb_success_
  name = 'psi_dswap_baseline_multivect'
  call psb_erractionsave(err_act)
  call psb_info(ctxt,me,np)
  if (np == -1) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  n = y%get_ncols()

  swap_mpi  = iand(flag,psb_swap_mpi_)  /= 0
  swap_sync = iand(flag,psb_swap_sync_) /= 0
  swap_send = iand(flag,psb_swap_send_) /= 0
  swap_recv = iand(flag,psb_swap_recv_) /= 0
  do_send = swap_mpi .or. swap_sync .or. swap_send
  do_recv = swap_mpi .or. swap_sync .or. swap_recv

  ! The send area of the buffer starts right after the receive area
  totrcv_ = totrcv * n
  p2ptag  = psb_double_swap_tag

  call idx%sync()

  if (debug) write(*,*) me,'Internal buffer'
  if (do_send) then
    if (allocated(y%comid)) then
      if (any(y%comid /= mpi_request_null)) then
        !
        ! Unfinished communication? Something is wrong....
        !
        info=psb_err_mpi_error_
        call psb_errpush(info,name,m_err=(/-2/))
        goto 9999
      end if
    end if
    if (debug) write(*,*) me,'do_send start'
    !
    ! Allocate buffer, request table and rank table; stop at the
    ! first failure instead of carrying on with unallocated storage.
    !
    call y%new_buffer(ione*size(idx%v),info)
    if (info == 0) call y%new_comid(totxch,info)
    if (info == 0) call psb_realloc(totxch,prcid,info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_,name)
      goto 9999
    end if
    y%comid = mpi_request_null

    ! First I post all the non blocking receives
    pnti   = 1
    rcv_pt = 1
    do i=1, totxch
      proc_to_comm = idx%v(pnti+psb_proc_id_)
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      prcid(i) = psb_get_mpi_rank(ctxt,proc_to_comm)
      if ((nerv>0).and.(proc_to_comm /= me)) then
        if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt
        call mpi_irecv(y%combuf(rcv_pt),n*nerv,&
             & psb_mpi_r_dpk_,prcid(i),&
             & p2ptag, icomm,y%comid(i,2),iret)
        if (iret /= mpi_success) then
          info=psb_err_mpi_error_
          call psb_errpush(info,name,m_err=(/iret/))
          goto 9999
        end if
      end if
      rcv_pt = rcv_pt + n*nerv
      pnti   = pnti + nerv + nesd + 3
    end do

    if (debug) write(*,*) me,' Gather '
    !
    ! Then gather for sending.
    !
    pnti   = 1
    snd_pt = totrcv_+1
    do i=1, totxch
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      idx_pt = 1+pnti+nerv+psb_n_elem_send_
      call y%gth(idx_pt,snd_pt,nesd,idx)
      snd_pt = snd_pt + n*nesd
      pnti   = pnti + nerv + nesd + 3
    end do

    !
    ! Then wait for device
    !
    call y%device_wait()

    if (debug) write(*,*) me,' isend'
    !
    ! Then send
    !
    pnti   = 1
    snd_pt = totrcv_+1
    do i=1, totxch
      proc_to_comm = idx%v(pnti+psb_proc_id_)
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)

      if ((nesd>0).and.(proc_to_comm /= me)) then
        call mpi_isend(y%combuf(snd_pt),n*nesd,&
             & psb_mpi_r_dpk_,prcid(i),&
             & p2ptag,icomm,y%comid(i,1),iret)
        ! iret is only meaningful when a send was actually posted,
        ! so it is tested here and not after the if block.
        if (iret /= mpi_success) then
          info=psb_err_mpi_error_
          call psb_errpush(info,name,m_err=(/iret/))
          goto 9999
        end if
      end if

      snd_pt = snd_pt + n*nesd
      pnti   = pnti + nerv + nesd + 3
    end do
  end if

  if (do_recv) then
    if (debug) write(*,*) me,' do_Recv'
    if (.not.allocated(y%comid)) then
      !
      ! No matching send? Something is wrong....
      !
      info=psb_err_mpi_error_
      call psb_errpush(info,name,m_err=(/-2/))
      goto 9999
    end if
    ! prcid is only read by the debug trace in the scatter loop below
    call psb_realloc(totxch,prcid,info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_,name)
      goto 9999
    end if

    if (debug) write(*,*) me,' wait'
    pnti   = 1
    snd_pt = totrcv_+1
    rcv_pt = 1
    do i=1, totxch
      proc_to_comm = idx%v(pnti+psb_proc_id_)
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      if (proc_to_comm /= me) then
        if (nesd>0) then
          call mpi_wait(y%comid(i,1),p2pstat,iret)
          if (iret /= mpi_success) then
            info=psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
        end if
        if (nerv>0) then
          call mpi_wait(y%comid(i,2),p2pstat,iret)
          if (iret /= mpi_success) then
            info=psb_err_mpi_error_
            call psb_errpush(info,name,m_err=(/iret/))
            goto 9999
          end if
        end if
      else
        ! Exchange with myself: plain copy inside the buffer. A size
        ! mismatch is reported on psb_err_unit but not treated as an error.
        if (nesd /= nerv) then
          write(psb_err_unit,*) &
               & 'Fatal error in swapdata: mismatch on self send',&
               & nerv,nesd
        end if
        y%combuf(rcv_pt:rcv_pt+n*nerv-1) = y%combuf(snd_pt:snd_pt+n*nesd-1)
      end if
      rcv_pt = rcv_pt + n*nerv
      snd_pt = snd_pt + n*nesd
      pnti   = pnti + nerv + nesd + 3
    end do

    if (debug) write(*,*) me,' scatter'
    pnti   = 1
    rcv_pt = 1
    do i=1, totxch
      nerv = idx%v(pnti+psb_n_elem_recv_)
      nesd = idx%v(pnti+nerv+psb_n_elem_send_)
      idx_pt = 1+pnti+psb_n_elem_recv_

      if (debug) write(0,*)me,' Received from: ',prcid(i),&
           & y%combuf(rcv_pt:rcv_pt+n*nerv-1)
      call y%sct(idx_pt,rcv_pt,nerv,idx,beta)
      rcv_pt = rcv_pt + n*nerv
      pnti   = pnti + nerv + nesd + 3
    end do
    !
    ! Waited for com, cleanup comid
    !
    y%comid = mpi_request_null

    !
    ! Then wait for device
    !
    if (debug) write(*,*) me,' wait'
    call y%device_wait()
    if (debug) write(*,*) me,' free buffer'
    call y%free_buffer(info)
    if (info == 0) call y%free_comid(info)
    if (info /= 0) then
      call psb_errpush(psb_err_alloc_dealloc_,name)
      goto 9999
    end if
    if (debug) write(*,*) me,' done'
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return
end subroutine psi_dswap_baseline_multivect



subroutine psi_dswap_neighbor_topology_multivect(ctxt,icomm,flag,beta,y,idx, &
     & totxch,totsnd,totrcv,work,info)

#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif

  type(psb_ctxt_type), intent(in) :: ctxt
  integer(psb_mpk_), intent(in) :: icomm
  integer(psb_ipk_), intent(in) :: flag
  integer(psb_ipk_), intent(out) :: info
  class(psb_d_base_multivect_type) :: y
  real(psb_dpk_), intent(in) :: beta
  real(psb_dpk_), target :: work(:)
  class(psb_i_base_vect_type), intent(inout) :: idx
  integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv

  ! locals
  integer(psb_mpk_) :: np, me
  integer(psb_mpk_) :: iret, p2pstat(mpi_status_size)
  integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size
  logical :: do_start, do_wait
  logical, parameter :: debug = .false.
  character(len=30) :: name


  info = psb_success_
  name = 'psi_dswap_nbr_vect'
  call psb_erractionsave(err_act)
  call psb_info(ctxt,me,np)
  if (np == -1) then
    info=psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  do_start = iand(flag,psb_swap_start_) /= 0
  do_wait = iand(flag,psb_swap_wait_) /= 0

  call idx%sync()

  ! ---------------------------------------------------------
  ! START phase: build topology (if needed), gather, post MPI
  ! ---------------------------------------------------------
  if (do_start) then
    if(debug) write(*,*) me,' nbr_vect: starting data exchange'
    ! Lazy initialization: build the topology on first call
    if (.not. 
y%neighbor_topology%is_initialized) then + if (debug) write(*,*) me,' nbr_vect: building topology' + call y%neighbor_topology%init(idx%v, totxch, totsnd, totrcv, & + & ctxt, icomm, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, & + & a_err='neighbor_topology_init') goto 9999 end if - if (debug) write(*,*) me,' done' end if + topology_total_send = y%neighbor_topology%total_send + topology_total_recv = y%neighbor_topology%total_recv - call psb_erractionrestore(err_act) - return + ! Buffer layout: + ! combuf(1 : total_send) = send area + ! combuf(total_send+1 : total_send+total_recv) = recv area + buffer_size = topology_total_send + topology_total_recv + + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + y%communication_handle = mpi_request_null + + ! Gather send data into contiguous send buffer (polymorphic for GPU) + if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems' + call y%gth(int(topology_total_send,psb_mpk_), & + & y%neighbor_topology%send_indexes, & + & y%combuf(1:topology_total_send)) + + ! Wait for device (important for GPU subclasses) + call y%device_wait() + + ! Post non-blocking neighborhood alltoallv + if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & y%neighbor_topology%send_counts, & + & y%neighbor_topology%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & y%neighbor_topology%recv_counts, & + & y%neighbor_topology%recv_displs, & + & psb_mpi_r_dpk_, & + & y%neighbor_topology%graph_comm, & + & y%communication_handle, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + + end if ! do_start + + ! --------------------------------------------------------- + ! 
WAIT phase: complete MPI, scatter received data + ! --------------------------------------------------------- + if (do_wait) then + + if (y%communication_handle == mpi_request_null) then + ! No matching start? Something is wrong + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if + + topology_total_send = y%neighbor_topology%total_send + topology_total_recv = y%neighbor_topology%total_recv + + ! Wait for the non-blocking collective to complete + if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' + call mpi_wait(y%communication_handle, p2pstat, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + + ! Scatter received data to local vector positions (polymorphic for GPU) + if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems' + call y%sct(int(topology_total_recv,psb_mpk_), & + & y%neighbor_topology%recv_indexes, & + & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & + & beta) + + + ! Clean up + y%communication_handle = mpi_request_null + call y%device_wait() + call y%maybe_free_buffer(info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + if (debug) write(*,*) me,' nbr_vect: done' + + end if ! do_wait + + call psb_erractionrestore(err_act) + return 9999 call psb_error_handler(ctxt,err_act) - return - end subroutine psi_dswap_vidx_multivect + return +end subroutine psi_dswap_neighbor_topology_multivect + + end submodule psi_d_swapdata_impl