[ADD] Added polymorphic comm_handle data exchange for non-blocking and persistent neighbor communication. See test/comm/ for usage examples. For now it works only for d types.

communication_v2
Stack-1 2 months ago
parent 02f1ef741c
commit 09a5a74d75

@ -437,7 +437,7 @@ set(PSB_base_source_files
# modules/comm/psi_i2_comm_a_mod.f90
modules/comm/psi_m_comm_a_mod.f90
modules/comm/psi_l_comm_v_mod.f90
modules/comm/psb_comm_mod.f90
modules/comm/psb_comm_mod.F90
modules/comm/psb_l_comm_mod.f90
modules/comm/psb_d_linmap_mod.f90
modules/comm/psi_d_comm_v_mod.f90

File diff suppressed because it is too large Load Diff

@ -59,12 +59,12 @@
!
!
! Arguments:
! flag - integer Choose the algorithm for data exchange:
! swap_status - integer Choose the algorithm for data exchange:
! this is chosen through bit fields.
! swap_mpi = iand(flag,psb_swap_mpi_) /= 0
! swap_sync = iand(flag,psb_swap_sync_) /= 0
! swap_send = iand(flag,psb_swap_send_) /= 0
! swap_recv = iand(flag,psb_swap_recv_) /= 0
! swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
! swap_sync = iand(swap_status,psb_swap_sync_) /= 0
! swap_send = iand(swap_status,psb_swap_send_) /= 0
! swap_recv = iand(swap_status,psb_swap_recv_) /= 0
! if (swap_mpi): use underlying MPI_ALLTOALLV.
! if (swap_sync): use PSB_SND and PSB_RCV in
! synchronized pairs
@ -92,8 +92,12 @@
!
submodule (psi_d_comm_v_mod) psi_d_swaptran_impl
use psb_base_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_
use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free
use psb_comm_baseline_mod, only: psb_comm_baseline_handle
use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
contains
module subroutine psi_dswaptran_vect(flag,beta,y,desc_a,info,data)
module subroutine psi_dswaptran_vect(swap_status,beta,y,desc_a,info,data)
#ifdef PSB_MPI_MOD
use mpi
@ -103,7 +107,7 @@ contains
include 'mpif.h'
#endif
integer(psb_ipk_), intent(in) :: flag
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_vect_type), intent(inout) :: y
type(psb_desc_type),target :: desc_a
@ -153,37 +157,46 @@ contains
goto 9999
end if
swap_mpi = iand(flag,psb_swap_mpi_) /= 0
swap_sync = iand(flag,psb_swap_sync_) /= 0
swap_send = iand(flag,psb_swap_send_) /= 0
swap_recv = iand(flag,psb_swap_recv_) /= 0
swap_start = iand(flag,psb_swap_start_) /= 0
swap_wait = iand(flag,psb_swap_wait_) /= 0
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
swap_start = iand(swap_status,psb_swap_start_) /= 0
swap_wait = iand(swap_status,psb_swap_wait_) /= 0
baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync
neighbor_a2av = swap_start .or. swap_wait
if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true')
call psb_errpush(info,name,a_err='Incompatible swap_status settings: both baseline and neighbor_a2av are true')
goto 9999
end if
if (.not. allocated(y%comm_handle)) then
call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info)
if (info /= psb_success_) then
call psb_errpush(psb_err_internal_error_, name, a_err='init comm default baseline')
goto 9999
end if
end if
if (baseline) then
call psi_dtran_baseline_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info)
call psi_dtran_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='baseline swap')
goto 9999
end if
else if (neighbor_a2av) then
call psi_dtran_neighbor_topology_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info)
call psi_dtran_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,&
& num_neighbors,total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='neighbor a2av swap')
goto 9999
end if
else
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true')
call psb_errpush(info,name,a_err='Incompatible swap_status settings: neither baseline nor neighbor_a2av is true')
goto 9999
end if
@ -208,8 +221,8 @@ contains
!
!
!
module subroutine psi_dtran_baseline_vect(ctxt,flag,beta,y,idx,&
& num_neighbors,total_send,total_recv,info)
module subroutine psi_dtran_baseline_vect(ctxt,swap_status,beta,y,idx,&
& num_neighbors,total_send,total_recv,comm_handle,info)
#ifdef PSB_MPI_MOD
use mpi
@ -220,11 +233,12 @@ contains
#endif
type(psb_ctxt_type), intent(in) :: ctxt
integer(psb_ipk_), intent(in) :: flag
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_vect_type), intent(inout) :: y
class(psb_i_base_vect_type), intent(inout) :: idx
integer(psb_ipk_), intent(in) :: num_neighbors,total_send, total_recv
class(psb_comm_handle_type), intent(inout) :: comm_handle
integer(psb_ipk_), intent(out) :: info
! locals
@ -232,6 +246,7 @@ contains
integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret
integer(psb_mpk_) :: icomm
integer(psb_mpk_), allocatable :: prcid(:)
type(psb_comm_baseline_handle), pointer :: baseline_comm_handle
integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,&
& snd_pt, rcv_pt, pnti
logical :: swap_mpi, swap_sync, swap_send, swap_recv,&
@ -250,11 +265,21 @@ contains
endif
icomm = ctxt%get_mpic()
baseline_comm_handle => null()
select type(ch => comm_handle)
type is(psb_comm_baseline_handle)
baseline_comm_handle => ch
class default
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Expected baseline comm_handle in baseline swaptran')
goto 9999
end select
n=1
swap_mpi = iand(flag,psb_swap_mpi_) /= 0
swap_sync = iand(flag,psb_swap_sync_) /= 0
swap_send = iand(flag,psb_swap_send_) /= 0
swap_recv = iand(flag,psb_swap_recv_) /= 0
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
do_send = swap_mpi .or. swap_sync .or. swap_send
do_recv = swap_mpi .or. swap_sync .or. swap_recv
@ -265,8 +290,8 @@ contains
if (debug) write(*,*) me,'Internal buffer'
if (do_send) then
if (allocated(y%comid)) then
if (any(y%comid /= mpi_request_null)) then
if (allocated(baseline_comm_handle%comid)) then
if (any(baseline_comm_handle%comid /= mpi_request_null)) then
!
! Unfinished communication? Something is wrong....
!
@ -277,8 +302,8 @@ contains
end if
if (debug) write(*,*) me,'do_send start'
call y%new_buffer(ione*size(idx%v),info)
call y%new_comid(num_neighbors,info)
y%comid = mpi_request_null
call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info)
baseline_comm_handle%comid = mpi_request_null
call psb_realloc(num_neighbors,prcid,info)
! First I post all the non blocking receives
pnti = 1
@ -295,7 +320,7 @@ contains
if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt
call mpi_irecv(y%combuf(snd_pt),nesd,&
& psb_mpi_r_dpk_,prcid(i),&
& p2ptag, icomm,y%comid(i,2),iret)
& p2ptag, icomm,baseline_comm_handle%comid(i,2),iret)
end if
pnti = pnti + nerv + nesd + 3
end do
@ -342,7 +367,7 @@ contains
if ((nerv>0).and.(proc_to_comm /= me)) then
call mpi_isend(y%combuf(rcv_pt),nerv,&
& psb_mpi_r_dpk_,prcid(i),&
& p2ptag,icomm,y%comid(i,1),iret)
& p2ptag,icomm,baseline_comm_handle%comid(i,1),iret)
end if
if(iret /= mpi_success) then
@ -357,7 +382,7 @@ contains
if (do_recv) then
if (debug) write(*,*) me,' do_Recv'
if (.not.allocated(y%comid)) then
if (.not.allocated(baseline_comm_handle%comid)) then
!
! No matching send? Something is wrong....
!
@ -379,7 +404,7 @@ contains
if (proc_to_comm /= me)then
if (nerv>0) then
call mpi_wait(y%comid(i,1),p2pstat,iret)
call mpi_wait(baseline_comm_handle%comid(i,1),p2pstat,iret)
if(iret /= mpi_success) then
info=psb_err_mpi_error_
call psb_errpush(info,name,m_err=(/iret/))
@ -387,7 +412,7 @@ contains
end if
end if
if (nesd>0) then
call mpi_wait(y%comid(i,2),p2pstat,iret)
call mpi_wait(baseline_comm_handle%comid(i,2),p2pstat,iret)
if(iret /= mpi_success) then
info=psb_err_mpi_error_
call psb_errpush(info,name,m_err=(/iret/))
@ -425,7 +450,7 @@ contains
!
! Waited for everybody, clean up
!
y%comid = mpi_request_null
baseline_comm_handle%comid = mpi_request_null
!
! Then wait for device
@ -434,7 +459,9 @@ contains
call y%device_wait()
if (debug) write(*,*) me,' free buffer'
call y%maybe_free_buffer(info)
if (info == 0) call y%free_comid(info)
if (info == 0) then
if (allocated(y%comm_handle)) call psb_comm_free(y%comm_handle, info)
end if
if (info /= 0) then
call psb_errpush(psb_err_alloc_dealloc_,name)
goto 9999
@ -455,7 +482,8 @@ contains
subroutine psi_dtran_neighbor_topology_vect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info)
subroutine psi_dtran_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,&
& num_neighbors,total_send,total_recv,comm_handle,info)
#ifdef PSB_MPI_MOD
use mpi
#endif
@ -465,17 +493,19 @@ contains
#endif
type(psb_ctxt_type), intent(in) :: ctxt
integer(psb_ipk_), intent(in) :: flag
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_vect_type), intent(inout) :: y
class(psb_i_base_vect_type), intent(inout) :: comm_indexes
integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv
class(psb_comm_handle_type), intent(inout) :: comm_handle
integer(psb_ipk_), intent(out) :: info
! locals
integer(psb_mpk_) :: icomm
integer(psb_mpk_) :: np, me
integer(psb_mpk_) :: iret, p2pstat(mpi_status_size)
type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle
integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size
logical :: do_start, do_wait
logical, parameter :: debug = .false.
@ -494,8 +524,18 @@ contains
icomm = ctxt%get_mpic()
do_start = iand(flag,psb_swap_start_) /= 0
do_wait = iand(flag,psb_swap_wait_) /= 0
neighbor_comm_handle => null()
select type(ch => comm_handle)
type is(psb_comm_neighbor_handle)
neighbor_comm_handle => ch
class default
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Expected neighbor comm_handle in neighbor swaptran')
goto 9999
end select
do_start = iand(swap_status,psb_swap_start_) /= 0
do_wait = iand(swap_status,psb_swap_wait_) /= 0
call comm_indexes%sync()
@ -505,9 +545,9 @@ contains
if (do_start) then
if(debug) write(*,*) me,' nbr_vect: starting data exchange'
! Lazy initialization: build the topology on first call
if (.not. y%neighbor_topology%is_initialized) then
if (.not. neighbor_comm_handle%is_initialized) then
if (debug) write(*,*) me,' nbr_vect: building topology'
call y%neighbor_topology%init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info)
call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info)
if (info /= psb_success_) then
call psb_errpush(psb_err_internal_error_, name, &
& a_err='neighbor_topology_init')
@ -515,8 +555,8 @@ contains
end if
end if
topology_total_send = y%neighbor_topology%total_send
topology_total_recv = y%neighbor_topology%total_recv
topology_total_send = neighbor_comm_handle%total_send
topology_total_recv = neighbor_comm_handle%total_recv
! Buffer layout:
! combuf(1 : total_send) = send area
@ -528,12 +568,12 @@ contains
call psb_errpush(psb_err_alloc_dealloc_, name)
goto 9999
end if
y%communication_handle = mpi_request_null
neighbor_comm_handle%comm_request = mpi_request_null
! For transpose exchange: gather recv area first (we will send "recv" data)
if (debug) write(*,*) me,' nbr_tran_vect: gathering recv data,', topology_total_recv,' elems'
call y%gth(int(topology_total_recv,psb_mpk_), &
& y%neighbor_topology%recv_indexes, &
& neighbor_comm_handle%recv_indexes, &
& y%combuf(1:topology_total_recv))
! Wait for device (important for GPU subclasses)
@ -543,15 +583,15 @@ contains
if (debug) write(*,*) me,' nbr_tran_vect: posting MPI_Ineighbor_alltoallv (swapped)'
call mpi_ineighbor_alltoallv( &
& y%combuf(1), & ! send buffer (recv_indexes gathered)
& y%neighbor_topology%recv_counts, &
& y%neighbor_topology%recv_displs, &
& neighbor_comm_handle%recv_counts, &
& neighbor_comm_handle%recv_displs, &
& psb_mpi_r_dpk_, &
& y%combuf(topology_total_recv + 1), & ! recv buffer (will contain send_indexes data)
& y%neighbor_topology%send_counts, &
& y%neighbor_topology%send_displs, &
& neighbor_comm_handle%send_counts, &
& neighbor_comm_handle%send_displs, &
& psb_mpi_r_dpk_, &
& y%neighbor_topology%graph_comm, &
& y%communication_handle, iret)
& neighbor_comm_handle%graph_comm, &
& neighbor_comm_handle%comm_request, iret)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -565,19 +605,19 @@ contains
! ---------------------------------------------------------
if (do_wait) then
if (y%communication_handle == mpi_request_null) then
if (neighbor_comm_handle%comm_request == mpi_request_null) then
! No matching start? Something is wrong
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/-2/))
goto 9999
end if
topology_total_send = y%neighbor_topology%total_send
topology_total_recv = y%neighbor_topology%total_recv
topology_total_send = neighbor_comm_handle%total_send
topology_total_recv = neighbor_comm_handle%total_recv
! Wait for the non-blocking collective to complete
if (debug) write(*,*) me,' nbr_vect: waiting on MPI request'
call mpi_wait(y%communication_handle, p2pstat, iret)
call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -587,13 +627,13 @@ contains
! For transpose exchange: scatter the data that correspond to peers' send area
if (debug) write(*,*) me,' nbr_tran_vect: scattering send-index data,', topology_total_send,' elems'
call y%sct(int(topology_total_send,psb_mpk_), &
& y%neighbor_topology%send_indexes, &
& neighbor_comm_handle%send_indexes, &
& y%combuf(topology_total_recv+1:topology_total_recv+topology_total_send), &
& beta)
! Clean up
y%communication_handle = mpi_request_null
neighbor_comm_handle%comm_request = mpi_request_null
call y%device_wait()
call y%maybe_free_buffer(info)
if (info /= 0) then
@ -625,7 +665,7 @@ contains
! Takes care of Y, an encapsulated multivector.
!
!
module subroutine psi_dswaptran_multivect(flag,beta,y,desc_a,info,data)
module subroutine psi_dswaptran_multivect(swap_status,beta,y,desc_a,info,data)
#ifdef PSB_MPI_MOD
use mpi
@ -635,7 +675,7 @@ contains
include 'mpif.h'
#endif
integer(psb_ipk_), intent(in) :: flag
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_multivect_type), intent(inout) :: y
type(psb_desc_type),target :: desc_a
@ -685,37 +725,46 @@ contains
goto 9999
end if
swap_mpi = iand(flag,psb_swap_mpi_) /= 0
swap_sync = iand(flag,psb_swap_sync_) /= 0
swap_send = iand(flag,psb_swap_send_) /= 0
swap_recv = iand(flag,psb_swap_recv_) /= 0
swap_start = iand(flag,psb_swap_start_) /= 0
swap_wait = iand(flag,psb_swap_wait_) /= 0
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
swap_start = iand(swap_status,psb_swap_start_) /= 0
swap_wait = iand(swap_status,psb_swap_wait_) /= 0
baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync
neighbor_a2av = swap_start .or. swap_wait
if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true')
call psb_errpush(info,name,a_err='Incompatible swap_status settings: both baseline and neighbor_a2av are true')
goto 9999
end if
if (.not. allocated(y%comm_handle)) then
call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info)
if (info /= psb_success_) then
call psb_errpush(psb_err_internal_error_, name, a_err='init comm default baseline')
goto 9999
end if
end if
if (baseline) then
call psi_dtran_baseline_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info)
call psi_dtran_baseline_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='baseline swap')
goto 9999
end if
else if (neighbor_a2av) then
call psi_dtran_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info)
call psi_dtran_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,&
& total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='neighbor a2av swap')
goto 9999
end if
else
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true')
call psb_errpush(info,name,a_err='Incompatible swap_status settings: neither baseline nor neighbor_a2av is true')
goto 9999
end if
@ -727,8 +776,8 @@ contains
return
end subroutine psi_dswaptran_multivect
subroutine psi_dtran_baseline_multivect(ctxt,flag,beta,y,idx,&
& num_neighbors,total_send,total_recv,info)
subroutine psi_dtran_baseline_multivect(ctxt,swap_status,beta,y,idx,&
& num_neighbors,total_send,total_recv,comm_handle,info)
#ifdef PSB_MPI_MOD
use mpi
@ -739,18 +788,20 @@ contains
#endif
type(psb_ctxt_type), intent(in) :: ctxt
integer(psb_ipk_), intent(in) :: flag
integer(psb_ipk_), intent(in) :: swap_status
integer(psb_ipk_), intent(out) :: info
class(psb_d_base_multivect_type), intent(inout) :: y
real(psb_dpk_), intent(in) :: beta
class(psb_i_base_vect_type), intent(inout) :: idx
integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv
class(psb_comm_handle_type), intent(inout) :: comm_handle
! locals
integer(psb_mpk_) :: np, me, nesd, nerv, n
integer(psb_mpk_) :: proc_to_comm, p2ptag, p2pstat(mpi_status_size), iret
integer(psb_mpk_) :: icomm
integer(psb_mpk_), allocatable :: prcid(:)
type(psb_comm_baseline_handle), pointer :: baseline_comm_handle
integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,&
& snd_pt, rcv_pt, pnti
logical :: swap_mpi, swap_sync, swap_send, swap_recv,&
@ -769,12 +820,22 @@ contains
endif
icomm = ctxt%get_mpic()
baseline_comm_handle => null()
select type(ch => comm_handle)
type is(psb_comm_baseline_handle)
baseline_comm_handle => ch
class default
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Expected baseline comm_handle in baseline multivect swaptran')
goto 9999
end select
n = y%get_ncols()
swap_mpi = iand(flag,psb_swap_mpi_) /= 0
swap_sync = iand(flag,psb_swap_sync_) /= 0
swap_send = iand(flag,psb_swap_send_) /= 0
swap_recv = iand(flag,psb_swap_recv_) /= 0
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
do_send = swap_mpi .or. swap_sync .or. swap_send
do_recv = swap_mpi .or. swap_sync .or. swap_recv
@ -785,8 +846,8 @@ contains
if (debug) write(*,*) me,'Internal buffer'
if (do_send) then
if (allocated(y%comid)) then
if (any(y%comid /= mpi_request_null)) then
if (allocated(baseline_comm_handle%comid)) then
if (any(baseline_comm_handle%comid /= mpi_request_null)) then
!
! Unfinished communication? Something is wrong....
!
@ -797,8 +858,8 @@ contains
end if
if (debug) write(*,*) me,'do_send start'
call y%new_buffer(ione*size(idx%v),info)
call y%new_comid(num_neighbors,info)
y%comid = mpi_request_null
call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info)
baseline_comm_handle%comid = mpi_request_null
call psb_realloc(num_neighbors,prcid,info)
! First I post all the non blocking receives
pnti = 1
@ -814,7 +875,7 @@ contains
if (debug) write(*,*) me,'Posting receive from',prcid(i),snd_pt
call mpi_irecv(y%combuf(snd_pt),n*nesd,&
& psb_mpi_r_dpk_,prcid(i),&
& p2ptag, icomm,y%comid(i,2),iret)
& p2ptag, icomm,baseline_comm_handle%comid(i,2),iret)
end if
rcv_pt = rcv_pt + n*nerv
snd_pt = snd_pt + n*nesd
@ -861,7 +922,7 @@ contains
if ((nerv>0).and.(proc_to_comm /= me)) then
call mpi_isend(y%combuf(rcv_pt),n*nerv,&
& psb_mpi_r_dpk_,prcid(i),&
& p2ptag,icomm,y%comid(i,1),iret)
& p2ptag,icomm,baseline_comm_handle%comid(i,1),iret)
end if
if(iret /= mpi_success) then
@ -877,7 +938,7 @@ contains
if (do_recv) then
if (debug) write(*,*) me,' do_Recv'
if (.not.allocated(y%comid)) then
if (.not.allocated(baseline_comm_handle%comid)) then
!
! No matching send? Something is wrong....
!
@ -898,7 +959,7 @@ contains
nesd = idx%v(pnti+nerv+psb_n_elem_send_)
if (proc_to_comm /= me)then
if (nerv>0) then
call mpi_wait(y%comid(i,1),p2pstat,iret)
call mpi_wait(baseline_comm_handle%comid(i,1),p2pstat,iret)
if(iret /= mpi_success) then
info=psb_err_mpi_error_
call psb_errpush(info,name,m_err=(/iret/))
@ -906,7 +967,7 @@ contains
end if
end if
if (nesd>0) then
call mpi_wait(y%comid(i,2),p2pstat,iret)
call mpi_wait(baseline_comm_handle%comid(i,2),p2pstat,iret)
if(iret /= mpi_success) then
info=psb_err_mpi_error_
call psb_errpush(info,name,m_err=(/iret/))
@ -948,7 +1009,7 @@ contains
!
! Waited for com, cleanup comid
!
y%comid = mpi_request_null
baseline_comm_handle%comid = mpi_request_null
!
! Then wait for device
@ -957,7 +1018,9 @@ contains
call y%device_wait()
if (debug) write(*,*) me,' free buffer'
call y%maybe_free_buffer(info)
if (info == 0) call y%free_comid(info)
if (info == 0) then
if (allocated(y%comm_handle)) call psb_comm_free(y%comm_handle, info)
end if
if (info /= 0) then
call psb_errpush(psb_err_alloc_dealloc_,name)
goto 9999
@ -976,7 +1039,8 @@ contains
end subroutine psi_dtran_baseline_multivect
subroutine psi_dtran_neighbor_topology_multivect(ctxt,flag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,info)
subroutine psi_dtran_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes,&
& num_neighbors,total_send,total_recv,comm_handle,info)
#ifdef PSB_MPI_MOD
use mpi
#endif
@ -986,17 +1050,19 @@ contains
#endif
type(psb_ctxt_type), intent(in) :: ctxt
integer(psb_ipk_), intent(in) :: flag
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_multivect_type), intent(inout) :: y
class(psb_i_base_vect_type), intent(inout) :: comm_indexes
integer(psb_ipk_), intent(in) :: num_neighbors,total_send,total_recv
class(psb_comm_handle_type), intent(inout) :: comm_handle
integer(psb_ipk_), intent(out) :: info
! locals
integer(psb_mpk_) :: icomm
integer(psb_mpk_) :: np, me
integer(psb_mpk_) :: iret, p2pstat(mpi_status_size)
type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle
integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size
logical :: do_start, do_wait
logical, parameter :: debug = .false.
@ -1015,8 +1081,18 @@ contains
icomm = ctxt%get_mpic()
do_start = iand(flag,psb_swap_start_) /= 0
do_wait = iand(flag,psb_swap_wait_) /= 0
neighbor_comm_handle => null()
select type(ch => comm_handle)
type is(psb_comm_neighbor_handle)
neighbor_comm_handle => ch
class default
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Expected neighbor comm_handle in neighbor multivect swaptran')
goto 9999
end select
do_start = iand(swap_status,psb_swap_start_) /= 0
do_wait = iand(swap_status,psb_swap_wait_) /= 0
call comm_indexes%sync()
@ -1026,9 +1102,9 @@ contains
if (do_start) then
if(debug) write(*,*) me,' nbr_vect: starting data exchange'
! Lazy initialization: build the topology on first call
if (.not. y%neighbor_topology%is_initialized) then
if (.not. neighbor_comm_handle%is_initialized) then
if (debug) write(*,*) me,' nbr_vect: building topology'
call y%neighbor_topology%init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info)
call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info)
if (info /= psb_success_) then
call psb_errpush(psb_err_internal_error_, name, &
& a_err='neighbor_topology_init')
@ -1036,8 +1112,8 @@ contains
end if
end if
topology_total_send = y%neighbor_topology%total_send
topology_total_recv = y%neighbor_topology%total_recv
topology_total_send = neighbor_comm_handle%total_send
topology_total_recv = neighbor_comm_handle%total_recv
! Buffer layout:
! combuf(1 : total_send) = send area
@ -1049,12 +1125,12 @@ contains
call psb_errpush(psb_err_alloc_dealloc_, name)
goto 9999
end if
y%communication_handle = mpi_request_null
neighbor_comm_handle%comm_request = mpi_request_null
! For transpose exchange: gather recv area first (we will send "recv" data)
if (debug) write(*,*) me,' nbr_tran_vect: gathering recv data,', topology_total_recv,' elems'
call y%gth(int(topology_total_recv,psb_mpk_), &
& y%neighbor_topology%recv_indexes, &
& neighbor_comm_handle%recv_indexes, &
& y%combuf(1:topology_total_recv))
! Wait for device (important for GPU subclasses)
@ -1064,15 +1140,15 @@ contains
if (debug) write(*,*) me,' nbr_tran_vect: posting MPI_Ineighbor_alltoallv (swapped)'
call mpi_ineighbor_alltoallv( &
& y%combuf(1), & ! send buffer (recv_indexes gathered)
& y%neighbor_topology%recv_counts, &
& y%neighbor_topology%recv_displs, &
& neighbor_comm_handle%recv_counts, &
& neighbor_comm_handle%recv_displs, &
& psb_mpi_r_dpk_, &
& y%combuf(topology_total_recv + 1), & ! recv buffer (will contain send_indexes data)
& y%neighbor_topology%send_counts, &
& y%neighbor_topology%send_displs, &
& neighbor_comm_handle%send_counts, &
& neighbor_comm_handle%send_displs, &
& psb_mpi_r_dpk_, &
& y%neighbor_topology%graph_comm, &
& y%communication_handle, iret)
& neighbor_comm_handle%graph_comm, &
& neighbor_comm_handle%comm_request, iret)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -1086,19 +1162,19 @@ contains
! ---------------------------------------------------------
if (do_wait) then
if (y%communication_handle == mpi_request_null) then
if (neighbor_comm_handle%comm_request == mpi_request_null) then
! No matching start? Something is wrong
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/-2/))
goto 9999
end if
topology_total_send = y%neighbor_topology%total_send
topology_total_recv = y%neighbor_topology%total_recv
topology_total_send = neighbor_comm_handle%total_send
topology_total_recv = neighbor_comm_handle%total_recv
! Wait for the non-blocking collective to complete
if (debug) write(*,*) me,' nbr_vect: waiting on MPI request'
call mpi_wait(y%communication_handle, p2pstat, iret)
call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -1108,13 +1184,13 @@ contains
! For transpose exchange: scatter the data that correspond to peers' send area
if (debug) write(*,*) me,' nbr_tran_vect: scattering send-index data,', topology_total_send,' elems'
call y%sct(int(topology_total_send,psb_mpk_), &
& y%neighbor_topology%send_indexes, &
& neighbor_comm_handle%send_indexes, &
& y%combuf(topology_total_recv+1:topology_total_recv+topology_total_send), &
& beta)
! Clean up
y%communication_handle = mpi_request_null
neighbor_comm_handle%comm_request = mpi_request_null
call y%device_wait()
call y%maybe_free_buffer(info)
if (info /= 0) then

