diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt index f2b83f91..1833a93c 100644 --- a/base/CMakeLists.txt +++ b/base/CMakeLists.txt @@ -437,7 +437,7 @@ set(PSB_base_source_files # modules/comm/psi_i2_comm_a_mod.f90 modules/comm/psi_m_comm_a_mod.f90 modules/comm/psi_l_comm_v_mod.f90 - modules/comm/psb_comm_mod.f90 + modules/comm/psb_comm_mod.F90 modules/comm/psb_l_comm_mod.f90 modules/comm/psb_d_linmap_mod.f90 modules/comm/psi_d_comm_v_mod.f90 diff --git a/base/comm/internals/psi_dswapdata.F90 b/base/comm/internals/psi_dswapdata.F90 index c21c2bbf..42e72902 100644 --- a/base/comm/internals/psi_dswapdata.F90 +++ b/base/comm/internals/psi_dswapdata.F90 @@ -56,21 +56,13 @@ ! so that special versions (i.e. GPU vectors can override them ! ! Arguments: -! flag - integer Choose the algorithm for data exchange: -! this is chosen through bit fields. -! swap_mpi = iand(flag,psb_swap_mpi_) /= 0 -! swap_sync = iand(flag,psb_swap_sync_) /= 0 -! swap_send = iand(flag,psb_swap_send_) /= 0 -! swap_recv = iand(flag,psb_swap_recv_) /= 0 -! if (swap_mpi): use underlying MPI_ALLTOALLV. -! if (swap_sync): use PSB_SND and PSB_RCV in -! synchronized pairs -! if (swap_send .and. swap_recv): use mpi_irecv -! and mpi_send -! if (swap_send): use psb_snd (but need another -! call with swap_recv to complete) -! if (swap_recv): use psb_rcv (completing a -! previous call with swap_send) +! swap_status - integer Swap status selector. +! It is interpreted as a communication status: +! psb_comm_status_start_ -> START phase +! psb_comm_status_wait_ -> WAIT phase +! psb_comm_status_unknown_ -> START+WAIT +! The communication scheme is selected from +! y%comm_handle%comm_type. ! ! ! n - integer Number of columns in Y @@ -90,8 +82,14 @@ submodule (psi_d_comm_v_mod) psi_d_swapdata_impl use psb_desc_const_mod, only: psb_swap_start_, psb_swap_wait_ use psb_base_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, & + & psb_comm_persistent_ineighbor_alltoallv_, & + & psb_comm_status_start_, psb_comm_status_wait_, psb_comm_status_unknown_ + use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free + use psb_comm_baseline_mod, only: psb_comm_baseline_handle, psb_comm_baseline_alloc_comid + use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle contains - module subroutine psi_dswapdata_vect(flag,beta,y,desc_a,info,data) + module subroutine psi_dswapdata_vect(swap_status,beta,y,desc_a,info,data) #ifdef PSB_MPI_MOD use mpi @@ -101,7 +99,7 @@ contains include 'mpif.h' #endif - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status class(psb_d_base_vect_type), intent(inout) :: y real(psb_dpk_), intent(in) :: beta type(psb_desc_type), target :: desc_a ! TODO: should this be intent(in)? @@ -111,10 +109,10 @@ contains ! locals type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me, total_send, total_recv, num_neighbors, data_, err_act + integer(psb_ipk_) :: setflag class(psb_i_base_vect_type), pointer :: comm_indexes - ! local variables used to detect the communication scheme - logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait + ! communication scheme/status selectors logical :: baseline, neighbor_a2av ! error handling variables @@ -152,37 +150,84 @@ contains goto 9999 end if - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - swap_start = iand(flag,psb_swap_start_) /= 0 - swap_wait = iand(flag,psb_swap_wait_) /= 0 - - baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync - neighbor_a2av = swap_start .or. swap_wait + ! Debug: report list sizes + ! if(me == 0) then + ! write(psb_err_unit,*) me, 'DBG: get_list_p -> num_neighbors=', & + ! & num_neighbors, ' total_send=', total_send, ' total_recv=', total_recv + ! end if + ! Accept both new comm-status enums and legacy descriptor bitfields. + setflag = swap_status + if (swap_status == psb_swap_start_) then + setflag = psb_comm_status_start_ + else if (swap_status == psb_swap_wait_) then + setflag = psb_comm_status_wait_ + else if (iand(swap_status, psb_swap_start_) /= 0 .and. iand(swap_status, psb_swap_wait_) /= 0) then + setflag = psb_comm_status_unknown_ + end if - if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then + if ((setflag /= psb_comm_status_start_) .and. (setflag /= psb_comm_status_wait_) .and. & + & (setflag /= psb_comm_status_unknown_)) then info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + call psb_errpush(info,name,a_err='Invalid swap_status swap_status') + goto 9999 + end if + + if (.not. allocated(y%comm_handle)) then + call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, a_err='init comm default baseline') + goto 9999 + end if + end if + + ! Debug: comm handle allocation and type + ! if(me == 0) then + ! write(psb_err_unit,*) me, 'CALLING' + ! write(psb_err_unit,*) me, 'DBG: comm_handle allocated=', allocated(y%comm_handle) + ! if (allocated(y%comm_handle)) then + ! write(psb_err_unit,*) me, 'DBG: comm_handle%comm_type=', y%comm_handle%comm_type + ! end if + ! end if + ! Set the normalized swap status on the comm handle + call y%comm_handle%set_swap_status(setflag, info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='set_swap_status') goto 9999 end if + ! if(me == 0) then + ! write(psb_err_unit,*) me, 'DBG: after set_swap_status, info=', info + ! end if + + baseline = .false. + neighbor_a2av = .false. + select case(y%comm_handle%comm_type) + case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_) + neighbor_a2av = .true. + case default + baseline = .true. + end select + + ! if(me == 0) then + ! write(psb_err_unit,*) me, 'DBG: selected baseline=', baseline, ' neighbor=', neighbor_a2av + ! end if + if (baseline) then - call psi_dswap_baseline_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dswap_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='baseline swap') goto 9999 end if else if (neighbor_a2av) then - call psi_dswap_neighbor_topology_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dswap_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,& + & total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='neighbor a2av swap') goto 9999 end if else info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + call psb_errpush(info,name,a_err='Incompatible swap_status settings: neither baseline nor neighbor_a2av is true') goto 9999 end if @@ -193,11 +238,8 @@ contains return end subroutine psi_dswapdata_vect - ! - ! subroutine psi_dswap_baseline_vect - ! This performs Isend/Irecv as a baseline communication mode - ! - subroutine psi_dswap_baseline_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + + subroutine psi_dswap_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi #endif @@ -207,11 +249,12 @@ contains #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: comm_indexes integer(psb_ipk_), intent(in) :: num_neighbors, total_send, total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info ! locals @@ -220,10 +263,10 @@ contains integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size),& & iret, nesd, nerv integer(psb_mpk_), allocatable :: prcid(:) + type(psb_comm_baseline_handle), pointer :: baseline_comm_handle integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,& & snd_pt, rcv_pt, pnti, n - logical :: swap_mpi, swap_sync, swap_send, swap_recv,& - & albf,do_send,do_recv + logical :: do_send,do_recv logical, parameter :: usersend=.false., debug=.false. character(len=20) :: name @@ -239,14 +282,19 @@ contains icomm = ctxt%get_mpic() - n=1 - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 + baseline_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_baseline_handle) + baseline_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected baseline comm_handle in baseline swap') + goto 9999 + end select - do_send = swap_mpi .or. swap_sync .or. swap_send - do_recv = swap_mpi .or. swap_sync .or. swap_recv + n=1 + do_send = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_unknown_) + do_recv = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_unknown_) total_recv_ = total_recv * n total_send_ = total_send * n @@ -254,8 +302,8 @@ contains if (debug) write(*,*) me,'Internal buffer' if (do_send) then - if (allocated(y%comid)) then - if (any(y%comid /= mpi_request_null)) then + if (allocated(baseline_comm_handle%comid)) then + if (any(baseline_comm_handle%comid /= mpi_request_null)) then ! ! Unfinished communication? Something is wrong.... ! @@ -266,8 +314,8 @@ contains end if if (debug) write(*,*) me,'do_send start' call y%new_buffer(ione*size(comm_indexes%v),info) - call y%new_comid(num_neighbors,info) - y%comid = mpi_request_null + call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info) + baseline_comm_handle%comid = mpi_request_null call psb_realloc(num_neighbors,prcid,info) ! First I post all the non blocking receives pnti = 1 @@ -281,9 +329,9 @@ contains if ((nerv>0).and.(proc_to_comm /= me)) then if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt p2ptag = psb_double_swap_tag - call mpi_irecv(y%combuf(rcv_pt),nerv,& + call mpi_irecv(y%combuf(rcv_pt),nerv,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag, icomm,y%comid(i,2),iret) + & p2ptag, icomm,baseline_comm_handle%comid(i,2),iret) end if pnti = pnti + nerv + nesd + 3 end do @@ -324,9 +372,9 @@ contains rcv_pt = 1+pnti+psb_n_elem_recv_ if ((nesd>0).and.(proc_to_comm /= me)) then - call mpi_isend(y%combuf(snd_pt),nesd,& + call mpi_isend(y%combuf(snd_pt),nesd,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag,icomm,y%comid(i,1),iret) + & p2ptag,icomm,baseline_comm_handle%comid(i,1),iret) end if if(iret /= mpi_success) then @@ -341,7 +389,7 @@ contains if (do_recv) then if (debug) write(*,*) me,' do_Recv' - if (.not.allocated(y%comid)) then + if (.not.allocated(baseline_comm_handle%comid)) then ! ! No matching send? Something is wrong.... ! @@ -363,7 +411,7 @@ contains if (proc_to_comm /= me)then if (nesd>0) then - call mpi_wait(y%comid(i,1),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,1),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -371,7 +419,7 @@ contains end if end if if (nerv>0) then - call mpi_wait(y%comid(i,2),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,2),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -409,7 +457,7 @@ contains ! ! Waited for everybody, clean up ! - y%comid = mpi_request_null + baseline_comm_handle%comid = mpi_request_null ! ! Then wait for device @@ -418,7 +466,9 @@ contains call y%device_wait() if (debug) write(*,*) me,' free buffer' call y%maybe_free_buffer(info) - if (info == 0) call y%free_comid(info) + if (info == 0) then + if (allocated(y%comm_handle)) call y%comm_handle%free(info) + end if if (info /= 0) then call psb_errpush(psb_err_alloc_dealloc_,name) goto 9999 @@ -437,7 +487,8 @@ contains - subroutine psi_dswap_neighbor_topology_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + subroutine psi_dswap_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,& + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi #endif @@ -447,17 +498,19 @@ contains #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: comm_indexes integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info ! locals integer(psb_mpk_) :: icomm integer(psb_mpk_) :: np, me integer(psb_mpk_) :: iret, p2pstat(mpi_status_size) + type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size logical :: do_start, do_wait logical, parameter :: debug = .false. @@ -476,8 +529,18 @@ contains icomm = ctxt%get_mpic() - do_start = iand(flag,psb_swap_start_) /= 0 - do_wait = iand(flag,psb_swap_wait_) /= 0 + neighbor_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_neighbor_handle) + neighbor_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected neighbor comm_handle in neighbor swap') + goto 9999 + end select + + do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_unknown_) + do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_unknown_) call comm_indexes%sync() @@ -486,58 +549,131 @@ contains ! --------------------------------------------------------- if (do_start) then if(debug) write(*,*) me,' nbr_vect: starting data exchange' - ! Lazy initialization: build the topology on first call - if (.not. y%neighbor_topology%is_initialized) then - if (debug) write(*,*) me,' nbr_vect: building topology' - call y%neighbor_topology%init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) + if (.not. neighbor_comm_handle%is_initialized) then + if (debug) write(*,*) me,' nbr_vect: building topology via handle' + call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) if (info /= psb_success_) then - call psb_errpush(psb_err_internal_error_, name, & - & a_err='neighbor_topology_init') + call psb_errpush(psb_err_internal_error_, name, a_err='neighbor_topology_init') goto 9999 end if end if - - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Buffer layout: ! combuf(1 : total_send) = send area ! combuf(total_send+1 : total_send+total_recv) = recv area buffer_size = topology_total_send + topology_total_recv - call y%new_buffer(buffer_size, info) - if (info /= 0) then - call psb_errpush(psb_err_alloc_dealloc_, name) - goto 9999 + if (neighbor_comm_handle%use_persistent_buffers) then + if ((.not.allocated(y%combuf)) .or. (size(y%combuf) < buffer_size)) then + if (neighbor_comm_handle%persistent_request_ready) then + if (neighbor_comm_handle%persistent_request /= mpi_request_null) then + call mpi_request_free(neighbor_comm_handle%persistent_request, iret) + end if + neighbor_comm_handle%persistent_request = mpi_request_null + neighbor_comm_handle%persistent_request_ready = .false. + neighbor_comm_handle%persistent_buffer_size = 0 + end if + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + end if + else + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if end if - y%communication_handle = mpi_request_null + neighbor_comm_handle%comm_request = mpi_request_null ! Gather send data into contiguous send buffer (polymorphic for GPU) if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems' call y%gth(int(topology_total_send,psb_mpk_), & - & y%neighbor_topology%send_indexes, & + & neighbor_comm_handle%send_indexes, & & y%combuf(1:topology_total_send)) ! Wait for device (important for GPU subclasses) call y%device_wait() - ! Post non-blocking neighborhood alltoallv - if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' - call mpi_ineighbor_alltoallv( & - & y%combuf(1), & ! send buffer - & y%neighbor_topology%send_counts, & - & y%neighbor_topology%send_displs, & - & psb_mpi_r_dpk_, & - & y%combuf(topology_total_send + 1), & ! recv buffer - & y%neighbor_topology%recv_counts, & - & y%neighbor_topology%recv_displs, & - & psb_mpi_r_dpk_, & - & y%neighbor_topology%graph_comm, & - & y%communication_handle, iret) - if (iret /= mpi_success) then - info = psb_err_mpi_error_ - call psb_errpush(info, name, m_err=(/iret/)) - goto 9999 + if (neighbor_comm_handle%use_persistent_buffers) then + ! Lazy persistent-init: build the request once, then reuse with START/WAIT. + if (.not. neighbor_comm_handle%persistent_request_ready) then +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + if (debug) write(*,*) me,' nbr_vect: posting MPI_Neighbor_alltoallv_init' + call mpi_neighbor_alltoallv_init( & + & y%combuf(1), & ! send buffer + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & + & psb_mpi_r_dpk_, & + & neighbor_comm_handle%graph_comm, & + & mpi_info_null, & + & neighbor_comm_handle%persistent_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + neighbor_comm_handle%persistent_request_ready = .true. + neighbor_comm_handle%persistent_buffer_size = buffer_size +#else + ! Fallback when persistent neighborhood collectives are not available + neighbor_comm_handle%persistent_request_ready = .false. + neighbor_comm_handle%persistent_buffer_size = 0 +#endif + end if + +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + call mpi_start(neighbor_comm_handle%persistent_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if +#else + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & + & psb_mpi_r_dpk_, & + & neighbor_comm_handle%graph_comm, & + & neighbor_comm_handle%comm_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if +#endif + else + ! Post non-blocking neighborhood alltoallv + if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & + & psb_mpi_r_dpk_, & + & neighbor_comm_handle%graph_comm, & + & neighbor_comm_handle%comm_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if end if end if ! do_start @@ -547,19 +683,44 @@ contains ! --------------------------------------------------------- if (do_wait) then - if (y%communication_handle == mpi_request_null) then - ! No matching start? Something is wrong - info = psb_err_mpi_error_ - call psb_errpush(info, name, m_err=(/-2/)) - goto 9999 + if (neighbor_comm_handle%use_persistent_buffers) then +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + if (.not. neighbor_comm_handle%persistent_request_ready) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if +#else + if (neighbor_comm_handle%comm_request == mpi_request_null) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if +#endif + else + if (neighbor_comm_handle%comm_request == mpi_request_null) then + write(psb_err_unit,*) me, 'DBG: neighbor WAIT but comm_request is NULL; is_initialized=', & + & neighbor_comm_handle%is_initialized + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if end if - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Wait for the non-blocking collective to complete if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' - call mpi_wait(y%communication_handle, p2pstat, iret) + if (neighbor_comm_handle%use_persistent_buffers) then +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + call mpi_wait(neighbor_comm_handle%persistent_request, p2pstat, iret) +#else + call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) +#endif + else + call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) + end if if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -569,18 +730,23 @@ contains ! Scatter received data to local vector positions (polymorphic for GPU) if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems' call y%sct(int(topology_total_recv,psb_mpk_), & - & y%neighbor_topology%recv_indexes, & + & neighbor_comm_handle%recv_indexes, & & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & & beta) ! Clean up - y%communication_handle = mpi_request_null + if ((.not. neighbor_comm_handle%use_persistent_buffers) .or. & + & (neighbor_comm_handle%use_persistent_buffers .and. .not. neighbor_comm_handle%persistent_request_ready)) then + neighbor_comm_handle%comm_request = mpi_request_null + end if call y%device_wait() - call y%maybe_free_buffer(info) - if (info /= 0) then - call psb_errpush(psb_err_alloc_dealloc_, name) - goto 9999 + if (.not. neighbor_comm_handle%use_persistent_buffers) then + call y%maybe_free_buffer(info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if end if if (debug) write(*,*) me,' nbr_vect: done' @@ -604,7 +770,7 @@ contains ! Takes care of Y an encaspulated multivector. ! ! - module subroutine psi_dswapdata_multivect(flag,beta,y,desc_a,info,data) + module subroutine psi_dswapdata_multivect(swap_status,beta,y,desc_a,info,data) #ifdef PSB_MPI_MOD use mpi #endif @@ -613,16 +779,15 @@ contains include 'mpif.h' #endif - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status class(psb_d_base_multivect_type), intent(inout) :: y real(psb_dpk_), intent(in) :: beta type(psb_desc_type), target :: desc_a integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), optional :: data - ! local variables used to detect the communication scheme - logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait - logical :: baseline, neighbor_a2av + ! communication scheme/status selectors + logical :: baseline, neighbor_a2av ! locals type(psb_ctxt_type) :: ctxt @@ -663,37 +828,52 @@ contains goto 9999 end if - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - swap_start = iand(flag,psb_swap_start_) /= 0 - swap_wait = iand(flag,psb_swap_wait_) /= 0 + if ((swap_status /= psb_comm_status_start_) .and. (swap_status /= psb_comm_status_wait_) & + & .and. (swap_status /= psb_comm_status_unknown_)) then + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Invalid swap_status swap_status') + goto 9999 + end if - baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync - neighbor_a2av = swap_start .or. swap_wait + if (.not. allocated(y%comm_handle)) then + call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, a_err='init comm default baseline') + goto 9999 + end if + end if - if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then - info=psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + call y%comm_handle%set_swap_status(swap_status, info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='set_swap_status') goto 9999 end if + baseline = .false. + neighbor_a2av = .false. + select case(y%comm_handle%comm_type) + case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_) + neighbor_a2av = .true. + case default + baseline = .true. + end select + if (baseline) then - call psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dswap_baseline_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='baseline swap') goto 9999 end if else if (neighbor_a2av) then - call psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dswap_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes, & + & num_neighbors,total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='neighbor a2av swap') goto 9999 end if else info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + call psb_errpush(info,name,a_err='Incompatible swap_status settings: neither baseline nor neighbor_a2av is true') goto 9999 end if @@ -708,10 +888,8 @@ contains -subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & - & num_neighbors,total_send,total_recv,info) - - +subroutine psi_dswap_baseline_multivect(ctxt,swap_status,beta,y,comm_indexes, & + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi #endif @@ -721,11 +899,12 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_multivect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: comm_indexes integer(psb_ipk_), intent(in) :: num_neighbors,total_send, total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info ! locals @@ -733,10 +912,10 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & integer(psb_mpk_) :: np, me, nesd, nerv, n integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret integer(psb_mpk_), allocatable :: prcid(:) + type(psb_comm_baseline_handle), pointer :: baseline_comm_handle integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,& & snd_pt, rcv_pt, pnti - logical :: swap_mpi, swap_sync, swap_send, swap_recv,& - & albf,do_send,do_recv + logical :: do_send,do_recv logical, parameter :: usersend=.false., debug=.false. character(len=20) :: name @@ -752,12 +931,18 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & n = y%get_ncols() - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - do_send = swap_mpi .or. swap_sync .or. swap_send - do_recv = swap_mpi .or. swap_sync .or. swap_recv + baseline_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_baseline_handle) + baseline_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected baseline comm_handle in baseline multivect swap') + goto 9999 + end select + + do_send = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_unknown_) + do_recv = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_unknown_) total_recv_ = total_recv * n total_send_ = total_send * n @@ -766,20 +951,21 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & if (debug) write(*,*) me,'Internal buffer' if (do_send) then - if (allocated(y%comid)) then - if (any(y%comid /= mpi_request_null)) then - ! - ! Unfinished communication? Something is wrong.... - ! - info=psb_err_mpi_error_ + if (allocated(baseline_comm_handle%comid)) then + if (any(baseline_comm_handle%comid /= mpi_request_null)) then + info = psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/-2/)) goto 9999 end if end if if (debug) write(*,*) me,'do_send start' call y%new_buffer(ione*size(comm_indexes%v),info) - call y%new_comid(num_neighbors,info) - y%comid = mpi_request_null + call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info) + if (info /= psb_success_) then + call psb_errpush(psb_err_alloc_dealloc_,name) + goto 9999 + end if + baseline_comm_handle%comid = mpi_request_null call psb_realloc(num_neighbors,prcid,info) ! First I post all the non blocking receives pnti = 1 @@ -793,9 +979,9 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & if ((nerv>0).and.(proc_to_comm /= me)) then if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt p2ptag = psb_double_swap_tag - call mpi_irecv(y%combuf(rcv_pt),n*nerv,& + call mpi_irecv(y%combuf(rcv_pt),n*nerv,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag, icomm,y%comid(i,2),iret) + & p2ptag, icomm,baseline_comm_handle%comid(i,2),iret) end if rcv_pt = rcv_pt + n*nerv snd_pt = snd_pt + n*nesd @@ -838,9 +1024,9 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & nesd = comm_indexes%v(pnti+nerv+psb_n_elem_send_) if ((nesd>0).and.(proc_to_comm /= me)) then - call mpi_isend(y%combuf(snd_pt),n*nesd,& + call mpi_isend(y%combuf(snd_pt),n*nesd,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag,icomm,y%comid(i,1),iret) + & p2ptag,icomm,baseline_comm_handle%comid(i,1),iret) end if if(iret /= mpi_success) then @@ -856,7 +1042,7 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & if (do_recv) then if (debug) write(*,*) me,' do_Recv' - if (.not.allocated(y%comid)) then + if (.not.allocated(baseline_comm_handle%comid)) then ! ! No matching send? Something is wrong.... ! @@ -877,7 +1063,7 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & nesd = comm_indexes%v(pnti+nerv+psb_n_elem_send_) if (proc_to_comm /= me)then if (nesd>0) then - call mpi_wait(y%comid(i,1),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,1),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -885,7 +1071,7 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & end if end if if (nerv>0) then - call mpi_wait(y%comid(i,2),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,2),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -925,7 +1111,7 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & ! ! Waited for com, cleanup comid ! - y%comid = mpi_request_null + baseline_comm_handle%comid = mpi_request_null ! ! Then wait for device @@ -934,7 +1120,9 @@ subroutine psi_dswap_baseline_multivect(ctxt,flag,beta,y,comm_indexes, & call y%device_wait() if (debug) write(*,*) me,' free buffer' call y%free_buffer(info) - if (info == 0) call y%free_comid(info) + if (info == 0) then + if (allocated(y%comm_handle)) call psb_comm_free(y%comm_handle, info) + end if if (info /= 0) then call psb_errpush(psb_err_alloc_dealloc_,name) goto 9999 @@ -953,7 +1141,8 @@ end subroutine psi_dswap_baseline_multivect -subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) +subroutine psi_dswap_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes,& + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi #endif @@ -963,11 +1152,12 @@ subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,n #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_multivect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: comm_indexes integer(psb_ipk_), intent(in) :: num_neighbors,total_send, total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info @@ -975,6 +1165,7 @@ subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,n integer(psb_mpk_) :: icomm integer(psb_mpk_) :: np, me integer(psb_mpk_) :: iret, p2pstat(mpi_status_size) + type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size logical :: do_start, do_wait logical, parameter :: debug = .false. @@ -991,8 +1182,20 @@ subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,n goto 9999 endif - do_start = iand(flag,psb_swap_start_) /= 0 - do_wait = iand(flag,psb_swap_wait_) /= 0 + icomm = ctxt%get_mpic() + + neighbor_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_neighbor_handle) + neighbor_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected neighbor comm_handle in neighbor multivect swap') + goto 9999 + end select + + do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_unknown_) + do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_unknown_) call comm_indexes%sync() @@ -1001,59 +1204,131 @@ subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,n ! --------------------------------------------------------- if (do_start) then if(debug) write(*,*) me,' nbr_vect: starting data exchange' - ! Lazy initialization: build the topology on first call - if (.not. y%neighbor_topology%is_initialized) then - if (debug) write(*,*) me,' nbr_vect: building topology' - call y%neighbor_topology%init(comm_indexes%v, num_neighbors, total_send, total_recv, & + if (.not. neighbor_comm_handle%is_initialized) then + if (debug) write(*,*) me,' nbr_vect: building topology via handle' + call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, & & ctxt, icomm, info) if (info /= psb_success_) then - call psb_errpush(psb_err_internal_error_, name, & - & a_err='neighbor_topology_init') + call psb_errpush(psb_err_internal_error_, name, a_err='neighbor_topology_init') goto 9999 end if end if - - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Buffer layout: ! combuf(1 : total_send) = send area ! combuf(total_send+1 : total_send+total_recv) = recv area buffer_size = topology_total_send + topology_total_recv - call y%new_buffer(buffer_size, info) - if (info /= 0) then - call psb_errpush(psb_err_alloc_dealloc_, name) - goto 9999 + if (neighbor_comm_handle%use_persistent_buffers) then + if ((.not.allocated(y%combuf)) .or. (size(y%combuf) < buffer_size)) then + if (neighbor_comm_handle%persistent_request_ready) then + if (neighbor_comm_handle%persistent_request /= mpi_request_null) then + call mpi_request_free(neighbor_comm_handle%persistent_request, iret) + end if + neighbor_comm_handle%persistent_request = mpi_request_null + neighbor_comm_handle%persistent_request_ready = .false. + neighbor_comm_handle%persistent_buffer_size = 0 + end if + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + end if + else + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if end if - y%communication_handle = mpi_request_null + neighbor_comm_handle%comm_request = mpi_request_null ! Gather send data into contiguous send buffer (polymorphic for GPU) if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems' call y%gth(int(topology_total_send,psb_mpk_), & - & y%neighbor_topology%send_indexes, & - & y%combuf(1:topology_total_send)) + & neighbor_comm_handle%send_indexes, & + & y%combuf(1:topology_total_send)) ! Wait for device (important for GPU subclasses) call y%device_wait() - ! Post non-blocking neighborhood alltoallv - if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' - call mpi_ineighbor_alltoallv( & - & y%combuf(1), & ! send buffer - & y%neighbor_topology%send_counts, & - & y%neighbor_topology%send_displs, & - & psb_mpi_r_dpk_, & - & y%combuf(topology_total_send + 1), & ! recv buffer - & y%neighbor_topology%recv_counts, & - & y%neighbor_topology%recv_displs, & - & psb_mpi_r_dpk_, & - & y%neighbor_topology%graph_comm, & - & y%communication_handle, iret) - if (iret /= mpi_success) then - info = psb_err_mpi_error_ - call psb_errpush(info, name, m_err=(/iret/)) - goto 9999 + if (neighbor_comm_handle%use_persistent_buffers) then + if (.not. neighbor_comm_handle%persistent_request_ready) then +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + if (debug) write(*,*) me,' nbr_vect: posting MPI_Neighbor_alltoallv_init' + call mpi_neighbor_alltoallv_init( & + & y%combuf(1), & ! send buffer + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & + & psb_mpi_r_dpk_, & + & neighbor_comm_handle%graph_comm, & + & mpi_info_null, & + & neighbor_comm_handle%persistent_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + neighbor_comm_handle%persistent_request_ready = .true. + neighbor_comm_handle%persistent_buffer_size = buffer_size +#else + ! Fallback when persistent neighborhood collectives are not available + neighbor_comm_handle%persistent_request_ready = .false. + neighbor_comm_handle%persistent_buffer_size = 0 +#endif + end if + +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + call mpi_start(neighbor_comm_handle%persistent_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if +#else + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & + & psb_mpi_r_dpk_, & + & neighbor_comm_handle%graph_comm, & + & neighbor_comm_handle%comm_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if +#endif + else + ! Post non-blocking neighborhood alltoallv + if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & + & psb_mpi_r_dpk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & + & psb_mpi_r_dpk_, & + & neighbor_comm_handle%graph_comm, & + & neighbor_comm_handle%comm_request, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if end if end if ! do_start @@ -1063,19 +1338,42 @@ subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,n ! --------------------------------------------------------- if (do_wait) then - if (y%communication_handle == mpi_request_null) then - ! No matching start? Something is wrong - info = psb_err_mpi_error_ - call psb_errpush(info, name, m_err=(/-2/)) - goto 9999 + if (neighbor_comm_handle%use_persistent_buffers) then +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + if (.not. neighbor_comm_handle%persistent_request_ready) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if +#else + if (neighbor_comm_handle%comm_request == mpi_request_null) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if +#endif + else + if (neighbor_comm_handle%comm_request == mpi_request_null) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if end if - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Wait for the non-blocking collective to complete if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' - call mpi_wait(y%communication_handle, p2pstat, iret) + if (neighbor_comm_handle%use_persistent_buffers) then +#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT + call mpi_wait(neighbor_comm_handle%persistent_request, p2pstat, iret) +#else + call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) +#endif + else + call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) + end if if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -1085,18 +1383,23 @@ subroutine psi_dswap_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,n ! Scatter received data to local vector positions (polymorphic for GPU) if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems' call y%sct(int(topology_total_recv,psb_mpk_), & - & y%neighbor_topology%recv_indexes, & - & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & - & beta) + & neighbor_comm_handle%recv_indexes, & + & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & + & beta) ! Clean up - y%communication_handle = mpi_request_null + if ((.not. neighbor_comm_handle%use_persistent_buffers) .or. & + & (neighbor_comm_handle%use_persistent_buffers .and. .not. neighbor_comm_handle%persistent_request_ready)) then + neighbor_comm_handle%comm_request = mpi_request_null + end if call y%device_wait() - call y%maybe_free_buffer(info) - if (info /= 0) then - call psb_errpush(psb_err_alloc_dealloc_, name) - goto 9999 + if (.not. neighbor_comm_handle%use_persistent_buffers) then + call y%maybe_free_buffer(info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if end if if (debug) write(*,*) me,' nbr_vect: done' diff --git a/base/comm/internals/psi_dswaptran.F90 b/base/comm/internals/psi_dswaptran.F90 index ec6e6d2e..2c75eb76 100644 --- a/base/comm/internals/psi_dswaptran.F90 +++ b/base/comm/internals/psi_dswaptran.F90 @@ -59,12 +59,12 @@ ! ! ! Arguments: -! flag - integer Choose the algorithm for data exchange: +! swap_status - integer Choose the algorithm for data exchange: ! this is chosen through bit fields. -! swap_mpi = iand(flag,psb_swap_mpi_) /= 0 -! swap_sync = iand(flag,psb_swap_sync_) /= 0 -! swap_send = iand(flag,psb_swap_send_) /= 0 -! swap_recv = iand(flag,psb_swap_recv_) /= 0 +! swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0 +! swap_sync = iand(swap_status,psb_swap_sync_) /= 0 +! swap_send = iand(swap_status,psb_swap_send_) /= 0 +! swap_recv = iand(swap_status,psb_swap_recv_) /= 0 ! if (swap_mpi): use underlying MPI_ALLTOALLV. ! if (swap_sync): use PSB_SND and PSB_RCV in ! synchronized pairs @@ -92,8 +92,12 @@ ! submodule (psi_d_comm_v_mod) psi_d_swaptran_impl use psb_base_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_ + use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free + use psb_comm_baseline_mod, only: psb_comm_baseline_handle + use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle contains - module subroutine psi_dswaptran_vect(flag,beta,y,desc_a,info,data) + module subroutine psi_dswaptran_vect(swap_status,beta,y,desc_a,info,data) #ifdef PSB_MPI_MOD use mpi @@ -103,7 +107,7 @@ contains include 'mpif.h' #endif - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y type(psb_desc_type),target :: desc_a @@ -153,37 +157,46 @@ contains goto 9999 end if - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - swap_start = iand(flag,psb_swap_start_) /= 0 - swap_wait = iand(flag,psb_swap_wait_) /= 0 + swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0 + swap_sync = iand(swap_status,psb_swap_sync_) /= 0 + swap_send = iand(swap_status,psb_swap_send_) /= 0 + swap_recv = iand(swap_status,psb_swap_recv_) /= 0 + swap_start = iand(swap_status,psb_swap_start_) /= 0 + swap_wait = iand(swap_status,psb_swap_wait_) /= 0 baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync neighbor_a2av = swap_start .or. swap_wait if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + call psb_errpush(info,name,a_err='Incompatible swap_status settings: both baseline and neighbor_a2av are true') goto 9999 end if + if (.not. allocated(y%comm_handle)) then + call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, a_err='init comm default baseline') + goto 9999 + end if + end if + if (baseline) then - call psi_dtran_baseline_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dtran_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='baseline swap') goto 9999 end if else if (neighbor_a2av) then - call psi_dtran_neighbor_topology_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dtran_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,& + & num_neighbors,total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='neighbor a2av swap') goto 9999 end if else info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + call psb_errpush(info,name,a_err='Incompatible swap_status settings: neither baseline nor neighbor_a2av is true') goto 9999 end if @@ -208,8 +221,8 @@ contains ! ! ! - module subroutine psi_dtran_baseline_vect(ctxt,flag,beta,y,idx,& - & num_neighbors,total_send,total_recv,info) + module subroutine psi_dtran_baseline_vect(ctxt,swap_status,beta,y,idx,& + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi @@ -220,11 +233,12 @@ contains #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: idx integer(psb_ipk_), intent(in) :: num_neighbors,total_send, total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info ! locals @@ -232,6 +246,7 @@ contains integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret integer(psb_mpk_) :: icomm integer(psb_mpk_), allocatable :: prcid(:) + type(psb_comm_baseline_handle), pointer :: baseline_comm_handle integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,& & snd_pt, rcv_pt, pnti logical :: swap_mpi, swap_sync, swap_send, swap_recv,& @@ -250,11 +265,21 @@ contains endif icomm = ctxt%get_mpic() + baseline_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_baseline_handle) + baseline_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected baseline comm_handle in baseline swaptran') + goto 9999 + end select + n=1 - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 + swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0 + swap_sync = iand(swap_status,psb_swap_sync_) /= 0 + swap_send = iand(swap_status,psb_swap_send_) /= 0 + swap_recv = iand(swap_status,psb_swap_recv_) /= 0 do_send = swap_mpi .or. swap_sync .or. swap_send do_recv = swap_mpi .or. swap_sync .or. swap_recv @@ -265,8 +290,8 @@ contains if (debug) write(*,*) me,'Internal buffer' if (do_send) then - if (allocated(y%comid)) then - if (any(y%comid /= mpi_request_null)) then + if (allocated(baseline_comm_handle%comid)) then + if (any(baseline_comm_handle%comid /= mpi_request_null)) then ! ! Unfinished communication? Something is wrong.... ! @@ -277,8 +302,8 @@ contains end if if (debug) write(*,*) me,'do_send start' call y%new_buffer(ione*size(idx%v),info) - call y%new_comid(num_neighbors,info) - y%comid = mpi_request_null + call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info) + baseline_comm_handle%comid = mpi_request_null call psb_realloc(num_neighbors,prcid,info) ! First I post all the non blocking receives pnti = 1 @@ -295,7 +320,7 @@ contains if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt call mpi_irecv(y%combuf(snd_pt),nesd,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag, icomm,y%comid(i,2),iret) + & p2ptag, icomm,baseline_comm_handle%comid(i,2),iret) end if pnti = pnti + nerv + nesd + 3 end do @@ -342,7 +367,7 @@ contains if ((nerv>0).and.(proc_to_comm /= me)) then call mpi_isend(y%combuf(rcv_pt),nerv,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag,icomm,y%comid(i,1),iret) + & p2ptag,icomm,baseline_comm_handle%comid(i,1),iret) end if if(iret /= mpi_success) then @@ -357,7 +382,7 @@ contains if (do_recv) then if (debug) write(*,*) me,' do_Recv' - if (.not.allocated(y%comid)) then + if (.not.allocated(baseline_comm_handle%comid)) then ! ! No matching send? Something is wrong.... ! @@ -379,7 +404,7 @@ contains if (proc_to_comm /= me)then if (nerv>0) then - call mpi_wait(y%comid(i,1),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,1),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -387,7 +412,7 @@ contains end if end if if (nesd>0) then - call mpi_wait(y%comid(i,2),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,2),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -425,7 +450,7 @@ contains ! ! Waited for everybody, clean up ! - y%comid = mpi_request_null + baseline_comm_handle%comid = mpi_request_null ! ! Then wait for device @@ -434,7 +459,9 @@ contains call y%device_wait() if (debug) write(*,*) me,' free buffer' call y%maybe_free_buffer(info) - if (info == 0) call y%free_comid(info) + if (info == 0) then + if (allocated(y%comm_handle)) call psb_comm_free(y%comm_handle, info) + end if if (info /= 0) then call psb_errpush(psb_err_alloc_dealloc_,name) goto 9999 @@ -455,7 +482,8 @@ contains - subroutine psi_dtran_neighbor_topology_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + subroutine psi_dtran_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,& + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi #endif @@ -465,17 +493,19 @@ contains #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: comm_indexes integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info ! locals integer(psb_mpk_) :: icomm integer(psb_mpk_) :: np, me integer(psb_mpk_) :: iret, p2pstat(mpi_status_size) + type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size logical :: do_start, do_wait logical, parameter :: debug = .false. @@ -494,8 +524,18 @@ contains icomm = ctxt%get_mpic() - do_start = iand(flag,psb_swap_start_) /= 0 - do_wait = iand(flag,psb_swap_wait_) /= 0 + neighbor_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_neighbor_handle) + neighbor_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected neighbor comm_handle in neighbor swaptran') + goto 9999 + end select + + do_start = iand(swap_status,psb_swap_start_) /= 0 + do_wait = iand(swap_status,psb_swap_wait_) /= 0 call comm_indexes%sync() @@ -505,9 +545,9 @@ contains if (do_start) then if(debug) write(*,*) me,' nbr_vect: starting data exchange' ! Lazy initialization: build the topology on first call - if (.not. y%neighbor_topology%is_initialized) then + if (.not. neighbor_comm_handle%is_initialized) then if (debug) write(*,*) me,' nbr_vect: building topology' - call y%neighbor_topology%init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) + call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) if (info /= psb_success_) then call psb_errpush(psb_err_internal_error_, name, & & a_err='neighbor_topology_init') @@ -515,8 +555,8 @@ contains end if end if - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Buffer layout: ! combuf(1 : total_send) = send area @@ -528,12 +568,12 @@ contains call psb_errpush(psb_err_alloc_dealloc_, name) goto 9999 end if - y%communication_handle = mpi_request_null + neighbor_comm_handle%comm_request = mpi_request_null ! For transpose exchange: gather recv area first (we will send "recv" data) if (debug) write(*,*) me,' nbr_tran_vect: gathering recv data,', topology_total_recv,' elems' call y%gth(int(topology_total_recv,psb_mpk_), & - & y%neighbor_topology%recv_indexes, & + & neighbor_comm_handle%recv_indexes, & & y%combuf(1:topology_total_recv)) ! Wait for device (important for GPU subclasses) @@ -543,15 +583,15 @@ contains if (debug) write(*,*) me,' nbr_tran_vect: posting MPI_Ineighbor_alltoallv (swapped)' call mpi_ineighbor_alltoallv( & & y%combuf(1), & ! send buffer (recv_indexes gathered) - & y%neighbor_topology%recv_counts, & - & y%neighbor_topology%recv_displs, & + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & & psb_mpi_r_dpk_, & & y%combuf(topology_total_recv + 1), & ! recv buffer (will contain send_indexes data) - & y%neighbor_topology%send_counts, & - & y%neighbor_topology%send_displs, & + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & & psb_mpi_r_dpk_, & - & y%neighbor_topology%graph_comm, & - & y%communication_handle, iret) + & neighbor_comm_handle%graph_comm, & + & neighbor_comm_handle%comm_request, iret) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -565,19 +605,19 @@ contains ! --------------------------------------------------------- if (do_wait) then - if (y%communication_handle == mpi_request_null) then + if (neighbor_comm_handle%comm_request == mpi_request_null) then ! No matching start? Something is wrong info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/-2/)) goto 9999 end if - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Wait for the non-blocking collective to complete if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' - call mpi_wait(y%communication_handle, p2pstat, iret) + call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -587,13 +627,13 @@ contains ! For transpose exchange: scatter the data that correspond to peers' send area if (debug) write(*,*) me,' nbr_tran_vect: scattering send-index data,', topology_total_send,' elems' call y%sct(int(topology_total_send,psb_mpk_), & - & y%neighbor_topology%send_indexes, & + & neighbor_comm_handle%send_indexes, & & y%combuf(topology_total_recv+1:topology_total_recv+topology_total_send), & & beta) ! Clean up - y%communication_handle = mpi_request_null + neighbor_comm_handle%comm_request = mpi_request_null call y%device_wait() call y%maybe_free_buffer(info) if (info /= 0) then @@ -625,7 +665,7 @@ contains ! Takes care of Y an encaspulated multivector. ! ! - module subroutine psi_dswaptran_multivect(flag,beta,y,desc_a,info,data) + module subroutine psi_dswaptran_multivect(swap_status,beta,y,desc_a,info,data) #ifdef PSB_MPI_MOD use mpi @@ -635,7 +675,7 @@ contains include 'mpif.h' #endif - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_multivect_type), intent(inout) :: y type(psb_desc_type),target :: desc_a @@ -685,37 +725,46 @@ contains goto 9999 end if - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 - swap_start = iand(flag,psb_swap_start_) /= 0 - swap_wait = iand(flag,psb_swap_wait_) /= 0 + swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0 + swap_sync = iand(swap_status,psb_swap_sync_) /= 0 + swap_send = iand(swap_status,psb_swap_send_) /= 0 + swap_recv = iand(swap_status,psb_swap_recv_) /= 0 + swap_start = iand(swap_status,psb_swap_start_) /= 0 + swap_wait = iand(swap_status,psb_swap_wait_) /= 0 baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync neighbor_a2av = swap_start .or. swap_wait if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + call psb_errpush(info,name,a_err='Incompatible swap_status settings: both baseline and neighbor_a2av are true') goto 9999 end if + if (.not. allocated(y%comm_handle)) then + call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, a_err='init comm default baseline') + goto 9999 + end if + end if + if (baseline) then - call psi_dtran_baseline_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dtran_baseline_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='baseline swap') goto 9999 end if else if (neighbor_a2av) then - call psi_dtran_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + call psi_dtran_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,& + & total_send,total_recv,y%comm_handle,info) if (info /= psb_success_) then call psb_errpush(info,name,a_err='neighbor a2av swap') goto 9999 end if else info = psb_err_mpi_error_ - call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + call psb_errpush(info,name,a_err='Incompatible swap_status settings: neither baseline nor neighbor_a2av is true') goto 9999 end if @@ -727,8 +776,8 @@ contains return end subroutine psi_dswaptran_multivect - subroutine psi_dtran_baseline_multivect(ctxt,flag,beta,y,idx,& - & num_neighbors,total_send,total_recv,info) + subroutine psi_dtran_baseline_multivect(ctxt,swap_status,beta,y,idx,& + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi @@ -739,18 +788,20 @@ contains #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status integer(psb_ipk_), intent(out) :: info class(psb_d_base_multivect_type), intent(inout) :: y real(psb_dpk_), intent(in) :: beta class(psb_i_base_vect_type), intent(inout) :: idx integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle ! locals integer(psb_mpk_) :: np, me, nesd, nerv, n integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret integer(psb_mpk_) :: icomm integer(psb_mpk_), allocatable :: prcid(:) + type(psb_comm_baseline_handle), pointer :: baseline_comm_handle integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,& & snd_pt, rcv_pt, pnti logical :: swap_mpi, swap_sync, swap_send, swap_recv,& @@ -769,12 +820,22 @@ contains endif icomm = ctxt%get_mpic() + baseline_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_baseline_handle) + baseline_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected baseline comm_handle in baseline multivect swaptran') + goto 9999 + end select + n = y%get_ncols() - swap_mpi = iand(flag,psb_swap_mpi_) /= 0 - swap_sync = iand(flag,psb_swap_sync_) /= 0 - swap_send = iand(flag,psb_swap_send_) /= 0 - swap_recv = iand(flag,psb_swap_recv_) /= 0 + swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0 + swap_sync = iand(swap_status,psb_swap_sync_) /= 0 + swap_send = iand(swap_status,psb_swap_send_) /= 0 + swap_recv = iand(swap_status,psb_swap_recv_) /= 0 do_send = swap_mpi .or. swap_sync .or. swap_send do_recv = swap_mpi .or. swap_sync .or. swap_recv @@ -785,8 +846,8 @@ contains if (debug) write(*,*) me,'Internal buffer' if (do_send) then - if (allocated(y%comid)) then - if (any(y%comid /= mpi_request_null)) then + if (allocated(baseline_comm_handle%comid)) then + if (any(baseline_comm_handle%comid /= mpi_request_null)) then ! ! Unfinished communication? Something is wrong.... ! @@ -797,8 +858,8 @@ contains end if if (debug) write(*,*) me,'do_send start' call y%new_buffer(ione*size(idx%v),info) - call y%new_comid(num_neighbors,info) - y%comid = mpi_request_null + call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info) + baseline_comm_handle%comid = mpi_request_null call psb_realloc(num_neighbors,prcid,info) ! First I post all the non blocking receives pnti = 1 @@ -814,7 +875,7 @@ contains if (debug) write(*,*) me,'Posting receive from',prcid(i),snd_pt call mpi_irecv(y%combuf(snd_pt),n*nesd,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag, icomm,y%comid(i,2),iret) + & p2ptag, icomm,baseline_comm_handle%comid(i,2),iret) end if rcv_pt = rcv_pt + n*nerv snd_pt = snd_pt + n*nesd @@ -861,7 +922,7 @@ contains if ((nerv>0).and.(proc_to_comm /= me)) then call mpi_isend(y%combuf(rcv_pt),n*nerv,& & psb_mpi_r_dpk_,prcid(i),& - & p2ptag,icomm,y%comid(i,1),iret) + & p2ptag,icomm,baseline_comm_handle%comid(i,1),iret) end if if(iret /= mpi_success) then @@ -877,7 +938,7 @@ contains if (do_recv) then if (debug) write(*,*) me,' do_Recv' - if (.not.allocated(y%comid)) then + if (.not.allocated(baseline_comm_handle%comid)) then ! ! No matching send? Something is wrong.... ! @@ -898,7 +959,7 @@ contains nesd = idx%v(pnti+nerv+psb_n_elem_send_) if (proc_to_comm /= me)then if (nerv>0) then - call mpi_wait(y%comid(i,1),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,1),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -906,7 +967,7 @@ contains end if end if if (nesd>0) then - call mpi_wait(y%comid(i,2),p2pstat,iret) + call mpi_wait(baseline_comm_handle%comid(i,2),p2pstat,iret) if(iret /= mpi_success) then info=psb_err_mpi_error_ call psb_errpush(info,name,m_err=(/iret/)) @@ -948,7 +1009,7 @@ contains ! ! Waited for com, cleanup comid ! - y%comid = mpi_request_null + baseline_comm_handle%comid = mpi_request_null ! ! Then wait for device @@ -957,7 +1018,9 @@ contains call y%device_wait() if (debug) write(*,*) me,' free buffer' call y%maybe_free_buffer(info) - if (info == 0) call y%free_comid(info) + if (info == 0) then + if (allocated(y%comm_handle)) call psb_comm_free(y%comm_handle, info) + end if if (info /= 0) then call psb_errpush(psb_err_alloc_dealloc_,name) goto 9999 @@ -976,7 +1039,8 @@ contains end subroutine psi_dtran_baseline_multivect - subroutine psi_dtran_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info) + subroutine psi_dtran_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes,& + & num_neighbors,total_send,total_recv,comm_handle,info) #ifdef PSB_MPI_MOD use mpi #endif @@ -986,17 +1050,19 @@ contains #endif type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_multivect_type), intent(inout) :: y class(psb_i_base_vect_type), intent(inout) :: comm_indexes integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv + class(psb_comm_handle_type), intent(inout) :: comm_handle integer(psb_ipk_), intent(out) :: info ! locals integer(psb_mpk_) :: icomm integer(psb_mpk_) :: np, me integer(psb_mpk_) :: iret, p2pstat(mpi_status_size) + type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size logical :: do_start, do_wait logical, parameter :: debug = .false. @@ -1015,8 +1081,18 @@ contains icomm = ctxt%get_mpic() - do_start = iand(flag,psb_swap_start_) /= 0 - do_wait = iand(flag,psb_swap_wait_) /= 0 + neighbor_comm_handle => null() + select type(ch => comm_handle) + type is(psb_comm_neighbor_handle) + neighbor_comm_handle => ch + class default + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Expected neighbor comm_handle in neighbor multivect swaptran') + goto 9999 + end select + + do_start = iand(swap_status,psb_swap_start_) /= 0 + do_wait = iand(swap_status,psb_swap_wait_) /= 0 call comm_indexes%sync() @@ -1026,9 +1102,9 @@ contains if (do_start) then if(debug) write(*,*) me,' nbr_vect: starting data exchange' ! Lazy initialization: build the topology on first call - if (.not. y%neighbor_topology%is_initialized) then + if (.not. neighbor_comm_handle%is_initialized) then if (debug) write(*,*) me,' nbr_vect: building topology' - call y%neighbor_topology%init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) + call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) if (info /= psb_success_) then call psb_errpush(psb_err_internal_error_, name, & & a_err='neighbor_topology_init') @@ -1036,8 +1112,8 @@ contains end if end if - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Buffer layout: ! combuf(1 : total_send) = send area @@ -1049,12 +1125,12 @@ contains call psb_errpush(psb_err_alloc_dealloc_, name) goto 9999 end if - y%communication_handle = mpi_request_null + neighbor_comm_handle%comm_request = mpi_request_null ! For transpose exchange: gather recv area first (we will send "recv" data) if (debug) write(*,*) me,' nbr_tran_vect: gathering recv data,', topology_total_recv,' elems' call y%gth(int(topology_total_recv,psb_mpk_), & - & y%neighbor_topology%recv_indexes, & + & neighbor_comm_handle%recv_indexes, & & y%combuf(1:topology_total_recv)) ! Wait for device (important for GPU subclasses) @@ -1064,15 +1140,15 @@ contains if (debug) write(*,*) me,' nbr_tran_vect: posting MPI_Ineighbor_alltoallv (swapped)' call mpi_ineighbor_alltoallv( & & y%combuf(1), & ! send buffer (recv_indexes gathered) - & y%neighbor_topology%recv_counts, & - & y%neighbor_topology%recv_displs, & + & neighbor_comm_handle%recv_counts, & + & neighbor_comm_handle%recv_displs, & & psb_mpi_r_dpk_, & & y%combuf(topology_total_recv + 1), & ! recv buffer (will contain send_indexes data) - & y%neighbor_topology%send_counts, & - & y%neighbor_topology%send_displs, & + & neighbor_comm_handle%send_counts, & + & neighbor_comm_handle%send_displs, & & psb_mpi_r_dpk_, & - & y%neighbor_topology%graph_comm, & - & y%communication_handle, iret) + & neighbor_comm_handle%graph_comm, & + & neighbor_comm_handle%comm_request, iret) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -1086,19 +1162,19 @@ contains ! --------------------------------------------------------- if (do_wait) then - if (y%communication_handle == mpi_request_null) then + if (neighbor_comm_handle%comm_request == mpi_request_null) then ! No matching start? Something is wrong info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/-2/)) goto 9999 end if - topology_total_send = y%neighbor_topology%total_send - topology_total_recv = y%neighbor_topology%total_recv + topology_total_send = neighbor_comm_handle%total_send + topology_total_recv = neighbor_comm_handle%total_recv ! Wait for the non-blocking collective to complete if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' - call mpi_wait(y%communication_handle, p2pstat, iret) + call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -1108,13 +1184,13 @@ contains ! For transpose exchange: scatter the data that correspond to peers' send area if (debug) write(*,*) me,' nbr_tran_vect: scattering send-index data,', topology_total_send,' elems' call y%sct(int(topology_total_send,psb_mpk_), & - & y%neighbor_topology%send_indexes, & + & neighbor_comm_handle%send_indexes, & & y%combuf(topology_total_recv+1:topology_total_recv+topology_total_send), & & beta) ! Clean up - y%communication_handle = mpi_request_null + neighbor_comm_handle%comm_request = mpi_request_null call y%device_wait() call y%maybe_free_buffer(info) if (info /= 0) then diff --git a/base/modules/Makefile b/base/modules/Makefile index 3fd14a42..7e6f0e7c 100644 --- a/base/modules/Makefile +++ b/base/modules/Makefile @@ -28,7 +28,10 @@ COMMINT= penv/psi_penv_mod.o \ SERIAL_MODS=serial/psb_s_serial_mod.o serial/psb_d_serial_mod.o \ serial/psb_c_serial_mod.o serial/psb_z_serial_mod.o \ serial/psb_serial_mod.o \ - comm/psb_neighbor_topology_mod.o \ + comm/comm_schemes/psb_comm_schemes_mod.o \ + comm/comm_schemes/psb_comm_baseline_mod.o \ + comm/comm_schemes/psb_comm_neighbor_impl_mod.o \ + comm/comm_schemes/psb_comm_factory_mod.o \ serial/psb_i_base_vect_mod.o serial/psb_i_vect_mod.o\ serial/psb_l_base_vect_mod.o serial/psb_l_vect_mod.o\ serial/psb_d_base_vect_mod.o serial/psb_d_vect_mod.o\ @@ -95,9 +98,11 @@ UTIL_MODS = desc/psb_desc_const_mod.o desc/psb_indx_map_mod.o\ comm/psb_s_linmap_mod.o comm/psb_d_linmap_mod.o \ comm/psb_c_linmap_mod.o comm/psb_z_linmap_mod.o \ comm/psb_comm_mod.o \ + comm/psb_neighbor_topology_mod.o \ comm/psb_i_comm_mod.o comm/psb_l_comm_mod.o \ comm/psb_s_comm_mod.o comm/psb_d_comm_mod.o\ comm/psb_c_comm_mod.o comm/psb_z_comm_mod.o \ + comm/psb_i2_comm_a_mod.o \ comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \ comm/psb_s_comm_a_mod.o comm/psb_d_comm_a_mod.o\ comm/psb_c_comm_a_mod.o comm/psb_z_comm_a_mod.o \ @@ -182,6 +187,14 @@ auxil/psb_string_mod.o auxil/psb_m_realloc_mod.o auxil/psb_e_realloc_mod.o auxil auxil/psb_d_realloc_mod.o auxil/psb_c_realloc_mod.o auxil/psb_z_realloc_mod.o \ desc/psb_desc_const_mod.o psi_penv_mod.o: psb_const_mod.o +comm/comm_schemes/psb_comm_schemes_mod.o: psb_const_mod.o + +comm/comm_schemes/psb_comm_baseline_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o + +comm/comm_schemes/psb_comm_neighbor_impl_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o comm/comm_schemes/psb_comm_schemes_mod.o + +comm/comm_schemes/psb_comm_factory_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_baseline_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o + comm/psb_neighbor_topology_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o desc/psb_indx_map_mod.o desc/psb_hash_mod.o: psb_realloc_mod.o psb_const_mod.o desc/psb_desc_const_mod.o @@ -263,8 +276,10 @@ serial/psb_d_base_mat_mod.o: serial/psb_d_base_vect_mod.o serial/psb_c_base_mat_mod.o: serial/psb_c_base_vect_mod.o serial/psb_z_base_mat_mod.o: serial/psb_z_base_vect_mod.o serial/psb_l_base_vect_mod.o: serial/psb_i_base_vect_mod.o -serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_d_base_vect_mod.o serial/psb_z_base_vect_mod.o: serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o comm/psb_neighbor_topology_mod.o -serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_d_base_vect_mod.o serial/psb_z_base_vect_mod.o: auxil/psi_serial_mod.o psb_realloc_mod.o comm/psb_neighbor_topology_mod.o +serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_z_base_vect_mod.o: serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o +serial/psb_d_base_vect_mod.o: serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_factory_mod.o +serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_d_base_vect_mod.o serial/psb_z_base_vect_mod.o: auxil/psi_serial_mod.o psb_realloc_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o \ + comm/psb_neighbor_topology_mod.o serial/psb_s_mat_mod.o: serial/psb_s_base_mat_mod.o serial/psb_s_csr_mat_mod.o serial/psb_s_csc_mat_mod.o serial/psb_s_vect_mod.o \ serial/psb_i_vect_mod.o serial/psb_l_vect_mod.o serial/psb_d_mat_mod.o: serial/psb_d_base_mat_mod.o serial/psb_d_csr_mat_mod.o serial/psb_d_csc_mat_mod.o serial/psb_d_vect_mod.o \ @@ -342,11 +357,12 @@ comm/psb_comm_mod.o: desc/psb_desc_mod.o serial/psb_mat_mod.o comm/psb_comm_mod.o: comm/psb_i_comm_mod.o comm/psb_l_comm_mod.o \ comm/psb_s_comm_mod.o comm/psb_d_comm_mod.o \ comm/psb_c_comm_mod.o comm/psb_z_comm_mod.o \ + comm/psb_i2_comm_a_mod.o \ comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \ comm/psb_s_comm_a_mod.o comm/psb_d_comm_a_mod.o\ comm/psb_c_comm_a_mod.o comm/psb_z_comm_a_mod.o -comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \ +comm/psb_i2_comm_a_mod.o comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \ comm/psb_s_comm_a_mod.o comm/psb_d_comm_a_mod.o\ comm/psb_c_comm_a_mod.o comm/psb_z_comm_a_mod.o: desc/psb_desc_mod.o @@ -420,5 +436,4 @@ clean: /bin/rm -f $(MODULES) $(OBJS) $(MPFOBJS) *$(.mod) veryclean: clean - /bin/rm -f *.h - + /bin/rm -f *.h \ No newline at end of file diff --git a/base/modules/comm/comm_schemes/psb_comm_baseline_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_baseline_mod.F90 new file mode 100644 index 00000000..350fd372 --- /dev/null +++ b/base/modules/comm/comm_schemes/psb_comm_baseline_mod.F90 @@ -0,0 +1,62 @@ +module psb_comm_baseline_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_ + use psb_const_mod + implicit none + + type, extends(psb_comm_handle_type) :: psb_comm_baseline_handle + ! MPI request IDs for Isend/Irecv (dimension: num_neighbors x 2) + ! First column: send requests, second column: recv requests + integer(psb_ipk_), allocatable :: comid(:,:) + contains + procedure, pass :: init => psb_comm_baseline_init + procedure, pass :: free => psb_comm_baseline_free + procedure, pass :: set_swap_status => psb_comm_baseline_set_swap_status + procedure, pass :: get_swap_status => psb_comm_baseline_get_swap_status + end type psb_comm_baseline_handle + +contains + + subroutine psb_comm_baseline_init(this, info) + implicit none + class(psb_comm_baseline_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + info = 0 + this%comm_type = psb_comm_isend_irecv_ + this%id = 0 + this%swap_status = 0 + end subroutine psb_comm_baseline_init + + subroutine psb_comm_baseline_free(this, info) + class(psb_comm_baseline_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + info = 0 + ! Free MPI resources (comid) + if (allocated(this%comid)) deallocate(this%comid, stat=info) + end subroutine psb_comm_baseline_free + + subroutine psb_comm_baseline_set_swap_status(this, flag, info) + class(psb_comm_baseline_handle), intent(inout) :: this + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + info = 0 + this%swap_status = flag + end subroutine psb_comm_baseline_set_swap_status + + subroutine psb_comm_baseline_get_swap_status(this, flag, info) + class(psb_comm_baseline_handle), intent(in) :: this + integer(psb_ipk_), intent(out) :: flag + integer(psb_ipk_), intent(out) :: info + info = 0 + flag = this%swap_status + end subroutine psb_comm_baseline_get_swap_status + + ! Allocate comid array for num_neighbors + subroutine psb_comm_baseline_alloc_comid(this, n, info) + implicit none + class(psb_comm_baseline_handle), intent(inout) :: this + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + allocate(this%comid(n, 2_psb_ipk_), stat=info) + end subroutine psb_comm_baseline_alloc_comid + +end module psb_comm_baseline_mod diff --git a/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 new file mode 100644 index 00000000..1652f90b --- /dev/null +++ b/base/modules/comm/comm_schemes/psb_comm_factory_mod.F90 @@ -0,0 +1,100 @@ +module psb_comm_factory_mod + use psb_const_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_ineighbor_alltoallv_, & + & psb_comm_persistent_ineighbor_alltoallv_, psb_comm_unknown_ + use psb_comm_baseline_mod, only: psb_comm_baseline_handle + use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle + implicit none + +contains + + ! Allocatable-based factory routines (preferred names) + subroutine psb_comm_init(comm_type, handle, info) + implicit none + integer(psb_ipk_), intent(in) :: comm_type + class(psb_comm_handle_type), allocatable, intent(inout) :: handle + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (allocated(handle)) then + info = -1 + return + end if + select case(comm_type) + case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_) + allocate(psb_comm_neighbor_handle :: handle, stat=info) + if (info /= 0) return + call handle%init(info) + if (info /= 0) return + select type(h => handle) + type is(psb_comm_neighbor_handle) + h%comm_type = comm_type + h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_) + end select + case default + allocate(psb_comm_baseline_handle :: handle, stat=info) + if (info /= 0) return + call handle%init(info) + end select + end subroutine psb_comm_init + + subroutine psb_comm_free(handle, info) + implicit none + class(psb_comm_handle_type), allocatable, intent(inout) :: handle + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (.not. allocated(handle)) return + call handle%free(info) + if (allocated(handle)) then + deallocate(handle) + end if + end subroutine psb_comm_free + + + ! Allocatable-based factory routines + subroutine psb_comm_create(comm_type, handle, info) + implicit none + integer(psb_ipk_), intent(in) :: comm_type + class(psb_comm_handle_type), allocatable, intent(inout) :: handle + integer(psb_ipk_), intent(out) :: info + + call psb_comm_init(comm_type, handle, info) + end subroutine psb_comm_create + + subroutine psb_comm_destroy(handle, info) + implicit none + class(psb_comm_handle_type), allocatable, intent(inout) :: handle + integer(psb_ipk_), intent(out) :: info + + call psb_comm_free(handle, info) + end subroutine psb_comm_destroy + + subroutine psb_comm_set_swap_status(handle, flag, info) + implicit none + class(psb_comm_handle_type), allocatable, intent(inout) :: handle + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + info = 0 + if (.not. allocated(handle)) then + info = -1 + return + end if + call handle%set_swap_status(flag, info) + end subroutine psb_comm_set_swap_status + + subroutine psb_comm_get_swap_status(handle, flag, info) + implicit none + class(psb_comm_handle_type), allocatable, intent(in) :: handle + integer(psb_ipk_), intent(out) :: flag + integer(psb_ipk_), intent(out) :: info + info = 0 + if (.not. allocated(handle)) then + flag = 0 + info = -1 + return + end if + call handle%get_swap_status(flag, info) + end subroutine psb_comm_get_swap_status + +end module psb_comm_factory_mod diff --git a/base/modules/comm/comm_schemes/psb_comm_neighbor_impl_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_neighbor_impl_mod.F90 new file mode 100644 index 00000000..a4e8423d --- /dev/null +++ b/base/modules/comm/comm_schemes/psb_comm_neighbor_impl_mod.F90 @@ -0,0 +1,431 @@ +! Merged neighbor-topology module +! +module psb_comm_neighbor_impl_mod + use psb_const_mod + use psb_desc_const_mod, only: psb_proc_id_, psb_n_elem_recv_, psb_elem_recv_, & + & psb_n_elem_send_, psb_elem_send_ + use psb_error_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_ineighbor_alltoallv_ +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type, extends(psb_comm_handle_type) :: psb_comm_neighbor_handle + integer(psb_mpk_) :: graph_comm = mpi_comm_null + integer(psb_ipk_) :: num_neighbors = 0 + integer(psb_mpk_), allocatable :: send_counts(:), recv_counts(:) + integer(psb_mpk_), allocatable :: send_displs(:), recv_displs(:) + integer(psb_ipk_), allocatable :: send_indexes(:) + integer(psb_ipk_), allocatable :: recv_indexes(:) + integer(psb_ipk_) :: total_send = 0 + integer(psb_ipk_) :: total_recv = 0 + logical :: is_initialized = .false. + logical :: use_persistent_buffers = .false. + integer(psb_mpk_) :: comm_request = mpi_request_null + integer(psb_mpk_) :: persistent_request = mpi_request_null + logical :: persistent_request_ready = .false. + integer(psb_ipk_) :: persistent_buffer_size = 0 + contains + procedure, pass :: init => psb_comm_neighbor_init + procedure, pass :: free => neighbor_topology_free + procedure, pass :: set_swap_status => psb_comm_neighbor_set_swap_status + procedure, pass :: get_swap_status => psb_comm_neighbor_get_swap_status + procedure, pass :: topology_init => neighbor_topology_init + procedure, pass :: sizeof => neighbor_topology_sizeof + end type psb_comm_neighbor_handle + + +contains + + ! --------------------------------------------------------------- + ! neighbor_topology_init + ! + ! Parse the halo index list (obtained via desc_a%get_list_p) + ! and build: + ! - MPI dist-graph communicator with only the true neighbors + ! - per-neighbor send/recv counts and displacements + ! - contiguous gather/scatter index arrays + ! + ! The topology is stored inside the vector and lazily built + ! on the first psi_swapdata call that uses the neighbor-alltoallv + ! communication mode. + ! + ! Arguments: + ! topology - the persistent state (output, intent inout) + ! halo_index - halo_index array (from get_list_p, intent in) + ! num_neighbors - number of exchanges (from get_list_p) + ! total_send_elems - total send count (from get_list_p) + ! total_recv_elems - total recv count (from get_list_p) + ! ctxt - PSBLAS context + ! icomm - MPI communicator + ! info - error code (output) + ! --------------------------------------------------------------- + subroutine neighbor_topology_init(topology, halo_index, num_neighbors, & + & total_send_elems, total_recv_elems, ctxt, icomm, info) +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + class(psb_comm_neighbor_handle), intent(inout) :: topology + integer(psb_ipk_), intent(in) :: halo_index(:) + integer(psb_ipk_), intent(in) :: num_neighbors, total_send_elems, total_recv_elems + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(out) :: info + + ! locals + integer(psb_mpk_) :: iret + integer(psb_ipk_) :: i, k, idx_ptr, num_elem_recv, num_elem_send, partner_proc + integer(psb_ipk_) :: neighbor_count, send_offset, recv_offset + integer(psb_mpk_), allocatable :: source_ranks(:), dest_ranks(:) + integer(psb_mpk_), allocatable :: source_weights(:), dest_weights(:) + integer(psb_mpk_) :: in_degree, out_degree + character(len=40) :: name + integer(psb_ipk_) :: proc_id + integer(psb_ipk_) :: position + integer(psb_ipk_) :: err_act + + info = psb_success_ + name = 'neighbor_topology_init' + call psb_erractionsave(err_act) + + ! Clean up any previous state + call topology%free(info) + + ! ---------------------------------------------------------- + ! First pass: count neighbors (excluding self) and totals + ! ---------------------------------------------------------- + topology%num_neighbors = 0 + topology%total_send = 0 + topology%total_recv = 0 + + if(size(halo_index) < 1) then + call psb_errpush(psb_err_topology_invalid_args_,name) + goto 9999 + end if + + allocate(source_ranks(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Source ranks allocation failed') + goto 9999 + end if + + allocate(dest_ranks(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Destination ranks allocation failed') + goto 9999 + end if + + allocate(source_weights(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Source weights allocation failed') + goto 9999 + end if + + allocate(dest_weights(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Destination weights allocation failed') + goto 9999 + end if + + allocate(topology%send_counts(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Send counts allocation failed') + goto 9999 + end if + + allocate(topology%recv_counts(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Receive counts allocation failed') + goto 9999 + end if + + allocate(topology%send_displs(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Send displacements allocation failed') + goto 9999 + end if + + allocate(topology%recv_displs(num_neighbors), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Receive displacements allocation failed') + goto 9999 + end if + + + ! ----------------------------------------------------------- + ! Allocate the gather/scatter index arrays + ! ----------------------------------------------------------- + allocate(topology%send_indexes(total_send_elems), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Send indexes allocation failed') + goto 9999 + end if + + allocate(topology%recv_indexes(total_recv_elems), stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name, a_err='Recv indexes allocation failed') + goto 9999 + end if + + ! ----------------------------------------------------------- + ! Fill neighbor ranks, weights, counts, displacements, + ! and gather/scatter index arrays. + ! + ! The halo_index layout per neighbor (starting at position): + ! position + 0 : process id + ! position + 1 : nerv (num recv elements) + ! position + 2 .. +1+nerv : recv element indexes + ! position + 2+nerv : nesd (num send elements) + ! position + 3+nerv .. +2+nerv+nesd : send element indexes + ! Total stride per neighbor: nerv + nesd + 3 + ! ----------------------------------------------------------- + send_offset = 0 + recv_offset = 0 + position = 1 + + do i = 1, num_neighbors + proc_id = halo_index(position) + num_elem_recv = halo_index(position + 1) + num_elem_send = halo_index(position + num_elem_recv + 2) + + ! Fill source/destination ranks and weights (weights are all 1 for now) + source_ranks(i) = int(proc_id, psb_mpk_) + dest_ranks(i) = int(proc_id, psb_mpk_) + source_weights(i) = 1 + dest_weights(i) = 1 + + ! Counts and displacements (displs set BEFORE accumulating offset) + topology%send_counts(i) = int(num_elem_send, psb_mpk_) + topology%recv_counts(i) = int(num_elem_recv, psb_mpk_) + topology%send_displs(i) = int(send_offset, psb_mpk_) + topology%recv_displs(i) = int(recv_offset, psb_mpk_) + + ! Fill recv_indexes from halo_index(position+2 .. position+1+nerv) + do k = 1, num_elem_recv + topology%recv_indexes(recv_offset + k) = halo_index(position + psb_elem_recv_ + k - 1) + end do + + ! Fill send_indexes from halo_index(position+3+nerv .. position+2+nerv+nesd) + do k = 1, num_elem_send + topology%send_indexes(send_offset + k) = halo_index(position + num_elem_recv + psb_elem_send_ + k - 1) + end do + + send_offset = send_offset + num_elem_send + recv_offset = recv_offset + num_elem_recv + + topology%num_neighbors = topology%num_neighbors + 1 + topology%total_send = topology%total_send + num_elem_send + topology%total_recv = topology%total_recv + num_elem_recv + + position = position + num_elem_recv + num_elem_send + 3 + end do + + ! ---------------------------------------------------------- + ! Sanity check: the totals computed from the neighbor list + ! should match the totals returned by get_list_p. + ! ---------------------------------------------------------- + if (topology%total_send /= total_send_elems) then + info = psb_err_topology_args_mismatch_ + call psb_errpush(info, name, a_err='Send elements mismatch') + goto 9999 + end if + + if (topology%total_recv /= total_recv_elems) then + info = psb_err_topology_args_mismatch_ + call psb_errpush(info, name, a_err='Receive elements mismatch') + goto 9999 + end if + + if(topology%num_neighbors /= num_neighbors) then + info = psb_err_topology_args_mismatch_ + call psb_errpush(info, name, a_err='Number of neighbors mismatch') + goto 9999 + end if + + + ! ---------------------------------------------------------- + ! Build the dist-graph communicator + ! ---------------------------------------------------------- + in_degree = topology%num_neighbors !! Just for clarity + out_degree = topology%num_neighbors !! Just for clarity + + call mpi_dist_graph_create_adjacent(icomm, & + & in_degree, source_ranks, source_weights, & + & out_degree, dest_ranks, dest_weights, & + & mpi_info_null, .false., & ! Check this line for optimizations + & topology%graph_comm, info) + if (info /= mpi_success) then + info = psb_err_topology_error_ + call psb_errpush(info, name) + goto 9999 + end if + + topology%is_initialized = .true. + + ! TODO: Is it safe to deallocate these temporary arrays here, or do we need them for the gather/scatter indexes? + ! deallocate(source_ranks, dest_ranks, source_weights, dest_weights) + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + end subroutine neighbor_topology_init + + + ! --------------------------------------------------------------- + ! neighbor_topology_free + ! Release all resources held by the persistent state. + ! --------------------------------------------------------------- + subroutine neighbor_topology_free(this, info) +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + class(psb_comm_neighbor_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + integer(psb_mpk_) :: iret + + info = psb_success_ + + if (this%persistent_request_ready) then + if (this%persistent_request /= mpi_request_null) then + call mpi_request_free(this%persistent_request, iret) + end if + this%persistent_request = mpi_request_null + this%persistent_request_ready = .false. + this%persistent_buffer_size = 0 + end if + + if (this%graph_comm /= mpi_comm_null) then + call mpi_comm_free(this%graph_comm, iret) + this%graph_comm = mpi_comm_null + end if + + if (allocated(this%send_counts)) deallocate(this%send_counts) + if (allocated(this%recv_counts)) deallocate(this%recv_counts) + if (allocated(this%send_displs)) deallocate(this%send_displs) + if (allocated(this%recv_displs)) deallocate(this%recv_displs) + if (allocated(this%send_indexes)) deallocate(this%send_indexes) + if (allocated(this%recv_indexes)) deallocate(this%recv_indexes) + + this%num_neighbors = 0 + this%total_send = 0 + this%total_recv = 0 + this%is_initialized = .false. + this%comm_request = mpi_request_null + + end subroutine neighbor_topology_free + + + ! --------------------------------------------------------------- + ! neighbor_topology_sizeof + ! Return approximate memory footprint in bytes. + ! --------------------------------------------------------------- + function neighbor_topology_sizeof(this) result(val) + implicit none + class(psb_comm_neighbor_handle), intent(in) :: this + integer(psb_epk_) :: val + + val = 0 + val = val + psb_sizeof_ip * 6 ! scalar integers + logicals + if (allocated(this%send_counts)) val = val + psb_sizeof_ip * size(this%send_counts) + if (allocated(this%recv_counts)) val = val + psb_sizeof_ip * size(this%recv_counts) + if (allocated(this%send_displs)) val = val + psb_sizeof_ip * size(this%send_displs) + if (allocated(this%recv_displs)) val = val + psb_sizeof_ip * size(this%recv_displs) + if (allocated(this%send_indexes)) val = val + psb_sizeof_ip * size(this%send_indexes) + if (allocated(this%recv_indexes)) val = val + psb_sizeof_ip * size(this%recv_indexes) + + + end function neighbor_topology_sizeof + + + subroutine psb_comm_neighbor_create(this, comm_type, info) + class(psb_comm_neighbor_handle), intent(inout) :: this + integer(psb_ipk_), intent(in) :: comm_type + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + this%comm_type = comm_type + this%id = 0 + this%swap_status = 0 + this%comm_request = mpi_request_null + this%persistent_request = mpi_request_null + this%persistent_request_ready = .false. + this%persistent_buffer_size = 0 + + call this%free(info) + end subroutine psb_comm_neighbor_create + + + subroutine psb_comm_neighbor_destroy(this, info) + class(psb_comm_neighbor_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + this%comm_request = mpi_request_null + this%persistent_request = mpi_request_null + this%persistent_request_ready = .false. + this%persistent_buffer_size = 0 + call this%free(info) + end subroutine psb_comm_neighbor_destroy + + + subroutine psb_comm_neighbor_set_swap_status(this, flag, info) + class(psb_comm_neighbor_handle), intent(inout) :: this + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + this%swap_status = flag + end subroutine psb_comm_neighbor_set_swap_status + + + subroutine psb_comm_neighbor_get_swap_status(this, flag, info) + class(psb_comm_neighbor_handle), intent(in) :: this + integer(psb_ipk_), intent(out) :: flag + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + flag = this%swap_status + end subroutine psb_comm_neighbor_get_swap_status + + subroutine psb_comm_neighbor_init(this, info) + class(psb_comm_neighbor_handle), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + info = 0 + this%comm_type = psb_comm_ineighbor_alltoallv_ + this%id = 0 + this%swap_status = 0 + this%is_initialized = .false. + this%use_persistent_buffers = .false. + this%comm_request = mpi_request_null + this%persistent_request = mpi_request_null + this%persistent_request_ready = .false. + this%persistent_buffer_size = 0 + end subroutine psb_comm_neighbor_init + + +end module psb_comm_neighbor_impl_mod diff --git a/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 b/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 new file mode 100644 index 00000000..0c91d858 --- /dev/null +++ b/base/modules/comm/comm_schemes/psb_comm_schemes_mod.F90 @@ -0,0 +1,68 @@ +! +! psb_comm_mod - communication handle module +! +module psb_comm_schemes_mod + use psb_const_mod + implicit none + + ! Communication type enumeration (keeps compatibility with integer selectors) + enum, bind(c) + enumerator psb_comm_unknown_ + enumerator psb_comm_isend_irecv_ + enumerator psb_comm_ineighbor_alltoallv_ + enumerator psb_comm_persistent_ineighbor_alltoallv_ + end enum + + enum, bind(c) + enumerator psb_comm_status_unknown_ + enumerator psb_comm_status_start_ + enumerator psb_comm_status_wait_ + end enum + + + ! (abstract interfaces moved below type definition) + + ! --- comm handle type --- + type, abstract :: psb_comm_handle_type + integer(psb_ipk_) :: id = -1 + integer(psb_ipk_) :: comm_type = psb_comm_unknown_ + integer(psb_ipk_) :: swap_status = psb_comm_status_unknown_ + contains + procedure(psb_comm_init), deferred :: init + procedure(psb_comm_free), deferred :: free + procedure(psb_comm_set_swap_status), deferred :: set_swap_status + procedure(psb_comm_get_swap_status), deferred :: get_swap_status + end type psb_comm_handle_type + + ! --- abstract interfaces --- + abstract interface + subroutine psb_comm_init(this, info) + import :: psb_ipk_, psb_comm_handle_type + class(psb_comm_handle_type), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + end subroutine + + subroutine psb_comm_free(this, info) + import :: psb_ipk_, psb_comm_handle_type + class(psb_comm_handle_type), intent(inout) :: this + integer(psb_ipk_), intent(out) :: info + end subroutine + + subroutine psb_comm_set_swap_status(this, flag, info) + import :: psb_ipk_, psb_comm_handle_type + class(psb_comm_handle_type), intent(inout) :: this + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + end subroutine + + subroutine psb_comm_get_swap_status(this, flag, info) + import :: psb_ipk_, psb_comm_handle_type + class(psb_comm_handle_type), intent(in) :: this + integer(psb_ipk_), intent(out) :: flag + integer(psb_ipk_), intent(out) :: info + end subroutine + end interface + +contains + +end module psb_comm_schemes_mod diff --git a/base/modules/comm/psi_d_comm_v_mod.f90 b/base/modules/comm/psi_d_comm_v_mod.f90 index 19380e28..fdc624b4 100644 --- a/base/modules/comm/psi_d_comm_v_mod.f90 +++ b/base/modules/comm/psi_d_comm_v_mod.f90 @@ -38,18 +38,18 @@ module psi_d_comm_v_mod interface psi_swapdata ! --------------------------------------------------------------- ! Wrapper that calls different communications schemes depending on - ! flag variable using communication buff obtained from desc_a%get_list_p + ! swap_status variable using communication buff obtained from desc_a%get_list_p ! --------------------------------------------------------------- - module subroutine psi_dswapdata_vect(flag,beta,y,desc_a,info,data) - integer(psb_ipk_), intent(in) :: flag + module subroutine psi_dswapdata_vect(swap_status,beta,y,desc_a,info,data) + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y type(psb_desc_type), target :: desc_a integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), optional :: data end subroutine psi_dswapdata_vect - module subroutine psi_dswapdata_multivect(flag,beta,y,desc_a,info,data) - integer(psb_ipk_), intent(in) :: flag + module subroutine psi_dswapdata_multivect(swap_status,beta,y,desc_a,info,data) + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_multivect_type), intent(inout) :: y type(psb_desc_type), target :: desc_a @@ -63,18 +63,18 @@ module psi_d_comm_v_mod ! --------------------------------------------------------------- ! Upper call in order to populate idx using desc_a%get_list_p ! and then call different communications schemes depending - ! on flag variable + ! on swap_status variable ! --------------------------------------------------------------- - module subroutine psi_dswaptran_vect(flag,beta,y,desc_a,info,data) - integer(psb_ipk_), intent(in) :: flag + module subroutine psi_dswaptran_vect(swap_status,beta,y,desc_a,info,data) + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_vect_type), intent(inout) :: y type(psb_desc_type), target :: desc_a integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), optional :: data end subroutine psi_dswaptran_vect - module subroutine psi_dswaptran_multivect(flag,beta,y,desc_a,info,data) - integer(psb_ipk_), intent(in) :: flag + module subroutine psi_dswaptran_multivect(swap_status,beta,y,desc_a,info,data) + integer(psb_ipk_), intent(in) :: swap_status real(psb_dpk_), intent(in) :: beta class(psb_d_base_multivect_type), intent(inout) :: y type(psb_desc_type), target :: desc_a diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 9a239220..1e270644 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -49,7 +49,8 @@ module psb_d_base_vect_mod use psb_realloc_mod use psb_i_base_vect_mod use psb_l_base_vect_mod - use psb_neighbor_topology_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, psb_comm_unknown_ + use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free !> \namespace psb_base_mod \class psb_d_base_vect_type @@ -64,10 +65,10 @@ module psb_d_base_vect_mod !! type psb_d_base_vect_type !> Values. - real(psb_dpk_), allocatable :: v(:) - real(psb_dpk_), allocatable :: combuf(:) - integer(psb_mpk_), allocatable :: comid(:,:) ! This is used only for Isend/Irecv scheme, to store the communication handles for each neighbor - integer(psb_mpk_) :: communication_handle ! This is used only for Isend/Irecv scheme, to store the communication handle for the whole halo exchange + real(psb_dpk_), allocatable :: v(:) + real(psb_dpk_), allocatable :: combuf(:) + ! Polymorphic communication handle stored at vector level. + class(psb_comm_handle_type), allocatable :: comm_handle !> vector bldstate: !! null: pristine; @@ -81,9 +82,6 @@ module psb_d_base_vect_mod integer(psb_ipk_), private :: dupl = psb_dupl_null_ integer(psb_ipk_), private :: ncfs = 0 integer(psb_ipk_), allocatable :: iv(:) - - type(psb_neighbor_topology_type) :: neighbor_topology - contains ! ! Constructors/allocators @@ -147,8 +145,6 @@ module psb_d_base_vect_mod procedure, nopass :: device_wait => d_base_device_wait procedure, pass(x) :: maybe_free_buffer => d_base_maybe_free_buffer procedure, pass(x) :: free_buffer => d_base_free_buffer - procedure, pass(x) :: new_comid => d_base_new_comid - procedure, pass(x) :: free_comid => d_base_free_comid ! ! Basic info @@ -179,7 +175,13 @@ module psb_d_base_vect_mod generic, public :: sct => sctb, sctb_x, sctb_buf procedure, pass(x) :: check_addr => d_base_check_addr - + + ! Communication lifecycle split: + ! - `create_comm`: allocate/select a fresh handle implementation via factory. + ! - `init_comm`: configure the current handle instance from an existing one + ! (e.g., copy `id` and swap status), and reset buffers; + ! it recreates the handle only if missing or scheme changes. + ! - `destroy_comm`/`free_comm`: release current handle resources. ! @@ -255,12 +257,6 @@ module psb_d_base_vect_mod procedure, pass(x) :: minquotient_v => d_base_minquotient_v procedure, pass(x) :: minquotient_a2 => d_base_minquotient_a2 generic, public :: minquotient => minquotient_v, minquotient_a2 - - - ! Methods used to handle topology in neighbor_alltoallv communication scheme - procedure, pass(x) :: init_topology => d_base_init_topology - procedure, pass(x) :: free_topology => d_base_free_topology - end type psb_d_base_vect_type public :: psb_d_base_vect @@ -416,6 +412,11 @@ contains call psb_realloc(n,x%iv,info) call x%set_ncfs(0) end if + if (info == psb_success_) then + if (.not. allocated(x%comm_handle)) then + call psb_comm_init(psb_comm_isend_irecv_, x%comm_handle, info) + end if + end if end subroutine d_base_all @@ -434,6 +435,9 @@ contains integer(psb_ipk_), intent(out) :: info allocate(psb_d_base_vect_type :: y, stat=info) + if (info == psb_success_) then + call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info) + end if end subroutine d_base_mold @@ -441,7 +445,7 @@ contains use psi_serial_mod use psb_realloc_mod implicit none - class(psb_d_base_vect_type), intent(out) :: x + class(psb_d_base_vect_type), intent(inout) :: x integer(psb_ipk_), intent(out) :: info logical, intent(in), optional :: clear logical :: clear_ @@ -458,6 +462,11 @@ contains call x%set_host() call x%set_upd() end if + if (info == psb_success_) then + if (.not. allocated(x%comm_handle)) then + call psb_comm_init(psb_comm_isend_irecv_, x%comm_handle, info) + end if + end if end subroutine d_base_reinit @@ -837,11 +846,16 @@ contains class(psb_d_base_vect_type), intent(inout) :: x integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: info_comm + info = 0 if (allocated(x%v)) deallocate(x%v, stat=info) if ((info == 0).and.allocated(x%combuf)) call x%free_buffer(info) - if ((info == 0).and.allocated(x%comid)) call x%free_comid(info) - if ((info == 0).and.allocated(x%iv)) deallocate(x%iv, stat=info) + if ((info == 0).and.allocated(x%iv)) deallocate(x%iv, stat=info) + if ((info == 0).and.allocated(x%comm_handle)) then + call psb_comm_free(x%comm_handle, info_comm) + if (info_comm /= psb_success_) info = info_comm + end if if (info /= 0) call & & psb_errpush(psb_err_alloc_dealloc_,'vect_free') call x%set_null() @@ -888,24 +902,6 @@ contains end subroutine d_base_maybe_free_buffer - ! - !> Function base_free_comid: - !! \memberof psb_d_base_vect_type - !! \brief Free aux MPI communication id buffer - !! - !! \param info return code - !! - ! - subroutine d_base_free_comid(x,info) - use psb_realloc_mod - implicit none - class(psb_d_base_vect_type), intent(inout) :: x - integer(psb_ipk_), intent(out) :: info - - if (allocated(x%comid)) & - & deallocate(x%comid,stat=info) - end subroutine d_base_free_comid - function d_base_get_ncfs(x) result(res) implicit none class(psb_d_base_vect_type), intent(in) :: x @@ -1109,12 +1105,25 @@ contains implicit none class(psb_d_base_vect_type), intent(in) :: x class(psb_d_base_vect_type), intent(out) :: y + integer(psb_ipk_) :: info + integer(psb_ipk_) :: swap_status if (allocated(x%v)) call y%bld(x%v) call y%set_state(x%get_state()) call y%set_dupl(x%get_dupl()) call y%set_ncfs(x%get_ncfs()) if (allocated(x%iv)) y%iv = x%iv + if (allocated(x%comm_handle)) then + call psb_comm_init(x%comm_handle%comm_type, y%comm_handle, info) + if (info /= psb_success_) return + y%comm_handle%id = x%comm_handle%id + call x%comm_handle%get_swap_status(swap_status, info) + if (info /= psb_success_) return + call y%comm_handle%set_swap_status(swap_status, info) + if (info /= psb_success_) return + else + call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info) + end if end subroutine d_base_cpy ! @@ -2364,6 +2373,56 @@ contains end subroutine d_base_device_wait + + subroutine d_base_init_comm(x, comm_handle, info) + ! `init_comm` is intentionally a configuration step. + ! It does not define the communication API surface itself; instead it: + ! 1) resets local communication buffers, + ! 2) ensures a handle exists with the requested concrete scheme, + ! recreating it only when needed (missing handle or type change), + ! 3) copies runtime state (`id`, swap status) from the input handle. + implicit none + class(psb_d_base_vect_type), intent(inout) :: x + class(psb_comm_handle_type), intent(in), pointer :: comm_handle + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: comm_type, swap_status + logical :: need_new_handle + + info = psb_success_ + + ! Reset/initialize communication related storage. Actual + ! topology/building is done lazily by neighbor_topology_init + if (allocated(x%combuf)) then + deallocate(x%combuf) + end if + + comm_type = psb_comm_isend_irecv_ + if (associated(comm_handle)) then + comm_type = comm_handle%comm_type + if (comm_type == psb_comm_unknown_) comm_type = psb_comm_isend_irecv_ + end if + + ! Recreate only when needed (missing handle or scheme change). + need_new_handle = .not. allocated(x%comm_handle) + if (.not. need_new_handle) then + need_new_handle = (x%comm_handle%comm_type /= comm_type) + end if + + if (need_new_handle) then + call psb_comm_init(comm_type, x%comm_handle, info) + if (info /= psb_success_) return + end if + + if (associated(comm_handle)) then + x%comm_handle%id = comm_handle%id + call comm_handle%get_swap_status(swap_status, info) + if (info /= psb_success_) return + call x%comm_handle%set_swap_status(swap_status, info) + if (info /= psb_success_) return + end if + + end subroutine d_base_init_comm + function d_base_use_buffer() result(res) logical :: res @@ -2380,16 +2439,6 @@ contains call psb_realloc(n,x%combuf,info) end subroutine d_base_new_buffer - subroutine d_base_new_comid(n,x,info) - use psb_realloc_mod - implicit none - class(psb_d_base_vect_type), intent(inout) :: x - integer(psb_ipk_), intent(in) :: n - integer(psb_ipk_), intent(out) :: info - - call psb_realloc(n,2_psb_ipk_,x%comid,info) - end subroutine d_base_new_comid - ! ! shortcut alpha=1 beta=0 @@ -2622,35 +2671,6 @@ contains call z%addconst(x%v,b,info) end subroutine d_base_addconst_v2 - - ! -------------------------------------------------------------------- - ! Implementation of methods used for neighbor alltoallv communication - ! -------------------------------------------------------------------- - subroutine d_base_init_topology(x, halo_index, num_exchanges, & - & total_send_elems, total_recv_elems, ctxt, icomm, info) - implicit none - class(psb_d_base_vect_type), intent(inout) :: x - integer(psb_ipk_), intent(in) :: halo_index(:) - integer(psb_ipk_), intent(in) :: num_exchanges, total_send_elems, total_recv_elems - type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_mpk_), intent(in) :: icomm - integer(psb_ipk_), intent(out) :: info - - call x%neighbor_topology%init(halo_index, num_exchanges, & - & total_send_elems, total_recv_elems, ctxt, icomm, info) - - end subroutine d_base_init_topology - - subroutine d_base_free_topology(x, info) - implicit none - class(psb_d_base_vect_type), intent(inout) :: x - integer(psb_ipk_), intent(out) :: info - - call x%neighbor_topology%free(info) - - end subroutine d_base_free_topology - ! -------------------------------------------------------------------- - end module psb_d_base_vect_mod @@ -2660,7 +2680,7 @@ module psb_d_base_multivect_mod use psb_error_mod use psb_realloc_mod use psb_d_base_vect_mod - use psb_neighbor_topology_mod + use psb_comm_schemes_mod, only: psb_comm_handle_type !> \namespace psb_base_mod \class psb_d_base_vect_type !! The psb_d_base_vect_type @@ -2679,8 +2699,7 @@ module psb_d_base_multivect_mod !> Values. real(psb_dpk_), allocatable :: v(:,:) real(psb_dpk_), allocatable :: combuf(:) - integer(psb_mpk_), allocatable :: comid(:,:) ! This is used only for Isend/Irecv scheme, to store the communication handles for each neighbor - integer(psb_mpk_) :: communication_handle ! This is used only for Isend/Irecv scheme, to store the communication handle for the whole halo exchange + ! neighbor-specific communication state removed; comm_handle owned below !> vector bldstate: !! null: pristine; @@ -2695,7 +2714,7 @@ module psb_d_base_multivect_mod integer(psb_ipk_), private :: ncfs = 0 integer(psb_ipk_), allocatable :: iv(:) - type(psb_neighbor_topology_type) :: neighbor_topology + class(psb_comm_handle_type), allocatable :: comm_handle contains ! @@ -2804,8 +2823,6 @@ module psb_d_base_multivect_mod procedure, nopass :: device_wait => d_base_mlv_device_wait procedure, pass(x) :: maybe_free_buffer => d_base_mlv_maybe_free_buffer procedure, pass(x) :: free_buffer => d_base_mlv_free_buffer - procedure, pass(x) :: new_comid => d_base_mlv_new_comid - procedure, pass(x) :: free_comid => d_base_mlv_free_comid ! ! Gather/scatter. These are needed for MPI interfacing. @@ -2823,11 +2840,6 @@ module psb_d_base_multivect_mod procedure, pass(y) :: sctb_buf => d_base_mlv_sctb_buf generic, public :: sct => sctb, sctbr2, sctb_x, sctb_buf - ! Neighbor alltoallv communication topology handling - procedure, pass(x) :: init_topology => d_base_mlv_init_topology - procedure, pass(x) :: free_topology => d_base_mlv_free_topology - - end type psb_d_base_multivect_type interface psb_d_base_multivect @@ -4085,17 +4097,6 @@ contains call psb_realloc(n*nc,x%combuf,info) end subroutine d_base_mlv_new_buffer - subroutine d_base_mlv_new_comid(n,x,info) - use psb_realloc_mod - implicit none - class(psb_d_base_multivect_type), intent(inout) :: x - integer(psb_ipk_), intent(in) :: n - integer(psb_ipk_), intent(out) :: info - - call psb_realloc(n,2_psb_ipk_,x%comid,info) - end subroutine d_base_mlv_new_comid - - subroutine d_base_mlv_maybe_free_buffer(x,info) use psb_realloc_mod implicit none @@ -4119,17 +4120,6 @@ contains & deallocate(x%combuf,stat=info) end subroutine d_base_mlv_free_buffer - subroutine d_base_mlv_free_comid(x,info) - use psb_realloc_mod - implicit none - class(psb_d_base_multivect_type), intent(inout) :: x - integer(psb_ipk_), intent(out) :: info - - if (allocated(x%comid)) & - & deallocate(x%comid,stat=info) - end subroutine d_base_mlv_free_comid - - ! ! Gather: Y = beta * Y + alpha * X(IDX(:)) ! @@ -4351,35 +4341,8 @@ contains end subroutine d_base_mlv_device_wait - - - ! -------------------------------------------------------------------- - ! Implementation of methods used for neighbor alltoallv communication - ! -------------------------------------------------------------------- - subroutine d_base_mlv_init_topology(x, halo_index, num_exchanges, & - & total_send_elems, total_recv_elems, ctxt, icomm, info) - implicit none - class(psb_d_base_multivect_type), intent(inout) :: x - integer(psb_ipk_), intent(in) :: halo_index(:) - integer(psb_ipk_), intent(in) :: num_exchanges, total_send_elems, total_recv_elems - type(psb_ctxt_type), intent(in) :: ctxt - integer(psb_mpk_), intent(in) :: icomm - integer(psb_ipk_), intent(out) :: info - - call x%neighbor_topology%init(halo_index, num_exchanges, & - & total_send_elems, total_recv_elems, ctxt, icomm, info) - - end subroutine d_base_mlv_init_topology - - subroutine d_base_mlv_free_topology(x, info) - implicit none - class(psb_d_base_multivect_type), intent(inout) :: x - integer(psb_ipk_), intent(out) :: info - - call x%neighbor_topology%free(info) - - end subroutine d_base_mlv_free_topology - ! -------------------------------------------------------------------- - + ! + ! Communication routines for multivectors (delegates to base vector implementation) + ! end module psb_d_base_multivect_mod diff --git a/test/comm/psb_comm_test.F90 b/test/comm/psb_comm_test.F90 index 4ed282de..a908068a 100644 --- a/test/comm/psb_comm_test.F90 +++ b/test/comm/psb_comm_test.F90 @@ -18,12 +18,18 @@ program psb_comm_test use psb_base_mod use psi_mod + use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free + use psb_comm_schemes_mod, only: psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, & + & psb_comm_isend_irecv_ + use psb_comm_schemes_mod, only: psb_comm_status_start_, psb_comm_status_wait_, psb_comm_status_unknown_ implicit none ! ---- parameters ---- integer(psb_ipk_) :: idim integer(psb_ipk_) :: argc + integer(psb_ipk_) :: iters character(len=32) :: arg + character(len=16) :: mode ! ---- descriptor / context ---- type(psb_ctxt_type) :: ctxt @@ -33,11 +39,11 @@ program psb_comm_test integer(psb_lpk_), allocatable :: myidx(:) ! ---- vectors ---- - type(psb_d_vect_type) :: v_baseline, v_neighbor + type(psb_d_vect_type) :: v_baseline, v_neighbor, v_neighbor_persistent ! ---- temporary / comparison arrays ---- real(psb_dpk_), allocatable :: vals(:) - real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:) + real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:), result_persistent(:) real(psb_dpk_), allocatable :: expected(:) ! ---- halo index bookkeeping ---- @@ -46,7 +52,9 @@ program psb_comm_test ! ---- error / reporting ---- integer(psb_ipk_) :: n_pass, n_total, imode + logical :: comm_ok real(psb_dpk_) :: err, tol + real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent integer(psb_lpk_), allocatable :: glob_col(:) character(len=40) :: name @@ -54,6 +62,8 @@ program psb_comm_test tol = 1.0d-12 n_pass = 0 n_total = 0 + iters = 5 + mode = 'both' ! ---- parse command-line argument for idim ---- idim = 10 @@ -64,7 +74,22 @@ program psb_comm_test if (i < argc) then call get_command_argument(i+1, arg) read(arg, *) idim - exit + end if + else if (trim(arg) == '--iters') then + if (i < argc) then + call get_command_argument(i+1, arg) + read(arg, *) iters + end if + end if + end do + + ! parse optional mode flag + do i = 1, argc + call get_command_argument(i, arg) + if (trim(arg) == '--mode') then + if (i < argc) then + call get_command_argument(i+1, arg) + read(arg, *) mode end if end if end do @@ -135,8 +160,10 @@ program psb_comm_test ! ================================================================== call psb_geall(v_baseline, desc_a, info) call psb_geall(v_neighbor, desc_a, info) + call psb_geall(v_neighbor_persistent, desc_a, info) call psb_geasb(v_baseline, desc_a, info, scratch=.true.) call psb_geasb(v_neighbor, desc_a, info, scratch=.true.) + call psb_geasb(v_neighbor_persistent, desc_a, info, scratch=.true.) ! Fill owned entries with the global index value allocate(vals(ncol)) @@ -146,6 +173,7 @@ program psb_comm_test end do call v_baseline%set_vect(vals) call v_neighbor%set_vect(vals) + call v_neighbor_persistent%set_vect(vals) deallocate(vals) ! ================================================================== @@ -162,36 +190,162 @@ program psb_comm_test ! ================================================================== ! 6. Baseline halo exchange (Isend/Irecv in one call) ! ================================================================== - imode = IOR(psb_swap_send_, psb_swap_recv_) ! v_baseline%v is a psb_d_base_vect_type - call psi_swapdata(flag=imode, beta=dzero, y=v_baseline%v, desc_a=desc_a, info=info, data=psb_comm_halo_) + call psi_swapdata( & + swap_status=psb_comm_status_start_, & + beta=dzero, & + y=v_baseline%v, & + desc_a=desc_a, & + info=info, & + data=psb_comm_halo_) if (info /= psb_success_) then write(psb_err_unit,*) my_rank, 'baseline swap error:', info call psb_abort(ctxt) end if + call psi_swapdata( & + swap_status=psb_comm_status_wait_, & + beta=dzero, & + y=v_baseline%v, & + desc_a=desc_a, & + info=info, & + data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'baseline swap error:', info + call psb_abort(ctxt) + end if + + ! ================================================================== ! 7. Neighbor topology halo exchange (start + wait) ! ================================================================== - imode = psb_swap_start_ - call psi_swapdata(imode, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + call psb_comm_init(psb_comm_ineighbor_alltoallv_, v_neighbor%v%comm_handle, info) + if (info /= 0) then + write(psb_err_unit,*) my_rank, 'psb_comm_init neighbor error:', info + call psb_abort(ctxt) + end if + call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) if (info /= psb_success_) then write(psb_err_unit,*) my_rank, 'neighbor start error:', info call psb_abort(ctxt) end if - imode = psb_swap_wait_ - call psi_swapdata(imode, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) if (info /= psb_success_) then write(psb_err_unit,*) my_rank, 'neighbor wait error:', info call psb_abort(ctxt) end if + ! ================================================================== + ! 7b. Persistent-neighbor halo exchange (start + wait) + ! ================================================================== + call psb_comm_init(psb_comm_persistent_ineighbor_alltoallv_, v_neighbor_persistent%v%comm_handle, info) + if (info /= 0) then + write(psb_err_unit,*) my_rank, 'psb_comm_init persistent-neighbor error:', info + call psb_abort(ctxt) + end if + call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'persistent-neighbor start error:', info + call psb_abort(ctxt) + end if + call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'persistent-neighbor wait error:', info + call psb_abort(ctxt) + end if + + ! ================================================================== + ! 8. Performance: repeat exchanges and measure timings + ! ================================================================== + if (my_rank == 0) then + write(psb_out_unit,'("Timing: running ",i0," iterations for baseline, neighbor and persistent-neighbor")') iters + end if + + tsum_baseline = 0.0_psb_dpk_ + tsum_neighbor = 0.0_psb_dpk_ + tsum_neighbor_persistent = 0.0_psb_dpk_ + + call psb_comm_init(psb_comm_isend_irecv_, v_baseline%v%comm_handle, info) + call psb_comm_init(psb_comm_ineighbor_alltoallv_, v_neighbor%v%comm_handle, info) + call psb_comm_init(psb_comm_persistent_ineighbor_alltoallv_, v_neighbor_persistent%v%comm_handle, info) + + ! ---- Comm check: verify selected communication schemes ---- + n_total = n_total + 1 + comm_ok = allocated(v_baseline%v%comm_handle) .and. allocated(v_neighbor%v%comm_handle) .and. & + & allocated(v_neighbor_persistent%v%comm_handle) + + if (comm_ok) then + comm_ok = (v_baseline%v%comm_handle%comm_type == psb_comm_isend_irecv_) .and. & + & (v_neighbor%v%comm_handle%comm_type == psb_comm_ineighbor_alltoallv_) .and. & + & (v_neighbor_persistent%v%comm_handle%comm_type == psb_comm_persistent_ineighbor_alltoallv_) + end if + + if (my_rank == 0) then + if (comm_ok) then + write(psb_out_unit,'(" [PASS] comm scheme selection : baseline/neighbor/persistent OK")') + n_pass = n_pass + 1 + else + write(psb_out_unit,'(" [FAIL] comm scheme selection : unexpected comm_type mapping")') + end if + end if + + do i = 1, iters + ! baseline timing + t0 = psb_wtime() + call psi_swapdata( & + swap_status=psb_comm_status_start_, & + beta=dzero, & + y=v_baseline%v, & + desc_a=desc_a, & + info=info, & + data=psb_comm_halo_) + call psi_swapdata( & + swap_status=psb_comm_status_wait_, & + beta=dzero, & + y=v_baseline%v, & + desc_a=desc_a, & + info=info, & + data=psb_comm_halo_) + t1 = psb_wtime() + dt = t1 - t0 + call psb_amx(ctxt, dt) + tsum_baseline = tsum_baseline + dt + + ! neighbor timing (start + wait) + t0 = psb_wtime() + call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + t1 = psb_wtime() + dt = t1 - t0 + call psb_amx(ctxt, dt) + tsum_neighbor = tsum_neighbor + dt + + ! persistent-neighbor timing (start + wait) + t0 = psb_wtime() + call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) + call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) + t1 = psb_wtime() + dt = t1 - t0 + call psb_amx(ctxt, dt) + tsum_neighbor_persistent = tsum_neighbor_persistent + dt + end do + + if (my_rank == 0) then + write(psb_out_unit,'(" Avg baseline time : ",es12.5)') (tsum_baseline / real(iters,psb_dpk_)) + write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline + write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') (tsum_neighbor / real(iters,psb_dpk_)) + write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor + write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') (tsum_neighbor_persistent / real(iters,psb_dpk_)) + write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_neighbor_persistent + end if + ! ================================================================== ! 8. Extract results and compare ! ================================================================== result_baseline = v_baseline%get_vect() result_neighbor = v_neighbor%get_vect() + result_persistent = v_neighbor_persistent%get_vect() ! ---- Test 1: cross-check baseline vs neighbor (all entries) ---- n_total = n_total + 1 @@ -232,15 +386,41 @@ program psb_comm_test end if end if - ! ---- Test 4: repeat neighbor exchange (topology reuse) ---- + ! ---- Test 4: cross-check baseline vs persistent-neighbor (all entries) ---- + n_total = n_total + 1 + err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol))) + call psb_amx(ctxt, err) + if (my_rank == 0) then + if (err < tol) then + write(psb_out_unit,'(" [PASS] cross-check baseline vs pers-nei : err = ",es12.5)') err + n_pass = n_pass + 1 + else + write(psb_out_unit,'(" [FAIL] cross-check baseline vs pers-nei : err = ",es12.5)') err + end if + end if + + ! ---- Test 5: persistent-neighbor absolute correctness ---- + n_total = n_total + 1 + err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) + call psb_amx(ctxt, err) + if (my_rank == 0) then + if (err < tol) then + write(psb_out_unit,'(" [PASS] pers-neigh absolute correctness : err = ",es12.5)') err + n_pass = n_pass + 1 + else + write(psb_out_unit,'(" [FAIL] pers-neigh absolute correctness : err = ",es12.5)') err + end if + end if + + ! ---- Test 6: repeat neighbor exchange (topology reuse) ---- ! Reset halo entries to zero, run again, and check do i = nrow+1, ncol result_neighbor(i) = dzero end do call v_neighbor%set_vect(result_neighbor) - call psi_swapdata(psb_swap_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) - call psi_swapdata(psb_swap_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) + call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) result_neighbor = v_neighbor%get_vect() n_total = n_total + 1 @@ -255,6 +435,28 @@ program psb_comm_test end if end if + ! ---- Test 7: repeat persistent-neighbor exchange (buffer reuse) ---- + do i = nrow+1, ncol + result_persistent(i) = dzero + end do + call v_neighbor_persistent%set_vect(result_persistent) + + call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) + call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) + + result_persistent = v_neighbor_persistent%get_vect() + n_total = n_total + 1 + err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) + call psb_amx(ctxt, err) + if (my_rank == 0) then + if (err < tol) then + write(psb_out_unit,'(" [PASS] pers-neigh buffer reuse : err = ",es12.5)') err + n_pass = n_pass + 1 + else + write(psb_out_unit,'(" [FAIL] pers-neigh buffer reuse : err = ",es12.5)') err + end if + end if + ! ================================================================== ! 9. Summary ! ================================================================== @@ -272,9 +474,11 @@ program psb_comm_test ! ================================================================== ! 10. Cleanup ! ================================================================== - deallocate(result_baseline, result_neighbor, expected, glob_col) - call psb_gefree(v_baseline, desc_a, info) + deallocate(result_baseline, result_neighbor, result_persistent, expected, glob_col) + +9999 call psb_gefree(v_baseline, desc_a, info) call psb_gefree(v_neighbor, desc_a, info) + call psb_gefree(v_neighbor_persistent, desc_a, info) call psb_cdfree(desc_a, info) call psb_exit(ctxt)