From 336f7bf1320108740d35854d8e0aa57cf190bca7 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Mon, 28 Mar 2022 10:32:29 +0200 Subject: [PATCH] Defined non-blocking version of PSB_SUM. --- base/modules/Makefile | 2 +- base/modules/desc/psb_desc_const_mod.f90 | 3 + base/modules/penv/psi_c_collective_mod.F90 | 179 +++++++++++++++----- base/modules/penv/psi_d_collective_mod.F90 | 179 +++++++++++++++----- base/modules/penv/psi_e_collective_mod.F90 | 179 +++++++++++++++----- base/modules/penv/psi_i2_collective_mod.F90 | 179 +++++++++++++++----- base/modules/penv/psi_m_collective_mod.F90 | 179 +++++++++++++++----- base/modules/penv/psi_s_collective_mod.F90 | 179 +++++++++++++++----- base/modules/penv/psi_z_collective_mod.F90 | 179 +++++++++++++++----- 9 files changed, 970 insertions(+), 288 deletions(-) diff --git a/base/modules/Makefile b/base/modules/Makefile index 8d50011f..9951bd89 100644 --- a/base/modules/Makefile +++ b/base/modules/Makefile @@ -375,7 +375,7 @@ psblas/psb_s_psblas_mod.o psblas/psb_c_psblas_mod.o psblas/psb_d_psblas_mod.o ps psb_base_mod.o: $(MODULES) -penv/psi_penv_mod.o: penv/psi_penv_mod.F90 psb_const_mod.o serial/psb_vect_mod.o serial/psb_mat_mod.o +penv/psi_penv_mod.o: penv/psi_penv_mod.F90 psb_const_mod.o serial/psb_vect_mod.o serial/psb_mat_mod.o desc/psb_desc_const_mod.o $(FC) $(FINCLUDES) $(FDEFINES) $(FCOPT) $(EXTRA_OPT) -c $< -o $@ psb_penv_mod.o: psb_penv_mod.F90 $(COMMINT) $(BASIC_MODS) diff --git a/base/modules/desc/psb_desc_const_mod.f90 b/base/modules/desc/psb_desc_const_mod.f90 index aa2ea2fe..8953aafc 100644 --- a/base/modules/desc/psb_desc_const_mod.f90 +++ b/base/modules/desc/psb_desc_const_mod.f90 @@ -48,6 +48,9 @@ module psb_desc_const_mod ! The following are bit fields. integer(psb_ipk_), parameter :: psb_swap_send_=1, psb_swap_recv_=2 integer(psb_ipk_), parameter :: psb_swap_sync_=4, psb_swap_mpi_=8 + integer(psb_ipk_), parameter :: psb_collective_start_=1, psb_collective_end_=2 + integer(psb_ipk_), parameter :: psb_collective_sync_=4 + ! Choice among lists on which to base data exchange integer(psb_ipk_), parameter :: psb_no_comm_=-1 integer(psb_ipk_), parameter :: psb_comm_halo_=1, psb_comm_ovr_=2 diff --git a/base/modules/penv/psi_c_collective_mod.F90 b/base/modules/penv/psi_c_collective_mod.F90 index 17113ec0..37609891 100644 --- a/base/modules/penv/psi_c_collective_mod.F90 +++ b/base/modules/penv/psi_c_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_c_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_sum module procedure psb_csums, psb_csumv, psb_csumm @@ -79,7 +80,7 @@ contains ! SUM ! - subroutine psb_csums(ctxt,dat,root) + subroutine psb_csums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -90,11 +91,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt complex(psb_spk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ complex(psb_spk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -104,17 +110,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_c_spk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_c_spk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_c_spk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_c_spk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_c_spk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_csums - subroutine psb_csumv(ctxt,dat,root) + subroutine psb_csumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -126,10 +156,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt complex(psb_spk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - complex(psb_spk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -140,25 +174,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_c_spk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_csumv - subroutine psb_csumm(ctxt,dat,root) + subroutine psb_csumm(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -170,11 +234,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt complex(psb_spk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - complex(psb_spk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -185,21 +253,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_c_spk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_spk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_csumm diff --git a/base/modules/penv/psi_d_collective_mod.F90 b/base/modules/penv/psi_d_collective_mod.F90 index 12d5f38b..c07affda 100644 --- a/base/modules/penv/psi_d_collective_mod.F90 +++ b/base/modules/penv/psi_d_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_d_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_max module procedure psb_dmaxs, psb_dmaxv, psb_dmaxm end interface @@ -441,7 +442,7 @@ contains ! SUM ! - subroutine psb_dsums(ctxt,dat,root) + subroutine psb_dsums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -452,11 +453,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt real(psb_dpk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ real(psb_dpk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -466,17 +472,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_r_dpk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_dsums - subroutine psb_dsumv(ctxt,dat,root) + subroutine psb_dsumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -488,10 +518,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt real(psb_dpk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - real(psb_dpk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -502,25 +536,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_dsumv - subroutine psb_dsumm(ctxt,dat,root) + subroutine psb_dsumm(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -532,11 +596,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt real(psb_dpk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - real(psb_dpk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -547,21 +615,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_dsumm diff --git a/base/modules/penv/psi_e_collective_mod.F90 b/base/modules/penv/psi_e_collective_mod.F90 index 215446c0..954c6141 100644 --- a/base/modules/penv/psi_e_collective_mod.F90 +++ b/base/modules/penv/psi_e_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_e_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_max module procedure psb_emaxs, psb_emaxv, psb_emaxm end interface @@ -349,7 +350,7 @@ contains ! SUM ! - subroutine psb_esums(ctxt,dat,root) + subroutine psb_esums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -360,11 +361,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_epk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ integer(psb_epk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -374,17 +380,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_epk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_epk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_esums - subroutine psb_esumv(ctxt,dat,root) + subroutine psb_esumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -396,10 +426,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_epk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - integer(psb_epk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -410,25 +444,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_epk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_esumv - subroutine psb_esumm(ctxt,dat,root) + subroutine psb_esumm(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -440,11 +504,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_epk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - integer(psb_epk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -455,21 +523,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_epk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_epk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_esumm diff --git a/base/modules/penv/psi_i2_collective_mod.F90 b/base/modules/penv/psi_i2_collective_mod.F90 index 781653d4..d53d7f83 100644 --- a/base/modules/penv/psi_i2_collective_mod.F90 +++ b/base/modules/penv/psi_i2_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_i2_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_max module procedure psb_i2maxs, psb_i2maxv, psb_i2maxm end interface @@ -349,7 +350,7 @@ contains ! SUM ! - subroutine psb_i2sums(ctxt,dat,root) + subroutine psb_i2sums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -360,11 +361,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_i2pk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ integer(psb_i2pk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -374,17 +380,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_i2pk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_i2pk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_i2sums - subroutine psb_i2sumv(ctxt,dat,root) + subroutine psb_i2sumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -396,10 +426,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_i2pk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - integer(psb_i2pk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -410,25 +444,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_i2pk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_i2sumv - subroutine psb_i2summ(ctxt,dat,root) + subroutine psb_i2summ(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -440,11 +504,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_i2pk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - integer(psb_i2pk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -455,21 +523,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_i2pk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_i2pk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_i2summ diff --git a/base/modules/penv/psi_m_collective_mod.F90 b/base/modules/penv/psi_m_collective_mod.F90 index 8fdea824..aa2ffc46 100644 --- a/base/modules/penv/psi_m_collective_mod.F90 +++ b/base/modules/penv/psi_m_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_m_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_max module procedure psb_mmaxs, psb_mmaxv, psb_mmaxm end interface @@ -349,7 +350,7 @@ contains ! SUM ! - subroutine psb_msums(ctxt,dat,root) + subroutine psb_msums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -360,11 +361,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_mpk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ integer(psb_mpk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -374,17 +380,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_mpk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_mpk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_msums - subroutine psb_msumv(ctxt,dat,root) + subroutine psb_msumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -396,10 +426,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_mpk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - integer(psb_mpk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -410,25 +444,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_mpk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_msumv - subroutine psb_msumm(ctxt,dat,root) + subroutine psb_msumm(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -440,11 +504,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt integer(psb_mpk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - integer(psb_mpk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -455,21 +523,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_mpk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_mpk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_msumm diff --git a/base/modules/penv/psi_s_collective_mod.F90 b/base/modules/penv/psi_s_collective_mod.F90 index 82f96aac..ef0b946c 100644 --- a/base/modules/penv/psi_s_collective_mod.F90 +++ b/base/modules/penv/psi_s_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_s_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_max module procedure psb_smaxs, psb_smaxv, psb_smaxm end interface @@ -441,7 +442,7 @@ contains ! SUM ! - subroutine psb_ssums(ctxt,dat,root) + subroutine psb_ssums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -452,11 +453,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt real(psb_spk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ real(psb_spk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -466,17 +472,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_r_spk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_r_spk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_ssums - subroutine psb_ssumv(ctxt,dat,root) + subroutine psb_ssumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -488,10 +518,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt real(psb_spk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - real(psb_spk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -502,25 +536,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_r_spk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_ssumv - subroutine psb_ssumm(ctxt,dat,root) + subroutine psb_ssumm(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -532,11 +596,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt real(psb_spk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - real(psb_spk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -547,21 +615,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_r_spk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_ssumm diff --git a/base/modules/penv/psi_z_collective_mod.F90 b/base/modules/penv/psi_z_collective_mod.F90 index 80c9213a..f2dfc131 100644 --- a/base/modules/penv/psi_z_collective_mod.F90 +++ b/base/modules/penv/psi_z_collective_mod.F90 @@ -31,7 +31,8 @@ ! module psi_z_collective_mod use psi_penv_mod - + use psb_desc_const_mod + interface psb_sum module procedure psb_zsums, psb_zsumv, psb_zsumm @@ -79,7 +80,7 @@ contains ! SUM ! - subroutine psb_zsums(ctxt,dat,root) + subroutine psb_zsums(ctxt,dat,root,mode,request) #ifdef MPI_MOD use mpi #endif @@ -90,11 +91,16 @@ contains type(psb_ctxt_type), intent(in) :: ctxt complex(psb_dpk_), intent(inout) :: dat integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ complex(psb_dpk_) :: dat_ - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -104,17 +110,41 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call mpi_allreduce(dat,dat_,1,psb_mpi_c_dpk_,mpi_sum,icomm,info) - dat = dat_ + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - call mpi_reduce(dat,dat_,1,psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) - if (iam == root_) dat = dat_ - endif + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_c_dpk_,mpi_sum,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallreduce(MPI_IN_PLACE,dat,1,psb_mpi_c_dpk_,mpi_sum,icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,1,psb_mpi_c_dpk_,mpi_sum,root_,icomm,request,info) + end if + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if #endif end subroutine psb_zsums - subroutine psb_zsumv(ctxt,dat,root) + subroutine psb_zsumv(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -126,10 +156,14 @@ contains type(psb_ctxt_type), intent(in) :: ctxt complex(psb_dpk_), intent(inout) :: dat(:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - complex(psb_dpk_), allocatable :: dat_(:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo + logical :: collective_start, collective_end, collective_sync #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -140,25 +174,55 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_) & - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,icomm,info) else - call psb_realloc(1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if + #endif end subroutine psb_zsumv - subroutine psb_zsumm(ctxt,dat,root) + subroutine psb_zsumm(ctxt,dat,root,mode,request) use psb_realloc_mod #ifdef MPI_MOD use mpi @@ -170,11 +234,15 @@ contains type(psb_ctxt_type), intent(in) :: ctxt complex(psb_dpk_), intent(inout) :: dat(:,:) integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request integer(psb_mpk_) :: root_ - complex(psb_dpk_), allocatable :: dat_(:,:) - integer(psb_mpk_) :: iam, np, info, icomm + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) integer(psb_ipk_) :: iinfo - + logical :: collective_start, collective_end, collective_sync + #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -185,21 +253,50 @@ contains root_ = -1 endif icomm = psb_get_mpi_comm(ctxt) - if (root_ == -1) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - if (iinfo == psb_success_)& - & call mpi_allreduce(dat_,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,icomm,info) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if else - if (iam == root_) then - call psb_realloc(size(dat,1),size(dat,2),dat_,iinfo) - dat_ = dat - call mpi_reduce(dat_,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,icomm,info) else - call psb_realloc(1,1,dat_,iinfo) - call mpi_reduce(dat,dat_,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + if (iam == root_) then + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + end if end if - endif + else + if (collective_start) then + if (root_ == -1) then + if (iinfo == psb_success_) & + & call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,& + & icomm,request,info) + else + if (iam == root_) then + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,& + & icomm,request,info) + else + call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_c_dpk_,mpi_sum,root_,& + & icomm,request,info) + end if + end if + else if (collective_end) then + call mpi_wait(request,status,info) + endif + end if #endif end subroutine psb_zsumm