@ -28,7 +28,10 @@ COMMINT= penv/psi_penv_mod.o \
SERIAL_MODS=serial/psb_s_serial_mod.o serial/psb_d_serial_mod.o \
serial/psb_c_serial_mod.o serial/psb_z_serial_mod.o \
serial/psb_serial_mod.o \
comm/psb_neighbor_topology_mod.o \
comm/comm_schemes/psb_comm_schemes_mod.o \
comm/comm_schemes/psb_comm_baseline_mod.o \
comm/comm_schemes/psb_comm_neighbor_impl_mod.o \
comm/comm_schemes/psb_comm_factory_mod.o \
serial/psb_i_base_vect_mod.o serial/psb_i_vect_mod.o\
serial/psb_l_base_vect_mod.o serial/psb_l_vect_mod.o\
serial/psb_d_base_vect_mod.o serial/psb_d_vect_mod.o\
@ -95,9 +98,11 @@ UTIL_MODS = desc/psb_desc_const_mod.o desc/psb_indx_map_mod.o\
comm/psb_s_linmap_mod.o comm/psb_d_linmap_mod.o \
comm/psb_c_linmap_mod.o comm/psb_z_linmap_mod.o \
comm/psb_comm_mod.o \
comm/psb_neighbor_topology_mod.o \
comm/psb_i_comm_mod.o comm/psb_l_comm_mod.o \
comm/psb_s_comm_mod.o comm/psb_d_comm_mod.o\
comm/psb_c_comm_mod.o comm/psb_z_comm_mod.o \
comm/psb_i2_comm_a_mod.o \
comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \
comm/psb_s_comm_a_mod.o comm/psb_d_comm_a_mod.o\
comm/psb_c_comm_a_mod.o comm/psb_z_comm_a_mod.o \
@ -182,6 +187,14 @@ auxil/psb_string_mod.o auxil/psb_m_realloc_mod.o auxil/psb_e_realloc_mod.o auxil
auxil/psb_d_realloc_mod.o auxil/psb_c_realloc_mod.o auxil/psb_z_realloc_mod.o \
desc/psb_desc_const_mod.o psi_penv_mod.o: psb_const_mod.o
comm/comm_schemes/psb_comm_schemes_mod.o: psb_const_mod.o
comm/comm_schemes/psb_comm_baseline_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o
comm/comm_schemes/psb_comm_neighbor_impl_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o comm/comm_schemes/psb_comm_schemes_mod.o
comm/comm_schemes/psb_comm_factory_mod.o: comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_baseline_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o
comm/psb_neighbor_topology_mod.o: psb_const_mod.o desc/psb_desc_const_mod.o
desc/psb_indx_map_mod.o desc/psb_hash_mod.o: psb_realloc_mod.o psb_const_mod.o desc/psb_desc_const_mod.o
@ -263,8 +276,10 @@ serial/psb_d_base_mat_mod.o: serial/psb_d_base_vect_mod.o
serial/psb_c_base_mat_mod.o: serial/psb_c_base_vect_mod.o
serial/psb_z_base_mat_mod.o: serial/psb_z_base_vect_mod.o
serial/psb_l_base_vect_mod.o: serial/psb_i_base_vect_mod.o
serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_d_base_vect_mod.o serial/psb_z_base_vect_mod.o: serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o comm/psb_neighbor_topology_mod.o
serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_d_base_vect_mod.o serial/psb_z_base_vect_mod.o: auxil/psi_serial_mod.o psb_realloc_mod.o comm/psb_neighbor_topology_mod.o
serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_z_base_vect_mod.o: serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o
serial/psb_d_base_vect_mod.o: serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o comm/comm_schemes/psb_comm_schemes_mod.o comm/comm_schemes/psb_comm_factory_mod.o
serial/psb_i_base_vect_mod.o serial/psb_l_base_vect_mod.o serial/psb_c_base_vect_mod.o serial/psb_s_base_vect_mod.o serial/psb_d_base_vect_mod.o serial/psb_z_base_vect_mod.o: auxil/psi_serial_mod.o psb_realloc_mod.o comm/comm_schemes/psb_comm_neighbor_impl_mod.o \
comm/psb_neighbor_topology_mod.o
serial/psb_s_mat_mod.o: serial/psb_s_base_mat_mod.o serial/psb_s_csr_mat_mod.o serial/psb_s_csc_mat_mod.o serial/psb_s_vect_mod.o \
serial/psb_i_vect_mod.o serial/psb_l_vect_mod.o
serial/psb_d_mat_mod.o: serial/psb_d_base_mat_mod.o serial/psb_d_csr_mat_mod.o serial/psb_d_csc_mat_mod.o serial/psb_d_vect_mod.o \
@ -342,11 +357,12 @@ comm/psb_comm_mod.o: desc/psb_desc_mod.o serial/psb_mat_mod.o
comm/psb_comm_mod.o: comm/psb_i_comm_mod.o comm/psb_l_comm_mod.o \
comm/psb_s_comm_mod.o comm/psb_d_comm_mod.o \
comm/psb_c_comm_mod.o comm/psb_z_comm_mod.o \
comm/psb_i2_comm_a_mod.o \
comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \
comm/psb_s_comm_a_mod.o comm/psb_d_comm_a_mod.o\
comm/psb_c_comm_a_mod.o comm/psb_z_comm_a_mod.o
comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \
comm/psb_i2_comm_a_mod.o comm/psb_m_comm_a_mod.o comm/psb_e_comm_a_mod.o \
comm/psb_s_comm_a_mod.o comm/psb_d_comm_a_mod.o\
comm/psb_c_comm_a_mod.o comm/psb_z_comm_a_mod.o: desc/psb_desc_mod.o
@ -420,5 +436,4 @@ clean:
/bin/rm -f $(MODULES) $(OBJS) $(MPFOBJS) *$(.mod)
veryclean: clean
/bin/rm -f *.h
/bin/rm -f *.h

@ -0,0 +1,62 @@
module psb_comm_baseline_mod
  use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_
  use psb_const_mod
  implicit none

  !
  ! Concrete communication handle for the baseline scheme, i.e. the one
  ! tagged with psb_comm_isend_irecv_ by its init procedure.
  !
  type, extends(psb_comm_handle_type) :: psb_comm_baseline_handle
    ! MPI request IDs for Isend/Irecv (dimension: num_neighbors x 2)
    ! First column: send requests, second column: recv requests
    ! NOTE(review): declared with kind psb_ipk_; MPI request handles are
    ! normally stored as psb_mpk_ - confirm against the code that posts
    ! the requests before relying on this with 8-byte psb_ipk_.
    integer(psb_ipk_), allocatable :: comid(:,:)
  contains
    procedure, pass :: init => psb_comm_baseline_init
    procedure, pass :: free => psb_comm_baseline_free
    procedure, pass :: set_swap_status => psb_comm_baseline_set_swap_status
    procedure, pass :: get_swap_status => psb_comm_baseline_get_swap_status
  end type psb_comm_baseline_handle

contains

  ! Put the handle in its initial state: baseline scheme selector,
  ! id 0 and swap status 0. Always returns info = 0.
  subroutine psb_comm_baseline_init(this, info)
    implicit none
    class(psb_comm_baseline_handle), intent(inout) :: this
    integer(psb_ipk_), intent(out) :: info

    info = 0
    this%comm_type   = psb_comm_isend_irecv_
    this%id          = 0
    this%swap_status = 0
  end subroutine psb_comm_baseline_init

  ! Release the request-id storage, if any.
  ! info is 0 on success, otherwise the stat value of the deallocate.
  subroutine psb_comm_baseline_free(this, info)
    implicit none
    class(psb_comm_baseline_handle), intent(inout) :: this
    integer(psb_ipk_), intent(out) :: info

    info = 0
    ! Free MPI resources (comid)
    if (allocated(this%comid)) deallocate(this%comid, stat=info)
  end subroutine psb_comm_baseline_free

  ! Store flag as the current swap status. Always returns info = 0.
  subroutine psb_comm_baseline_set_swap_status(this, flag, info)
    implicit none
    class(psb_comm_baseline_handle), intent(inout) :: this
    integer(psb_ipk_), intent(in) :: flag
    integer(psb_ipk_), intent(out) :: info

    info = 0
    this%swap_status = flag
  end subroutine psb_comm_baseline_set_swap_status

  ! Return the current swap status in flag. Always returns info = 0.
  subroutine psb_comm_baseline_get_swap_status(this, flag, info)
    implicit none
    class(psb_comm_baseline_handle), intent(in) :: this
    integer(psb_ipk_), intent(out) :: flag
    integer(psb_ipk_), intent(out) :: info

    info = 0
    flag = this%swap_status
  end subroutine psb_comm_baseline_get_swap_status

  ! Make sure comid has shape (n,2).
  !
  ! The routine may be called repeatedly on the same handle: storage that
  ! already has the requested shape is kept untouched, storage of a
  ! different shape is released and allocated again (contents are not
  ! carried over). A plain allocate on an already allocated array would
  ! instead fail with a non-zero stat.
  !
  ! info is 0 on success, otherwise the stat value of the failing
  ! allocate/deallocate.
  subroutine psb_comm_baseline_alloc_comid(this, n, info)
    implicit none
    class(psb_comm_baseline_handle), intent(inout) :: this
    integer(psb_ipk_), intent(in) :: n
    integer(psb_ipk_), intent(out) :: info

    info = 0
    if (allocated(this%comid)) then
      ! A negative n yields a zero-extent first dimension on allocation,
      ! hence the comparison against max(n,0).
      if (size(this%comid, 1) == max(n, 0_psb_ipk_)) return
      deallocate(this%comid, stat=info)
      if (info /= 0) return
    end if
    allocate(this%comid(n, 2_psb_ipk_), stat=info)
  end subroutine psb_comm_baseline_alloc_comid

end module psb_comm_baseline_mod

@ -0,0 +1,100 @@
module psb_comm_factory_mod
  use psb_const_mod
  use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_ineighbor_alltoallv_, &
       & psb_comm_persistent_ineighbor_alltoallv_, psb_comm_unknown_
  use psb_comm_baseline_mod, only: psb_comm_baseline_handle
  use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
  implicit none

contains

  ! Allocatable-based factory routines (preferred names)

  ! Allocate handle with the concrete type matching comm_type and run
  ! its init procedure.
  !
  !   comm_type - scheme selector. The two neighbor-alltoallv selectors
  !               produce a psb_comm_neighbor_handle; every other value
  !               produces a psb_comm_baseline_handle.
  !   handle    - must be unallocated on entry.
  !   info      - 0 on success; -1 if handle is already allocated;
  !               otherwise the allocate stat or the code returned by
  !               the handle's init.
  !
  ! On any failure handle is returned unallocated, so that the call can
  ! simply be retried.
  subroutine psb_comm_init(comm_type, handle, info)
    implicit none
    integer(psb_ipk_), intent(in) :: comm_type
    class(psb_comm_handle_type), allocatable, intent(inout) :: handle
    integer(psb_ipk_), intent(out) :: info

    integer(psb_ipk_) :: cleanup_stat

    info = 0
    if (allocated(handle)) then
      info = -1
      return
    end if

    select case(comm_type)
    case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_)
      allocate(psb_comm_neighbor_handle :: handle, stat=info)
    case default
      allocate(psb_comm_baseline_handle :: handle, stat=info)
    end select
    if (info /= 0) return

    call handle%init(info)
    if (info /= 0) then
      ! Do not hand back a half-built handle: a later psb_comm_init on it
      ! would be rejected with -1. The init error code is the one reported.
      deallocate(handle, stat=cleanup_stat)
      return
    end if

    ! The neighbor init always selects the non-persistent variant;
    ! record what the caller actually asked for.
    select type(h => handle)
    type is(psb_comm_neighbor_handle)
      h%comm_type = comm_type
      h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_)
    end select
  end subroutine psb_comm_init

  ! Release the resources owned by handle (through its free procedure)
  ! and deallocate it. An unallocated handle is not an error.
  !
  !   info - 0 on success; otherwise the code returned by the handle's
  !          free or, if that succeeded, the deallocate stat.
  !
  ! The handle is deallocated even when its free procedure reports an
  ! error.
  subroutine psb_comm_free(handle, info)
    implicit none
    class(psb_comm_handle_type), allocatable, intent(inout) :: handle
    integer(psb_ipk_), intent(out) :: info

    integer(psb_ipk_) :: dealloc_stat

    info = 0
    if (.not. allocated(handle)) return

    call handle%free(info)
    deallocate(handle, stat=dealloc_stat)
    if (info == 0) info = dealloc_stat
  end subroutine psb_comm_free

  ! Allocatable-based factory routines

  ! Alias of psb_comm_init.
  subroutine psb_comm_create(comm_type, handle, info)
    implicit none
    integer(psb_ipk_), intent(in) :: comm_type
    class(psb_comm_handle_type), allocatable, intent(inout) :: handle
    integer(psb_ipk_), intent(out) :: info

    call psb_comm_init(comm_type, handle, info)
  end subroutine psb_comm_create

  ! Alias of psb_comm_free.
  subroutine psb_comm_destroy(handle, info)
    implicit none
    class(psb_comm_handle_type), allocatable, intent(inout) :: handle
    integer(psb_ipk_), intent(out) :: info

    call psb_comm_free(handle, info)
  end subroutine psb_comm_destroy

  ! Forward flag to the handle's set_swap_status.
  ! info is -1 when handle is not allocated.
  subroutine psb_comm_set_swap_status(handle, flag, info)
    implicit none
    class(psb_comm_handle_type), allocatable, intent(inout) :: handle
    integer(psb_ipk_), intent(in) :: flag
    integer(psb_ipk_), intent(out) :: info

    info = 0
    if (.not. allocated(handle)) then
      info = -1
      return
    end if
    call handle%set_swap_status(flag, info)
  end subroutine psb_comm_set_swap_status

  ! Read the swap status through the handle's get_swap_status.
  ! When handle is not allocated flag is 0 and info is -1.
  subroutine psb_comm_get_swap_status(handle, flag, info)
    implicit none
    class(psb_comm_handle_type), allocatable, intent(in) :: handle
    integer(psb_ipk_), intent(out) :: flag
    integer(psb_ipk_), intent(out) :: info

    info = 0
    if (.not. allocated(handle)) then
      flag = 0
      info = -1
      return
    end if
    call handle%get_swap_status(flag, info)
  end subroutine psb_comm_get_swap_status

end module psb_comm_factory_mod

@ -0,0 +1,431 @@
! Merged neighbor-topology module
!
module psb_comm_neighbor_impl_mod
use psb_const_mod
use psb_desc_const_mod, only: psb_proc_id_, psb_n_elem_recv_, psb_elem_recv_, &
& psb_n_elem_send_, psb_elem_send_
use psb_error_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_ineighbor_alltoallv_
#ifdef PSB_MPI_MOD
use mpi
#endif
implicit none
#ifdef PSB_MPI_H
include 'mpif.h'
#endif
type, extends(psb_comm_handle_type) :: psb_comm_neighbor_handle
  ! Communicator created by mpi_dist_graph_create_adjacent in topology_init;
  ! mpi_comm_null while no topology has been built.
  integer(psb_mpk_) :: graph_comm = mpi_comm_null
  ! Number of entries parsed from the halo index list.
  integer(psb_ipk_) :: num_neighbors = 0
  ! Per-neighbor element counts and zero-based offsets into the
  ! contiguous send/recv index arrays below.
  integer(psb_mpk_), allocatable :: send_counts(:)
  integer(psb_mpk_), allocatable :: recv_counts(:)
  integer(psb_mpk_), allocatable :: send_displs(:)
  integer(psb_mpk_), allocatable :: recv_displs(:)
  ! Element indexes copied from the halo list, grouped by neighbor.
  integer(psb_ipk_), allocatable :: send_indexes(:)
  integer(psb_ipk_), allocatable :: recv_indexes(:)
  ! Sums of the per-neighbor counts.
  integer(psb_ipk_) :: total_send = 0
  integer(psb_ipk_) :: total_recv = 0
  ! Set to .true. at the end of a successful topology_init.
  logical :: is_initialized = .false.
  ! Set by the factory when the persistent scheme is requested.
  logical :: use_persistent_buffers = .false.
  ! Request bookkeeping. These fields are only reset inside this module;
  ! presumably they are driven by the data-exchange routines - TODO confirm.
  integer(psb_mpk_) :: comm_request = mpi_request_null
  integer(psb_mpk_) :: persistent_request = mpi_request_null
  logical :: persistent_request_ready = .false.
  integer(psb_ipk_) :: persistent_buffer_size = 0
contains
  procedure, pass :: init            => psb_comm_neighbor_init
  procedure, pass :: free            => neighbor_topology_free
  procedure, pass :: set_swap_status => psb_comm_neighbor_set_swap_status
  procedure, pass :: get_swap_status => psb_comm_neighbor_get_swap_status
  procedure, pass :: topology_init   => neighbor_topology_init
  procedure, pass :: sizeof          => neighbor_topology_sizeof
end type psb_comm_neighbor_handle
contains
! ---------------------------------------------------------------
! neighbor_topology_init
!
! Parse a halo index list and build, inside the handle:
!   - an MPI dist-graph communicator whose sources and destinations
!     are the processes named in the list
!   - per-neighbor send/recv counts and displacements
!   - contiguous gather/scatter index arrays
!
! Whatever the handle held before is released first (topology%free),
! so the routine can be called again to rebuild the topology.
! NOTE(review): the previous header said the topology is built lazily
! on the first psi_swapdata call using the neighbor-alltoallv mode;
! that caller is not visible from this module - confirm.
!
! Arguments:
!   topology         - the handle receiving the topology (inout)
!   halo_index       - halo index list (intent in)
!   num_neighbors    - number of exchanges described by halo_index
!   total_send_elems - expected total number of elements to send
!   total_recv_elems - expected total number of elements to receive
!   ctxt             - PSBLAS context, used for error handling
!   icomm            - MPI communicator the graph is created from
!   info             - error code (output), psb_success_ on success
!
! Errors are pushed on the PSBLAS error stack and handed to
! psb_error_handler; info is always set to a non-success code on
! those paths.
! ---------------------------------------------------------------
subroutine neighbor_topology_init(topology, halo_index, num_neighbors, &
     & total_send_elems, total_recv_elems, ctxt, icomm, info)
#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif
  class(psb_comm_neighbor_handle), intent(inout) :: topology
  integer(psb_ipk_), intent(in) :: halo_index(:)
  integer(psb_ipk_), intent(in) :: num_neighbors, total_send_elems, total_recv_elems
  type(psb_ctxt_type), intent(in) :: ctxt
  integer(psb_mpk_), intent(in) :: icomm
  integer(psb_ipk_), intent(out) :: info

  ! locals
  integer(psb_mpk_) :: iret
  integer(psb_ipk_) :: i, k, num_elem_recv, num_elem_send
  integer(psb_ipk_) :: send_offset, recv_offset
  integer(psb_ipk_) :: halo_size
  integer(psb_mpk_), allocatable :: source_ranks(:), dest_ranks(:)
  integer(psb_mpk_), allocatable :: source_weights(:), dest_weights(:)
  integer(psb_mpk_) :: in_degree, out_degree
  character(len=40) :: name
  integer(psb_ipk_) :: proc_id
  integer(psb_ipk_) :: position
  integer(psb_ipk_) :: err_act

  info = psb_success_
  name = 'neighbor_topology_init'
  call psb_erractionsave(err_act)

  ! Clean up any previous state
  call topology%free(info)
  if (info /= psb_success_) then
    call psb_errpush(info, name)
    goto 9999
  end if

  topology%num_neighbors = 0
  topology%total_send = 0
  topology%total_recv = 0

  halo_size = size(halo_index)
  if (halo_size < 1) then
    ! info must carry the failure too, not only the error stack
    info = psb_err_topology_invalid_args_
    call psb_errpush(info, name)
    goto 9999
  end if

  ! ----------------------------------------------------------
  ! Work arrays for the dist-graph creation (one entry per
  ! neighbor) and the per-neighbor counts/displacements
  ! ----------------------------------------------------------
  allocate(source_ranks(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Source ranks allocation failed')
    goto 9999
  end if
  allocate(dest_ranks(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Destination ranks allocation failed')
    goto 9999
  end if
  allocate(source_weights(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Source weights allocation failed')
    goto 9999
  end if
  allocate(dest_weights(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Destination weights allocation failed')
    goto 9999
  end if
  allocate(topology%send_counts(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Send counts allocation failed')
    goto 9999
  end if
  allocate(topology%recv_counts(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Receive counts allocation failed')
    goto 9999
  end if
  allocate(topology%send_displs(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Send displacements allocation failed')
    goto 9999
  end if
  allocate(topology%recv_displs(num_neighbors), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Receive displacements allocation failed')
    goto 9999
  end if

  ! -----------------------------------------------------------
  ! Allocate the gather/scatter index arrays
  ! -----------------------------------------------------------
  allocate(topology%send_indexes(total_send_elems), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Send indexes allocation failed')
    goto 9999
  end if
  allocate(topology%recv_indexes(total_recv_elems), stat=info)
  if (info /= psb_success_) then
    info = psb_err_alloc_dealloc_
    call psb_errpush(info, name, a_err='Recv indexes allocation failed')
    goto 9999
  end if

  ! -----------------------------------------------------------
  ! Fill neighbor ranks, weights, counts, displacements,
  ! and gather/scatter index arrays.
  !
  ! The halo_index layout per neighbor (starting at position):
  !   position + 0                       : process id
  !   position + 1                       : nerv (num recv elements)
  !   position + 2 .. +1+nerv            : recv element indexes
  !   position + 2+nerv                  : nesd (num send elements)
  !   position + 3+nerv .. +2+nerv+nesd  : send element indexes
  ! Total stride per neighbor: nerv + nesd + 3
  !
  ! Every read of halo_index and every write to the index arrays
  ! is range-checked first: the consistency checks after the loop
  ! would come too late to prevent an out-of-bounds access.
  ! -----------------------------------------------------------
  send_offset = 0
  recv_offset = 0
  position = 1
  do i = 1, num_neighbors
    if (position + 1 > halo_size) then
      info = psb_err_topology_args_mismatch_
      call psb_errpush(info, name, a_err='Halo index list is too short')
      goto 9999
    end if
    proc_id = halo_index(position)
    num_elem_recv = halo_index(position + 1)
    if ((num_elem_recv < 0) .or. (position + num_elem_recv + 2 > halo_size)) then
      info = psb_err_topology_args_mismatch_
      call psb_errpush(info, name, a_err='Halo index list is too short')
      goto 9999
    end if
    num_elem_send = halo_index(position + num_elem_recv + 2)
    if ((num_elem_send < 0) .or. &
         & (position + num_elem_recv + num_elem_send + 2 > halo_size)) then
      info = psb_err_topology_args_mismatch_
      call psb_errpush(info, name, a_err='Halo index list is too short')
      goto 9999
    end if
    if (send_offset + num_elem_send > size(topology%send_indexes)) then
      info = psb_err_topology_args_mismatch_
      call psb_errpush(info, name, a_err='Send elements mismatch')
      goto 9999
    end if
    if (recv_offset + num_elem_recv > size(topology%recv_indexes)) then
      info = psb_err_topology_args_mismatch_
      call psb_errpush(info, name, a_err='Receive elements mismatch')
      goto 9999
    end if

    ! Fill source/destination ranks and weights (weights are all 1 for now)
    source_ranks(i) = int(proc_id, psb_mpk_)
    dest_ranks(i) = int(proc_id, psb_mpk_)
    source_weights(i) = 1
    dest_weights(i) = 1

    ! Counts and displacements (displs set BEFORE accumulating offset)
    topology%send_counts(i) = int(num_elem_send, psb_mpk_)
    topology%recv_counts(i) = int(num_elem_recv, psb_mpk_)
    topology%send_displs(i) = int(send_offset, psb_mpk_)
    topology%recv_displs(i) = int(recv_offset, psb_mpk_)

    ! Fill recv_indexes from halo_index(position+2 .. position+1+nerv)
    do k = 1, num_elem_recv
      topology%recv_indexes(recv_offset + k) = halo_index(position + psb_elem_recv_ + k - 1)
    end do
    ! Fill send_indexes from halo_index(position+3+nerv .. position+2+nerv+nesd)
    do k = 1, num_elem_send
      topology%send_indexes(send_offset + k) = &
           & halo_index(position + num_elem_recv + psb_elem_send_ + k - 1)
    end do

    send_offset = send_offset + num_elem_send
    recv_offset = recv_offset + num_elem_recv
    topology%num_neighbors = topology%num_neighbors + 1
    topology%total_send = topology%total_send + num_elem_send
    topology%total_recv = topology%total_recv + num_elem_recv
    position = position + num_elem_recv + num_elem_send + 3
  end do

  ! ----------------------------------------------------------
  ! Sanity check: the totals computed from the neighbor list
  ! should match the totals passed in by the caller.
  ! ----------------------------------------------------------
  if (topology%total_send /= total_send_elems) then
    info = psb_err_topology_args_mismatch_
    call psb_errpush(info, name, a_err='Send elements mismatch')
    goto 9999
  end if
  if (topology%total_recv /= total_recv_elems) then
    info = psb_err_topology_args_mismatch_
    call psb_errpush(info, name, a_err='Receive elements mismatch')
    goto 9999
  end if
  if (topology%num_neighbors /= num_neighbors) then
    info = psb_err_topology_args_mismatch_
    call psb_errpush(info, name, a_err='Number of neighbors mismatch')
    goto 9999
  end if

  ! ----------------------------------------------------------
  ! Build the dist-graph communicator
  ! ----------------------------------------------------------
  in_degree = int(topology%num_neighbors, psb_mpk_)   !! Just for clarity
  out_degree = int(topology%num_neighbors, psb_mpk_)  !! Just for clarity
  ! The MPI error argument is a default-kind integer: use iret, not the
  ! psb_ipk_ info, which may have a different size.
  call mpi_dist_graph_create_adjacent(icomm, &
       & in_degree, source_ranks, source_weights, &
       & out_degree, dest_ranks, dest_weights, &
       & mpi_info_null, .false., & ! Check this line for optimizations
       & topology%graph_comm, iret)
  if (iret /= mpi_success) then
    info = psb_err_topology_error_
    call psb_errpush(info, name)
    goto 9999
  end if

  topology%is_initialized = .true.

  ! The rank/weight work arrays are local allocatables: they are only
  ! inputs to the graph creation above and are released automatically
  ! on return.
  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)
  return
end subroutine neighbor_topology_init
! ---------------------------------------------------------------
! neighbor_topology_free
! Release all resources held by the handle and bring it back to
! the "no topology" state:
!   - the persistent request, if one is stored
!   - the dist-graph communicator, if one was created
!   - the count/displacement/index arrays
!
! info is psb_success_ unless an MPI release call reports an
! error, in which case it is psb_err_topology_error_. The reset of
! the handle fields is carried out in either case.
! ---------------------------------------------------------------
subroutine neighbor_topology_free(this, info)
#ifdef PSB_MPI_MOD
  use mpi
#endif
  implicit none
#ifdef PSB_MPI_H
  include 'mpif.h'
#endif
  class(psb_comm_neighbor_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info

  integer(psb_mpk_) :: iret

  info = psb_success_

  ! Free the request whenever one is stored, whatever the value of the
  ! ready flag, so that a request created but not yet flagged as ready
  ! is not leaked.
  if (this%persistent_request /= mpi_request_null) then
    call mpi_request_free(this%persistent_request, iret)
    if (iret /= mpi_success) info = psb_err_topology_error_
  end if
  this%persistent_request = mpi_request_null
  this%persistent_request_ready = .false.
  this%persistent_buffer_size = 0

  if (this%graph_comm /= mpi_comm_null) then
    call mpi_comm_free(this%graph_comm, iret)
    if (iret /= mpi_success) info = psb_err_topology_error_
    this%graph_comm = mpi_comm_null
  end if

  if (allocated(this%send_counts))  deallocate(this%send_counts)
  if (allocated(this%recv_counts))  deallocate(this%recv_counts)
  if (allocated(this%send_displs))  deallocate(this%send_displs)
  if (allocated(this%recv_displs))  deallocate(this%recv_displs)
  if (allocated(this%send_indexes)) deallocate(this%send_indexes)
  if (allocated(this%recv_indexes)) deallocate(this%recv_indexes)

  this%num_neighbors = 0
  this%total_send = 0
  this%total_recv = 0
  this%is_initialized = .false.
  ! NOTE(review): comm_request is only forgotten here, not waited on or
  ! freed; presumably the exchange routines complete it before the handle
  ! is released - confirm.
  this%comm_request = mpi_request_null
end subroutine neighbor_topology_free
! ---------------------------------------------------------------
! neighbor_topology_sizeof
! Approximate memory footprint of the handle, in bytes: a fixed
! allowance for the scalar components plus one psb_sizeof_ip per
! element of each allocated array.
! NOTE(review): the count/displacement arrays are psb_mpk_ but are
! weighted with psb_sizeof_ip like the index arrays; the figure is
! an estimate only.
! ---------------------------------------------------------------
function neighbor_topology_sizeof(this) result(val)
  implicit none
  class(psb_comm_neighbor_handle), intent(in) :: this
  integer(psb_epk_) :: val

  ! scalar integers + logicals
  val = 0
  val = val + psb_sizeof_ip * 6

  ! per-neighbor counts and displacements
  if (allocated(this%send_counts)) then
    val = val + psb_sizeof_ip * size(this%send_counts)
  end if
  if (allocated(this%recv_counts)) then
    val = val + psb_sizeof_ip * size(this%recv_counts)
  end if
  if (allocated(this%send_displs)) then
    val = val + psb_sizeof_ip * size(this%send_displs)
  end if
  if (allocated(this%recv_displs)) then
    val = val + psb_sizeof_ip * size(this%recv_displs)
  end if

  ! gather/scatter index lists
  if (allocated(this%send_indexes)) then
    val = val + psb_sizeof_ip * size(this%send_indexes)
  end if
  if (allocated(this%recv_indexes)) then
    val = val + psb_sizeof_ip * size(this%recv_indexes)
  end if
end function neighbor_topology_sizeof
! Reset the handle for use with the scheme comm_type.
!
! Resources still held by the handle are released first, through its
! free procedure, and only then are the request fields reset. Doing it
! the other way round would overwrite a live persistent request with
! mpi_request_null before free had a chance to release it.
!
!   comm_type - scheme selector stored in the handle
!   info      - code returned by the free procedure
subroutine psb_comm_neighbor_create(this, comm_type, info)
  class(psb_comm_neighbor_handle), intent(inout) :: this
  integer(psb_ipk_), intent(in) :: comm_type
  integer(psb_ipk_), intent(out) :: info

  info = psb_success_
  call this%free(info)

  this%comm_type = comm_type
  this%id = 0
  this%swap_status = 0
  this%comm_request = mpi_request_null
  this%persistent_request = mpi_request_null
  this%persistent_request_ready = .false.
  this%persistent_buffer_size = 0
end subroutine psb_comm_neighbor_create
! Release everything the handle owns and clear its request fields.
!
! The free procedure runs before the request fields are reset: if the
! persistent request were set to mpi_request_null first, free would no
! longer see it and the MPI request would be leaked.
!
!   info - code returned by the free procedure
subroutine psb_comm_neighbor_destroy(this, info)
  class(psb_comm_neighbor_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info

  info = psb_success_
  call this%free(info)

  this%comm_request = mpi_request_null
  this%persistent_request = mpi_request_null
  this%persistent_request_ready = .false.
  this%persistent_buffer_size = 0
end subroutine psb_comm_neighbor_destroy
! Record flag as the swap status of the handle.
! info is always psb_success_.
subroutine psb_comm_neighbor_set_swap_status(this, flag, info)
  class(psb_comm_neighbor_handle), intent(inout) :: this
  integer(psb_ipk_), intent(in) :: flag
  integer(psb_ipk_), intent(out) :: info

  this%swap_status = flag
  info = psb_success_
end subroutine psb_comm_neighbor_set_swap_status
! Return the swap status stored in the handle.
! info is always psb_success_.
subroutine psb_comm_neighbor_get_swap_status(this, flag, info)
  class(psb_comm_neighbor_handle), intent(in) :: this
  integer(psb_ipk_), intent(out) :: flag
  integer(psb_ipk_), intent(out) :: info

  flag = this%swap_status
  info = psb_success_
end subroutine psb_comm_neighbor_get_swap_status
! Put a handle in its initial state: non-persistent neighbor-alltoallv
! scheme, id and swap status 0, no topology, no pending requests.
! Arrays and the graph communicator are not touched here.
! info is always 0.
subroutine psb_comm_neighbor_init(this, info)
  class(psb_comm_neighbor_handle), intent(inout) :: this
  integer(psb_ipk_), intent(out) :: info

  ! identity inherited from the base handle
  this%comm_type   = psb_comm_ineighbor_alltoallv_
  this%id          = 0
  this%swap_status = 0

  ! topology flags
  this%is_initialized         = .false.
  this%use_persistent_buffers = .false.

  ! request bookkeeping
  this%comm_request             = mpi_request_null
  this%persistent_request       = mpi_request_null
  this%persistent_request_ready = .false.
  this%persistent_buffer_size   = 0

  info = 0
end subroutine psb_comm_neighbor_init
end module psb_comm_neighbor_impl_mod

@ -0,0 +1,68 @@
!
! psb_comm_schemes_mod - communication handle module
!
module psb_comm_schemes_mod
  use psb_const_mod
  implicit none

  ! Communication scheme selectors (kept as an enumeration so they stay
  ! compatible with plain integer selectors).
  enum, bind(c)
    enumerator :: psb_comm_unknown_                        = 0
    enumerator :: psb_comm_isend_irecv_                    = 1
    enumerator :: psb_comm_ineighbor_alltoallv_            = 2
    enumerator :: psb_comm_persistent_ineighbor_alltoallv_ = 3
  end enum

  ! Swap status values.
  enum, bind(c)
    enumerator :: psb_comm_status_unknown_ = 0
    enumerator :: psb_comm_status_start_   = 1
    enumerator :: psb_comm_status_wait_    = 2
  end enum

  ! --- comm handle type ---
  ! Abstract base of every communication handle. Concrete schemes extend
  ! it and must supply the four deferred procedures, whose interfaces are
  ! declared right after the type.
  type, abstract :: psb_comm_handle_type
    integer(psb_ipk_) :: id          = -1
    integer(psb_ipk_) :: comm_type   = psb_comm_unknown_
    integer(psb_ipk_) :: swap_status = psb_comm_status_unknown_
  contains
    procedure(psb_comm_init),            deferred :: init
    procedure(psb_comm_free),            deferred :: free
    procedure(psb_comm_set_swap_status), deferred :: set_swap_status
    procedure(psb_comm_get_swap_status), deferred :: get_swap_status
  end type psb_comm_handle_type

  ! --- abstract interfaces ---
  abstract interface

    ! Bring the handle to its initial state.
    subroutine psb_comm_init(this, info)
      import :: psb_ipk_, psb_comm_handle_type
      class(psb_comm_handle_type), intent(inout) :: this
      integer(psb_ipk_), intent(out) :: info
    end subroutine

    ! Release whatever the handle owns.
    subroutine psb_comm_free(this, info)
      import :: psb_ipk_, psb_comm_handle_type
      class(psb_comm_handle_type), intent(inout) :: this
      integer(psb_ipk_), intent(out) :: info
    end subroutine

    ! Store a new swap status.
    subroutine psb_comm_set_swap_status(this, flag, info)
      import :: psb_ipk_, psb_comm_handle_type
      class(psb_comm_handle_type), intent(inout) :: this
      integer(psb_ipk_), intent(in) :: flag
      integer(psb_ipk_), intent(out) :: info
    end subroutine

    ! Read the current swap status.
    subroutine psb_comm_get_swap_status(this, flag, info)
      import :: psb_ipk_, psb_comm_handle_type
      class(psb_comm_handle_type), intent(in) :: this
      integer(psb_ipk_), intent(out) :: flag
      integer(psb_ipk_), intent(out) :: info
    end subroutine

  end interface

end module psb_comm_schemes_mod

@ -38,18 +38,18 @@ module psi_d_comm_v_mod
interface psi_swapdata
! ---------------------------------------------------------------
! Wrapper that calls different communications schemes depending on
! flag variable using communication buff obtained from desc_a%get_list_p
! swap_status variable using communication buff obtained from desc_a%get_list_p
! ---------------------------------------------------------------
module subroutine psi_dswapdata_vect(flag,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: flag
module subroutine psi_dswapdata_vect(swap_status,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_vect_type), intent(inout) :: y
type(psb_desc_type), target :: desc_a
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_), optional :: data
end subroutine psi_dswapdata_vect
module subroutine psi_dswapdata_multivect(flag,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: flag
module subroutine psi_dswapdata_multivect(swap_status,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_multivect_type), intent(inout) :: y
type(psb_desc_type), target :: desc_a
@ -63,18 +63,18 @@ module psi_d_comm_v_mod
! ---------------------------------------------------------------
! Upper call in order to populate idx using desc_a%get_list_p
! and then call different communications schemes depending
! on flag variable
! on swap_status variable
! ---------------------------------------------------------------
module subroutine psi_dswaptran_vect(flag,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: flag
module subroutine psi_dswaptran_vect(swap_status,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_vect_type), intent(inout) :: y
type(psb_desc_type), target :: desc_a
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_), optional :: data
end subroutine psi_dswaptran_vect
module subroutine psi_dswaptran_multivect(flag,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: flag
module subroutine psi_dswaptran_multivect(swap_status,beta,y,desc_a,info,data)
integer(psb_ipk_), intent(in) :: swap_status
real(psb_dpk_), intent(in) :: beta
class(psb_d_base_multivect_type), intent(inout) :: y
type(psb_desc_type), target :: desc_a

@ -49,7 +49,8 @@ module psb_d_base_vect_mod
use psb_realloc_mod
use psb_i_base_vect_mod
use psb_l_base_vect_mod
use psb_neighbor_topology_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, psb_comm_unknown_
use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free
!> \namespace psb_base_mod \class psb_d_base_vect_type
@ -64,10 +65,10 @@ module psb_d_base_vect_mod
!!
type psb_d_base_vect_type
!> Values.
real(psb_dpk_), allocatable :: v(:)
real(psb_dpk_), allocatable :: combuf(:)
integer(psb_mpk_), allocatable :: comid(:,:) ! This is used only for Isend/Irecv scheme, to store the communication handles for each neighbor
integer(psb_mpk_) :: communication_handle ! This is used only for Isend/Irecv scheme, to store the communication handle for the whole halo exchange
real(psb_dpk_), allocatable :: v(:)
real(psb_dpk_), allocatable :: combuf(:)
! Polymorphic communication handle stored at vector level.
class(psb_comm_handle_type), allocatable :: comm_handle
!> vector bldstate:
!! null: pristine;
@ -81,9 +82,6 @@ module psb_d_base_vect_mod
integer(psb_ipk_), private :: dupl = psb_dupl_null_
integer(psb_ipk_), private :: ncfs = 0
integer(psb_ipk_), allocatable :: iv(:)
type(psb_neighbor_topology_type) :: neighbor_topology
contains
!
! Constructors/allocators
@ -147,8 +145,6 @@ module psb_d_base_vect_mod
procedure, nopass :: device_wait => d_base_device_wait
procedure, pass(x) :: maybe_free_buffer => d_base_maybe_free_buffer
procedure, pass(x) :: free_buffer => d_base_free_buffer
procedure, pass(x) :: new_comid => d_base_new_comid
procedure, pass(x) :: free_comid => d_base_free_comid
!
! Basic info
@ -179,7 +175,13 @@ module psb_d_base_vect_mod
generic, public :: sct => sctb, sctb_x, sctb_buf
procedure, pass(x) :: check_addr => d_base_check_addr
! Communication lifecycle split:
! - `create_comm`: allocate/select a fresh handle implementation via factory.
! - `init_comm`: configure the current handle instance from an existing one
! (e.g., copy `id` and swap status), and reset buffers;
! it recreates the handle only if missing or scheme changes.
! - `destroy_comm`/`free_comm`: release current handle resources.
!
@ -255,12 +257,6 @@ module psb_d_base_vect_mod
procedure, pass(x) :: minquotient_v => d_base_minquotient_v
procedure, pass(x) :: minquotient_a2 => d_base_minquotient_a2
generic, public :: minquotient => minquotient_v, minquotient_a2
! Methods used to handle topology in neighbor_alltoallv communication scheme
procedure, pass(x) :: init_topology => d_base_init_topology
procedure, pass(x) :: free_topology => d_base_free_topology
end type psb_d_base_vect_type
public :: psb_d_base_vect
@ -416,6 +412,11 @@ contains
call psb_realloc(n,x%iv,info)
call x%set_ncfs(0)
end if
if (info == psb_success_) then
if (.not. allocated(x%comm_handle)) then
call psb_comm_init(psb_comm_isend_irecv_, x%comm_handle, info)
end if
end if
end subroutine d_base_all
@ -434,6 +435,9 @@ contains
integer(psb_ipk_), intent(out) :: info
allocate(psb_d_base_vect_type :: y, stat=info)
if (info == psb_success_) then
call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info)
end if
end subroutine d_base_mold
@ -441,7 +445,7 @@ contains
use psi_serial_mod
use psb_realloc_mod
implicit none
class(psb_d_base_vect_type), intent(out) :: x
class(psb_d_base_vect_type), intent(inout) :: x
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: clear
logical :: clear_
@ -458,6 +462,11 @@ contains
call x%set_host()
call x%set_upd()
end if
if (info == psb_success_) then
if (.not. allocated(x%comm_handle)) then
call psb_comm_init(psb_comm_isend_irecv_, x%comm_handle, info)
end if
end if
end subroutine d_base_reinit
@ -837,11 +846,16 @@ contains
class(psb_d_base_vect_type), intent(inout) :: x
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: info_comm
info = 0
if (allocated(x%v)) deallocate(x%v, stat=info)
if ((info == 0).and.allocated(x%combuf)) call x%free_buffer(info)
if ((info == 0).and.allocated(x%comid)) call x%free_comid(info)
if ((info == 0).and.allocated(x%iv)) deallocate(x%iv, stat=info)
if ((info == 0).and.allocated(x%iv)) deallocate(x%iv, stat=info)
if ((info == 0).and.allocated(x%comm_handle)) then
call psb_comm_free(x%comm_handle, info_comm)
if (info_comm /= psb_success_) info = info_comm
end if
if (info /= 0) call &
& psb_errpush(psb_err_alloc_dealloc_,'vect_free')
call x%set_null()
@ -888,24 +902,6 @@ contains
end subroutine d_base_maybe_free_buffer
!
!> Function base_free_comid:
!! \memberof psb_d_base_vect_type
!! \brief Free aux MPI communication id buffer
!!
!! \param info return code
!!
!
subroutine d_base_free_comid(x,info)
  use psb_realloc_mod
  implicit none
  class(psb_d_base_vect_type), intent(inout) :: x
  integer(psb_ipk_), intent(out) :: info

  ! info is intent(out): give it a value even when there is nothing to
  ! release, otherwise the caller would test an undefined variable.
  info = 0
  if (allocated(x%comid)) &
       & deallocate(x%comid,stat=info)
end subroutine d_base_free_comid
function d_base_get_ncfs(x) result(res)
implicit none
class(psb_d_base_vect_type), intent(in) :: x
@ -1109,12 +1105,25 @@ contains
implicit none
class(psb_d_base_vect_type), intent(in) :: x
class(psb_d_base_vect_type), intent(out) :: y
integer(psb_ipk_) :: info
integer(psb_ipk_) :: swap_status
if (allocated(x%v)) call y%bld(x%v)
call y%set_state(x%get_state())
call y%set_dupl(x%get_dupl())
call y%set_ncfs(x%get_ncfs())
if (allocated(x%iv)) y%iv = x%iv
if (allocated(x%comm_handle)) then
call psb_comm_init(x%comm_handle%comm_type, y%comm_handle, info)
if (info /= psb_success_) return
y%comm_handle%id = x%comm_handle%id
call x%comm_handle%get_swap_status(swap_status, info)
if (info /= psb_success_) return
call y%comm_handle%set_swap_status(swap_status, info)
if (info /= psb_success_) return
else
call psb_comm_init(psb_comm_isend_irecv_, y%comm_handle, info)
end if
end subroutine d_base_cpy
!
@ -2364,6 +2373,56 @@ contains
end subroutine d_base_device_wait
subroutine d_base_init_comm(x, comm_handle, info)
  ! `init_comm` is intentionally a configuration step.
  ! It does not define the communication API surface itself; instead it:
  ! 1) resets local communication buffers,
  ! 2) ensures a handle exists with the requested concrete scheme,
  !    recreating it only when needed (missing handle or type change),
  ! 3) copies runtime state (`id`, swap status) from the input handle.
  !
  ! Arguments:
  !   x           - vector whose communication handle is configured
  !   comm_handle - handle to copy scheme, id and swap status from; when
  !                 not associated the isend/irecv scheme is selected and
  !                 no state is copied
  !   info        - psb_success_ on success, otherwise the code returned
  !                 by the failing handle or factory call
  implicit none
  class(psb_d_base_vect_type), intent(inout) :: x
  class(psb_comm_handle_type), intent(in), pointer :: comm_handle
  integer(psb_ipk_), intent(out) :: info

  integer(psb_ipk_) :: comm_type, swap_status, source_id
  logical :: have_source, need_new_handle

  info = psb_success_

  ! Reset/initialize communication related storage.
  if (allocated(x%combuf)) then
    deallocate(x%combuf)
  end if

  ! Read everything needed from the source handle up front, so that it
  ! is never dereferenced after x%comm_handle may have been released.
  comm_type   = psb_comm_isend_irecv_
  source_id   = -1
  swap_status = 0
  have_source = associated(comm_handle)
  if (have_source) then
    if (comm_handle%comm_type /= psb_comm_unknown_) comm_type = comm_handle%comm_type
    source_id = comm_handle%id
    call comm_handle%get_swap_status(swap_status, info)
    if (info /= psb_success_) return
  end if

  ! Recreate only when needed (missing handle or scheme change).
  need_new_handle = .not. allocated(x%comm_handle)
  if (.not. need_new_handle) then
    need_new_handle = (x%comm_handle%comm_type /= comm_type)
  end if

  if (need_new_handle) then
    ! psb_comm_init refuses an already allocated handle, so on a scheme
    ! change the old one has to be released first.
    if (allocated(x%comm_handle)) then
      call psb_comm_free(x%comm_handle, info)
      if (info /= psb_success_) return
    end if
    call psb_comm_init(comm_type, x%comm_handle, info)
    if (info /= psb_success_) return
  end if

  if (have_source) then
    x%comm_handle%id = source_id
    call x%comm_handle%set_swap_status(swap_status, info)
    if (info /= psb_success_) return
  end if
end subroutine d_base_init_comm
function d_base_use_buffer() result(res)
logical :: res
@ -2380,16 +2439,6 @@ contains
call psb_realloc(n,x%combuf,info)
end subroutine d_base_new_buffer
! (Re)size the communication-id array of x to n rows by 2 columns.
!   n    - requested number of rows
!   x    - vector owning the comid array
!   info - return code of psb_realloc
subroutine d_base_new_comid(n,x,info)
use psb_realloc_mod
implicit none
class(psb_d_base_vect_type), intent(inout) :: x
integer(psb_ipk_), intent(in) :: n
integer(psb_ipk_), intent(out) :: info
! The sizing is delegated entirely to psb_realloc.
call psb_realloc(n,2_psb_ipk_,x%comid,info)
end subroutine d_base_new_comid
!
! shortcut alpha=1 beta=0
@ -2622,35 +2671,6 @@ contains
call z%addconst(x%v,b,info)
end subroutine d_base_addconst_v2
! --------------------------------------------------------------------
! Implementation of methods used for neighbor alltoallv communication
! --------------------------------------------------------------------
! Build the neighbor topology stored in x by forwarding every argument,
! unchanged, to the init procedure of x%neighbor_topology.
!   x                - vector owning the topology
!   halo_index       - halo index list
!   num_exchanges    - number of exchanges described by halo_index
!   total_send_elems - total number of elements to send
!   total_recv_elems - total number of elements to receive
!   ctxt             - PSBLAS context
!   icomm            - MPI communicator
!   info             - return code of the topology init
subroutine d_base_init_topology(x, halo_index, num_exchanges, &
& total_send_elems, total_recv_elems, ctxt, icomm, info)
implicit none
class(psb_d_base_vect_type), intent(inout) :: x
integer(psb_ipk_), intent(in) :: halo_index(:)
integer(psb_ipk_), intent(in) :: num_exchanges, total_send_elems, total_recv_elems
type(psb_ctxt_type), intent(in) :: ctxt
integer(psb_mpk_), intent(in) :: icomm
integer(psb_ipk_), intent(out) :: info
call x%neighbor_topology%init(halo_index, num_exchanges, &
& total_send_elems, total_recv_elems, ctxt, icomm, info)
end subroutine d_base_init_topology
! Release the neighbor topology stored in x by calling the free
! procedure of x%neighbor_topology.
!   x    - vector owning the topology
!   info - return code of the topology free
subroutine d_base_free_topology(x, info)
implicit none
class(psb_d_base_vect_type), intent(inout) :: x
integer(psb_ipk_), intent(out) :: info
call x%neighbor_topology%free(info)
end subroutine d_base_free_topology
! --------------------------------------------------------------------
end module psb_d_base_vect_mod
@ -2660,7 +2680,7 @@ module psb_d_base_multivect_mod
use psb_error_mod
use psb_realloc_mod
use psb_d_base_vect_mod
use psb_neighbor_topology_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type
!> \namespace psb_base_mod \class psb_d_base_vect_type
!! The psb_d_base_vect_type
@ -2679,8 +2699,7 @@ module psb_d_base_multivect_mod
!> Values.
real(psb_dpk_), allocatable :: v(:,:)
real(psb_dpk_), allocatable :: combuf(:)
integer(psb_mpk_), allocatable :: comid(:,:) ! This is used only for Isend/Irecv scheme, to store the communication handles for each neighbor
integer(psb_mpk_) :: communication_handle ! This is used only for Isend/Irecv scheme, to store the communication handle for the whole halo exchange
! neighbor-specific communication state removed; comm_handle owned below
!> vector bldstate:
!! null: pristine;
@ -2695,7 +2714,7 @@ module psb_d_base_multivect_mod
integer(psb_ipk_), private :: ncfs = 0
integer(psb_ipk_), allocatable :: iv(:)
type(psb_neighbor_topology_type) :: neighbor_topology
class(psb_comm_handle_type), allocatable :: comm_handle
contains
!
@ -2804,8 +2823,6 @@ module psb_d_base_multivect_mod
procedure, nopass :: device_wait => d_base_mlv_device_wait
procedure, pass(x) :: maybe_free_buffer => d_base_mlv_maybe_free_buffer
procedure, pass(x) :: free_buffer => d_base_mlv_free_buffer
procedure, pass(x) :: new_comid => d_base_mlv_new_comid
procedure, pass(x) :: free_comid => d_base_mlv_free_comid
!
! Gather/scatter. These are needed for MPI interfacing.
@ -2823,11 +2840,6 @@ module psb_d_base_multivect_mod
procedure, pass(y) :: sctb_buf => d_base_mlv_sctb_buf
generic, public :: sct => sctb, sctbr2, sctb_x, sctb_buf
! Neighbor alltoallv communication topology handling
procedure, pass(x) :: init_topology => d_base_mlv_init_topology
procedure, pass(x) :: free_topology => d_base_mlv_free_topology
end type psb_d_base_multivect_type
interface psb_d_base_multivect
@ -4085,17 +4097,6 @@ contains
call psb_realloc(n*nc,x%combuf,info)
end subroutine d_base_mlv_new_buffer
!
!> Size the comid table of a multivector.
!! comid is the component the type definition describes as storing the
!! communication handles for each neighbour in the Isend/Irecv scheme.
!! It is passed to psb_realloc with requested extents (n, 2); the
!! status produced by psb_realloc is returned unchanged in info.
!
subroutine d_base_mlv_new_comid(n,x,info)
  use psb_realloc_mod
  implicit none
  integer(psb_ipk_), intent(in)                   :: n     ! requested first extent
  class(psb_d_base_multivect_type), intent(inout) :: x     ! multivector owning comid
  integer(psb_ipk_), intent(out)                  :: info  ! status from psb_realloc
  ! Second extent requested for comid
  integer(psb_ipk_), parameter :: second_extent = 2_psb_ipk_

  call psb_realloc(n, second_extent, x%comid, info)
end subroutine d_base_mlv_new_comid
subroutine d_base_mlv_maybe_free_buffer(x,info)
use psb_realloc_mod
implicit none
@ -4119,17 +4120,6 @@ contains
& deallocate(x%combuf,stat=info)
end subroutine d_base_mlv_free_buffer
!
!> Release the comid table of a multivector.
!! Deallocates x%comid when it is allocated and reports the DEALLOCATE
!! status in info. When there is nothing to release, info is 0.
!
subroutine d_base_mlv_free_comid(x,info)
  use psb_realloc_mod
  implicit none
  class(psb_d_base_multivect_type), intent(inout) :: x     ! multivector owning comid
  integer(psb_ipk_), intent(out)                  :: info  ! 0 on success, DEALLOCATE stat otherwise

  ! info is intent(out) and therefore undefined on entry: give it a
  ! defined value so that a call on a multivector whose comid was never
  ! allocated does not hand an undefined status back to the caller.
  info = 0_psb_ipk_
  if (allocated(x%comid)) deallocate(x%comid, stat=info)
end subroutine d_base_mlv_free_comid
!
! Gather: Y = beta * Y + alpha * X(IDX(:))
!
@ -4351,35 +4341,8 @@ contains
end subroutine d_base_mlv_device_wait
! --------------------------------------------------------------------
! Implementation of methods used for neighbor alltoallv communication
! --------------------------------------------------------------------
!> Initialise the neighbour topology stored inside the multivector.
!! Forwarding wrapper: every argument is handed, in the same order, to
!! the init binding of x%neighbor_topology; info is whatever that
!! binding reports.
subroutine d_base_mlv_init_topology(x, halo_index, num_exchanges, &
     & total_send_elems, total_recv_elems, ctxt, icomm, info)
  implicit none
  class(psb_d_base_multivect_type), intent(inout) :: x
  integer(psb_ipk_), intent(in)   :: halo_index(:)
  integer(psb_ipk_), intent(in)   :: num_exchanges
  integer(psb_ipk_), intent(in)   :: total_send_elems
  integer(psb_ipk_), intent(in)   :: total_recv_elems
  type(psb_ctxt_type), intent(in) :: ctxt
  integer(psb_mpk_), intent(in)   :: icomm
  integer(psb_ipk_), intent(out)  :: info

  associate (topology => x%neighbor_topology)
    call topology%init(halo_index, num_exchanges, total_send_elems, &
         & total_recv_elems, ctxt, icomm, info)
  end associate
end subroutine d_base_mlv_init_topology
!> Release the neighbour topology stored inside the multivector.
!! Delegates to the free binding of x%neighbor_topology and returns
!! its status in info.
subroutine d_base_mlv_free_topology(x, info)
  implicit none
  integer(psb_ipk_), intent(out)                  :: info  ! status from the free binding
  class(psb_d_base_multivect_type), intent(inout) :: x     ! multivector owning the topology

  call x%neighbor_topology%free(info)
end subroutine d_base_mlv_free_topology
! --------------------------------------------------------------------
!
! Communication routines for multivectors (delegates to base vector implementation)
!
end module psb_d_base_multivect_mod

@ -18,12 +18,18 @@
program psb_comm_test
use psb_base_mod
use psi_mod
use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free
use psb_comm_schemes_mod, only: psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, &
& psb_comm_isend_irecv_
use psb_comm_schemes_mod, only: psb_comm_status_start_, psb_comm_status_wait_, psb_comm_status_unknown_
implicit none
! ---- parameters ----
integer(psb_ipk_) :: idim
integer(psb_ipk_) :: argc
integer(psb_ipk_) :: iters
character(len=32) :: arg
character(len=16) :: mode
! ---- descriptor / context ----
type(psb_ctxt_type) :: ctxt
@ -33,11 +39,11 @@ program psb_comm_test
integer(psb_lpk_), allocatable :: myidx(:)
! ---- vectors ----
type(psb_d_vect_type) :: v_baseline, v_neighbor
type(psb_d_vect_type) :: v_baseline, v_neighbor, v_neighbor_persistent
! ---- temporary / comparison arrays ----
real(psb_dpk_), allocatable :: vals(:)
real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:)
real(psb_dpk_), allocatable :: result_baseline(:), result_neighbor(:), result_persistent(:)
real(psb_dpk_), allocatable :: expected(:)
! ---- halo index bookkeeping ----
@ -46,7 +52,9 @@ program psb_comm_test
! ---- error / reporting ----
integer(psb_ipk_) :: n_pass, n_total, imode
logical :: comm_ok
real(psb_dpk_) :: err, tol
real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent
integer(psb_lpk_), allocatable :: glob_col(:)
character(len=40) :: name
@ -54,6 +62,8 @@ program psb_comm_test
tol = 1.0d-12
n_pass = 0
n_total = 0
iters = 5
mode = 'both'
! ---- parse command-line argument for idim ----
idim = 10
@ -64,7 +74,22 @@ program psb_comm_test
if (i < argc) then
call get_command_argument(i+1, arg)
read(arg, *) idim
exit
end if
else if (trim(arg) == '--iters') then
if (i < argc) then
call get_command_argument(i+1, arg)
read(arg, *) iters
end if
end if
end do
! parse optional mode flag:  --mode <value>
! The value is copied verbatim (left-justified, truncated to len(mode))
! instead of being parsed with a list-directed read. A list-directed
! read of a character item stops at blanks, commas and slashes, treats
! quotes as delimiters, and raises an end-of-file condition on a blank
! value; without iostat= that aborts the program.
! When --mode is given more than once the last occurrence wins.
do i = 1, argc - 1
  call get_command_argument(i, arg)
  if (trim(arg) /= '--mode') cycle
  ! i < argc is guaranteed by the loop bound, so a value argument exists
  call get_command_argument(i+1, arg)
  mode = adjustl(arg)
end do
@ -135,8 +160,10 @@ program psb_comm_test
! ==================================================================
call psb_geall(v_baseline, desc_a, info)
call psb_geall(v_neighbor, desc_a, info)
call psb_geall(v_neighbor_persistent, desc_a, info)
call psb_geasb(v_baseline, desc_a, info, scratch=.true.)
call psb_geasb(v_neighbor, desc_a, info, scratch=.true.)
call psb_geasb(v_neighbor_persistent, desc_a, info, scratch=.true.)
! Fill owned entries with the global index value
allocate(vals(ncol))
@ -146,6 +173,7 @@ program psb_comm_test
end do
call v_baseline%set_vect(vals)
call v_neighbor%set_vect(vals)
call v_neighbor_persistent%set_vect(vals)
deallocate(vals)
! ==================================================================
@ -162,36 +190,162 @@ program psb_comm_test
! ==================================================================
! 6. Baseline halo exchange (Isend/Irecv in one call)
! ==================================================================
imode = IOR(psb_swap_send_, psb_swap_recv_)
! v_baseline%v is a psb_d_base_vect_type
call psi_swapdata(flag=imode, beta=dzero, y=v_baseline%v, desc_a=desc_a, info=info, data=psb_comm_halo_)
call psi_swapdata( &
swap_status=psb_comm_status_start_, &
beta=dzero, &
y=v_baseline%v, &
desc_a=desc_a, &
info=info, &
data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'baseline swap error:', info
call psb_abort(ctxt)
end if
call psi_swapdata( &
swap_status=psb_comm_status_wait_, &
beta=dzero, &
y=v_baseline%v, &
desc_a=desc_a, &
info=info, &
data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'baseline swap error:', info
call psb_abort(ctxt)
end if
! ==================================================================
! 7. Neighbor topology halo exchange (start + wait)
! ==================================================================
imode = psb_swap_start_
call psi_swapdata(imode, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psb_comm_init(psb_comm_ineighbor_alltoallv_, v_neighbor%v%comm_handle, info)
if (info /= 0) then
write(psb_err_unit,*) my_rank, 'psb_comm_init neighbor error:', info
call psb_abort(ctxt)
end if
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'neighbor start error:', info
call psb_abort(ctxt)
end if
imode = psb_swap_wait_
call psi_swapdata(imode, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'neighbor wait error:', info
call psb_abort(ctxt)
end if
! ==================================================================
! 7b. Persistent-neighbor halo exchange (start + wait)
! ==================================================================
! Select the persistent neighbor-alltoallv scheme for this vector by
! initialising its comm_handle; abort on a non-zero status.
call psb_comm_init(psb_comm_persistent_ineighbor_alltoallv_, &
     & v_neighbor_persistent%v%comm_handle, info)
if (info /= 0) then
  write(psb_err_unit,*) my_rank, 'psb_comm_init persistent-neighbor error:', info
  call psb_abort(ctxt)
end if

! First call: start phase of the halo exchange
call psi_swapdata( &
     & swap_status=psb_comm_status_start_, &
     & beta=dzero, &
     & y=v_neighbor_persistent%v, &
     & desc_a=desc_a, &
     & info=info, &
     & data=psb_comm_halo_)
if (info /= psb_success_) then
  write(psb_err_unit,*) my_rank, 'persistent-neighbor start error:', info
  call psb_abort(ctxt)
end if

! Second call: wait phase of the same exchange
call psi_swapdata( &
     & swap_status=psb_comm_status_wait_, &
     & beta=dzero, &
     & y=v_neighbor_persistent%v, &
     & desc_a=desc_a, &
     & info=info, &
     & data=psb_comm_halo_)
if (info /= psb_success_) then
  write(psb_err_unit,*) my_rank, 'persistent-neighbor wait error:', info
  call psb_abort(ctxt)
end if
! ==================================================================
! 8. Performance: repeat exchanges and measure timings
! ==================================================================
! Announce the timing run (rank 0 only)
if (my_rank == 0) then
  write(psb_out_unit,'("Timing: running ",i0," iterations for baseline, neighbor and persistent-neighbor")') iters
end if

! Accumulated wall-clock time per communication scheme
tsum_baseline            = 0.0_psb_dpk_
tsum_neighbor            = 0.0_psb_dpk_
tsum_neighbor_persistent = 0.0_psb_dpk_

! Select the communication scheme of each vector. The info returned by
! these calls is not inspected here; the check below looks at the
! resulting handles instead.
call psb_comm_init(psb_comm_isend_irecv_, &
     & v_baseline%v%comm_handle, info)
call psb_comm_init(psb_comm_ineighbor_alltoallv_, &
     & v_neighbor%v%comm_handle, info)
call psb_comm_init(psb_comm_persistent_ineighbor_alltoallv_, &
     & v_neighbor_persistent%v%comm_handle, info)

! ---- Comm check: verify selected communication schemes ----
! comm_type is only read once all three handles are known to be
! allocated; Fortran gives no short-circuit guarantee for .and., so the
! allocation test and the comparison are kept in separate steps.
n_total = n_total + 1
comm_ok = .false.
if (allocated(v_baseline%v%comm_handle) .and. &
     & allocated(v_neighbor%v%comm_handle) .and. &
     & allocated(v_neighbor_persistent%v%comm_handle)) then
  comm_ok = (v_baseline%v%comm_handle%comm_type == psb_comm_isend_irecv_) .and. &
       & (v_neighbor%v%comm_handle%comm_type == psb_comm_ineighbor_alltoallv_) .and. &
       & (v_neighbor_persistent%v%comm_handle%comm_type == psb_comm_persistent_ineighbor_alltoallv_)
end if

! Only rank 0 prints and counts the pass
if (my_rank == 0) then
  if (.not. comm_ok) then
    write(psb_out_unit,'(" [FAIL] comm scheme selection : unexpected comm_type mapping")')
  else
    write(psb_out_unit,'(" [PASS] comm scheme selection : baseline/neighbor/persistent OK")')
    n_pass = n_pass + 1
  end if
end if
! Time iters start+wait exchanges for each scheme. Every sample dt is
! passed through psb_amx over ctxt before being accumulated
! (presumably a max-reduction across processes -- confirm).
! info is not checked inside the loop.
do i = 1, iters
  ! --- baseline ---
  t0 = psb_wtime()
  call psi_swapdata(psb_comm_status_start_, dzero, v_baseline%v, desc_a, info, &
       & data=psb_comm_halo_)
  call psi_swapdata(psb_comm_status_wait_, dzero, v_baseline%v, desc_a, info, &
       & data=psb_comm_halo_)
  t1 = psb_wtime()
  dt = t1 - t0
  call psb_amx(ctxt, dt)
  tsum_baseline = tsum_baseline + dt

  ! --- neighbor (start + wait) ---
  t0 = psb_wtime()
  call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, &
       & data=psb_comm_halo_)
  call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, &
       & data=psb_comm_halo_)
  t1 = psb_wtime()
  dt = t1 - t0
  call psb_amx(ctxt, dt)
  tsum_neighbor = tsum_neighbor + dt

  ! --- persistent-neighbor (start + wait) ---
  t0 = psb_wtime()
  call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, &
       & data=psb_comm_halo_)
  call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, &
       & data=psb_comm_halo_)
  t1 = psb_wtime()
  dt = t1 - t0
  call psb_amx(ctxt, dt)
  tsum_neighbor_persistent = tsum_neighbor_persistent + dt
end do
! Report average and total time per scheme (rank 0 only).
! iters is taken from the command line without validation, so the
! divisor is clamped to 1: with "--iters 0" (or a negative value) the
! timing loop does not run, the totals stay at zero, and the averages
! are printed as zero instead of evaluating 0/0.
if (my_rank == 0) then
  write(psb_out_unit,'(" Avg baseline time : ",es12.5)') &
       & tsum_baseline / real(max(iters, 1_psb_ipk_), psb_dpk_)
  write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline
  write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') &
       & tsum_neighbor / real(max(iters, 1_psb_ipk_), psb_dpk_)
  write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor
  write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') &
       & tsum_neighbor_persistent / real(max(iters, 1_psb_ipk_), psb_dpk_)
  write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_neighbor_persistent
end if
! ==================================================================
! 8. Extract results and compare
! ==================================================================
result_baseline = v_baseline%get_vect()
result_neighbor = v_neighbor%get_vect()
result_persistent = v_neighbor_persistent%get_vect()
! ---- Test 1: cross-check baseline vs neighbor (all entries) ----
n_total = n_total + 1
@ -232,15 +386,41 @@ program psb_comm_test
end if
end if
! ---- Test 4: repeat neighbor exchange (topology reuse) ----
! ---- Test 4: cross-check baseline vs persistent-neighbor (all entries) ----
! Largest absolute difference over entries 1..ncol, passed through
! psb_amx over ctxt before the comparison with tol.
err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol)))
call psb_amx(ctxt, err)
n_total = n_total + 1
if (my_rank == 0) then
  ! .not.(err < tol) rather than err >= tol, so a NaN still counts as a failure
  if (.not. (err < tol)) then
    write(psb_out_unit,'(" [FAIL] cross-check baseline vs pers-nei : err = ",es12.5)') err
  else
    write(psb_out_unit,'(" [PASS] cross-check baseline vs pers-nei : err = ",es12.5)') err
    n_pass = n_pass + 1
  end if
end if

! ---- Test 5: persistent-neighbor absolute correctness ----
! Same measure, this time against the expected array.
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
call psb_amx(ctxt, err)
n_total = n_total + 1
if (my_rank == 0) then
  if (.not. (err < tol)) then
    write(psb_out_unit,'(" [FAIL] pers-neigh absolute correctness : err = ",es12.5)') err
  else
    write(psb_out_unit,'(" [PASS] pers-neigh absolute correctness : err = ",es12.5)') err
    n_pass = n_pass + 1
  end if
end if
! ---- Test 6: repeat neighbor exchange (topology reuse) ----
! Reset halo entries to zero, run again, and check
do i = nrow+1, ncol
result_neighbor(i) = dzero
end do
call v_neighbor%set_vect(result_neighbor)
call psi_swapdata(psb_swap_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_swap_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
result_neighbor = v_neighbor%get_vect()
n_total = n_total + 1
@ -255,6 +435,28 @@ program psb_comm_test
end if
end if
! ---- Test 7: repeat persistent-neighbor exchange (buffer reuse) ----
! Zero entries nrow+1..ncol of the previous result, load it back into
! the vector, and run a second start+wait exchange on the same handle.
result_persistent(nrow+1:ncol) = dzero
call v_neighbor_persistent%set_vect(result_persistent)

call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, &
     & data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, &
     & data=psb_comm_halo_)
result_persistent = v_neighbor_persistent%get_vect()

! Compare against the expected array, as in the earlier checks
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
call psb_amx(ctxt, err)
n_total = n_total + 1
if (my_rank == 0) then
  ! .not.(err < tol) rather than err >= tol, so a NaN still counts as a failure
  if (.not. (err < tol)) then
    write(psb_out_unit,'(" [FAIL] pers-neigh buffer reuse : err = ",es12.5)') err
  else
    write(psb_out_unit,'(" [PASS] pers-neigh buffer reuse : err = ",es12.5)') err
    n_pass = n_pass + 1
  end if
end if
! ==================================================================
! 9. Summary
! ==================================================================
@ -272,9 +474,11 @@ program psb_comm_test
! ==================================================================
! 10. Cleanup
! ==================================================================
deallocate(result_baseline, result_neighbor, expected, glob_col)
call psb_gefree(v_baseline, desc_a, info)
deallocate(result_baseline, result_neighbor, result_persistent, expected, glob_col)
9999 call psb_gefree(v_baseline, desc_a, info)
call psb_gefree(v_neighbor, desc_a, info)
call psb_gefree(v_neighbor_persistent, desc_a, info)
call psb_cdfree(desc_a, info)
call psb_exit(ctxt)

Loading…
Cancel
Save