From 00cc83cde8d03e85539ee06fbb3873ab80357a4f Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 18 Jan 2023 05:04:16 -0500 Subject: [PATCH 01/48] First version of AD/AND with memory duplication --- base/modules/serial/psb_c_mat_mod.F90 | 1 + base/modules/serial/psb_d_mat_mod.F90 | 1 + base/modules/serial/psb_s_mat_mod.F90 | 1 + base/modules/serial/psb_z_mat_mod.F90 | 1 + base/psblas/psb_cspmm.f90 | 21 ++++++++++++++----- base/psblas/psb_dspmm.f90 | 21 ++++++++++++++----- base/psblas/psb_sspmm.f90 | 21 ++++++++++++++----- base/psblas/psb_zspmm.f90 | 21 ++++++++++++++----- base/tools/psb_cspasb.f90 | 30 ++++++++++++++++++++++++++- base/tools/psb_dspasb.f90 | 30 ++++++++++++++++++++++++++- base/tools/psb_sspasb.f90 | 30 ++++++++++++++++++++++++++- base/tools/psb_zspasb.f90 | 30 ++++++++++++++++++++++++++- compile | 0 test/pargen/runs/ppde.inp | 2 +- 14 files changed, 185 insertions(+), 25 deletions(-) create mode 100644 compile diff --git a/base/modules/serial/psb_c_mat_mod.F90 b/base/modules/serial/psb_c_mat_mod.F90 index fd423de3..2e365858 100644 --- a/base/modules/serial/psb_c_mat_mod.F90 +++ b/base/modules/serial/psb_c_mat_mod.F90 @@ -85,6 +85,7 @@ module psb_c_mat_mod type :: psb_cspmat_type class(psb_c_base_sparse_mat), allocatable :: a + class(psb_c_base_sparse_mat), allocatable :: ad, and integer(psb_ipk_) :: remote_build=psb_matbld_noremote_ type(psb_lc_coo_sparse_mat), allocatable :: rmta diff --git a/base/modules/serial/psb_d_mat_mod.F90 b/base/modules/serial/psb_d_mat_mod.F90 index 8f967ce1..49a9545e 100644 --- a/base/modules/serial/psb_d_mat_mod.F90 +++ b/base/modules/serial/psb_d_mat_mod.F90 @@ -85,6 +85,7 @@ module psb_d_mat_mod type :: psb_dspmat_type class(psb_d_base_sparse_mat), allocatable :: a + class(psb_d_base_sparse_mat), allocatable :: ad, and integer(psb_ipk_) :: remote_build=psb_matbld_noremote_ type(psb_ld_coo_sparse_mat), allocatable :: rmta diff --git a/base/modules/serial/psb_s_mat_mod.F90 b/base/modules/serial/psb_s_mat_mod.F90 index 43f1c619..eb444249 100644 --- a/base/modules/serial/psb_s_mat_mod.F90 +++ b/base/modules/serial/psb_s_mat_mod.F90 @@ -85,6 +85,7 @@ module psb_s_mat_mod type :: psb_sspmat_type class(psb_s_base_sparse_mat), allocatable :: a + class(psb_s_base_sparse_mat), allocatable :: ad, and integer(psb_ipk_) :: remote_build=psb_matbld_noremote_ type(psb_ls_coo_sparse_mat), allocatable :: rmta diff --git a/base/modules/serial/psb_z_mat_mod.F90 b/base/modules/serial/psb_z_mat_mod.F90 index c534cad5..e70e48aa 100644 --- a/base/modules/serial/psb_z_mat_mod.F90 +++ b/base/modules/serial/psb_z_mat_mod.F90 @@ -85,6 +85,7 @@ module psb_z_mat_mod type :: psb_zspmat_type class(psb_z_base_sparse_mat), allocatable :: a + class(psb_z_base_sparse_mat), allocatable :: ad, and integer(psb_ipk_) :: remote_build=psb_matbld_noremote_ type(psb_lz_coo_sparse_mat), allocatable :: rmta diff --git a/base/psblas/psb_cspmm.f90 b/base/psblas/psb_cspmm.f90 index fd8a9c39..555461df 100644 --- a/base/psblas/psb_cspmm.f90 +++ b/base/psblas/psb_cspmm.f90 @@ -179,13 +179,24 @@ subroutine psb_cspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + if (.true.) then + call psi_swapdata(psb_swap_send_,& & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + call psi_swapdata(psb_swap_recv_,& + & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%and%spmm(alpha,x%v,cone,y%v,info) + + else + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + + call psb_csmm(alpha,a,x,beta,y,info) + end if - - call psb_csmm(alpha,a,x,beta,y,info) - + if(info /= psb_success_) then info = psb_err_from_subroutine_non_ call psb_errpush(info,name) diff --git a/base/psblas/psb_dspmm.f90 b/base/psblas/psb_dspmm.f90 index a006c7e9..be8a493f 100644 --- a/base/psblas/psb_dspmm.f90 +++ b/base/psblas/psb_dspmm.f90 @@ -179,13 +179,24 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + if (.true.) then + call psi_swapdata(psb_swap_send_,& & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + call psi_swapdata(psb_swap_recv_,& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%and%spmm(alpha,x%v,done,y%v,info) + + else + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + + call psb_csmm(alpha,a,x,beta,y,info) + end if - - call psb_csmm(alpha,a,x,beta,y,info) - + if(info /= psb_success_) then info = psb_err_from_subroutine_non_ call psb_errpush(info,name) diff --git a/base/psblas/psb_sspmm.f90 b/base/psblas/psb_sspmm.f90 index 43ee0d48..79bfbdd1 100644 --- a/base/psblas/psb_sspmm.f90 +++ b/base/psblas/psb_sspmm.f90 @@ -179,13 +179,24 @@ subroutine psb_sspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + if (.true.) then + call psi_swapdata(psb_swap_send_,& & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + call psi_swapdata(psb_swap_recv_,& + & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%and%spmm(alpha,x%v,sone,y%v,info) + + else + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + + call psb_csmm(alpha,a,x,beta,y,info) + end if - - call psb_csmm(alpha,a,x,beta,y,info) - + if(info /= psb_success_) then info = psb_err_from_subroutine_non_ call psb_errpush(info,name) diff --git a/base/psblas/psb_zspmm.f90 b/base/psblas/psb_zspmm.f90 index b58ca303..f248db8b 100644 --- a/base/psblas/psb_zspmm.f90 +++ b/base/psblas/psb_zspmm.f90 @@ -179,13 +179,24 @@ subroutine psb_zspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + if (.true.) then + call psi_swapdata(psb_swap_send_,& & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + call psi_swapdata(psb_swap_recv_,& + & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + call a%and%spmm(alpha,x%v,zone,y%v,info) + + else + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + + call psb_csmm(alpha,a,x,beta,y,info) + end if - - call psb_csmm(alpha,a,x,beta,y,info) - + if(info /= psb_success_) then info = psb_err_from_subroutine_non_ call psb_errpush(info,name) diff --git a/base/tools/psb_cspasb.f90 b/base/tools/psb_cspasb.f90 index 0c5f14ab..ea7789f2 100644 --- a/base/tools/psb_cspasb.f90 +++ b/base/tools/psb_cspasb.f90 @@ -171,7 +171,35 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) end if - + if (.true.) then + block + character(len=1024) :: fname + type(psb_c_coo_sparse_mat) :: acoo + type(psb_c_csr_sparse_mat), allocatable :: aclip, andclip + allocate(aclip,andclip) + call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) + call aclip%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) + call andclip%mv_from_coo(acoo,info) + call move_alloc(aclip,a%ad) + call move_alloc(andclip,a%and) + if (.false.) then + write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' + open(25,file=fname) + call a%ad%print(25) + close(25) + write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' + open(25,file=fname) + call a%and%print(25) + close(25) + !call andclip%set_cols(n_col) + write(*,*) me,' ',trim(name),' ad ',& + &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col + write(*,*) me,' ',trim(name),' and ',& + &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col + end if + end block + end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() write(debug_unit, *) me,' ',trim(name),': From SPCNV',& diff --git a/base/tools/psb_dspasb.f90 b/base/tools/psb_dspasb.f90 index 3132f249..89ceef8d 100644 --- a/base/tools/psb_dspasb.f90 +++ b/base/tools/psb_dspasb.f90 @@ -171,7 +171,35 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) end if - + if (.true.) then + block + character(len=1024) :: fname + type(psb_d_coo_sparse_mat) :: acoo + type(psb_d_csr_sparse_mat), allocatable :: aclip, andclip + allocate(aclip,andclip) + call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) + call aclip%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) + call andclip%mv_from_coo(acoo,info) + call move_alloc(aclip,a%ad) + call move_alloc(andclip,a%and) + if (.false.) then + write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' + open(25,file=fname) + call a%ad%print(25) + close(25) + write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' + open(25,file=fname) + call a%and%print(25) + close(25) + !call andclip%set_cols(n_col) + write(*,*) me,' ',trim(name),' ad ',& + &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col + write(*,*) me,' ',trim(name),' and ',& + &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col + end if + end block + end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() write(debug_unit, *) me,' ',trim(name),': From SPCNV',& diff --git a/base/tools/psb_sspasb.f90 b/base/tools/psb_sspasb.f90 index cfa316eb..14ad5246 100644 --- a/base/tools/psb_sspasb.f90 +++ b/base/tools/psb_sspasb.f90 @@ -171,7 +171,35 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) end if - + if (.true.) then + block + character(len=1024) :: fname + type(psb_s_coo_sparse_mat) :: acoo + type(psb_s_csr_sparse_mat), allocatable :: aclip, andclip + allocate(aclip,andclip) + call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) + call aclip%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) + call andclip%mv_from_coo(acoo,info) + call move_alloc(aclip,a%ad) + call move_alloc(andclip,a%and) + if (.false.) then + write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' + open(25,file=fname) + call a%ad%print(25) + close(25) + write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' + open(25,file=fname) + call a%and%print(25) + close(25) + !call andclip%set_cols(n_col) + write(*,*) me,' ',trim(name),' ad ',& + &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col + write(*,*) me,' ',trim(name),' and ',& + &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col + end if + end block + end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() write(debug_unit, *) me,' ',trim(name),': From SPCNV',& diff --git a/base/tools/psb_zspasb.f90 b/base/tools/psb_zspasb.f90 index aeeef94d..f65be363 100644 --- a/base/tools/psb_zspasb.f90 +++ b/base/tools/psb_zspasb.f90 @@ -171,7 +171,35 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) end if - + if (.true.) then + block + character(len=1024) :: fname + type(psb_z_coo_sparse_mat) :: acoo + type(psb_z_csr_sparse_mat), allocatable :: aclip, andclip + allocate(aclip,andclip) + call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) + call aclip%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) + call andclip%mv_from_coo(acoo,info) + call move_alloc(aclip,a%ad) + call move_alloc(andclip,a%and) + if (.false.) then + write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' + open(25,file=fname) + call a%ad%print(25) + close(25) + write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' + open(25,file=fname) + call a%and%print(25) + close(25) + !call andclip%set_cols(n_col) + write(*,*) me,' ',trim(name),' ad ',& + &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col + write(*,*) me,' ',trim(name),' and ',& + &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col + end if + end block + end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() write(debug_unit, *) me,' ',trim(name),': From SPCNV',& diff --git a/compile b/compile new file mode 100644 index 00000000..e69de29b diff --git a/test/pargen/runs/ppde.inp b/test/pargen/runs/ppde.inp index e7e5dca2..5f040075 100644 --- a/test/pargen/runs/ppde.inp +++ b/test/pargen/runs/ppde.inp @@ -2,7 +2,7 @@ BICGSTAB Iterative method BICGSTAB CGS BICG BICGSTABL RGMRES FCG CGR BJAC Preconditioner NONE DIAG BJAC CSR Storage format for matrix A: CSR COO -040 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) +140 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) 3 Partition: 1 BLOCK 3 3D 2 Stopping criterion 1 2 0100 MAXIT From f09e25524ed30ab5f783ac2d6efbd1684717d317 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 18 Jan 2023 09:17:49 -0500 Subject: [PATCH 02/48] Create ECSR format and use it for A%AND --- base/modules/serial/psb_c_csr_mat_mod.f90 | 126 ++++++++++- base/modules/serial/psb_c_mat_mod.F90 | 3 +- base/modules/serial/psb_d_csr_mat_mod.f90 | 126 ++++++++++- base/modules/serial/psb_d_mat_mod.F90 | 3 +- base/modules/serial/psb_s_csr_mat_mod.f90 | 126 ++++++++++- base/modules/serial/psb_s_mat_mod.F90 | 3 +- base/modules/serial/psb_z_csr_mat_mod.f90 | 126 ++++++++++- base/modules/serial/psb_z_mat_mod.F90 | 3 +- base/serial/impl/psb_c_csr_impl.f90 | 263 ++++++++++++++++++++++ base/serial/impl/psb_d_csr_impl.f90 | 263 ++++++++++++++++++++++ base/serial/impl/psb_s_csr_impl.f90 | 263 ++++++++++++++++++++++ base/serial/impl/psb_z_csr_impl.f90 | 263 ++++++++++++++++++++++ base/tools/psb_cspasb.f90 | 7 +- base/tools/psb_dspasb.f90 | 7 +- base/tools/psb_sspasb.f90 | 7 +- base/tools/psb_zspasb.f90 | 7 +- test/pargen/runs/ppde.inp | 2 +- 17 files changed, 1577 insertions(+), 21 deletions(-) diff --git a/base/modules/serial/psb_c_csr_mat_mod.f90 b/base/modules/serial/psb_c_csr_mat_mod.f90 index 8b076cc2..d09eca2b 100644 --- a/base/modules/serial/psb_c_csr_mat_mod.f90 +++ b/base/modules/serial/psb_c_csr_mat_mod.f90 @@ -579,7 +579,111 @@ module psb_c_csr_mat_mod end subroutine psb_c_csr_scals end interface - !> \namespace psb_base_mod \class psb_lc_csr_sparse_mat + + type, extends(psb_c_csr_sparse_mat) :: psb_c_ecsr_sparse_mat + + !> Number of non-empty rows + integer(psb_ipk_) :: nnerws + !> Indices of non-empty rows + integer(psb_ipk_), allocatable :: nerwp(:) + + contains + procedure, nopass :: get_fmt => c_ecsr_get_fmt + + ! procedure, pass(a) :: csmm => psb_c_ecsr_csmm + procedure, pass(a) :: csmv => psb_c_ecsr_csmv + + procedure, pass(a) :: cp_from_coo => psb_c_cp_ecsr_from_coo + procedure, pass(a) :: cp_from_fmt => psb_c_cp_ecsr_from_fmt + procedure, pass(a) :: mv_from_coo => psb_c_mv_ecsr_from_coo + procedure, pass(a) :: mv_from_fmt => psb_c_mv_ecsr_from_fmt + + procedure, pass(a) :: cmp_nerwp => psb_c_ecsr_cmp_nerwp + procedure, pass(a) :: free => c_ecsr_free + procedure, pass(a) :: mold => psb_c_ecsr_mold + + end type psb_c_ecsr_sparse_mat + !> \memberof psb_c_ecsr_sparse_mat + !! \see psb_c_base_mat_mod::psb_c_base_csmv + interface + subroutine psb_c_ecsr_csmv(alpha,a,x,beta,y,info,trans) + import + class(psb_c_ecsr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta, x(:) + complex(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_c_ecsr_csmv + end interface + + !> \memberof psb_c_ecsr_sparse_mat + !! \see psb_c_base_mat_mod::psb_c_base_cp_from_coo + interface + subroutine psb_c_ecsr_cmp_nerwp(a,info) + import + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_ecsr_cmp_nerwp + end interface + + !> \memberof psb_c_ecsr_sparse_mat + !! \see psb_c_base_mat_mod::psb_c_base_cp_from_coo + interface + subroutine psb_c_cp_ecsr_from_coo(a,b,info) + import + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_cp_ecsr_from_coo + end interface + + !> \memberof psb_c_ecsr_sparse_mat + !! \see psb_c_base_mat_mod::psb_c_base_cp_from_fmt + interface + subroutine psb_c_cp_ecsr_from_fmt(a,b,info) + import + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_cp_ecsr_from_fmt + end interface + + !> \memberof psb_c_ecsr_sparse_mat + !! \see psb_c_base_mat_mod::psb_c_base_mv_from_coo + interface + subroutine psb_c_mv_ecsr_from_coo(a,b,info) + import + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_mv_ecsr_from_coo + end interface + + !> \memberof psb_c_ecsr_sparse_mat + !! \see psb_c_base_mat_mod::psb_c_base_mv_from_fmt + interface + subroutine psb_c_mv_ecsr_from_fmt(a,b,info) + import + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_mv_ecsr_from_fmt + end interface + + !> \memberof psb_c_ecsr_sparse_mat + !| \see psb_base_mat_mod::psb_base_mold + interface + subroutine psb_c_ecsr_mold(a,b,info) + import + class(psb_c_ecsr_sparse_mat), intent(in) :: a + class(psb_c_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_ecsr_mold + end interface + + + + !> \namespace psb_base_mod \class psb_lc_csr_sparse_mat !! \extends psb_lc_base_mat_mod::psb_lc_base_sparse_mat !! !! psb_lc_csr_sparse_mat type and the related methods. @@ -1178,6 +1282,26 @@ contains + function c_ecsr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'ECSR' + end function c_ecsr_get_fmt + + subroutine c_ecsr_free(a) + implicit none + + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + + + if (allocated(a%nerwp)) deallocate(a%nerwp) + a%nnerws = 0 + call a%psb_c_csr_sparse_mat%free() + + return + end subroutine c_ecsr_free + + ! == =================================== ! ! diff --git a/base/modules/serial/psb_c_mat_mod.F90 b/base/modules/serial/psb_c_mat_mod.F90 index 2e365858..aa891381 100644 --- a/base/modules/serial/psb_c_mat_mod.F90 +++ b/base/modules/serial/psb_c_mat_mod.F90 @@ -79,7 +79,8 @@ module psb_c_mat_mod use psb_c_base_mat_mod - use psb_c_csr_mat_mod, only : psb_c_csr_sparse_mat, psb_lc_csr_sparse_mat + use psb_c_csr_mat_mod, only : psb_c_csr_sparse_mat, psb_lc_csr_sparse_mat,& + & psb_c_ecsr_sparse_mat use psb_c_csc_mat_mod, only : psb_c_csc_sparse_mat, psb_lc_csc_sparse_mat type :: psb_cspmat_type diff --git a/base/modules/serial/psb_d_csr_mat_mod.f90 b/base/modules/serial/psb_d_csr_mat_mod.f90 index d0aa622b..12d71755 100644 --- a/base/modules/serial/psb_d_csr_mat_mod.f90 +++ b/base/modules/serial/psb_d_csr_mat_mod.f90 @@ -579,7 +579,111 @@ module psb_d_csr_mat_mod end subroutine psb_d_csr_scals end interface - !> \namespace psb_base_mod \class psb_ld_csr_sparse_mat + + type, extends(psb_d_csr_sparse_mat) :: psb_d_ecsr_sparse_mat + + !> Number of non-empty rows + integer(psb_ipk_) :: nnerws + !> Indices of non-empty rows + integer(psb_ipk_), allocatable :: nerwp(:) + + contains + procedure, nopass :: get_fmt => d_ecsr_get_fmt + + ! procedure, pass(a) :: csmm => psb_d_ecsr_csmm + procedure, pass(a) :: csmv => psb_d_ecsr_csmv + + procedure, pass(a) :: cp_from_coo => psb_d_cp_ecsr_from_coo + procedure, pass(a) :: cp_from_fmt => psb_d_cp_ecsr_from_fmt + procedure, pass(a) :: mv_from_coo => psb_d_mv_ecsr_from_coo + procedure, pass(a) :: mv_from_fmt => psb_d_mv_ecsr_from_fmt + + procedure, pass(a) :: cmp_nerwp => psb_d_ecsr_cmp_nerwp + procedure, pass(a) :: free => d_ecsr_free + procedure, pass(a) :: mold => psb_d_ecsr_mold + + end type psb_d_ecsr_sparse_mat + !> \memberof psb_d_ecsr_sparse_mat + !! \see psb_d_base_mat_mod::psb_d_base_csmv + interface + subroutine psb_d_ecsr_csmv(alpha,a,x,beta,y,info,trans) + import + class(psb_d_ecsr_sparse_mat), intent(in) :: a + real(psb_dpk_), intent(in) :: alpha, beta, x(:) + real(psb_dpk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_d_ecsr_csmv + end interface + + !> \memberof psb_d_ecsr_sparse_mat + !! \see psb_d_base_mat_mod::psb_d_base_cp_from_coo + interface + subroutine psb_d_ecsr_cmp_nerwp(a,info) + import + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_d_ecsr_cmp_nerwp + end interface + + !> \memberof psb_d_ecsr_sparse_mat + !! \see psb_d_base_mat_mod::psb_d_base_cp_from_coo + interface + subroutine psb_d_cp_ecsr_from_coo(a,b,info) + import + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_d_cp_ecsr_from_coo + end interface + + !> \memberof psb_d_ecsr_sparse_mat + !! \see psb_d_base_mat_mod::psb_d_base_cp_from_fmt + interface + subroutine psb_d_cp_ecsr_from_fmt(a,b,info) + import + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_d_cp_ecsr_from_fmt + end interface + + !> \memberof psb_d_ecsr_sparse_mat + !! \see psb_d_base_mat_mod::psb_d_base_mv_from_coo + interface + subroutine psb_d_mv_ecsr_from_coo(a,b,info) + import + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_d_mv_ecsr_from_coo + end interface + + !> \memberof psb_d_ecsr_sparse_mat + !! \see psb_d_base_mat_mod::psb_d_base_mv_from_fmt + interface + subroutine psb_d_mv_ecsr_from_fmt(a,b,info) + import + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_d_mv_ecsr_from_fmt + end interface + + !> \memberof psb_d_ecsr_sparse_mat + !| \see psb_base_mat_mod::psb_base_mold + interface + subroutine psb_d_ecsr_mold(a,b,info) + import + class(psb_d_ecsr_sparse_mat), intent(in) :: a + class(psb_d_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_d_ecsr_mold + end interface + + + + !> \namespace psb_base_mod \class psb_ld_csr_sparse_mat !! \extends psb_ld_base_mat_mod::psb_ld_base_sparse_mat !! !! psb_ld_csr_sparse_mat type and the related methods. @@ -1178,6 +1282,26 @@ contains + function d_ecsr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'ECSR' + end function d_ecsr_get_fmt + + subroutine d_ecsr_free(a) + implicit none + + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + + + if (allocated(a%nerwp)) deallocate(a%nerwp) + a%nnerws = 0 + call a%psb_d_csr_sparse_mat%free() + + return + end subroutine d_ecsr_free + + ! == =================================== ! ! diff --git a/base/modules/serial/psb_d_mat_mod.F90 b/base/modules/serial/psb_d_mat_mod.F90 index 49a9545e..c647e76b 100644 --- a/base/modules/serial/psb_d_mat_mod.F90 +++ b/base/modules/serial/psb_d_mat_mod.F90 @@ -79,7 +79,8 @@ module psb_d_mat_mod use psb_d_base_mat_mod - use psb_d_csr_mat_mod, only : psb_d_csr_sparse_mat, psb_ld_csr_sparse_mat + use psb_d_csr_mat_mod, only : psb_d_csr_sparse_mat, psb_ld_csr_sparse_mat,& + & psb_d_ecsr_sparse_mat use psb_d_csc_mat_mod, only : psb_d_csc_sparse_mat, psb_ld_csc_sparse_mat type :: psb_dspmat_type diff --git a/base/modules/serial/psb_s_csr_mat_mod.f90 b/base/modules/serial/psb_s_csr_mat_mod.f90 index 6b4c51c7..884ede38 100644 --- a/base/modules/serial/psb_s_csr_mat_mod.f90 +++ b/base/modules/serial/psb_s_csr_mat_mod.f90 @@ -579,7 +579,111 @@ module psb_s_csr_mat_mod end subroutine psb_s_csr_scals end interface - !> \namespace psb_base_mod \class psb_ls_csr_sparse_mat + + type, extends(psb_s_csr_sparse_mat) :: psb_s_ecsr_sparse_mat + + !> Number of non-empty rows + integer(psb_ipk_) :: nnerws + !> Indices of non-empty rows + integer(psb_ipk_), allocatable :: nerwp(:) + + contains + procedure, nopass :: get_fmt => s_ecsr_get_fmt + + ! procedure, pass(a) :: csmm => psb_s_ecsr_csmm + procedure, pass(a) :: csmv => psb_s_ecsr_csmv + + procedure, pass(a) :: cp_from_coo => psb_s_cp_ecsr_from_coo + procedure, pass(a) :: cp_from_fmt => psb_s_cp_ecsr_from_fmt + procedure, pass(a) :: mv_from_coo => psb_s_mv_ecsr_from_coo + procedure, pass(a) :: mv_from_fmt => psb_s_mv_ecsr_from_fmt + + procedure, pass(a) :: cmp_nerwp => psb_s_ecsr_cmp_nerwp + procedure, pass(a) :: free => s_ecsr_free + procedure, pass(a) :: mold => psb_s_ecsr_mold + + end type psb_s_ecsr_sparse_mat + !> \memberof psb_s_ecsr_sparse_mat + !! \see psb_s_base_mat_mod::psb_s_base_csmv + interface + subroutine psb_s_ecsr_csmv(alpha,a,x,beta,y,info,trans) + import + class(psb_s_ecsr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta, x(:) + real(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_s_ecsr_csmv + end interface + + !> \memberof psb_s_ecsr_sparse_mat + !! \see psb_s_base_mat_mod::psb_s_base_cp_from_coo + interface + subroutine psb_s_ecsr_cmp_nerwp(a,info) + import + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_ecsr_cmp_nerwp + end interface + + !> \memberof psb_s_ecsr_sparse_mat + !! \see psb_s_base_mat_mod::psb_s_base_cp_from_coo + interface + subroutine psb_s_cp_ecsr_from_coo(a,b,info) + import + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_cp_ecsr_from_coo + end interface + + !> \memberof psb_s_ecsr_sparse_mat + !! \see psb_s_base_mat_mod::psb_s_base_cp_from_fmt + interface + subroutine psb_s_cp_ecsr_from_fmt(a,b,info) + import + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_cp_ecsr_from_fmt + end interface + + !> \memberof psb_s_ecsr_sparse_mat + !! \see psb_s_base_mat_mod::psb_s_base_mv_from_coo + interface + subroutine psb_s_mv_ecsr_from_coo(a,b,info) + import + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_mv_ecsr_from_coo + end interface + + !> \memberof psb_s_ecsr_sparse_mat + !! \see psb_s_base_mat_mod::psb_s_base_mv_from_fmt + interface + subroutine psb_s_mv_ecsr_from_fmt(a,b,info) + import + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_mv_ecsr_from_fmt + end interface + + !> \memberof psb_s_ecsr_sparse_mat + !| \see psb_base_mat_mod::psb_base_mold + interface + subroutine psb_s_ecsr_mold(a,b,info) + import + class(psb_s_ecsr_sparse_mat), intent(in) :: a + class(psb_s_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_ecsr_mold + end interface + + + + !> \namespace psb_base_mod \class psb_ls_csr_sparse_mat !! \extends psb_ls_base_mat_mod::psb_ls_base_sparse_mat !! !! psb_ls_csr_sparse_mat type and the related methods. @@ -1178,6 +1282,26 @@ contains + function s_ecsr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'ECSR' + end function s_ecsr_get_fmt + + subroutine s_ecsr_free(a) + implicit none + + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + + + if (allocated(a%nerwp)) deallocate(a%nerwp) + a%nnerws = 0 + call a%psb_s_csr_sparse_mat%free() + + return + end subroutine s_ecsr_free + + ! == =================================== ! ! diff --git a/base/modules/serial/psb_s_mat_mod.F90 b/base/modules/serial/psb_s_mat_mod.F90 index eb444249..3e6b286a 100644 --- a/base/modules/serial/psb_s_mat_mod.F90 +++ b/base/modules/serial/psb_s_mat_mod.F90 @@ -79,7 +79,8 @@ module psb_s_mat_mod use psb_s_base_mat_mod - use psb_s_csr_mat_mod, only : psb_s_csr_sparse_mat, psb_ls_csr_sparse_mat + use psb_s_csr_mat_mod, only : psb_s_csr_sparse_mat, psb_ls_csr_sparse_mat,& + & psb_s_ecsr_sparse_mat use psb_s_csc_mat_mod, only : psb_s_csc_sparse_mat, psb_ls_csc_sparse_mat type :: psb_sspmat_type diff --git a/base/modules/serial/psb_z_csr_mat_mod.f90 b/base/modules/serial/psb_z_csr_mat_mod.f90 index 4ec8dd00..c328fead 100644 --- a/base/modules/serial/psb_z_csr_mat_mod.f90 +++ b/base/modules/serial/psb_z_csr_mat_mod.f90 @@ -579,7 +579,111 @@ module psb_z_csr_mat_mod end subroutine psb_z_csr_scals end interface - !> \namespace psb_base_mod \class psb_lz_csr_sparse_mat + + type, extends(psb_z_csr_sparse_mat) :: psb_z_ecsr_sparse_mat + + !> Number of non-empty rows + integer(psb_ipk_) :: nnerws + !> Indices of non-empty rows + integer(psb_ipk_), allocatable :: nerwp(:) + + contains + procedure, nopass :: get_fmt => z_ecsr_get_fmt + + ! procedure, pass(a) :: csmm => psb_z_ecsr_csmm + procedure, pass(a) :: csmv => psb_z_ecsr_csmv + + procedure, pass(a) :: cp_from_coo => psb_z_cp_ecsr_from_coo + procedure, pass(a) :: cp_from_fmt => psb_z_cp_ecsr_from_fmt + procedure, pass(a) :: mv_from_coo => psb_z_mv_ecsr_from_coo + procedure, pass(a) :: mv_from_fmt => psb_z_mv_ecsr_from_fmt + + procedure, pass(a) :: cmp_nerwp => psb_z_ecsr_cmp_nerwp + procedure, pass(a) :: free => z_ecsr_free + procedure, pass(a) :: mold => psb_z_ecsr_mold + + end type psb_z_ecsr_sparse_mat + !> \memberof psb_z_ecsr_sparse_mat + !! \see psb_z_base_mat_mod::psb_z_base_csmv + interface + subroutine psb_z_ecsr_csmv(alpha,a,x,beta,y,info,trans) + import + class(psb_z_ecsr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta, x(:) + complex(psb_dpk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_z_ecsr_csmv + end interface + + !> \memberof psb_z_ecsr_sparse_mat + !! \see psb_z_base_mat_mod::psb_z_base_cp_from_coo + interface + subroutine psb_z_ecsr_cmp_nerwp(a,info) + import + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_ecsr_cmp_nerwp + end interface + + !> \memberof psb_z_ecsr_sparse_mat + !! \see psb_z_base_mat_mod::psb_z_base_cp_from_coo + interface + subroutine psb_z_cp_ecsr_from_coo(a,b,info) + import + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_cp_ecsr_from_coo + end interface + + !> \memberof psb_z_ecsr_sparse_mat + !! \see psb_z_base_mat_mod::psb_z_base_cp_from_fmt + interface + subroutine psb_z_cp_ecsr_from_fmt(a,b,info) + import + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_cp_ecsr_from_fmt + end interface + + !> \memberof psb_z_ecsr_sparse_mat + !! \see psb_z_base_mat_mod::psb_z_base_mv_from_coo + interface + subroutine psb_z_mv_ecsr_from_coo(a,b,info) + import + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_mv_ecsr_from_coo + end interface + + !> \memberof psb_z_ecsr_sparse_mat + !! \see psb_z_base_mat_mod::psb_z_base_mv_from_fmt + interface + subroutine psb_z_mv_ecsr_from_fmt(a,b,info) + import + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_mv_ecsr_from_fmt + end interface + + !> \memberof psb_z_ecsr_sparse_mat + !| \see psb_base_mat_mod::psb_base_mold + interface + subroutine psb_z_ecsr_mold(a,b,info) + import + class(psb_z_ecsr_sparse_mat), intent(in) :: a + class(psb_z_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_ecsr_mold + end interface + + + + !> \namespace psb_base_mod \class psb_lz_csr_sparse_mat !! \extends psb_lz_base_mat_mod::psb_lz_base_sparse_mat !! !! psb_lz_csr_sparse_mat type and the related methods. @@ -1178,6 +1282,26 @@ contains + function z_ecsr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'ECSR' + end function z_ecsr_get_fmt + + subroutine z_ecsr_free(a) + implicit none + + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + + + if (allocated(a%nerwp)) deallocate(a%nerwp) + a%nnerws = 0 + call a%psb_z_csr_sparse_mat%free() + + return + end subroutine z_ecsr_free + + ! == =================================== ! ! diff --git a/base/modules/serial/psb_z_mat_mod.F90 b/base/modules/serial/psb_z_mat_mod.F90 index e70e48aa..148e9ab9 100644 --- a/base/modules/serial/psb_z_mat_mod.F90 +++ b/base/modules/serial/psb_z_mat_mod.F90 @@ -79,7 +79,8 @@ module psb_z_mat_mod use psb_z_base_mat_mod - use psb_z_csr_mat_mod, only : psb_z_csr_sparse_mat, psb_lz_csr_sparse_mat + use psb_z_csr_mat_mod, only : psb_z_csr_sparse_mat, psb_lz_csr_sparse_mat,& + & psb_z_ecsr_sparse_mat use psb_z_csc_mat_mod, only : psb_z_csc_sparse_mat, psb_lz_csc_sparse_mat type :: psb_zspmat_type diff --git a/base/serial/impl/psb_c_csr_impl.f90 b/base/serial/impl/psb_c_csr_impl.f90 index 55a91648..1fed09ba 100644 --- a/base/serial/impl/psb_c_csr_impl.f90 +++ b/base/serial/impl/psb_c_csr_impl.f90 @@ -3550,6 +3550,269 @@ contains end subroutine psb_ccsrspspmm +subroutine psb_c_ecsr_mold(a,b,info) + use psb_c_csr_mat_mod, psb_protect_name => psb_c_ecsr_mold + use psb_error_mod + implicit none + class(psb_c_ecsr_sparse_mat), intent(in) :: a + class(psb_c_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='ecsr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b,stat=info) + end if + if (info == 0) allocate(psb_c_ecsr_sparse_mat :: b, stat=info) + + if (info /= 0) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + return + +end subroutine psb_c_ecsr_mold + +subroutine psb_c_ecsr_csmv(alpha,a,x,beta,y,info,trans) + use psb_error_mod + use psb_string_mod + use psb_c_csr_mat_mod, psb_protect_name => psb_c_ecsr_csmv + implicit none + class(psb_c_ecsr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta, x(:) + complex(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: m, n + logical :: tra, ctra + integer(psb_ipk_) :: err_act + integer(psb_ipk_) :: ierr(5) + character(len=20) :: name='c_csr_csmv' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + if (a%is_dev()) call a%sync() + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalid_mat_state_ + call psb_errpush(info,name) + goto 9999 + endif + + + tra = (psb_toupper(trans_) == 'T') + ctra = (psb_toupper(trans_) == 'C') + + if (tra.or.ctra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) psb_c_ecsr_cmp_nerwp + implicit none + + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: nnerws, i, nr, nzr + info = psb_success_ + nr = a%get_nrows() + call psb_realloc(nr,a%nerwp,info) + nnerws = 0 + do i=1, nr + nzr = a%irp(i+1)-a%irp(i) + if (nzr>0) then + nnerws = nnerws + 1 + a%nerwp(nnerws) = i + end if + end do + call psb_realloc(nnerws,a%nerwp,info) +end subroutine psb_c_ecsr_cmp_nerwp + +subroutine psb_c_cp_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_c_base_mat_mod + use psb_c_csr_mat_mod, psb_protect_name => psb_c_cp_ecsr_from_coo + implicit none + + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + call a%psb_c_csr_sparse_mat%cp_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_c_cp_ecsr_from_coo + +subroutine psb_c_mv_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_error_mod + use psb_c_base_mat_mod + use psb_c_csr_mat_mod, psb_protect_name => psb_c_mv_ecsr_from_coo + implicit none + + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_c_csr_sparse_mat%mv_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_c_mv_ecsr_from_coo + +subroutine psb_c_mv_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_c_base_mat_mod + use psb_c_csr_mat_mod, psb_protect_name => psb_c_mv_ecsr_from_fmt + implicit none + + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_c_csr_sparse_mat%mv_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_c_mv_ecsr_from_fmt + +subroutine psb_c_cp_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_c_base_mat_mod + use psb_realloc_mod + use psb_c_csr_mat_mod, psb_protect_name => psb_c_cp_ecsr_from_fmt + implicit none + + class(psb_c_ecsr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_c_csr_sparse_mat%cp_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_c_cp_ecsr_from_fmt + ! ! diff --git a/base/serial/impl/psb_d_csr_impl.f90 b/base/serial/impl/psb_d_csr_impl.f90 index 2c59c1a5..1bcc82a9 100644 --- a/base/serial/impl/psb_d_csr_impl.f90 +++ b/base/serial/impl/psb_d_csr_impl.f90 @@ -3550,6 +3550,269 @@ contains end subroutine psb_dcsrspspmm +subroutine psb_d_ecsr_mold(a,b,info) + use psb_d_csr_mat_mod, psb_protect_name => psb_d_ecsr_mold + use psb_error_mod + implicit none + class(psb_d_ecsr_sparse_mat), intent(in) :: a + class(psb_d_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='ecsr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b,stat=info) + end if + if (info == 0) allocate(psb_d_ecsr_sparse_mat :: b, stat=info) + + if (info /= 0) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + return + +end subroutine psb_d_ecsr_mold + +subroutine psb_d_ecsr_csmv(alpha,a,x,beta,y,info,trans) + use psb_error_mod + use psb_string_mod + use psb_d_csr_mat_mod, psb_protect_name => psb_d_ecsr_csmv + implicit none + class(psb_d_ecsr_sparse_mat), intent(in) :: a + real(psb_dpk_), intent(in) :: alpha, beta, x(:) + real(psb_dpk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: m, n + logical :: tra, ctra + integer(psb_ipk_) :: err_act + integer(psb_ipk_) :: ierr(5) + character(len=20) :: name='d_csr_csmv' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + if (a%is_dev()) call a%sync() + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalid_mat_state_ + call psb_errpush(info,name) + goto 9999 + endif + + + tra = (psb_toupper(trans_) == 'T') + ctra = (psb_toupper(trans_) == 'C') + + if (tra.or.ctra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) psb_d_ecsr_cmp_nerwp + implicit none + + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: nnerws, i, nr, nzr + info = psb_success_ + nr = a%get_nrows() + call psb_realloc(nr,a%nerwp,info) + nnerws = 0 + do i=1, nr + nzr = a%irp(i+1)-a%irp(i) + if (nzr>0) then + nnerws = nnerws + 1 + a%nerwp(nnerws) = i + end if + end do + call psb_realloc(nnerws,a%nerwp,info) +end subroutine psb_d_ecsr_cmp_nerwp + +subroutine psb_d_cp_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_d_base_mat_mod + use psb_d_csr_mat_mod, psb_protect_name => psb_d_cp_ecsr_from_coo + implicit none + + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + call a%psb_d_csr_sparse_mat%cp_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_d_cp_ecsr_from_coo + +subroutine psb_d_mv_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_error_mod + use psb_d_base_mat_mod + use psb_d_csr_mat_mod, psb_protect_name => psb_d_mv_ecsr_from_coo + implicit none + + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_d_csr_sparse_mat%mv_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_d_mv_ecsr_from_coo + +subroutine psb_d_mv_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_d_base_mat_mod + use psb_d_csr_mat_mod, psb_protect_name => psb_d_mv_ecsr_from_fmt + implicit none + + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_d_csr_sparse_mat%mv_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_d_mv_ecsr_from_fmt + +subroutine psb_d_cp_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_d_base_mat_mod + use psb_realloc_mod + use psb_d_csr_mat_mod, psb_protect_name => psb_d_cp_ecsr_from_fmt + implicit none + + class(psb_d_ecsr_sparse_mat), intent(inout) :: a + class(psb_d_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_d_csr_sparse_mat%cp_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_d_cp_ecsr_from_fmt + ! ! diff --git a/base/serial/impl/psb_s_csr_impl.f90 b/base/serial/impl/psb_s_csr_impl.f90 index 75358dbc..9670aeb9 100644 --- a/base/serial/impl/psb_s_csr_impl.f90 +++ b/base/serial/impl/psb_s_csr_impl.f90 @@ -3550,6 +3550,269 @@ contains end subroutine psb_scsrspspmm +subroutine psb_s_ecsr_mold(a,b,info) + use psb_s_csr_mat_mod, psb_protect_name => psb_s_ecsr_mold + use psb_error_mod + implicit none + class(psb_s_ecsr_sparse_mat), intent(in) :: a + class(psb_s_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='ecsr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b,stat=info) + end if + if (info == 0) allocate(psb_s_ecsr_sparse_mat :: b, stat=info) + + if (info /= 0) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + return + +end subroutine psb_s_ecsr_mold + +subroutine psb_s_ecsr_csmv(alpha,a,x,beta,y,info,trans) + use psb_error_mod + use psb_string_mod + use psb_s_csr_mat_mod, psb_protect_name => psb_s_ecsr_csmv + implicit none + class(psb_s_ecsr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta, x(:) + real(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: m, n + logical :: tra, ctra + integer(psb_ipk_) :: err_act + integer(psb_ipk_) :: ierr(5) + character(len=20) :: name='s_csr_csmv' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + if (a%is_dev()) call a%sync() + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalid_mat_state_ + call psb_errpush(info,name) + goto 9999 + endif + + + tra = (psb_toupper(trans_) == 'T') + ctra = (psb_toupper(trans_) == 'C') + + if (tra.or.ctra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) psb_s_ecsr_cmp_nerwp + implicit none + + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: nnerws, i, nr, nzr + info = psb_success_ + nr = a%get_nrows() + call psb_realloc(nr,a%nerwp,info) + nnerws = 0 + do i=1, nr + nzr = a%irp(i+1)-a%irp(i) + if (nzr>0) then + nnerws = nnerws + 1 + a%nerwp(nnerws) = i + end if + end do + call psb_realloc(nnerws,a%nerwp,info) +end subroutine psb_s_ecsr_cmp_nerwp + +subroutine psb_s_cp_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_s_base_mat_mod + use psb_s_csr_mat_mod, psb_protect_name => psb_s_cp_ecsr_from_coo + implicit none + + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + call a%psb_s_csr_sparse_mat%cp_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_s_cp_ecsr_from_coo + +subroutine psb_s_mv_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_error_mod + use psb_s_base_mat_mod + use psb_s_csr_mat_mod, psb_protect_name => psb_s_mv_ecsr_from_coo + implicit none + + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_s_csr_sparse_mat%mv_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_s_mv_ecsr_from_coo + +subroutine psb_s_mv_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_s_base_mat_mod + use psb_s_csr_mat_mod, psb_protect_name => psb_s_mv_ecsr_from_fmt + implicit none + + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_s_csr_sparse_mat%mv_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_s_mv_ecsr_from_fmt + +subroutine psb_s_cp_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_s_base_mat_mod + use psb_realloc_mod + use psb_s_csr_mat_mod, psb_protect_name => psb_s_cp_ecsr_from_fmt + implicit none + + class(psb_s_ecsr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_s_csr_sparse_mat%cp_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_s_cp_ecsr_from_fmt + ! ! diff --git a/base/serial/impl/psb_z_csr_impl.f90 b/base/serial/impl/psb_z_csr_impl.f90 index 4f2693c0..e9847849 100644 --- a/base/serial/impl/psb_z_csr_impl.f90 +++ b/base/serial/impl/psb_z_csr_impl.f90 @@ -3550,6 +3550,269 @@ contains end subroutine psb_zcsrspspmm +subroutine psb_z_ecsr_mold(a,b,info) + use psb_z_csr_mat_mod, psb_protect_name => psb_z_ecsr_mold + use psb_error_mod + implicit none + class(psb_z_ecsr_sparse_mat), intent(in) :: a + class(psb_z_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='ecsr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b,stat=info) + end if + if (info == 0) allocate(psb_z_ecsr_sparse_mat :: b, stat=info) + + if (info /= 0) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + return + +end subroutine psb_z_ecsr_mold + +subroutine psb_z_ecsr_csmv(alpha,a,x,beta,y,info,trans) + use psb_error_mod + use psb_string_mod + use psb_z_csr_mat_mod, psb_protect_name => psb_z_ecsr_csmv + implicit none + class(psb_z_ecsr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta, x(:) + complex(psb_dpk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: m, n + logical :: tra, ctra + integer(psb_ipk_) :: err_act + integer(psb_ipk_) :: ierr(5) + character(len=20) :: name='z_csr_csmv' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + if (a%is_dev()) call a%sync() + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalid_mat_state_ + call psb_errpush(info,name) + goto 9999 + endif + + + tra = (psb_toupper(trans_) == 'T') + ctra = (psb_toupper(trans_) == 'C') + + if (tra.or.ctra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) psb_z_ecsr_cmp_nerwp + implicit none + + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: nnerws, i, nr, nzr + info = psb_success_ + nr = a%get_nrows() + call psb_realloc(nr,a%nerwp,info) + nnerws = 0 + do i=1, nr + nzr = a%irp(i+1)-a%irp(i) + if (nzr>0) then + nnerws = nnerws + 1 + a%nerwp(nnerws) = i + end if + end do + call psb_realloc(nnerws,a%nerwp,info) +end subroutine psb_z_ecsr_cmp_nerwp + +subroutine psb_z_cp_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_z_base_mat_mod + use psb_z_csr_mat_mod, psb_protect_name => psb_z_cp_ecsr_from_coo + implicit none + + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + call a%psb_z_csr_sparse_mat%cp_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_z_cp_ecsr_from_coo + +subroutine psb_z_mv_ecsr_from_coo(a,b,info) + use psb_const_mod + use psb_realloc_mod + use psb_error_mod + use psb_z_base_mat_mod + use psb_z_csr_mat_mod, psb_protect_name => psb_z_mv_ecsr_from_coo + implicit none + + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_z_csr_sparse_mat%mv_from_coo(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_z_mv_ecsr_from_coo + +subroutine psb_z_mv_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_z_base_mat_mod + use psb_z_csr_mat_mod, psb_protect_name => psb_z_mv_ecsr_from_fmt + implicit none + + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_z_csr_sparse_mat%mv_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_z_mv_ecsr_from_fmt + +subroutine psb_z_cp_ecsr_from_fmt(a,b,info) + use psb_const_mod + use psb_z_base_mat_mod + use psb_realloc_mod + use psb_z_csr_mat_mod, psb_protect_name => psb_z_cp_ecsr_from_fmt + implicit none + + class(psb_z_ecsr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + + info = psb_success_ + call a%psb_z_csr_sparse_mat%cp_from_fmt(b,info) + if (info == psb_success_) call a%cmp_nerwp(info) + +end subroutine psb_z_cp_ecsr_from_fmt + ! ! diff --git a/base/tools/psb_cspasb.f90 b/base/tools/psb_cspasb.f90 index ea7789f2..b4c957b0 100644 --- a/base/tools/psb_cspasb.f90 +++ b/base/tools/psb_cspasb.f90 @@ -175,13 +175,14 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) block character(len=1024) :: fname type(psb_c_coo_sparse_mat) :: acoo - type(psb_c_csr_sparse_mat), allocatable :: aclip, andclip + type(psb_c_csr_sparse_mat), allocatable :: aclip + type(psb_c_ecsr_sparse_mat), allocatable :: andclip allocate(aclip,andclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - call aclip%mv_from_coo(acoo,info) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) call andclip%mv_from_coo(acoo,info) - call move_alloc(aclip,a%ad) call move_alloc(andclip,a%and) if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' diff --git a/base/tools/psb_dspasb.f90 b/base/tools/psb_dspasb.f90 index 89ceef8d..5ebc47e8 100644 --- a/base/tools/psb_dspasb.f90 +++ b/base/tools/psb_dspasb.f90 @@ -175,13 +175,14 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) block character(len=1024) :: fname type(psb_d_coo_sparse_mat) :: acoo - type(psb_d_csr_sparse_mat), allocatable :: aclip, andclip + type(psb_d_csr_sparse_mat), allocatable :: aclip + type(psb_d_ecsr_sparse_mat), allocatable :: andclip allocate(aclip,andclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - call aclip%mv_from_coo(acoo,info) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) call andclip%mv_from_coo(acoo,info) - call move_alloc(aclip,a%ad) call move_alloc(andclip,a%and) if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' diff --git a/base/tools/psb_sspasb.f90 b/base/tools/psb_sspasb.f90 index 14ad5246..5423c2a7 100644 --- a/base/tools/psb_sspasb.f90 +++ b/base/tools/psb_sspasb.f90 @@ -175,13 +175,14 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) block character(len=1024) :: fname type(psb_s_coo_sparse_mat) :: acoo - type(psb_s_csr_sparse_mat), allocatable :: aclip, andclip + type(psb_s_csr_sparse_mat), allocatable :: aclip + type(psb_s_ecsr_sparse_mat), allocatable :: andclip allocate(aclip,andclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - call aclip%mv_from_coo(acoo,info) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) call andclip%mv_from_coo(acoo,info) - call move_alloc(aclip,a%ad) call move_alloc(andclip,a%and) if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' diff --git a/base/tools/psb_zspasb.f90 b/base/tools/psb_zspasb.f90 index f65be363..66fc8cd7 100644 --- a/base/tools/psb_zspasb.f90 +++ b/base/tools/psb_zspasb.f90 @@ -175,13 +175,14 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) block character(len=1024) :: fname type(psb_z_coo_sparse_mat) :: acoo - type(psb_z_csr_sparse_mat), allocatable :: aclip, andclip + type(psb_z_csr_sparse_mat), allocatable :: aclip + type(psb_z_ecsr_sparse_mat), allocatable :: andclip allocate(aclip,andclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - call aclip%mv_from_coo(acoo,info) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) call andclip%mv_from_coo(acoo,info) - call move_alloc(aclip,a%ad) call move_alloc(andclip,a%and) if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' diff --git a/test/pargen/runs/ppde.inp b/test/pargen/runs/ppde.inp index 5f040075..57fda01a 100644 --- a/test/pargen/runs/ppde.inp +++ b/test/pargen/runs/ppde.inp @@ -2,7 +2,7 @@ BICGSTAB Iterative method BICGSTAB CGS BICG BICGSTABL RGMRES FCG CGR BJAC Preconditioner NONE DIAG BJAC CSR Storage format for matrix A: CSR COO -140 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) +100 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) 3 Partition: 1 BLOCK 3 3D 2 Stopping criterion 1 2 0100 MAXIT From 86b8a261efd23d244a034a2b1826cdc3ecae2c43 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Thu, 19 Jan 2023 08:36:22 -0500 Subject: [PATCH 03/48] Fixed conversion bug, changed SPASB interface --- base/modules/tools/psb_c_tools_mod.F90 | 3 ++- base/modules/tools/psb_d_tools_mod.F90 | 3 ++- base/modules/tools/psb_s_tools_mod.F90 | 3 ++- base/modules/tools/psb_z_tools_mod.F90 | 3 ++- base/psblas/psb_cspmm.f90 | 6 +++--- base/psblas/psb_dspmm.f90 | 6 +++--- base/psblas/psb_sspmm.f90 | 6 +++--- base/psblas/psb_zspmm.f90 | 6 +++--- base/serial/impl/psb_c_csr_impl.f90 | 9 +++------ base/serial/impl/psb_d_csr_impl.f90 | 9 +++------ base/serial/impl/psb_s_csr_impl.f90 | 9 +++------ base/serial/impl/psb_z_csr_impl.f90 | 9 +++------ base/tools/psb_cspasb.f90 | 28 ++++++++++++++++++++------ base/tools/psb_dspasb.f90 | 28 ++++++++++++++++++++------ base/tools/psb_sspasb.f90 | 28 ++++++++++++++++++++------ base/tools/psb_zspasb.f90 | 28 ++++++++++++++++++++------ test/pargen/psb_d_pde3d.F90 | 4 ++-- test/pargen/runs/ppde.inp | 6 +++--- 18 files changed, 125 insertions(+), 69 deletions(-) diff --git a/base/modules/tools/psb_c_tools_mod.F90 b/base/modules/tools/psb_c_tools_mod.F90 index 2de8f906..0ed2d82c 100644 --- a/base/modules/tools/psb_c_tools_mod.F90 +++ b/base/modules/tools/psb_c_tools_mod.F90 @@ -250,7 +250,7 @@ Module psb_c_tools_mod end interface interface psb_spasb - subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) + subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold, bld_and) import implicit none type(psb_cspmat_type), intent (inout) :: a @@ -259,6 +259,7 @@ Module psb_c_tools_mod integer(psb_ipk_),optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_c_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and end subroutine psb_cspasb end interface diff --git a/base/modules/tools/psb_d_tools_mod.F90 b/base/modules/tools/psb_d_tools_mod.F90 index 30e45d53..26f83201 100644 --- a/base/modules/tools/psb_d_tools_mod.F90 +++ b/base/modules/tools/psb_d_tools_mod.F90 @@ -250,7 +250,7 @@ Module psb_d_tools_mod end interface interface psb_spasb - subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) + subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold, bld_and) import implicit none type(psb_dspmat_type), intent (inout) :: a @@ -259,6 +259,7 @@ Module psb_d_tools_mod integer(psb_ipk_),optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_d_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and end subroutine psb_dspasb end interface diff --git a/base/modules/tools/psb_s_tools_mod.F90 b/base/modules/tools/psb_s_tools_mod.F90 index 5d2f8d00..0f70a31a 100644 --- a/base/modules/tools/psb_s_tools_mod.F90 +++ b/base/modules/tools/psb_s_tools_mod.F90 @@ -250,7 +250,7 @@ Module psb_s_tools_mod end interface interface psb_spasb - subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) + subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold, bld_and) import implicit none type(psb_sspmat_type), intent (inout) :: a @@ -259,6 +259,7 @@ Module psb_s_tools_mod integer(psb_ipk_),optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_s_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and end subroutine psb_sspasb end interface diff --git a/base/modules/tools/psb_z_tools_mod.F90 b/base/modules/tools/psb_z_tools_mod.F90 index 9d6bd77b..1f24e05a 100644 --- a/base/modules/tools/psb_z_tools_mod.F90 +++ b/base/modules/tools/psb_z_tools_mod.F90 @@ -250,7 +250,7 @@ Module psb_z_tools_mod end interface interface psb_spasb - subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) + subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold, bld_and) import implicit none type(psb_zspmat_type), intent (inout) :: a @@ -259,6 +259,7 @@ Module psb_z_tools_mod integer(psb_ipk_),optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_z_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and end subroutine psb_zspasb end interface diff --git a/base/psblas/psb_cspmm.f90 b/base/psblas/psb_cspmm.f90 index 555461df..84d8a7d8 100644 --- a/base/psblas/psb_cspmm.f90 +++ b/base/psblas/psb_cspmm.f90 @@ -179,11 +179,11 @@ subroutine psb_cspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (.true.) then - call psi_swapdata(psb_swap_send_,& + if (allocated(a%ad)) then + if (doswap_) call psi_swapdata(psb_swap_send_,& & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%ad%spmm(alpha,x%v,beta,y%v,info) - call psi_swapdata(psb_swap_recv_,& + if (doswap_) call psi_swapdata(psb_swap_recv_,& & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%and%spmm(alpha,x%v,cone,y%v,info) diff --git a/base/psblas/psb_dspmm.f90 b/base/psblas/psb_dspmm.f90 index be8a493f..d5897f82 100644 --- a/base/psblas/psb_dspmm.f90 +++ b/base/psblas/psb_dspmm.f90 @@ -179,11 +179,11 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (.true.) then - call psi_swapdata(psb_swap_send_,& + if (allocated(a%ad)) then + if (doswap_) call psi_swapdata(psb_swap_send_,& & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%ad%spmm(alpha,x%v,beta,y%v,info) - call psi_swapdata(psb_swap_recv_,& + if (doswap_) call psi_swapdata(psb_swap_recv_,& & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%and%spmm(alpha,x%v,done,y%v,info) diff --git a/base/psblas/psb_sspmm.f90 b/base/psblas/psb_sspmm.f90 index 79bfbdd1..7c1e0ab3 100644 --- a/base/psblas/psb_sspmm.f90 +++ b/base/psblas/psb_sspmm.f90 @@ -179,11 +179,11 @@ subroutine psb_sspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (.true.) then - call psi_swapdata(psb_swap_send_,& + if (allocated(a%ad)) then + if (doswap_) call psi_swapdata(psb_swap_send_,& & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%ad%spmm(alpha,x%v,beta,y%v,info) - call psi_swapdata(psb_swap_recv_,& + if (doswap_) call psi_swapdata(psb_swap_recv_,& & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%and%spmm(alpha,x%v,sone,y%v,info) diff --git a/base/psblas/psb_zspmm.f90 b/base/psblas/psb_zspmm.f90 index f248db8b..4dc73f83 100644 --- a/base/psblas/psb_zspmm.f90 +++ b/base/psblas/psb_zspmm.f90 @@ -179,11 +179,11 @@ subroutine psb_zspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - if (.true.) then - call psi_swapdata(psb_swap_send_,& + if (allocated(a%ad)) then + if (doswap_) call psi_swapdata(psb_swap_send_,& & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%ad%spmm(alpha,x%v,beta,y%v,info) - call psi_swapdata(psb_swap_recv_,& + if (doswap_) call psi_swapdata(psb_swap_recv_,& & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) call a%and%spmm(alpha,x%v,zone,y%v,info) diff --git a/base/serial/impl/psb_c_csr_impl.f90 b/base/serial/impl/psb_c_csr_impl.f90 index 1fed09ba..4744d947 100644 --- a/base/serial/impl/psb_c_csr_impl.f90 +++ b/base/serial/impl/psb_c_csr_impl.f90 @@ -3643,9 +3643,8 @@ subroutine psb_c_ecsr_csmv(alpha,a,x,beta,y,info,trans) goto 9999 end if - if (((beta == cone).and..not.(tra.or.ctra))& - & .or.(a%is_triangle()).or.(a%is_unit())) then - + if ((beta == cone).and.& + & .not.(tra.or.ctra.or.(a%is_triangle()).or.(a%is_unit()))) then call psb_c_ecsr_csmv_inner(m,n,alpha,a%irp,a%ja,a%val,& & a%nnerws,a%nerwp,x,y) else @@ -3672,9 +3671,6 @@ contains if (alpha == czero) return - - - if (alpha == cone) then !$omp parallel do private(ir,i,j,acc) do ir=1,nnerws @@ -3740,6 +3736,7 @@ subroutine psb_c_ecsr_cmp_nerwp(a,info) end if end do call psb_realloc(nnerws,a%nerwp,info) + a%nnerws = nnerws end subroutine psb_c_ecsr_cmp_nerwp subroutine psb_c_cp_ecsr_from_coo(a,b,info) diff --git a/base/serial/impl/psb_d_csr_impl.f90 b/base/serial/impl/psb_d_csr_impl.f90 index 1bcc82a9..6d2b58ad 100644 --- a/base/serial/impl/psb_d_csr_impl.f90 +++ b/base/serial/impl/psb_d_csr_impl.f90 @@ -3643,9 +3643,8 @@ subroutine psb_d_ecsr_csmv(alpha,a,x,beta,y,info,trans) goto 9999 end if - if (((beta == done).and..not.(tra.or.ctra))& - & .or.(a%is_triangle()).or.(a%is_unit())) then - + if ((beta == done).and.& + & .not.(tra.or.ctra.or.(a%is_triangle()).or.(a%is_unit()))) then call psb_d_ecsr_csmv_inner(m,n,alpha,a%irp,a%ja,a%val,& & a%nnerws,a%nerwp,x,y) else @@ -3672,9 +3671,6 @@ contains if (alpha == dzero) return - - - if (alpha == done) then !$omp parallel do private(ir,i,j,acc) do ir=1,nnerws @@ -3740,6 +3736,7 @@ subroutine psb_d_ecsr_cmp_nerwp(a,info) end if end do call psb_realloc(nnerws,a%nerwp,info) + a%nnerws = nnerws end subroutine psb_d_ecsr_cmp_nerwp subroutine psb_d_cp_ecsr_from_coo(a,b,info) diff --git a/base/serial/impl/psb_s_csr_impl.f90 b/base/serial/impl/psb_s_csr_impl.f90 index 9670aeb9..87cfff68 100644 --- a/base/serial/impl/psb_s_csr_impl.f90 +++ b/base/serial/impl/psb_s_csr_impl.f90 @@ -3643,9 +3643,8 @@ subroutine psb_s_ecsr_csmv(alpha,a,x,beta,y,info,trans) goto 9999 end if - if (((beta == sone).and..not.(tra.or.ctra))& - & .or.(a%is_triangle()).or.(a%is_unit())) then - + if ((beta == sone).and.& + & .not.(tra.or.ctra.or.(a%is_triangle()).or.(a%is_unit()))) then call psb_s_ecsr_csmv_inner(m,n,alpha,a%irp,a%ja,a%val,& & a%nnerws,a%nerwp,x,y) else @@ -3672,9 +3671,6 @@ contains if (alpha == szero) return - - - if (alpha == sone) then !$omp parallel do private(ir,i,j,acc) do ir=1,nnerws @@ -3740,6 +3736,7 @@ subroutine psb_s_ecsr_cmp_nerwp(a,info) end if end do call psb_realloc(nnerws,a%nerwp,info) + a%nnerws = nnerws end subroutine psb_s_ecsr_cmp_nerwp subroutine psb_s_cp_ecsr_from_coo(a,b,info) diff --git a/base/serial/impl/psb_z_csr_impl.f90 b/base/serial/impl/psb_z_csr_impl.f90 index e9847849..a4a2dd5a 100644 --- a/base/serial/impl/psb_z_csr_impl.f90 +++ b/base/serial/impl/psb_z_csr_impl.f90 @@ -3643,9 +3643,8 @@ subroutine psb_z_ecsr_csmv(alpha,a,x,beta,y,info,trans) goto 9999 end if - if (((beta == zone).and..not.(tra.or.ctra))& - & .or.(a%is_triangle()).or.(a%is_unit())) then - + if ((beta == zone).and.& + & .not.(tra.or.ctra.or.(a%is_triangle()).or.(a%is_unit()))) then call psb_z_ecsr_csmv_inner(m,n,alpha,a%irp,a%ja,a%val,& & a%nnerws,a%nerwp,x,y) else @@ -3672,9 +3671,6 @@ contains if (alpha == zzero) return - - - if (alpha == zone) then !$omp parallel do private(ir,i,j,acc) do ir=1,nnerws @@ -3740,6 +3736,7 @@ subroutine psb_z_ecsr_cmp_nerwp(a,info) end if end do call psb_realloc(nnerws,a%nerwp,info) + a%nnerws = nnerws end subroutine psb_z_ecsr_cmp_nerwp subroutine psb_z_cp_ecsr_from_coo(a,b,info) diff --git a/base/tools/psb_cspasb.f90 b/base/tools/psb_cspasb.f90 index b4c957b0..46258139 100644 --- a/base/tools/psb_cspasb.f90 +++ b/base/tools/psb_cspasb.f90 @@ -44,7 +44,7 @@ ! psb_upd_perm_ Permutation(more memory) ! ! -subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) +subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold, bld_and) use psb_base_mod, psb_protect_name => psb_cspasb use psb_sort_mod use psi_mod @@ -58,6 +58,7 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_), optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_c_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and !....Locals.... type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np,me, err_act @@ -65,6 +66,7 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name, ch_err class(psb_i_base_vect_type), allocatable :: ivm + logical :: bld_and_ info = psb_success_ name = 'psb_spasb' @@ -93,7 +95,11 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) if (debug_level >= psb_debug_ext_)& & write(debug_unit, *) me,' ',trim(name),& & ' Begin matrix assembly...' - + if (present(bld_and)) then + bld_and_ = bld_and + else + bld_and_ = .false. + end if !check on errors encountered in psdspins if (a%is_bld()) then @@ -171,19 +177,26 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) end if - if (.true.) then + if (bld_and_) then block character(len=1024) :: fname type(psb_c_coo_sparse_mat) :: acoo type(psb_c_csr_sparse_mat), allocatable :: aclip type(psb_c_ecsr_sparse_mat), allocatable :: andclip - allocate(aclip,andclip) + logical, parameter :: use_ecsr=.false. + allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' open(25,file=fname) @@ -200,6 +213,9 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold) &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col end if end block + else + if (allocated(a%ad)) deallocate(a%ad) + if (allocated(a%and)) deallocate(a%and) end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() diff --git a/base/tools/psb_dspasb.f90 b/base/tools/psb_dspasb.f90 index 5ebc47e8..6beb0e6f 100644 --- a/base/tools/psb_dspasb.f90 +++ b/base/tools/psb_dspasb.f90 @@ -44,7 +44,7 @@ ! psb_upd_perm_ Permutation(more memory) ! ! -subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) +subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold, bld_and) use psb_base_mod, psb_protect_name => psb_dspasb use psb_sort_mod use psi_mod @@ -58,6 +58,7 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_), optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_d_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and !....Locals.... type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np,me, err_act @@ -65,6 +66,7 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name, ch_err class(psb_i_base_vect_type), allocatable :: ivm + logical :: bld_and_ info = psb_success_ name = 'psb_spasb' @@ -93,7 +95,11 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) if (debug_level >= psb_debug_ext_)& & write(debug_unit, *) me,' ',trim(name),& & ' Begin matrix assembly...' - + if (present(bld_and)) then + bld_and_ = bld_and + else + bld_and_ = .false. + end if !check on errors encountered in psdspins if (a%is_bld()) then @@ -171,19 +177,26 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) end if - if (.true.) then + if (bld_and_) then block character(len=1024) :: fname type(psb_d_coo_sparse_mat) :: acoo type(psb_d_csr_sparse_mat), allocatable :: aclip type(psb_d_ecsr_sparse_mat), allocatable :: andclip - allocate(aclip,andclip) + logical, parameter :: use_ecsr=.true. + allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' open(25,file=fname) @@ -200,6 +213,9 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold) &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col end if end block + else + if (allocated(a%ad)) deallocate(a%ad) + if (allocated(a%and)) deallocate(a%and) end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() diff --git a/base/tools/psb_sspasb.f90 b/base/tools/psb_sspasb.f90 index 5423c2a7..0edae30e 100644 --- a/base/tools/psb_sspasb.f90 +++ b/base/tools/psb_sspasb.f90 @@ -44,7 +44,7 @@ ! psb_upd_perm_ Permutation(more memory) ! ! -subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) +subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold, bld_and) use psb_base_mod, psb_protect_name => psb_sspasb use psb_sort_mod use psi_mod @@ -58,6 +58,7 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_), optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_s_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and !....Locals.... type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np,me, err_act @@ -65,6 +66,7 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name, ch_err class(psb_i_base_vect_type), allocatable :: ivm + logical :: bld_and_ info = psb_success_ name = 'psb_spasb' @@ -93,7 +95,11 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) if (debug_level >= psb_debug_ext_)& & write(debug_unit, *) me,' ',trim(name),& & ' Begin matrix assembly...' - + if (present(bld_and)) then + bld_and_ = bld_and + else + bld_and_ = .false. + end if !check on errors encountered in psdspins if (a%is_bld()) then @@ -171,19 +177,26 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) end if - if (.true.) then + if (bld_and_) then block character(len=1024) :: fname type(psb_s_coo_sparse_mat) :: acoo type(psb_s_csr_sparse_mat), allocatable :: aclip type(psb_s_ecsr_sparse_mat), allocatable :: andclip - allocate(aclip,andclip) + logical, parameter :: use_ecsr=.false. + allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' open(25,file=fname) @@ -200,6 +213,9 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold) &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col end if end block + else + if (allocated(a%ad)) deallocate(a%ad) + if (allocated(a%and)) deallocate(a%and) end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() diff --git a/base/tools/psb_zspasb.f90 b/base/tools/psb_zspasb.f90 index 66fc8cd7..cd77de15 100644 --- a/base/tools/psb_zspasb.f90 +++ b/base/tools/psb_zspasb.f90 @@ -44,7 +44,7 @@ ! psb_upd_perm_ Permutation(more memory) ! ! -subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) +subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold, bld_and) use psb_base_mod, psb_protect_name => psb_zspasb use psb_sort_mod use psi_mod @@ -58,6 +58,7 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_), optional, intent(in) :: upd character(len=*), optional, intent(in) :: afmt class(psb_z_base_sparse_mat), intent(in), optional :: mold + logical, intent(in), optional :: bld_and !....Locals.... type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np,me, err_act @@ -65,6 +66,7 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name, ch_err class(psb_i_base_vect_type), allocatable :: ivm + logical :: bld_and_ info = psb_success_ name = 'psb_spasb' @@ -93,7 +95,11 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) if (debug_level >= psb_debug_ext_)& & write(debug_unit, *) me,' ',trim(name),& & ' Begin matrix assembly...' - + if (present(bld_and)) then + bld_and_ = bld_and + else + bld_and_ = .false. + end if !check on errors encountered in psdspins if (a%is_bld()) then @@ -171,19 +177,26 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) end if - if (.true.) then + if (bld_and_) then block character(len=1024) :: fname type(psb_z_coo_sparse_mat) :: acoo type(psb_z_csr_sparse_mat), allocatable :: aclip type(psb_z_ecsr_sparse_mat), allocatable :: andclip - allocate(aclip,andclip) + logical, parameter :: use_ecsr=.false. + allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) call a%ad%mv_from_coo(acoo,info) call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if if (.false.) then write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' open(25,file=fname) @@ -200,6 +213,9 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold) &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col end if end block + else + if (allocated(a%ad)) deallocate(a%ad) + if (allocated(a%and)) deallocate(a%and) end if if (debug_level >= psb_debug_ext_) then ch_err=a%get_fmt() diff --git a/test/pargen/psb_d_pde3d.F90 b/test/pargen/psb_d_pde3d.F90 index d4eeccf2..cd503d29 100644 --- a/test/pargen/psb_d_pde3d.F90 +++ b/test/pargen/psb_d_pde3d.F90 @@ -680,9 +680,9 @@ contains t1 = psb_wtime() if (info == psb_success_) then if (present(amold)) then - call psb_spasb(a,desc_a,info,mold=amold) + call psb_spasb(a,desc_a,info,mold=amold,bld_and=.true.) else - call psb_spasb(a,desc_a,info,afmt=afmt) + call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=.true.) end if end if call psb_barrier(ctxt) diff --git a/test/pargen/runs/ppde.inp b/test/pargen/runs/ppde.inp index 57fda01a..c70a973f 100644 --- a/test/pargen/runs/ppde.inp +++ b/test/pargen/runs/ppde.inp @@ -2,11 +2,11 @@ BICGSTAB Iterative method BICGSTAB CGS BICG BICGSTABL RGMRES FCG CGR BJAC Preconditioner NONE DIAG BJAC CSR Storage format for matrix A: CSR COO -100 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) +200 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) 3 Partition: 1 BLOCK 3 3D 2 Stopping criterion 1 2 -0100 MAXIT -05 ITRACE +0300 MAXIT +10 ITRACE 002 IRST restart for RGMRES and BiCGSTABL ILU Block Solver ILU,ILUT,INVK,AINVT,AORTH NONE If ILU : MILU or NONE othewise ignored From d3fcd566d92605de57ece25977544b32f046e59f Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Mon, 23 Oct 2023 14:16:30 +0200 Subject: [PATCH 04/48] Define a SHIFT argument to compute ILU( A+shft I) --- prec/impl/psb_c_ilu0_fact.f90 | 27 ++++++++++++++++++--------- prec/impl/psb_c_iluk_fact.f90 | 21 ++++++++++++++------- prec/impl/psb_c_ilut_fact.f90 | 17 +++++++++++------ prec/impl/psb_d_ilu0_fact.f90 | 27 ++++++++++++++++++--------- prec/impl/psb_d_iluk_fact.f90 | 21 ++++++++++++++------- prec/impl/psb_d_ilut_fact.f90 | 17 +++++++++++------ prec/impl/psb_s_ilu0_fact.f90 | 27 ++++++++++++++++++--------- prec/impl/psb_s_iluk_fact.f90 | 21 ++++++++++++++------- prec/impl/psb_s_ilut_fact.f90 | 17 +++++++++++------ prec/impl/psb_z_ilu0_fact.f90 | 27 ++++++++++++++++++--------- prec/impl/psb_z_iluk_fact.f90 | 21 ++++++++++++++------- prec/impl/psb_z_ilut_fact.f90 | 17 +++++++++++------ prec/psb_c_ilu_fact_mod.f90 | 9 ++++++--- prec/psb_d_ilu_fact_mod.f90 | 9 ++++++--- prec/psb_s_ilu_fact_mod.f90 | 9 ++++++--- prec/psb_z_ilu_fact_mod.f90 | 9 ++++++--- 16 files changed, 196 insertions(+), 100 deletions(-) diff --git a/prec/impl/psb_c_ilu0_fact.f90 b/prec/impl/psb_c_ilu0_fact.f90 index 1a3e1046..9e039b19 100644 --- a/prec/impl/psb_c_ilu0_fact.f90 +++ b/prec/impl/psb_c_ilu0_fact.f90 @@ -130,7 +130,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck is empty. ! -subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck, upd) +subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck, upd,shft) use psb_base_mod use psb_c_ilu_fact_mod, psb_protect_name => psb_cilu0_fact @@ -145,11 +145,13 @@ subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck, upd) integer(psb_ipk_), intent(out) :: info type(psb_cspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd + complex(psb_spk_), intent(in), optional :: shft ! Local variables integer(psb_ipk_) :: l1, l2, m, err_act type(psb_cspmat_type), pointer :: blck_ type(psb_c_csr_sparse_mat) :: ll, uu + complex(psb_spk_) :: shft_ character :: upd_ character(len=20) :: name, ch_err @@ -177,7 +179,12 @@ subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck, upd) else upd_ = 'F' end if - + if (present(shft)) then + shft_ = shft + else + shft_ = czero + end if + m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& & (m > size(d)) ) then @@ -193,7 +200,7 @@ subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck, upd) ! Compute the ILU(0) or the MILU(0) factorization, depending on ialg ! call psb_cilu0_factint(ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,shft_,info) if(info.ne.0) then info=psb_err_from_subroutine_ ch_err='psb_cilu0_factint' @@ -314,7 +321,7 @@ contains ! Error code. ! subroutine psb_cilu0_factint(ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,shft,info) implicit none @@ -325,6 +332,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) complex(psb_spk_), intent(inout) :: lval(:),uval(:),d(:) character, intent(in) :: upd + complex(psb_spk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: i,j,k,l,low1,low2,kk,jj,ll, ktrw,err_act, m @@ -382,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) endif lirp(i+1) = l1 + 1 @@ -583,7 +591,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilu_copyin(i,m,a,jd,jmin,jmax,l1,lja,lval,& - & dia,l2,uja,uval,ktrw,trw,upd) + & dia,l2,uja,uval,ktrw,trw,upd,shft) use psb_base_mod @@ -597,6 +605,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:), uja(:) complex(psb_spk_), intent(inout) :: lval(:), uval(:), dia character, intent(in) :: upd + complex(psb_spk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: k,j,info,irb, nz integer(psb_ipk_), parameter :: nrb=40 @@ -625,7 +634,7 @@ contains lval(l1) = aa%val(j) lja(l1) = k else if (k == jd) then - dia = aa%val(j) + dia = aa%val(j) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = aa%val(j) @@ -665,7 +674,7 @@ contains lval(l1) = trw%val(ktrw) lja(l1) = k else if (k == jd) then - dia = trw%val(ktrw) + dia = trw%val(ktrw) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = trw%val(ktrw) diff --git a/prec/impl/psb_c_iluk_fact.f90 b/prec/impl/psb_c_iluk_fact.f90 index c4ebc678..ddda6d20 100644 --- a/prec/impl/psb_c_iluk_fact.f90 +++ b/prec/impl/psb_c_iluk_fact.f90 @@ -127,7 +127,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck) +subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) use psb_base_mod use psb_c_ilu_fact_mod, psb_protect_name => psb_ciluk_fact @@ -141,6 +141,7 @@ subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck) type(psb_cspmat_type),intent(inout) :: l,u type(psb_cspmat_type),intent(in), optional, target :: blck complex(psb_spk_), intent(inout) :: d(:) + complex(psb_spk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act @@ -184,7 +185,7 @@ subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_ciluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_ciluk_factint' @@ -298,7 +299,7 @@ contains ! Error code. ! subroutine psb_ciluk_factint(fill_in,ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,shft) use psb_base_mod @@ -311,6 +312,7 @@ contains integer(psb_ipk_), allocatable, intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) complex(psb_spk_), allocatable, intent(inout) :: lval(:),uval(:) complex(psb_spk_), intent(inout) :: d(:) + complex(psb_spk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: ma,mb,i, ktrw,err_act,nidx, m @@ -400,13 +402,13 @@ contains ! ! Copy into trw the i-th local row of the matrix, stored in a ! - call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) else ! ! Copy into trw the i-th local row of the matrix, stored in b ! (as (i-ma)-th row) ! - call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) endif ! Do an elimination step on the current row. It turns out we only @@ -516,7 +518,7 @@ contains ! until we empty the buffer. Thus we will make a call to psb_sp_getblk ! every nrb calls to copyin. If A is in CSR format it is unused. ! - subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info) + subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info,shft) use psb_base_mod @@ -530,6 +532,8 @@ contains integer(psb_ipk_), intent(inout) :: rowlevs(:) complex(psb_spk_), intent(inout) :: row(:) type(psb_i_heap), intent(inout) :: heap + complex(psb_spk_), intent(in) :: shft + ! Local variables integer(psb_ipk_) :: k,j,irb,err_act,nz @@ -554,6 +558,7 @@ contains k = aa%ja(j) if ((jmin<=k).and.(k<=jmax)) then row(k) = aa%val(j) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -587,6 +592,7 @@ contains k = trw%ja(ktrw) if ((jmin<=k).and.(k<=jmax)) then row(k) = trw%val(ktrw) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -670,7 +676,8 @@ contains ! Note: this argument is intent(inout) and not only intent(out) ! to retain its allocation, done by this routine. ! - subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,uja,uirp,uval,uplevs,nidx,idxs,info) + subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,& + & uja,uirp,uval,uplevs,nidx,idxs,info) use psb_base_mod diff --git a/prec/impl/psb_c_ilut_fact.f90 b/prec/impl/psb_c_ilut_fact.f90 index 633899de..997b3b84 100644 --- a/prec/impl/psb_c_ilut_fact.f90 +++ b/prec/impl/psb_c_ilut_fact.f90 @@ -123,7 +123,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) +subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) use psb_base_mod use psb_c_ilu_fact_mod, psb_protect_name => psb_cilut_fact @@ -139,6 +139,7 @@ subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) complex(psb_spk_), intent(inout) :: d(:) type(psb_cspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + complex(psb_spk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ @@ -206,7 +207,7 @@ subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) ! Compute the ILU(k,t) factorization ! call psb_cilut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_cilut_factint' @@ -316,7 +317,7 @@ contains ! Error code. ! subroutine psb_cilut_factint(fill_in,thres,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale,shft) use psb_base_mod @@ -331,6 +332,7 @@ contains complex(psb_spk_), allocatable, intent(inout) :: lval(:),uval(:) complex(psb_spk_), intent(inout) :: d(:) real(psb_spk_), intent(in), optional :: scale + complex(psb_spk_), intent(in) :: shft ! Local Variables integer(psb_ipk_) :: i, ktrw,err_act,nidx,nlw,nup,jmaxup, ma, mb, m @@ -401,10 +403,10 @@ contains d(i) = czero if (i<=ma) then call ilut_copyin(i,ma,a,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) else call ilut_copyin(i-ma,mb,b,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) endif ! @@ -540,7 +542,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilut_copyin(i,m,a,jd,jmin,jmax,nlw,nup,jmaxup,& - & nrmi,weight,row,heap,ktrw,trw,info) + & nrmi,weight,row,heap,ktrw,trw,info,shft) use psb_base_mod implicit none type(psb_cspmat_type), intent(in) :: a @@ -551,6 +553,7 @@ contains complex(psb_spk_), intent(inout) :: row(:) real(psb_spk_), intent(in) :: weight type(psb_i_heap), intent(inout) :: heap + complex(psb_spk_), intent(in) :: shft integer(psb_ipk_) :: k,j,irb,kin,nz integer(psb_ipk_), parameter :: nrb=40 @@ -597,6 +600,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then @@ -648,6 +652,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then diff --git a/prec/impl/psb_d_ilu0_fact.f90 b/prec/impl/psb_d_ilu0_fact.f90 index 478eedfa..29968e0c 100644 --- a/prec/impl/psb_d_ilu0_fact.f90 +++ b/prec/impl/psb_d_ilu0_fact.f90 @@ -130,7 +130,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck is empty. ! -subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck, upd) +subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck, upd,shft) use psb_base_mod use psb_d_ilu_fact_mod, psb_protect_name => psb_dilu0_fact @@ -145,11 +145,13 @@ subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck, upd) integer(psb_ipk_), intent(out) :: info type(psb_dspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd + real(psb_dpk_), intent(in), optional :: shft ! Local variables integer(psb_ipk_) :: l1, l2, m, err_act type(psb_dspmat_type), pointer :: blck_ type(psb_d_csr_sparse_mat) :: ll, uu + real(psb_dpk_) :: shft_ character :: upd_ character(len=20) :: name, ch_err @@ -177,7 +179,12 @@ subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck, upd) else upd_ = 'F' end if - + if (present(shft)) then + shft_ = shft + else + shft_ = dzero + end if + m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& & (m > size(d)) ) then @@ -193,7 +200,7 @@ subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck, upd) ! Compute the ILU(0) or the MILU(0) factorization, depending on ialg ! call psb_dilu0_factint(ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,shft_,info) if(info.ne.0) then info=psb_err_from_subroutine_ ch_err='psb_dilu0_factint' @@ -314,7 +321,7 @@ contains ! Error code. ! subroutine psb_dilu0_factint(ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,shft,info) implicit none @@ -325,6 +332,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) real(psb_dpk_), intent(inout) :: lval(:),uval(:),d(:) character, intent(in) :: upd + real(psb_dpk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: i,j,k,l,low1,low2,kk,jj,ll, ktrw,err_act, m @@ -382,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) endif lirp(i+1) = l1 + 1 @@ -583,7 +591,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilu_copyin(i,m,a,jd,jmin,jmax,l1,lja,lval,& - & dia,l2,uja,uval,ktrw,trw,upd) + & dia,l2,uja,uval,ktrw,trw,upd,shft) use psb_base_mod @@ -597,6 +605,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:), uja(:) real(psb_dpk_), intent(inout) :: lval(:), uval(:), dia character, intent(in) :: upd + real(psb_dpk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: k,j,info,irb, nz integer(psb_ipk_), parameter :: nrb=40 @@ -625,7 +634,7 @@ contains lval(l1) = aa%val(j) lja(l1) = k else if (k == jd) then - dia = aa%val(j) + dia = aa%val(j) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = aa%val(j) @@ -665,7 +674,7 @@ contains lval(l1) = trw%val(ktrw) lja(l1) = k else if (k == jd) then - dia = trw%val(ktrw) + dia = trw%val(ktrw) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = trw%val(ktrw) diff --git a/prec/impl/psb_d_iluk_fact.f90 b/prec/impl/psb_d_iluk_fact.f90 index 544ec987..8f1e6c7b 100644 --- a/prec/impl/psb_d_iluk_fact.f90 +++ b/prec/impl/psb_d_iluk_fact.f90 @@ -127,7 +127,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck) +subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) use psb_base_mod use psb_d_ilu_fact_mod, psb_protect_name => psb_diluk_fact @@ -141,6 +141,7 @@ subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck) type(psb_dspmat_type),intent(inout) :: l,u type(psb_dspmat_type),intent(in), optional, target :: blck real(psb_dpk_), intent(inout) :: d(:) + real(psb_dpk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act @@ -184,7 +185,7 @@ subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_diluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_diluk_factint' @@ -298,7 +299,7 @@ contains ! Error code. ! subroutine psb_diluk_factint(fill_in,ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,shft) use psb_base_mod @@ -311,6 +312,7 @@ contains integer(psb_ipk_), allocatable, intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) real(psb_dpk_), allocatable, intent(inout) :: lval(:),uval(:) real(psb_dpk_), intent(inout) :: d(:) + real(psb_dpk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: ma,mb,i, ktrw,err_act,nidx, m @@ -400,13 +402,13 @@ contains ! ! Copy into trw the i-th local row of the matrix, stored in a ! - call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) else ! ! Copy into trw the i-th local row of the matrix, stored in b ! (as (i-ma)-th row) ! - call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) endif ! Do an elimination step on the current row. It turns out we only @@ -516,7 +518,7 @@ contains ! until we empty the buffer. Thus we will make a call to psb_sp_getblk ! every nrb calls to copyin. If A is in CSR format it is unused. ! - subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info) + subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info,shft) use psb_base_mod @@ -530,6 +532,8 @@ contains integer(psb_ipk_), intent(inout) :: rowlevs(:) real(psb_dpk_), intent(inout) :: row(:) type(psb_i_heap), intent(inout) :: heap + real(psb_dpk_), intent(in) :: shft + ! Local variables integer(psb_ipk_) :: k,j,irb,err_act,nz @@ -554,6 +558,7 @@ contains k = aa%ja(j) if ((jmin<=k).and.(k<=jmax)) then row(k) = aa%val(j) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -587,6 +592,7 @@ contains k = trw%ja(ktrw) if ((jmin<=k).and.(k<=jmax)) then row(k) = trw%val(ktrw) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -670,7 +676,8 @@ contains ! Note: this argument is intent(inout) and not only intent(out) ! to retain its allocation, done by this routine. ! - subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,uja,uirp,uval,uplevs,nidx,idxs,info) + subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,& + & uja,uirp,uval,uplevs,nidx,idxs,info) use psb_base_mod diff --git a/prec/impl/psb_d_ilut_fact.f90 b/prec/impl/psb_d_ilut_fact.f90 index 6c2dc698..c6079b55 100644 --- a/prec/impl/psb_d_ilut_fact.f90 +++ b/prec/impl/psb_d_ilut_fact.f90 @@ -123,7 +123,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) +subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) use psb_base_mod use psb_d_ilu_fact_mod, psb_protect_name => psb_dilut_fact @@ -139,6 +139,7 @@ subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) real(psb_dpk_), intent(inout) :: d(:) type(psb_dspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + real(psb_dpk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ @@ -206,7 +207,7 @@ subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) ! Compute the ILU(k,t) factorization ! call psb_dilut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_dilut_factint' @@ -316,7 +317,7 @@ contains ! Error code. ! subroutine psb_dilut_factint(fill_in,thres,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale,shft) use psb_base_mod @@ -331,6 +332,7 @@ contains real(psb_dpk_), allocatable, intent(inout) :: lval(:),uval(:) real(psb_dpk_), intent(inout) :: d(:) real(psb_dpk_), intent(in), optional :: scale + real(psb_dpk_), intent(in) :: shft ! Local Variables integer(psb_ipk_) :: i, ktrw,err_act,nidx,nlw,nup,jmaxup, ma, mb, m @@ -401,10 +403,10 @@ contains d(i) = czero if (i<=ma) then call ilut_copyin(i,ma,a,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) else call ilut_copyin(i-ma,mb,b,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) endif ! @@ -540,7 +542,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilut_copyin(i,m,a,jd,jmin,jmax,nlw,nup,jmaxup,& - & nrmi,weight,row,heap,ktrw,trw,info) + & nrmi,weight,row,heap,ktrw,trw,info,shft) use psb_base_mod implicit none type(psb_dspmat_type), intent(in) :: a @@ -551,6 +553,7 @@ contains real(psb_dpk_), intent(inout) :: row(:) real(psb_dpk_), intent(in) :: weight type(psb_i_heap), intent(inout) :: heap + real(psb_dpk_), intent(in) :: shft integer(psb_ipk_) :: k,j,irb,kin,nz integer(psb_ipk_), parameter :: nrb=40 @@ -597,6 +600,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then @@ -648,6 +652,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then diff --git a/prec/impl/psb_s_ilu0_fact.f90 b/prec/impl/psb_s_ilu0_fact.f90 index b6f442e9..faa21fda 100644 --- a/prec/impl/psb_s_ilu0_fact.f90 +++ b/prec/impl/psb_s_ilu0_fact.f90 @@ -130,7 +130,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck is empty. ! -subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck, upd) +subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck, upd,shft) use psb_base_mod use psb_s_ilu_fact_mod, psb_protect_name => psb_silu0_fact @@ -145,11 +145,13 @@ subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck, upd) integer(psb_ipk_), intent(out) :: info type(psb_sspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd + real(psb_spk_), intent(in), optional :: shft ! Local variables integer(psb_ipk_) :: l1, l2, m, err_act type(psb_sspmat_type), pointer :: blck_ type(psb_s_csr_sparse_mat) :: ll, uu + real(psb_spk_) :: shft_ character :: upd_ character(len=20) :: name, ch_err @@ -177,7 +179,12 @@ subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck, upd) else upd_ = 'F' end if - + if (present(shft)) then + shft_ = shft + else + shft_ = szero + end if + m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& & (m > size(d)) ) then @@ -193,7 +200,7 @@ subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck, upd) ! Compute the ILU(0) or the MILU(0) factorization, depending on ialg ! call psb_silu0_factint(ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,shft_,info) if(info.ne.0) then info=psb_err_from_subroutine_ ch_err='psb_silu0_factint' @@ -314,7 +321,7 @@ contains ! Error code. ! subroutine psb_silu0_factint(ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,shft,info) implicit none @@ -325,6 +332,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) real(psb_spk_), intent(inout) :: lval(:),uval(:),d(:) character, intent(in) :: upd + real(psb_spk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: i,j,k,l,low1,low2,kk,jj,ll, ktrw,err_act, m @@ -382,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) endif lirp(i+1) = l1 + 1 @@ -583,7 +591,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilu_copyin(i,m,a,jd,jmin,jmax,l1,lja,lval,& - & dia,l2,uja,uval,ktrw,trw,upd) + & dia,l2,uja,uval,ktrw,trw,upd,shft) use psb_base_mod @@ -597,6 +605,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:), uja(:) real(psb_spk_), intent(inout) :: lval(:), uval(:), dia character, intent(in) :: upd + real(psb_spk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: k,j,info,irb, nz integer(psb_ipk_), parameter :: nrb=40 @@ -625,7 +634,7 @@ contains lval(l1) = aa%val(j) lja(l1) = k else if (k == jd) then - dia = aa%val(j) + dia = aa%val(j) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = aa%val(j) @@ -665,7 +674,7 @@ contains lval(l1) = trw%val(ktrw) lja(l1) = k else if (k == jd) then - dia = trw%val(ktrw) + dia = trw%val(ktrw) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = trw%val(ktrw) diff --git a/prec/impl/psb_s_iluk_fact.f90 b/prec/impl/psb_s_iluk_fact.f90 index 6129663b..17ef4eef 100644 --- a/prec/impl/psb_s_iluk_fact.f90 +++ b/prec/impl/psb_s_iluk_fact.f90 @@ -127,7 +127,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck) +subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) use psb_base_mod use psb_s_ilu_fact_mod, psb_protect_name => psb_siluk_fact @@ -141,6 +141,7 @@ subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck) type(psb_sspmat_type),intent(inout) :: l,u type(psb_sspmat_type),intent(in), optional, target :: blck real(psb_spk_), intent(inout) :: d(:) + real(psb_spk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act @@ -184,7 +185,7 @@ subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_siluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_siluk_factint' @@ -298,7 +299,7 @@ contains ! Error code. ! subroutine psb_siluk_factint(fill_in,ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,shft) use psb_base_mod @@ -311,6 +312,7 @@ contains integer(psb_ipk_), allocatable, intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) real(psb_spk_), allocatable, intent(inout) :: lval(:),uval(:) real(psb_spk_), intent(inout) :: d(:) + real(psb_spk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: ma,mb,i, ktrw,err_act,nidx, m @@ -400,13 +402,13 @@ contains ! ! Copy into trw the i-th local row of the matrix, stored in a ! - call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) else ! ! Copy into trw the i-th local row of the matrix, stored in b ! (as (i-ma)-th row) ! - call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) endif ! Do an elimination step on the current row. It turns out we only @@ -516,7 +518,7 @@ contains ! until we empty the buffer. Thus we will make a call to psb_sp_getblk ! every nrb calls to copyin. If A is in CSR format it is unused. ! - subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info) + subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info,shft) use psb_base_mod @@ -530,6 +532,8 @@ contains integer(psb_ipk_), intent(inout) :: rowlevs(:) real(psb_spk_), intent(inout) :: row(:) type(psb_i_heap), intent(inout) :: heap + real(psb_spk_), intent(in) :: shft + ! Local variables integer(psb_ipk_) :: k,j,irb,err_act,nz @@ -554,6 +558,7 @@ contains k = aa%ja(j) if ((jmin<=k).and.(k<=jmax)) then row(k) = aa%val(j) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -587,6 +592,7 @@ contains k = trw%ja(ktrw) if ((jmin<=k).and.(k<=jmax)) then row(k) = trw%val(ktrw) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -670,7 +676,8 @@ contains ! Note: this argument is intent(inout) and not only intent(out) ! to retain its allocation, done by this routine. ! - subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,uja,uirp,uval,uplevs,nidx,idxs,info) + subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,& + & uja,uirp,uval,uplevs,nidx,idxs,info) use psb_base_mod diff --git a/prec/impl/psb_s_ilut_fact.f90 b/prec/impl/psb_s_ilut_fact.f90 index 43cacf41..76b514a3 100644 --- a/prec/impl/psb_s_ilut_fact.f90 +++ b/prec/impl/psb_s_ilut_fact.f90 @@ -123,7 +123,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) +subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) use psb_base_mod use psb_s_ilu_fact_mod, psb_protect_name => psb_silut_fact @@ -139,6 +139,7 @@ subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) real(psb_spk_), intent(inout) :: d(:) type(psb_sspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + real(psb_spk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ @@ -206,7 +207,7 @@ subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) ! Compute the ILU(k,t) factorization ! call psb_silut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_silut_factint' @@ -316,7 +317,7 @@ contains ! Error code. ! subroutine psb_silut_factint(fill_in,thres,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale,shft) use psb_base_mod @@ -331,6 +332,7 @@ contains real(psb_spk_), allocatable, intent(inout) :: lval(:),uval(:) real(psb_spk_), intent(inout) :: d(:) real(psb_spk_), intent(in), optional :: scale + real(psb_spk_), intent(in) :: shft ! Local Variables integer(psb_ipk_) :: i, ktrw,err_act,nidx,nlw,nup,jmaxup, ma, mb, m @@ -401,10 +403,10 @@ contains d(i) = czero if (i<=ma) then call ilut_copyin(i,ma,a,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) else call ilut_copyin(i-ma,mb,b,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) endif ! @@ -540,7 +542,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilut_copyin(i,m,a,jd,jmin,jmax,nlw,nup,jmaxup,& - & nrmi,weight,row,heap,ktrw,trw,info) + & nrmi,weight,row,heap,ktrw,trw,info,shft) use psb_base_mod implicit none type(psb_sspmat_type), intent(in) :: a @@ -551,6 +553,7 @@ contains real(psb_spk_), intent(inout) :: row(:) real(psb_spk_), intent(in) :: weight type(psb_i_heap), intent(inout) :: heap + real(psb_spk_), intent(in) :: shft integer(psb_ipk_) :: k,j,irb,kin,nz integer(psb_ipk_), parameter :: nrb=40 @@ -597,6 +600,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then @@ -648,6 +652,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then diff --git a/prec/impl/psb_z_ilu0_fact.f90 b/prec/impl/psb_z_ilu0_fact.f90 index 26322e95..c6c0fc55 100644 --- a/prec/impl/psb_z_ilu0_fact.f90 +++ b/prec/impl/psb_z_ilu0_fact.f90 @@ -130,7 +130,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck is empty. ! -subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck, upd) +subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck, upd,shft) use psb_base_mod use psb_z_ilu_fact_mod, psb_protect_name => psb_zilu0_fact @@ -145,11 +145,13 @@ subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck, upd) integer(psb_ipk_), intent(out) :: info type(psb_zspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd + complex(psb_dpk_), intent(in), optional :: shft ! Local variables integer(psb_ipk_) :: l1, l2, m, err_act type(psb_zspmat_type), pointer :: blck_ type(psb_z_csr_sparse_mat) :: ll, uu + complex(psb_dpk_) :: shft_ character :: upd_ character(len=20) :: name, ch_err @@ -177,7 +179,12 @@ subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck, upd) else upd_ = 'F' end if - + if (present(shft)) then + shft_ = shft + else + shft_ = zzero + end if + m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& & (m > size(d)) ) then @@ -193,7 +200,7 @@ subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck, upd) ! Compute the ILU(0) or the MILU(0) factorization, depending on ialg ! call psb_zilu0_factint(ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,upd_,shft_,info) if(info.ne.0) then info=psb_err_from_subroutine_ ch_err='psb_zilu0_factint' @@ -314,7 +321,7 @@ contains ! Error code. ! subroutine psb_zilu0_factint(ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,upd,shft,info) implicit none @@ -325,6 +332,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) complex(psb_dpk_), intent(inout) :: lval(:),uval(:),d(:) character, intent(in) :: upd + complex(psb_dpk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: i,j,k,l,low1,low2,kk,jj,ll, ktrw,err_act, m @@ -382,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd) + & d(i),l2,uja,uval,ktrw,trw,upd,shft) endif lirp(i+1) = l1 + 1 @@ -583,7 +591,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilu_copyin(i,m,a,jd,jmin,jmax,l1,lja,lval,& - & dia,l2,uja,uval,ktrw,trw,upd) + & dia,l2,uja,uval,ktrw,trw,upd,shft) use psb_base_mod @@ -597,6 +605,7 @@ contains integer(psb_ipk_), intent(inout) :: lja(:), uja(:) complex(psb_dpk_), intent(inout) :: lval(:), uval(:), dia character, intent(in) :: upd + complex(psb_dpk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: k,j,info,irb, nz integer(psb_ipk_), parameter :: nrb=40 @@ -625,7 +634,7 @@ contains lval(l1) = aa%val(j) lja(l1) = k else if (k == jd) then - dia = aa%val(j) + dia = aa%val(j) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = aa%val(j) @@ -665,7 +674,7 @@ contains lval(l1) = trw%val(ktrw) lja(l1) = k else if (k == jd) then - dia = trw%val(ktrw) + dia = trw%val(ktrw) + shft else if ((k > jd).and.(k <= jmax)) then l2 = l2 + 1 uval(l2) = trw%val(ktrw) diff --git a/prec/impl/psb_z_iluk_fact.f90 b/prec/impl/psb_z_iluk_fact.f90 index 1a398cda..7675226a 100644 --- a/prec/impl/psb_z_iluk_fact.f90 +++ b/prec/impl/psb_z_iluk_fact.f90 @@ -127,7 +127,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck) +subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) use psb_base_mod use psb_z_ilu_fact_mod, psb_protect_name => psb_ziluk_fact @@ -141,6 +141,7 @@ subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck) type(psb_zspmat_type),intent(inout) :: l,u type(psb_zspmat_type),intent(in), optional, target :: blck complex(psb_dpk_), intent(inout) :: d(:) + complex(psb_dpk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act @@ -184,7 +185,7 @@ subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_ziluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_ziluk_factint' @@ -298,7 +299,7 @@ contains ! Error code. ! subroutine psb_ziluk_factint(fill_in,ialg,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,shft) use psb_base_mod @@ -311,6 +312,7 @@ contains integer(psb_ipk_), allocatable, intent(inout) :: lja(:),lirp(:),uja(:),uirp(:) complex(psb_dpk_), allocatable, intent(inout) :: lval(:),uval(:) complex(psb_dpk_), intent(inout) :: d(:) + complex(psb_dpk_), intent(in) :: shft ! Local variables integer(psb_ipk_) :: ma,mb,i, ktrw,err_act,nidx, m @@ -400,13 +402,13 @@ contains ! ! Copy into trw the i-th local row of the matrix, stored in a ! - call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i,ma,a,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) else ! ! Copy into trw the i-th local row of the matrix, stored in b ! (as (i-ma)-th row) ! - call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info) + call iluk_copyin(i-ma,mb,b,ione,m,row,rowlevs,heap,ktrw,trw,info,shft) endif ! Do an elimination step on the current row. It turns out we only @@ -516,7 +518,7 @@ contains ! until we empty the buffer. Thus we will make a call to psb_sp_getblk ! every nrb calls to copyin. If A is in CSR format it is unused. ! - subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info) + subroutine iluk_copyin(i,m,a,jmin,jmax,row,rowlevs,heap,ktrw,trw,info,shft) use psb_base_mod @@ -530,6 +532,8 @@ contains integer(psb_ipk_), intent(inout) :: rowlevs(:) complex(psb_dpk_), intent(inout) :: row(:) type(psb_i_heap), intent(inout) :: heap + complex(psb_dpk_), intent(in) :: shft + ! Local variables integer(psb_ipk_) :: k,j,irb,err_act,nz @@ -554,6 +558,7 @@ contains k = aa%ja(j) if ((jmin<=k).and.(k<=jmax)) then row(k) = aa%val(j) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -587,6 +592,7 @@ contains k = trw%ja(ktrw) if ((jmin<=k).and.(k<=jmax)) then row(k) = trw%val(ktrw) + if (k==i) row(k) = row(k) + shft rowlevs(k) = 0 call heap%insert(k,info) end if @@ -670,7 +676,8 @@ contains ! Note: this argument is intent(inout) and not only intent(out) ! to retain its allocation, done by this routine. ! - subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,uja,uirp,uval,uplevs,nidx,idxs,info) + subroutine iluk_fact(fill_in,i,row,rowlevs,heap,d,& + & uja,uirp,uval,uplevs,nidx,idxs,info) use psb_base_mod diff --git a/prec/impl/psb_z_ilut_fact.f90 b/prec/impl/psb_z_ilut_fact.f90 index 291dc778..2f004725 100644 --- a/prec/impl/psb_z_ilut_fact.f90 +++ b/prec/impl/psb_z_ilut_fact.f90 @@ -123,7 +123,7 @@ ! greater than 0. If the overlap is 0 or the matrix has been reordered ! (see psb_fact_bld), then blck does not contain any row. ! -subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) +subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) use psb_base_mod use psb_z_ilu_fact_mod, psb_protect_name => psb_zilut_fact @@ -139,6 +139,7 @@ subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) complex(psb_dpk_), intent(inout) :: d(:) type(psb_zspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + complex(psb_dpk_), intent(in), optional :: shft ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ @@ -206,7 +207,7 @@ subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) ! Compute the ILU(k,t) factorization ! call psb_zilut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_zilut_factint' @@ -316,7 +317,7 @@ contains ! Error code. ! subroutine psb_zilut_factint(fill_in,thres,a,b,& - & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale) + & d,lval,lja,lirp,uval,uja,uirp,l1,l2,info,scale,shft) use psb_base_mod @@ -331,6 +332,7 @@ contains complex(psb_dpk_), allocatable, intent(inout) :: lval(:),uval(:) complex(psb_dpk_), intent(inout) :: d(:) real(psb_dpk_), intent(in), optional :: scale + complex(psb_dpk_), intent(in) :: shft ! Local Variables integer(psb_ipk_) :: i, ktrw,err_act,nidx,nlw,nup,jmaxup, ma, mb, m @@ -401,10 +403,10 @@ contains d(i) = czero if (i<=ma) then call ilut_copyin(i,ma,a,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) else call ilut_copyin(i-ma,mb,b,i,ione,m,nlw,nup,jmaxup,nrmi,weight,& - & row,heap,ktrw,trw,info) + & row,heap,ktrw,trw,info,shft) endif ! @@ -540,7 +542,7 @@ contains ! every nrb calls to copyin. If A is in CSR format it is unused. ! subroutine ilut_copyin(i,m,a,jd,jmin,jmax,nlw,nup,jmaxup,& - & nrmi,weight,row,heap,ktrw,trw,info) + & nrmi,weight,row,heap,ktrw,trw,info,shft) use psb_base_mod implicit none type(psb_zspmat_type), intent(in) :: a @@ -551,6 +553,7 @@ contains complex(psb_dpk_), intent(inout) :: row(:) real(psb_dpk_), intent(in) :: weight type(psb_i_heap), intent(inout) :: heap + complex(psb_dpk_), intent(in) :: shft integer(psb_ipk_) :: k,j,irb,kin,nz integer(psb_ipk_), parameter :: nrb=40 @@ -597,6 +600,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then @@ -648,6 +652,7 @@ contains call heap%insert(k,info) if (info /= psb_success_) exit if (kjd) then nup = nup + 1 if (abs(row(k))>dmaxup) then diff --git a/prec/psb_c_ilu_fact_mod.f90 b/prec/psb_c_ilu_fact_mod.f90 index 45d06211..0fae1fc5 100644 --- a/prec/psb_c_ilu_fact_mod.f90 +++ b/prec/psb_c_ilu_fact_mod.f90 @@ -80,7 +80,7 @@ module psb_c_ilu_fact_mod use psb_base_mod use psb_prec_const_mod interface psb_ilu0_fact - subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck,upd) + subroutine psb_cilu0_fact(ialg,a,l,u,d,info,blck,upd,shft) import psb_cspmat_type, psb_spk_, psb_ipk_ integer(psb_ipk_), intent(in) :: ialg integer(psb_ipk_), intent(out) :: info @@ -89,11 +89,12 @@ module psb_c_ilu_fact_mod type(psb_cspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd complex(psb_spk_), intent(inout) :: d(:) + complex(psb_spk_), intent(in), optional :: shft end subroutine psb_cilu0_fact end interface interface psb_iluk_fact - subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck) + subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) import psb_cspmat_type, psb_spk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in,ialg integer(psb_ipk_), intent(out) :: info @@ -101,11 +102,12 @@ module psb_c_ilu_fact_mod type(psb_cspmat_type),intent(inout) :: l,u type(psb_cspmat_type),intent(in), optional, target :: blck complex(psb_spk_), intent(inout) :: d(:) + complex(psb_spk_), intent(in), optional :: shft end subroutine psb_ciluk_fact end interface interface psb_ilut_fact - subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) + subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) import psb_cspmat_type, psb_spk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in real(psb_spk_), intent(in) :: thres @@ -115,6 +117,7 @@ module psb_c_ilu_fact_mod complex(psb_spk_), intent(inout) :: d(:) type(psb_cspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + complex(psb_spk_), intent(in), optional :: shft end subroutine psb_cilut_fact end interface diff --git a/prec/psb_d_ilu_fact_mod.f90 b/prec/psb_d_ilu_fact_mod.f90 index 02753a4c..6354573d 100644 --- a/prec/psb_d_ilu_fact_mod.f90 +++ b/prec/psb_d_ilu_fact_mod.f90 @@ -80,7 +80,7 @@ module psb_d_ilu_fact_mod use psb_base_mod use psb_prec_const_mod interface psb_ilu0_fact - subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck,upd) + subroutine psb_dilu0_fact(ialg,a,l,u,d,info,blck,upd,shft) import psb_dspmat_type, psb_dpk_, psb_ipk_ integer(psb_ipk_), intent(in) :: ialg integer(psb_ipk_), intent(out) :: info @@ -89,11 +89,12 @@ module psb_d_ilu_fact_mod type(psb_dspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd real(psb_dpk_), intent(inout) :: d(:) + real(psb_dpk_), intent(in), optional :: shft end subroutine psb_dilu0_fact end interface interface psb_iluk_fact - subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck) + subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) import psb_dspmat_type, psb_dpk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in,ialg integer(psb_ipk_), intent(out) :: info @@ -101,11 +102,12 @@ module psb_d_ilu_fact_mod type(psb_dspmat_type),intent(inout) :: l,u type(psb_dspmat_type),intent(in), optional, target :: blck real(psb_dpk_), intent(inout) :: d(:) + real(psb_dpk_), intent(in), optional :: shft end subroutine psb_diluk_fact end interface interface psb_ilut_fact - subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) + subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) import psb_dspmat_type, psb_dpk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in real(psb_dpk_), intent(in) :: thres @@ -115,6 +117,7 @@ module psb_d_ilu_fact_mod real(psb_dpk_), intent(inout) :: d(:) type(psb_dspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + real(psb_dpk_), intent(in), optional :: shft end subroutine psb_dilut_fact end interface diff --git a/prec/psb_s_ilu_fact_mod.f90 b/prec/psb_s_ilu_fact_mod.f90 index 6334df15..4021adc9 100644 --- a/prec/psb_s_ilu_fact_mod.f90 +++ b/prec/psb_s_ilu_fact_mod.f90 @@ -80,7 +80,7 @@ module psb_s_ilu_fact_mod use psb_base_mod use psb_prec_const_mod interface psb_ilu0_fact - subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck,upd) + subroutine psb_silu0_fact(ialg,a,l,u,d,info,blck,upd,shft) import psb_sspmat_type, psb_spk_, psb_ipk_ integer(psb_ipk_), intent(in) :: ialg integer(psb_ipk_), intent(out) :: info @@ -89,11 +89,12 @@ module psb_s_ilu_fact_mod type(psb_sspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd real(psb_spk_), intent(inout) :: d(:) + real(psb_spk_), intent(in), optional :: shft end subroutine psb_silu0_fact end interface interface psb_iluk_fact - subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck) + subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) import psb_sspmat_type, psb_spk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in,ialg integer(psb_ipk_), intent(out) :: info @@ -101,11 +102,12 @@ module psb_s_ilu_fact_mod type(psb_sspmat_type),intent(inout) :: l,u type(psb_sspmat_type),intent(in), optional, target :: blck real(psb_spk_), intent(inout) :: d(:) + real(psb_spk_), intent(in), optional :: shft end subroutine psb_siluk_fact end interface interface psb_ilut_fact - subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) + subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) import psb_sspmat_type, psb_spk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in real(psb_spk_), intent(in) :: thres @@ -115,6 +117,7 @@ module psb_s_ilu_fact_mod real(psb_spk_), intent(inout) :: d(:) type(psb_sspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + real(psb_spk_), intent(in), optional :: shft end subroutine psb_silut_fact end interface diff --git a/prec/psb_z_ilu_fact_mod.f90 b/prec/psb_z_ilu_fact_mod.f90 index 220d673f..4793b43b 100644 --- a/prec/psb_z_ilu_fact_mod.f90 +++ b/prec/psb_z_ilu_fact_mod.f90 @@ -80,7 +80,7 @@ module psb_z_ilu_fact_mod use psb_base_mod use psb_prec_const_mod interface psb_ilu0_fact - subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck,upd) + subroutine psb_zilu0_fact(ialg,a,l,u,d,info,blck,upd,shft) import psb_zspmat_type, psb_dpk_, psb_ipk_ integer(psb_ipk_), intent(in) :: ialg integer(psb_ipk_), intent(out) :: info @@ -89,11 +89,12 @@ module psb_z_ilu_fact_mod type(psb_zspmat_type),intent(in), optional, target :: blck character, intent(in), optional :: upd complex(psb_dpk_), intent(inout) :: d(:) + complex(psb_dpk_), intent(in), optional :: shft end subroutine psb_zilu0_fact end interface interface psb_iluk_fact - subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck) + subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) import psb_zspmat_type, psb_dpk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in,ialg integer(psb_ipk_), intent(out) :: info @@ -101,11 +102,12 @@ module psb_z_ilu_fact_mod type(psb_zspmat_type),intent(inout) :: l,u type(psb_zspmat_type),intent(in), optional, target :: blck complex(psb_dpk_), intent(inout) :: d(:) + complex(psb_dpk_), intent(in), optional :: shft end subroutine psb_ziluk_fact end interface interface psb_ilut_fact - subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale) + subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) import psb_zspmat_type, psb_dpk_, psb_ipk_ integer(psb_ipk_), intent(in) :: fill_in real(psb_dpk_), intent(in) :: thres @@ -115,6 +117,7 @@ module psb_z_ilu_fact_mod complex(psb_dpk_), intent(inout) :: d(:) type(psb_zspmat_type),intent(in), optional, target :: blck integer(psb_ipk_), intent(in), optional :: iscale + complex(psb_dpk_), intent(in), optional :: shft end subroutine psb_zilut_fact end interface From 250a6300bab7dc75cdf7722c998db4ca4c91173f Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Mon, 23 Oct 2023 15:14:48 +0200 Subject: [PATCH 05/48] Fix SHFT implementation --- prec/impl/psb_c_iluk_fact.f90 | 6 ++++++ prec/impl/psb_c_ilut_fact.f90 | 6 ++++++ prec/impl/psb_d_iluk_fact.f90 | 6 ++++++ prec/impl/psb_d_ilut_fact.f90 | 6 ++++++ prec/impl/psb_s_iluk_fact.f90 | 6 ++++++ prec/impl/psb_s_ilut_fact.f90 | 6 ++++++ prec/impl/psb_z_iluk_fact.f90 | 6 ++++++ prec/impl/psb_z_ilut_fact.f90 | 6 ++++++ 8 files changed, 48 insertions(+) diff --git a/prec/impl/psb_c_iluk_fact.f90 b/prec/impl/psb_c_iluk_fact.f90 index ddda6d20..4e6d7a5c 100644 --- a/prec/impl/psb_c_iluk_fact.f90 +++ b/prec/impl/psb_c_iluk_fact.f90 @@ -145,6 +145,7 @@ subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act + complex(psb_spk_) :: shft_ type(psb_cspmat_type), pointer :: blck_ type(psb_c_csr_sparse_mat) :: ll, uu character(len=20) :: name, ch_err @@ -168,6 +169,11 @@ subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) goto 9999 end if endif + if (present(shft)) then + shft_ = shft + else + shft_ = czero + end if m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& diff --git a/prec/impl/psb_c_ilut_fact.f90 b/prec/impl/psb_c_ilut_fact.f90 index 997b3b84..6aecb53b 100644 --- a/prec/impl/psb_c_ilut_fact.f90 +++ b/prec/impl/psb_c_ilut_fact.f90 @@ -143,6 +143,7 @@ subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ + complex(psb_spk_) :: shft_ type(psb_cspmat_type), pointer :: blck_ type(psb_c_csr_sparse_mat) :: ll, uu real(psb_spk_) :: scale @@ -178,6 +179,11 @@ subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) else iscale_ = psb_ilu_scale_none_ end if + if (present(shft)) then + shft_ = shft + else + shft_ = czero + end if select case(iscale_) case(psb_ilu_scale_none_) diff --git a/prec/impl/psb_d_iluk_fact.f90 b/prec/impl/psb_d_iluk_fact.f90 index 8f1e6c7b..94c34fec 100644 --- a/prec/impl/psb_d_iluk_fact.f90 +++ b/prec/impl/psb_d_iluk_fact.f90 @@ -145,6 +145,7 @@ subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act + real(psb_dpk_) :: shft_ type(psb_dspmat_type), pointer :: blck_ type(psb_d_csr_sparse_mat) :: ll, uu character(len=20) :: name, ch_err @@ -168,6 +169,11 @@ subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) goto 9999 end if endif + if (present(shft)) then + shft_ = shft + else + shft_ = dzero + end if m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& diff --git a/prec/impl/psb_d_ilut_fact.f90 b/prec/impl/psb_d_ilut_fact.f90 index c6079b55..c4f05df6 100644 --- a/prec/impl/psb_d_ilut_fact.f90 +++ b/prec/impl/psb_d_ilut_fact.f90 @@ -143,6 +143,7 @@ subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ + real(psb_dpk_) :: shft_ type(psb_dspmat_type), pointer :: blck_ type(psb_d_csr_sparse_mat) :: ll, uu real(psb_dpk_) :: scale @@ -178,6 +179,11 @@ subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) else iscale_ = psb_ilu_scale_none_ end if + if (present(shft)) then + shft_ = shft + else + shft_ = dzero + end if select case(iscale_) case(psb_ilu_scale_none_) diff --git a/prec/impl/psb_s_iluk_fact.f90 b/prec/impl/psb_s_iluk_fact.f90 index 17ef4eef..99d48880 100644 --- a/prec/impl/psb_s_iluk_fact.f90 +++ b/prec/impl/psb_s_iluk_fact.f90 @@ -145,6 +145,7 @@ subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act + real(psb_spk_) :: shft_ type(psb_sspmat_type), pointer :: blck_ type(psb_s_csr_sparse_mat) :: ll, uu character(len=20) :: name, ch_err @@ -168,6 +169,11 @@ subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) goto 9999 end if endif + if (present(shft)) then + shft_ = shft + else + shft_ = szero + end if m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& diff --git a/prec/impl/psb_s_ilut_fact.f90 b/prec/impl/psb_s_ilut_fact.f90 index 76b514a3..d63c5fb8 100644 --- a/prec/impl/psb_s_ilut_fact.f90 +++ b/prec/impl/psb_s_ilut_fact.f90 @@ -143,6 +143,7 @@ subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ + real(psb_spk_) :: shft_ type(psb_sspmat_type), pointer :: blck_ type(psb_s_csr_sparse_mat) :: ll, uu real(psb_spk_) :: scale @@ -178,6 +179,11 @@ subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) else iscale_ = psb_ilu_scale_none_ end if + if (present(shft)) then + shft_ = shft + else + shft_ = szero + end if select case(iscale_) case(psb_ilu_scale_none_) diff --git a/prec/impl/psb_z_iluk_fact.f90 b/prec/impl/psb_z_iluk_fact.f90 index 7675226a..dbd9430c 100644 --- a/prec/impl/psb_z_iluk_fact.f90 +++ b/prec/impl/psb_z_iluk_fact.f90 @@ -145,6 +145,7 @@ subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act + complex(psb_dpk_) :: shft_ type(psb_zspmat_type), pointer :: blck_ type(psb_z_csr_sparse_mat) :: ll, uu character(len=20) :: name, ch_err @@ -168,6 +169,11 @@ subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) goto 9999 end if endif + if (present(shft)) then + shft_ = shft + else + shft_ = zzero + end if m = a%get_nrows() + blck_%get_nrows() if ((m /= l%get_nrows()).or.(m /= u%get_nrows()).or.& diff --git a/prec/impl/psb_z_ilut_fact.f90 b/prec/impl/psb_z_ilut_fact.f90 index 2f004725..5e44f14a 100644 --- a/prec/impl/psb_z_ilut_fact.f90 +++ b/prec/impl/psb_z_ilut_fact.f90 @@ -143,6 +143,7 @@ subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Local Variables integer(psb_ipk_) :: l1, l2, m, err_act, iscale_ + complex(psb_dpk_) :: shft_ type(psb_zspmat_type), pointer :: blck_ type(psb_z_csr_sparse_mat) :: ll, uu real(psb_dpk_) :: scale @@ -178,6 +179,11 @@ subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) else iscale_ = psb_ilu_scale_none_ end if + if (present(shft)) then + shft_ = shft + else + shft_ = zzero + end if select case(iscale_) case(psb_ilu_scale_none_) From 25e9183e5055d5050f7521e78e7b321331493f6c Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Mon, 23 Oct 2023 15:31:13 +0200 Subject: [PATCH 06/48] Fix SHFT implementation, step 2 --- prec/impl/psb_c_ilu0_fact.f90 | 4 ++-- prec/impl/psb_c_iluk_fact.f90 | 2 +- prec/impl/psb_c_ilut_fact.f90 | 2 +- prec/impl/psb_d_ilu0_fact.f90 | 4 ++-- prec/impl/psb_d_iluk_fact.f90 | 2 +- prec/impl/psb_d_ilut_fact.f90 | 2 +- prec/impl/psb_s_ilu0_fact.f90 | 4 ++-- prec/impl/psb_s_iluk_fact.f90 | 2 +- prec/impl/psb_s_ilut_fact.f90 | 2 +- prec/impl/psb_z_ilu0_fact.f90 | 4 ++-- prec/impl/psb_z_iluk_fact.f90 | 2 +- prec/impl/psb_z_ilut_fact.f90 | 2 +- 12 files changed, 16 insertions(+), 16 deletions(-) diff --git a/prec/impl/psb_c_ilu0_fact.f90 b/prec/impl/psb_c_ilu0_fact.f90 index 9e039b19..c016359f 100644 --- a/prec/impl/psb_c_ilu0_fact.f90 +++ b/prec/impl/psb_c_ilu0_fact.f90 @@ -390,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) endif lirp(i+1) = l1 + 1 diff --git a/prec/impl/psb_c_iluk_fact.f90 b/prec/impl/psb_c_iluk_fact.f90 index 4e6d7a5c..6c6d8a5f 100644 --- a/prec/impl/psb_c_iluk_fact.f90 +++ b/prec/impl/psb_c_iluk_fact.f90 @@ -191,7 +191,7 @@ subroutine psb_ciluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_ciluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_ciluk_factint' diff --git a/prec/impl/psb_c_ilut_fact.f90 b/prec/impl/psb_c_ilut_fact.f90 index 6aecb53b..8421ee1c 100644 --- a/prec/impl/psb_c_ilut_fact.f90 +++ b/prec/impl/psb_c_ilut_fact.f90 @@ -213,7 +213,7 @@ subroutine psb_cilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Compute the ILU(k,t) factorization ! call psb_cilut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_cilut_factint' diff --git a/prec/impl/psb_d_ilu0_fact.f90 b/prec/impl/psb_d_ilu0_fact.f90 index 29968e0c..dde22249 100644 --- a/prec/impl/psb_d_ilu0_fact.f90 +++ b/prec/impl/psb_d_ilu0_fact.f90 @@ -390,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) endif lirp(i+1) = l1 + 1 diff --git a/prec/impl/psb_d_iluk_fact.f90 b/prec/impl/psb_d_iluk_fact.f90 index 94c34fec..dc837ba9 100644 --- a/prec/impl/psb_d_iluk_fact.f90 +++ b/prec/impl/psb_d_iluk_fact.f90 @@ -191,7 +191,7 @@ subroutine psb_diluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_diluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_diluk_factint' diff --git a/prec/impl/psb_d_ilut_fact.f90 b/prec/impl/psb_d_ilut_fact.f90 index c4f05df6..cd185e80 100644 --- a/prec/impl/psb_d_ilut_fact.f90 +++ b/prec/impl/psb_d_ilut_fact.f90 @@ -213,7 +213,7 @@ subroutine psb_dilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Compute the ILU(k,t) factorization ! call psb_dilut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_dilut_factint' diff --git a/prec/impl/psb_s_ilu0_fact.f90 b/prec/impl/psb_s_ilu0_fact.f90 index faa21fda..d9ce1298 100644 --- a/prec/impl/psb_s_ilu0_fact.f90 +++ b/prec/impl/psb_s_ilu0_fact.f90 @@ -390,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) endif lirp(i+1) = l1 + 1 diff --git a/prec/impl/psb_s_iluk_fact.f90 b/prec/impl/psb_s_iluk_fact.f90 index 99d48880..67fb8ada 100644 --- a/prec/impl/psb_s_iluk_fact.f90 +++ b/prec/impl/psb_s_iluk_fact.f90 @@ -191,7 +191,7 @@ subroutine psb_siluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_siluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_siluk_factint' diff --git a/prec/impl/psb_s_ilut_fact.f90 b/prec/impl/psb_s_ilut_fact.f90 index d63c5fb8..3d111103 100644 --- a/prec/impl/psb_s_ilut_fact.f90 +++ b/prec/impl/psb_s_ilut_fact.f90 @@ -213,7 +213,7 @@ subroutine psb_silut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Compute the ILU(k,t) factorization ! call psb_silut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_silut_factint' diff --git a/prec/impl/psb_z_ilu0_fact.f90 b/prec/impl/psb_z_ilu0_fact.f90 index c6c0fc55..997a5e05 100644 --- a/prec/impl/psb_z_ilu0_fact.f90 +++ b/prec/impl/psb_z_ilu0_fact.f90 @@ -390,14 +390,14 @@ contains ! into lval/d(i)/uval ! call ilu_copyin(i,ma,a,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) else ! ! Copy the i-th local row of the matrix, stored in b ! (as (i-ma)-th row), into lval/d(i)/uval ! call ilu_copyin(i-ma,mb,b,i,ione,m,l1,lja,lval,& - & d(i),l2,uja,uval,ktrw,trw,upd,shft) + & d(i),l2,uja,uval,ktrw,trw,upd,shft_) endif lirp(i+1) = l1 + 1 diff --git a/prec/impl/psb_z_iluk_fact.f90 b/prec/impl/psb_z_iluk_fact.f90 index dbd9430c..a5540880 100644 --- a/prec/impl/psb_z_iluk_fact.f90 +++ b/prec/impl/psb_z_iluk_fact.f90 @@ -191,7 +191,7 @@ subroutine psb_ziluk_fact(fill_in,ialg,a,l,u,d,info,blck,shft) ! Compute the ILU(k) or the MILU(k) factorization, depending on ialg ! call psb_ziluk_factint(fill_in,ialg,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_ziluk_factint' diff --git a/prec/impl/psb_z_ilut_fact.f90 b/prec/impl/psb_z_ilut_fact.f90 index 5e44f14a..0c278515 100644 --- a/prec/impl/psb_z_ilut_fact.f90 +++ b/prec/impl/psb_z_ilut_fact.f90 @@ -213,7 +213,7 @@ subroutine psb_zilut_fact(fill_in,thres,a,l,u,d,info,blck,iscale,shft) ! Compute the ILU(k,t) factorization ! call psb_zilut_factint(fill_in,thres,a,blck_,& - & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft) + & d,ll%val,ll%ja,ll%irp,uu%val,uu%ja,uu%irp,l1,l2,info,scale,shft_) if (info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_zilut_factint' From d82b0902899a8d436c005d8c11bdde9fd7efc55f Mon Sep 17 00:00:00 2001 From: sfilippone Date: Wed, 25 Oct 2023 12:47:31 +0200 Subject: [PATCH 07/48] Fix makefile for psi_acx & friends --- base/modules/Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/base/modules/Makefile b/base/modules/Makefile index f1cb5e53..9c14dab8 100644 --- a/base/modules/Makefile +++ b/base/modules/Makefile @@ -39,6 +39,7 @@ SERIAL_MODS=serial/psb_s_serial_mod.o serial/psb_d_serial_mod.o \ auxil/psi_c_serial_mod.o auxil/psi_z_serial_mod.o \ psi_mod.o psi_i_mod.o psi_l_mod.o psi_s_mod.o psi_d_mod.o psi_c_mod.o psi_z_mod.o\ auxil/psb_ip_reord_mod.o\ + auxil/psi_acx_mod.o auxil/psi_alcx_mod.o auxil/psi_lcx_mod.o \ auxil/psb_m_ip_reord_mod.o auxil/psb_e_ip_reord_mod.o \ auxil/psb_s_ip_reord_mod.o auxil/psb_d_ip_reord_mod.o \ auxil/psb_c_ip_reord_mod.o auxil/psb_z_ip_reord_mod.o \ From 5caee551e5c81a042b66a6e9ff5059e8ba9587d3 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Fri, 3 Nov 2023 14:28:04 +0100 Subject: [PATCH 08/48] Fixed IN_PLACE option for collectives. --- base/comm/psb_cgather.f90 | 43 +- base/comm/psb_cgather_a.f90 | 37 +- base/comm/psb_dgather.f90 | 43 +- base/comm/psb_dgather_a.f90 | 37 +- base/comm/psb_egather_a.f90 | 37 +- base/comm/psb_i2gather_a.f90 | 37 +- base/comm/psb_igather.f90 | 43 +- base/comm/psb_lgather.f90 | 43 +- base/comm/psb_mgather_a.f90 | 37 +- base/comm/psb_sgather.f90 | 43 +- base/comm/psb_sgather_a.f90 | 37 +- base/comm/psb_zgather.f90 | 43 +- base/comm/psb_zgather_a.f90 | 37 +- base/modules/penv/psi_c_collective_mod.F90 | 469 ++++++++++++-- base/modules/penv/psi_d_collective_mod.F90 | 651 ++++++++++++++++---- base/modules/penv/psi_e_collective_mod.F90 | 607 +++++++++++++++--- base/modules/penv/psi_i2_collective_mod.F90 | 607 +++++++++++++++--- base/modules/penv/psi_m_collective_mod.F90 | 607 +++++++++++++++--- base/modules/penv/psi_s_collective_mod.F90 | 651 ++++++++++++++++---- base/modules/penv/psi_z_collective_mod.F90 | 469 ++++++++++++-- 20 files changed, 3720 insertions(+), 858 deletions(-) diff --git a/base/comm/psb_cgather.f90 b/base/comm/psb_cgather.f90 index 7893d7c3..fc7ba7fb 100644 --- a/base/comm/psb_cgather.f90 +++ b/base/comm/psb_cgather.f90 @@ -58,10 +58,11 @@ subroutine psb_cgather_vect(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, jlx, ilx, lda_locx, lda_globx, i integer(psb_lpk_) :: m, n, k, ilocx, jlocx, idx, iglobx, jglobx complex(psb_spk_), allocatable :: llocx(:) + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_cgatherv' @@ -125,32 +126,36 @@ subroutine psb_cgather_vect(globx, locx, desc_a, info, iroot) goto 9999 end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - globx(:) = czero - llocx = locx%get_vect() - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = llocx(i) - end do - + llocx = locx%get_vect() ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = czero + llocx(idx) = czero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + loc_rows = desc_a%get_local_rows() + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,llocx(1:loc_rows),globx,szs,root=root) + call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_cgather_a.f90 b/base/comm/psb_cgather_a.f90 index ac2e66e4..9212b328 100644 --- a/base/comm/psb_cgather_a.f90 +++ b/base/comm/psb_cgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_cgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_cgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_cgatherv' @@ -307,23 +307,32 @@ subroutine psb_cgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=czero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = czero + locx(idx) = czero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_dgather.f90 b/base/comm/psb_dgather.f90 index c1619b1b..a12be1e4 100644 --- a/base/comm/psb_dgather.f90 +++ b/base/comm/psb_dgather.f90 @@ -58,10 +58,11 @@ subroutine psb_dgather_vect(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, jlx, ilx, lda_locx, lda_globx, i integer(psb_lpk_) :: m, n, k, ilocx, jlocx, idx, iglobx, jglobx real(psb_dpk_), allocatable :: llocx(:) + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_dgatherv' @@ -125,32 +126,36 @@ subroutine psb_dgather_vect(globx, locx, desc_a, info, iroot) goto 9999 end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - globx(:) = dzero - llocx = locx%get_vect() - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = llocx(i) - end do - + llocx = locx%get_vect() ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = dzero + llocx(idx) = dzero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + loc_rows = desc_a%get_local_rows() + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,llocx(1:loc_rows),globx,szs,root=root) + call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_dgather_a.f90 b/base/comm/psb_dgather_a.f90 index 1e03ccfd..eec28bdc 100644 --- a/base/comm/psb_dgather_a.f90 +++ b/base/comm/psb_dgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_dgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_dgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_dgatherv' @@ -307,23 +307,32 @@ subroutine psb_dgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=dzero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = dzero + locx(idx) = dzero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_egather_a.f90 b/base/comm/psb_egather_a.f90 index b777cebd..21a41143 100644 --- a/base/comm/psb_egather_a.f90 +++ b/base/comm/psb_egather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_egatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_egatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_egatherv' @@ -307,23 +307,32 @@ subroutine psb_egatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=ezero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = ezero + locx(idx) = ezero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_i2gather_a.f90 b/base/comm/psb_i2gather_a.f90 index e0e1ed7a..f0f2a93a 100644 --- a/base/comm/psb_i2gather_a.f90 +++ b/base/comm/psb_i2gather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_i2gatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_i2gatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_i2gatherv' @@ -307,23 +307,32 @@ subroutine psb_i2gatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=i2zero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = i2zero + locx(idx) = i2zero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_igather.f90 b/base/comm/psb_igather.f90 index afa794eb..62a84173 100644 --- a/base/comm/psb_igather.f90 +++ b/base/comm/psb_igather.f90 @@ -58,10 +58,11 @@ subroutine psb_igather_vect(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, jlx, ilx, lda_locx, lda_globx, i integer(psb_lpk_) :: m, n, k, ilocx, jlocx, idx, iglobx, jglobx integer(psb_ipk_), allocatable :: llocx(:) + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_igatherv' @@ -125,32 +126,36 @@ subroutine psb_igather_vect(globx, locx, desc_a, info, iroot) goto 9999 end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - globx(:) = izero - llocx = locx%get_vect() - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = llocx(i) - end do - + llocx = locx%get_vect() ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = izero + llocx(idx) = izero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + loc_rows = desc_a%get_local_rows() + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,llocx(1:loc_rows),globx,szs,root=root) + call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_lgather.f90 b/base/comm/psb_lgather.f90 index 00af3cd1..7b4e7ac9 100644 --- a/base/comm/psb_lgather.f90 +++ b/base/comm/psb_lgather.f90 @@ -58,10 +58,11 @@ subroutine psb_lgather_vect(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, jlx, ilx, lda_locx, lda_globx, i integer(psb_lpk_) :: m, n, k, ilocx, jlocx, idx, iglobx, jglobx integer(psb_lpk_), allocatable :: llocx(:) + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_lgatherv' @@ -125,32 +126,36 @@ subroutine psb_lgather_vect(globx, locx, desc_a, info, iroot) goto 9999 end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - globx(:) = lzero - llocx = locx%get_vect() - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = llocx(i) - end do - + llocx = locx%get_vect() ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = lzero + llocx(idx) = lzero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + loc_rows = desc_a%get_local_rows() + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,llocx(1:loc_rows),globx,szs,root=root) + call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_mgather_a.f90 b/base/comm/psb_mgather_a.f90 index df574ea2..ccf2f0c0 100644 --- a/base/comm/psb_mgather_a.f90 +++ b/base/comm/psb_mgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_mgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_mgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_mgatherv' @@ -307,23 +307,32 @@ subroutine psb_mgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=mzero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = mzero + locx(idx) = mzero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_sgather.f90 b/base/comm/psb_sgather.f90 index 21ce1408..30d25440 100644 --- a/base/comm/psb_sgather.f90 +++ b/base/comm/psb_sgather.f90 @@ -58,10 +58,11 @@ subroutine psb_sgather_vect(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, jlx, ilx, lda_locx, lda_globx, i integer(psb_lpk_) :: m, n, k, ilocx, jlocx, idx, iglobx, jglobx real(psb_spk_), allocatable :: llocx(:) + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_sgatherv' @@ -125,32 +126,36 @@ subroutine psb_sgather_vect(globx, locx, desc_a, info, iroot) goto 9999 end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - globx(:) = szero - llocx = locx%get_vect() - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = llocx(i) - end do - + llocx = locx%get_vect() ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = szero + llocx(idx) = szero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + loc_rows = desc_a%get_local_rows() + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,llocx(1:loc_rows),globx,szs,root=root) + call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_sgather_a.f90 b/base/comm/psb_sgather_a.f90 index 28d5f5dc..27e21e78 100644 --- a/base/comm/psb_sgather_a.f90 +++ b/base/comm/psb_sgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_sgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_sgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_sgatherv' @@ -307,23 +307,32 @@ subroutine psb_sgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=szero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = szero + locx(idx) = szero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_zgather.f90 b/base/comm/psb_zgather.f90 index 53cba210..d60f15c6 100644 --- a/base/comm/psb_zgather.f90 +++ b/base/comm/psb_zgather.f90 @@ -58,10 +58,11 @@ subroutine psb_zgather_vect(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, jlx, ilx, lda_locx, lda_globx, i integer(psb_lpk_) :: m, n, k, ilocx, jlocx, idx, iglobx, jglobx complex(psb_dpk_), allocatable :: llocx(:) + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_zgatherv' @@ -125,32 +126,36 @@ subroutine psb_zgather_vect(globx, locx, desc_a, info, iroot) goto 9999 end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - globx(:) = zzero - llocx = locx%get_vect() - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = llocx(i) - end do - + llocx = locx%get_vect() ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = zzero + llocx(idx) = zzero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + loc_rows = desc_a%get_local_rows() + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,llocx(1:loc_rows),globx,szs,root=root) + call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_zgather_a.f90 b/base/comm/psb_zgather_a.f90 index fa5f288b..98ed8772 100644 --- a/base/comm/psb_zgather_a.f90 +++ b/base/comm/psb_zgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_zgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j + & maxk, k, jlx, ilx, i, j, loc_rows integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_zgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - + integer(psb_mpk_), allocatable :: szs(:) character(len=20) :: name, ch_err name='psb_zgatherv' @@ -307,23 +307,32 @@ subroutine psb_zgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if - globx(:)=zzero - - do i=1,desc_a%get_local_rows() - call psb_loc_to_glob(i,idx,desc_a,info) - globx(idx) = locx(i) - end do - ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - call psb_loc_to_glob(idx,desc_a,info) - globx(idx) = zzero + locx(idx) = zzero end if end do - - call psb_sum(ctxt,globx(1:m),root=root) + loc_rows = desc_a%get_local_rows() + if ((me == root).or.(root == -1)) then + allocate(szs(np)) + end if + call psb_gather(ctxt,loc_rows,szs,root=root) + if ((me == root).or.(root == -1)) then + if (sum(szs) /= m) then + info=psb_err_internal_error_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_realloc(m,globx,info) + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) call psb_erractionrestore(err_act) return diff --git a/base/modules/penv/psi_c_collective_mod.F90 b/base/modules/penv/psi_c_collective_mod.F90 index 6da00176..8da302d0 100644 --- a/base/modules/penv/psi_c_collective_mod.F90 +++ b/base/modules/penv/psi_c_collective_mod.F90 @@ -34,6 +34,14 @@ module psi_c_collective_mod use psb_desc_const_mod + interface psb_gather + module procedure psb_cgather_s, psb_cgather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_cgatherv_v + end interface + interface psb_sum module procedure psb_csums, psb_csumv, psb_csumm end interface @@ -76,6 +84,250 @@ contains + ! + ! gather + ! + subroutine psb_cgather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + complex(psb_spk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_c_spk_,& + & resv,1,psb_mpi_c_spk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_c_spk_,& + & resv,1,psb_mpi_c_spk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_c_spk_,& + & resv,1,psb_mpi_c_spk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_c_spk_,& + & resv,1,psb_mpi_c_spk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_cgather_s + + subroutine psb_cgather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + complex(psb_spk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_c_spk_,& + & resv,size(dat),psb_mpi_c_spk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_c_spk_,& + & resv,size(dat),psb_mpi_c_spk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_c_spk_,& + & resv,size(dat),psb_mpi_c_spk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_c_spk_,& + & resv,size(dat),psb_mpi_c_spk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_cgather_v + + subroutine psb_cgatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + complex(psb_spk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_c_spk_,& + & resv,szs,displs,psb_mpi_c_spk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_c_spk_,& + & resv,szs,displs,psb_mpi_c_spk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_c_spk_,& + & resv,szs,displs,psb_mpi_c_spk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_c_spk_,& + & resv,szs,displs,psb_mpi_c_spk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_cgatherv_v + + + ! ! SUM ! @@ -124,20 +376,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_c_spk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_c_spk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_spk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -190,20 +452,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -258,20 +530,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -328,20 +610,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_c_spk_,mpi_camx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_c_spk_,mpi_camx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -395,20 +687,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_c_spk_,mpi_camx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -463,20 +765,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_c_spk_,mpi_camx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -532,20 +844,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_c_spk_,mpi_camn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_c_spk_,mpi_camn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -599,20 +921,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -667,20 +999,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_spk_,mpi_camn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_spk_,mpi_camn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -901,12 +1243,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_c_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_c_spk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -952,12 +1295,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_c_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_c_spk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -980,12 +1324,13 @@ contains complex(psb_spk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + complex(psb_spk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1003,12 +1348,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1029,12 +1375,13 @@ contains complex(psb_spk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - complex(psb_spk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + complex(psb_spk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1053,12 +1400,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_c_spk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1271,6 +1619,5 @@ contains Enddo end subroutine psb_c_e_simple_triad_a2av - end module psi_c_collective_mod diff --git a/base/modules/penv/psi_d_collective_mod.F90 b/base/modules/penv/psi_d_collective_mod.F90 index 1f91c69e..9639d650 100644 --- a/base/modules/penv/psi_d_collective_mod.F90 +++ b/base/modules/penv/psi_d_collective_mod.F90 @@ -45,6 +45,14 @@ module psi_d_collective_mod module procedure psb_d_nrm2s, psb_d_nrm2v end interface psb_nrm2 + interface psb_gather + module procedure psb_dgather_s, psb_dgather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_dgatherv_v + end interface + interface psb_sum module procedure psb_dsums, psb_dsumv, psb_dsumm end interface @@ -110,6 +118,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_dpk_) :: dat_ #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -134,20 +143,29 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_max,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_r_dpk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + if (iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + if (iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -174,6 +192,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_dpk_) :: dat_(1) ! This is a dummy #if !defined(SERIAL_MPI) @@ -200,21 +219,31 @@ contains collective_end = .false. end if if (collective_sync) then - if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + if (root_ == -1) then + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -242,6 +271,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_dpk_) :: dat_(1,1) ! this is a dummy #if !defined(SERIAL_MPI) @@ -268,28 +298,37 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_max,root_,icomm,info) - endif + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,info) + endif + end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) end if end if - #endif end subroutine psb_dmaxm @@ -340,18 +379,27 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_min,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_r_dpk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -405,20 +453,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -473,20 +531,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -545,20 +613,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_dnrm2_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_dnrm2_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -612,20 +690,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,& + call mpi_allreduce(mpi_in_place,dat,size(dat),psb_mpi_r_dpk_,& & mpi_dnrm2_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_dpk_,& - & mpi_dnrm2_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),psb_mpi_r_dpk_,& + & mpi_dnrm2_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),psb_mpi_r_dpk_,& + & mpi_dnrm2_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_dnrm2_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_dnrm2_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -636,6 +724,250 @@ contains end subroutine psb_d_nrm2v + ! + ! gather + ! + subroutine psb_dgather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + real(psb_dpk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_r_dpk_,& + & resv,1,psb_mpi_r_dpk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_r_dpk_,& + & resv,1,psb_mpi_r_dpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_r_dpk_,& + & resv,1,psb_mpi_r_dpk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_r_dpk_,& + & resv,1,psb_mpi_r_dpk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_dgather_s + + subroutine psb_dgather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + real(psb_dpk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_r_dpk_,& + & resv,size(dat),psb_mpi_r_dpk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_r_dpk_,& + & resv,size(dat),psb_mpi_r_dpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_r_dpk_,& + & resv,size(dat),psb_mpi_r_dpk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_r_dpk_,& + & resv,size(dat),psb_mpi_r_dpk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_dgather_v + + subroutine psb_dgatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + real(psb_dpk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_r_dpk_,& + & resv,szs,displs,psb_mpi_r_dpk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_r_dpk_,& + & resv,szs,displs,psb_mpi_r_dpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_r_dpk_,& + & resv,szs,displs,psb_mpi_r_dpk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_r_dpk_,& + & resv,szs,displs,psb_mpi_r_dpk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_dgatherv_v + + + ! ! SUM ! @@ -684,20 +1016,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -750,20 +1092,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -818,20 +1170,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -888,20 +1250,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_damx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_damx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -955,20 +1327,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_r_dpk_,mpi_damx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1023,20 +1405,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1092,20 +1484,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_damn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_dpk_,mpi_damn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1159,20 +1561,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1227,20 +1639,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_dpk_,mpi_damn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_dpk_,mpi_damn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1461,12 +1883,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_r_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_r_dpk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1512,12 +1935,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_r_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_r_dpk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1540,12 +1964,13 @@ contains real(psb_dpk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + real(psb_dpk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1563,12 +1988,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1589,12 +2015,13 @@ contains real(psb_dpk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - real(psb_dpk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_dpk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1613,12 +2040,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_r_dpk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1831,6 +2259,5 @@ contains Enddo end subroutine psb_d_e_simple_triad_a2av - end module psi_d_collective_mod diff --git a/base/modules/penv/psi_e_collective_mod.F90 b/base/modules/penv/psi_e_collective_mod.F90 index 7f57f2a6..b9ab089b 100644 --- a/base/modules/penv/psi_e_collective_mod.F90 +++ b/base/modules/penv/psi_e_collective_mod.F90 @@ -42,6 +42,14 @@ module psi_e_collective_mod end interface psb_min + interface psb_gather + module procedure psb_egather_s, psb_egather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_egatherv_v + end interface + interface psb_sum module procedure psb_esums, psb_esumv, psb_esumm end interface @@ -107,6 +115,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_epk_) :: dat_ #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -131,20 +140,29 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_max,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_epk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_max,root_,icomm,info) + if (iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_epk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_epk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + if (iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -171,6 +189,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_epk_) :: dat_(1) ! This is a dummy #if !defined(SERIAL_MPI) @@ -197,21 +216,31 @@ contains collective_end = .false. end if if (collective_sync) then - if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + if (root_ == -1) then + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_max,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -239,6 +268,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_epk_) :: dat_(1,1) ! this is a dummy #if !defined(SERIAL_MPI) @@ -265,28 +295,37 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_max,root_,icomm,info) - endif + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,info) + endif + end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) end if end if - #endif end subroutine psb_emaxm @@ -337,18 +376,27 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_min,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_epk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_epk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_epk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_epk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -402,20 +450,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -470,20 +528,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -494,6 +562,250 @@ contains + ! + ! gather + ! + subroutine psb_egather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_epk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_epk_,& + & resv,1,psb_mpi_epk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_epk_,& + & resv,1,psb_mpi_epk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_epk_,& + & resv,1,psb_mpi_epk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_epk_,& + & resv,1,psb_mpi_epk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_egather_s + + subroutine psb_egather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_epk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_epk_,& + & resv,size(dat),psb_mpi_epk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_epk_,& + & resv,size(dat),psb_mpi_epk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_epk_,& + & resv,size(dat),psb_mpi_epk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_epk_,& + & resv,size(dat),psb_mpi_epk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_egather_v + + subroutine psb_egatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_epk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_epk_,& + & resv,szs,displs,psb_mpi_epk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_epk_,& + & resv,szs,displs,psb_mpi_epk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_epk_,& + & resv,szs,displs,psb_mpi_epk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_epk_,& + & resv,szs,displs,psb_mpi_epk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_egatherv_v + + + ! ! SUM ! @@ -542,20 +854,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_epk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_epk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -608,20 +930,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -676,20 +1008,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -746,20 +1088,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_eamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_eamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -813,20 +1165,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_epk_,mpi_eamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -881,20 +1243,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_epk_,mpi_eamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -950,20 +1322,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_eamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_epk_,mpi_eamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1017,20 +1399,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1085,20 +1477,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_epk_,mpi_eamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_epk_,mpi_eamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1319,12 +1721,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_epk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_epk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1370,12 +1773,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_epk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_epk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1398,12 +1802,13 @@ contains integer(psb_epk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + integer(psb_epk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1421,12 +1826,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1447,12 +1853,13 @@ contains integer(psb_epk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - integer(psb_epk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_epk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1471,12 +1878,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_epk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1689,6 +2097,5 @@ contains Enddo end subroutine psb_e_e_simple_triad_a2av - end module psi_e_collective_mod diff --git a/base/modules/penv/psi_i2_collective_mod.F90 b/base/modules/penv/psi_i2_collective_mod.F90 index bfe3bf35..339e4281 100644 --- a/base/modules/penv/psi_i2_collective_mod.F90 +++ b/base/modules/penv/psi_i2_collective_mod.F90 @@ -42,6 +42,14 @@ module psi_i2_collective_mod end interface psb_min + interface psb_gather + module procedure psb_i2gather_s, psb_i2gather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_i2gatherv_v + end interface + interface psb_sum module procedure psb_i2sums, psb_i2sumv, psb_i2summ end interface @@ -107,6 +115,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_i2pk_) :: dat_ #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -131,20 +140,29 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_max,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_i2pk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_max,root_,icomm,info) + if (iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_i2pk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_i2pk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + if (iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -171,6 +189,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_i2pk_) :: dat_(1) ! This is a dummy #if !defined(SERIAL_MPI) @@ -197,21 +216,31 @@ contains collective_end = .false. end if if (collective_sync) then - if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + if (root_ == -1) then + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_max,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -239,6 +268,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_i2pk_) :: dat_(1,1) ! this is a dummy #if !defined(SERIAL_MPI) @@ -265,28 +295,37 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_max,root_,icomm,info) - endif + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,info) + endif + end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) end if end if - #endif end subroutine psb_i2maxm @@ -337,18 +376,27 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_min,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_i2pk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_i2pk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_i2pk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_i2pk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -402,20 +450,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -470,20 +528,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -494,6 +562,250 @@ contains + ! + ! gather + ! + subroutine psb_i2gather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_i2pk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_i2pk_,& + & resv,1,psb_mpi_i2pk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_i2pk_,& + & resv,1,psb_mpi_i2pk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_i2pk_,& + & resv,1,psb_mpi_i2pk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_i2pk_,& + & resv,1,psb_mpi_i2pk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_i2gather_s + + subroutine psb_i2gather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_i2pk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_i2pk_,& + & resv,size(dat),psb_mpi_i2pk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_i2pk_,& + & resv,size(dat),psb_mpi_i2pk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_i2pk_,& + & resv,size(dat),psb_mpi_i2pk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_i2pk_,& + & resv,size(dat),psb_mpi_i2pk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_i2gather_v + + subroutine psb_i2gatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_i2pk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_i2pk_,& + & resv,szs,displs,psb_mpi_i2pk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_i2pk_,& + & resv,szs,displs,psb_mpi_i2pk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_i2pk_,& + & resv,szs,displs,psb_mpi_i2pk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_i2pk_,& + & resv,szs,displs,psb_mpi_i2pk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_i2gatherv_v + + + ! ! SUM ! @@ -542,20 +854,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -608,20 +930,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -676,20 +1008,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -746,20 +1088,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_i2amx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_i2amx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -813,20 +1165,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_i2pk_,mpi_i2amx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -881,20 +1243,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -950,20 +1322,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_i2amn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_i2pk_,mpi_i2amn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1017,20 +1399,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1085,20 +1477,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_i2pk_,mpi_i2amn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_i2pk_,mpi_i2amn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1319,12 +1721,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_i2pk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_i2pk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1370,12 +1773,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_i2pk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_i2pk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1398,12 +1802,13 @@ contains integer(psb_i2pk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + integer(psb_i2pk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1421,12 +1826,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1447,12 +1853,13 @@ contains integer(psb_i2pk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - integer(psb_i2pk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_i2pk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1471,12 +1878,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_i2pk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1689,6 +2097,5 @@ contains Enddo end subroutine psb_i2_e_simple_triad_a2av - end module psi_i2_collective_mod diff --git a/base/modules/penv/psi_m_collective_mod.F90 b/base/modules/penv/psi_m_collective_mod.F90 index 09995175..8f45d398 100644 --- a/base/modules/penv/psi_m_collective_mod.F90 +++ b/base/modules/penv/psi_m_collective_mod.F90 @@ -42,6 +42,14 @@ module psi_m_collective_mod end interface psb_min + interface psb_gather + module procedure psb_mgather_s, psb_mgather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_mgatherv_v + end interface + interface psb_sum module procedure psb_msums, psb_msumv, psb_msumm end interface @@ -107,6 +115,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_mpk_) :: dat_ #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -131,20 +140,29 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_max,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_mpk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_max,root_,icomm,info) + if (iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_mpk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_mpk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + if (iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -171,6 +189,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_mpk_) :: dat_(1) ! This is a dummy #if !defined(SERIAL_MPI) @@ -197,21 +216,31 @@ contains collective_end = .false. end if if (collective_sync) then - if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + if (root_ == -1) then + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_max,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -239,6 +268,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_mpk_) :: dat_(1,1) ! this is a dummy #if !defined(SERIAL_MPI) @@ -265,28 +295,37 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_max,root_,icomm,info) - endif + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,info) + endif + end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) end if end if - #endif end subroutine psb_mmaxm @@ -337,18 +376,27 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_min,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_mpk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_mpk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_mpk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_mpk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -402,20 +450,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -470,20 +528,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -494,6 +562,250 @@ contains + ! + ! gather + ! + subroutine psb_mgather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_mpk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_mpk_,& + & resv,1,psb_mpi_mpk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_mpk_,& + & resv,1,psb_mpi_mpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_mpk_,& + & resv,1,psb_mpi_mpk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_mpk_,& + & resv,1,psb_mpi_mpk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_mgather_s + + subroutine psb_mgather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_mpk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_mpk_,& + & resv,size(dat),psb_mpi_mpk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_mpk_,& + & resv,size(dat),psb_mpi_mpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_mpk_,& + & resv,size(dat),psb_mpi_mpk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_mpk_,& + & resv,size(dat),psb_mpi_mpk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_mgather_v + + subroutine psb_mgatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_mpk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_mpk_,& + & resv,szs,displs,psb_mpi_mpk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_mpk_,& + & resv,szs,displs,psb_mpi_mpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_mpk_,& + & resv,szs,displs,psb_mpi_mpk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_mpk_,& + & resv,szs,displs,psb_mpi_mpk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_mgatherv_v + + + ! ! SUM ! @@ -542,20 +854,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -608,20 +930,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -676,20 +1008,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -746,20 +1088,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_mamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_mamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -813,20 +1165,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_mpk_,mpi_mamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -881,20 +1243,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_mpk_,mpi_mamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -950,20 +1322,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_mamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_mpk_,mpi_mamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1017,20 +1399,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1085,20 +1477,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_mpk_,mpi_mamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_mpk_,mpi_mamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1319,12 +1721,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_mpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_mpk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1370,12 +1773,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_mpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_mpk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1398,12 +1802,13 @@ contains integer(psb_mpk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + integer(psb_mpk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1421,12 +1826,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1447,12 +1853,13 @@ contains integer(psb_mpk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - integer(psb_mpk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + integer(psb_mpk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1471,12 +1878,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_mpk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1689,6 +2097,5 @@ contains Enddo end subroutine psb_m_e_simple_triad_a2av - end module psi_m_collective_mod diff --git a/base/modules/penv/psi_s_collective_mod.F90 b/base/modules/penv/psi_s_collective_mod.F90 index d8e6ba82..6ffaae05 100644 --- a/base/modules/penv/psi_s_collective_mod.F90 +++ b/base/modules/penv/psi_s_collective_mod.F90 @@ -45,6 +45,14 @@ module psi_s_collective_mod module procedure psb_s_nrm2s, psb_s_nrm2v end interface psb_nrm2 + interface psb_gather + module procedure psb_sgather_s, psb_sgather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_sgatherv_v + end interface + interface psb_sum module procedure psb_ssums, psb_ssumv, psb_ssumm end interface @@ -110,6 +118,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_spk_) :: dat_ #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -134,20 +143,29 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_max,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_r_spk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_max,root_,icomm,info) + if (iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_r_spk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_r_spk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + if (iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -174,6 +192,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_spk_) :: dat_(1) ! This is a dummy #if !defined(SERIAL_MPI) @@ -200,21 +219,31 @@ contains collective_end = .false. end if if (collective_sync) then - if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + if (root_ == -1) then + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_max,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -242,6 +271,7 @@ contains integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_spk_) :: dat_(1,1) ! this is a dummy #if !defined(SERIAL_MPI) @@ -268,28 +298,37 @@ contains collective_start = .false. collective_end = .false. end if - if (collective_sync) then + if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_max,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_max,root_,icomm,info) - endif + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,info) + endif + end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_max,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_max,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) end if end if - #endif end subroutine psb_smaxm @@ -340,18 +379,27 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_min,icomm,info) + call mpi_allreduce(mpi_in_place,dat,1,psb_mpi_r_spk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,psb_mpi_r_spk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,psb_mpi_r_spk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,psb_mpi_r_spk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -405,20 +453,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -473,20 +531,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_min,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_min,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_min,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_min,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -545,20 +613,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_snrm2_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_snrm2_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -612,20 +690,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,& + call mpi_allreduce(mpi_in_place,dat,size(dat),psb_mpi_r_spk_,& & mpi_snrm2_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),psb_mpi_r_spk_,& - & mpi_snrm2_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),psb_mpi_r_spk_,& + & mpi_snrm2_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),psb_mpi_r_spk_,& + & mpi_snrm2_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_snrm2_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_snrm2_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -636,6 +724,250 @@ contains end subroutine psb_s_nrm2v + ! + ! gather + ! + subroutine psb_sgather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + real(psb_spk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_r_spk_,& + & resv,1,psb_mpi_r_spk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_r_spk_,& + & resv,1,psb_mpi_r_spk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_r_spk_,& + & resv,1,psb_mpi_r_spk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_r_spk_,& + & resv,1,psb_mpi_r_spk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_sgather_s + + subroutine psb_sgather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + real(psb_spk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_r_spk_,& + & resv,size(dat),psb_mpi_r_spk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_r_spk_,& + & resv,size(dat),psb_mpi_r_spk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_r_spk_,& + & resv,size(dat),psb_mpi_r_spk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_r_spk_,& + & resv,size(dat),psb_mpi_r_spk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_sgather_v + + subroutine psb_sgatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + real(psb_spk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_r_spk_,& + & resv,szs,displs,psb_mpi_r_spk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_r_spk_,& + & resv,szs,displs,psb_mpi_r_spk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_r_spk_,& + & resv,szs,displs,psb_mpi_r_spk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_r_spk_,& + & resv,szs,displs,psb_mpi_r_spk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_sgatherv_v + + + ! ! SUM ! @@ -684,20 +1016,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -750,20 +1092,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -818,20 +1170,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -888,20 +1250,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_samx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_samx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -955,20 +1327,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_r_spk_,mpi_samx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1023,20 +1405,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_r_spk_,mpi_samx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1092,20 +1484,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_samn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_r_spk_,mpi_samn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1159,20 +1561,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1227,20 +1639,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_r_spk_,mpi_samn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_r_spk_,mpi_samn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -1461,12 +1883,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_r_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_r_spk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1512,12 +1935,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_r_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_r_spk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -1540,12 +1964,13 @@ contains real(psb_spk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + real(psb_spk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1563,12 +1988,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1589,12 +2015,13 @@ contains real(psb_spk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - real(psb_spk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + real(psb_spk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1613,12 +2040,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_r_spk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1831,6 +2259,5 @@ contains Enddo end subroutine psb_s_e_simple_triad_a2av - end module psi_s_collective_mod diff --git a/base/modules/penv/psi_z_collective_mod.F90 b/base/modules/penv/psi_z_collective_mod.F90 index 6f43742f..8b3ec277 100644 --- a/base/modules/penv/psi_z_collective_mod.F90 +++ b/base/modules/penv/psi_z_collective_mod.F90 @@ -34,6 +34,14 @@ module psi_z_collective_mod use psb_desc_const_mod + interface psb_gather + module procedure psb_zgather_s, psb_zgather_v + end interface psb_gather + + interface psb_gatherv + module procedure psb_zgatherv_v + end interface + interface psb_sum module procedure psb_zsums, psb_zsumv, psb_zsumm end interface @@ -76,6 +84,250 @@ contains + ! + ! gather + ! + subroutine psb_zgather_s(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + complex(psb_dpk_), intent(inout) :: dat, resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,1,psb_mpi_c_dpk_,& + & resv,1,psb_mpi_c_dpk_,icomm,info) + else + call mpi_gather(dat,1,psb_mpi_c_dpk_,& + & resv,1,psb_mpi_c_dpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,1,psb_mpi_c_dpk_,& + & resv,1,psb_mpi_c_dpk_,icomm,request,info) + else + call mpi_igather(dat,1,psb_mpi_c_dpk_,& + & resv,1,psb_mpi_c_dpk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_zgather_s + + subroutine psb_zgather_v(ctxt,dat,resv,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + complex(psb_dpk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + call mpi_allgather(dat,size(dat),psb_mpi_c_dpk_,& + & resv,size(dat),psb_mpi_c_dpk_,icomm,info) + else + call mpi_gather(dat,size(dat),psb_mpi_c_dpk_,& + & resv,size(dat),psb_mpi_c_dpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + call mpi_iallgather(dat,size(dat),psb_mpi_c_dpk_,& + & resv,size(dat),psb_mpi_c_dpk_,icomm,request,info) + else + call mpi_igather(dat,size(dat),psb_mpi_c_dpk_,& + & resv,size(dat),psb_mpi_c_dpk_,root_,icomm,request,info) + endif + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_zgather_v + + subroutine psb_zgatherv_v(ctxt,dat,resv,szs,root,mode,request) +#ifdef MPI_MOD + use mpi +#endif + implicit none +#ifdef MPI_H + include 'mpif.h' +#endif + type(psb_ctxt_type), intent(in) :: ctxt + complex(psb_dpk_), intent(inout) :: dat(:), resv(:) + integer(psb_mpk_), intent(in), optional :: root + integer(psb_mpk_), intent(in), optional :: szs(:) + integer(psb_ipk_), intent(in), optional :: mode + integer(psb_mpk_), intent(inout), optional :: request + integer(psb_mpk_) :: root_ + integer(psb_mpk_) :: iam, np, info,i + integer(psb_mpk_) :: icomm + integer(psb_mpk_) :: status(mpi_status_size) + integer(psb_mpk_), allocatable :: displs(:) + logical :: collective_start, collective_end, collective_sync + +#if defined(SERIAL_MPI) + resv(0) = dat +#else + call psb_info(ctxt,iam,np) + + if (present(root)) then + root_ = root + else + root_ = -1 + endif + icomm = psb_get_mpi_comm(ctxt) + if (present(mode)) then + collective_sync = .false. + collective_start = iand(mode,psb_collective_start_) /= 0 + collective_end = iand(mode,psb_collective_end_) /= 0 + if (.not.present(request)) then + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + else + collective_sync = .true. + collective_start = .false. + collective_end = .false. + end if + if (collective_sync) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_allgatherv(dat,size(dat),psb_mpi_c_dpk_,& + & resv,szs,displs,psb_mpi_c_dpk_,icomm,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_gatherv(dat,size(dat),psb_mpi_c_dpk_,& + & resv,szs,displs,psb_mpi_c_dpk_,root_,icomm,info) + endif + else + if (collective_start) then + if (root_ == -1) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + call mpi_iallgatherv(dat,size(dat),psb_mpi_c_dpk_,& + & resv,szs,displs,psb_mpi_c_dpk_,icomm,request,info) + else + if (iam == root_) then + if (size(szs) < np) write(0,*) 'Error: bad input sizes' + allocate(displs(np)) + displs(1) = 0 + do i=2, np + displs(i) = displs(i-1) + szs(i-1) + end do + else + allocate(displs(0)) + end if + call mpi_igatherv(dat,size(dat),psb_mpi_c_dpk_,& + & resv,szs,displs,psb_mpi_c_dpk_,root_,icomm,request,info) + endif + + else if (collective_end) then + call mpi_wait(request,status,info) + end if + end if +#endif + end subroutine psb_zgatherv_v + + + ! ! SUM ! @@ -124,20 +376,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_c_dpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_c_dpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_dpk_,mpi_sum,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -190,20 +452,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -258,20 +530,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_,icomm,info) + end if end if else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_sum,root_, icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_, icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_sum,root_, icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -328,20 +610,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_c_dpk_,mpi_zamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_c_dpk_,mpi_zamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -395,20 +687,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& psb_mpi_c_dpk_,mpi_zamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -463,20 +765,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamx_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamx_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamx_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -532,20 +844,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,1,& + call mpi_allreduce(mpi_in_place,dat,1,& & psb_mpi_c_dpk_,mpi_zamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,1,& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,1,& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,1,& + call mpi_iallreduce(mpi_in_place,dat,1,& & psb_mpi_c_dpk_,mpi_zamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,1,& - & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,1,& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,1,& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -599,20 +921,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -667,20 +999,30 @@ contains end if if (collective_sync) then if (root_ == -1) then - call mpi_allreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_allreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamn_op,icomm,info) else - call mpi_reduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + if(iam==root_) then + call mpi_reduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + else + call mpi_reduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,info) + end if endif else if (collective_start) then if (root_ == -1) then - call mpi_iallreduce(MPI_IN_PLACE,dat,size(dat),& + call mpi_iallreduce(mpi_in_place,dat,size(dat),& & psb_mpi_c_dpk_,mpi_zamn_op,icomm,request,info) else - call mpi_ireduce(MPI_IN_PLACE,dat,size(dat),& - & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + if(iam==root_) then + call mpi_ireduce(mpi_in_place,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + else + call mpi_ireduce(dat,dat,size(dat),& + & psb_mpi_c_dpk_,mpi_zamn_op,root_,icomm,request,info) + end if end if else if (collective_end) then call mpi_wait(request,status,info) @@ -901,12 +1243,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,1,& + call mpi_scan(dat_,dat,1,& & psb_mpi_c_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,1,& + call mpi_iscan(dat_,dat,1,& & psb_mpi_c_dpk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -952,12 +1295,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,1,& + call mpi_exscan(dat_,dat,1,& & psb_mpi_c_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,1,& + call mpi_iexscan(dat_,dat,1,& & psb_mpi_c_dpk_,mpi_sum,icomm,request,minfo) else if (collective_end) then call mpi_wait(request,status,minfo) @@ -980,12 +1324,13 @@ contains complex(psb_dpk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync - + complex(psb_dpk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) icomm = psb_get_mpi_comm(ctxt) @@ -1003,12 +1348,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_scan(MPI_IN_PLACE,dat,size(dat),& + call mpi_scan(dat_,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iscan(dat_,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1029,12 +1375,13 @@ contains complex(psb_dpk_), intent(inout) :: dat(:) integer(psb_ipk_), intent(in), optional :: mode integer(psb_mpk_), intent(inout), optional :: request - complex(psb_dpk_), allocatable :: dat_(:) + integer(psb_ipk_) :: iam, np, info integer(psb_mpk_) :: minfo integer(psb_mpk_) :: icomm integer(psb_mpk_) :: status(mpi_status_size) logical :: collective_start, collective_end, collective_sync + complex(psb_dpk_), allocatable :: dat_(:) #if !defined(SERIAL_MPI) call psb_info(ctxt,iam,np) @@ -1053,12 +1400,13 @@ contains collective_start = .false. collective_end = .false. end if + dat_ = dat if (collective_sync) then - call mpi_exscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_exscan(dat_,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,minfo) else if (collective_start) then - call mpi_iexscan(MPI_IN_PLACE,dat,size(dat),& + call mpi_iexscan(dat_,dat,size(dat),& & psb_mpi_c_dpk_,mpi_sum,icomm,request,info) else if (collective_end) then call mpi_wait(request,status,info) @@ -1271,6 +1619,5 @@ contains Enddo end subroutine psb_z_e_simple_triad_a2av - end module psi_z_collective_mod From baf18cebd728f50b71b738affd23694934676584 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Fri, 3 Nov 2023 14:39:01 +0100 Subject: [PATCH 09/48] Further fix for gather. --- base/comm/psb_cgather_a.f90 | 37 ++++++++++++++---------------------- base/comm/psb_dgather_a.f90 | 37 ++++++++++++++---------------------- base/comm/psb_egather_a.f90 | 37 ++++++++++++++---------------------- base/comm/psb_i2gather_a.f90 | 37 ++++++++++++++---------------------- base/comm/psb_mgather_a.f90 | 37 ++++++++++++++---------------------- base/comm/psb_sgather_a.f90 | 37 ++++++++++++++---------------------- base/comm/psb_zgather_a.f90 | 37 ++++++++++++++---------------------- 7 files changed, 98 insertions(+), 161 deletions(-) diff --git a/base/comm/psb_cgather_a.f90 b/base/comm/psb_cgather_a.f90 index 9212b328..ac2e66e4 100644 --- a/base/comm/psb_cgather_a.f90 +++ b/base/comm/psb_cgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_cgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_cgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_cgatherv' @@ -307,32 +307,23 @@ subroutine psb_cgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=czero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = czero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = czero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_dgather_a.f90 b/base/comm/psb_dgather_a.f90 index eec28bdc..1e03ccfd 100644 --- a/base/comm/psb_dgather_a.f90 +++ b/base/comm/psb_dgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_dgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_dgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_dgatherv' @@ -307,32 +307,23 @@ subroutine psb_dgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=dzero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = dzero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = dzero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_egather_a.f90 b/base/comm/psb_egather_a.f90 index 21a41143..b777cebd 100644 --- a/base/comm/psb_egather_a.f90 +++ b/base/comm/psb_egather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_egatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_egatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_egatherv' @@ -307,32 +307,23 @@ subroutine psb_egatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=ezero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = ezero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = ezero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_i2gather_a.f90 b/base/comm/psb_i2gather_a.f90 index f0f2a93a..e0e1ed7a 100644 --- a/base/comm/psb_i2gather_a.f90 +++ b/base/comm/psb_i2gather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_i2gatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_i2gatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_i2gatherv' @@ -307,32 +307,23 @@ subroutine psb_i2gatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=i2zero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = i2zero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = i2zero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_mgather_a.f90 b/base/comm/psb_mgather_a.f90 index ccf2f0c0..df574ea2 100644 --- a/base/comm/psb_mgather_a.f90 +++ b/base/comm/psb_mgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_mgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_mgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_mgatherv' @@ -307,32 +307,23 @@ subroutine psb_mgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=mzero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = mzero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = mzero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_sgather_a.f90 b/base/comm/psb_sgather_a.f90 index 27e21e78..28d5f5dc 100644 --- a/base/comm/psb_sgather_a.f90 +++ b/base/comm/psb_sgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_sgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_sgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_sgatherv' @@ -307,32 +307,23 @@ subroutine psb_sgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=szero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = szero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = szero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return diff --git a/base/comm/psb_zgather_a.f90 b/base/comm/psb_zgather_a.f90 index 98ed8772..fa5f288b 100644 --- a/base/comm/psb_zgather_a.f90 +++ b/base/comm/psb_zgather_a.f90 @@ -60,7 +60,7 @@ subroutine psb_zgatherm(globx, locx, desc_a, info, iroot) type(psb_ctxt_type) :: ctxt integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& - & maxk, k, jlx, ilx, i, j, loc_rows + & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx character(len=20) :: name, ch_err @@ -232,11 +232,11 @@ subroutine psb_zgatherv(globx, locx, desc_a, info, iroot) ! locals type(psb_ctxt_type) :: ctxt - integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank, loc_rows + integer(psb_mpk_) :: np, me, root, iiroot, icomm, myrank, rootrank integer(psb_ipk_) :: ierr(5), err_act, lda_locx, lda_globx, lock, globk,& & maxk, k, jlx, ilx, i, j integer(psb_lpk_) :: m, n, ilocx, jlocx, idx, iglobx, jglobx - integer(psb_mpk_), allocatable :: szs(:) + character(len=20) :: name, ch_err name='psb_zgatherv' @@ -307,32 +307,23 @@ subroutine psb_zgatherv(globx, locx, desc_a, info, iroot) goto 9999 end if + globx(:)=zzero + + do i=1,desc_a%get_local_rows() + call psb_loc_to_glob(i,idx,desc_a,info) + globx(idx) = locx(i) + end do + ! adjust overlapped elements do i=1, size(desc_a%ovrlap_elem,1) if (me /= desc_a%ovrlap_elem(i,3)) then idx = desc_a%ovrlap_elem(i,1) - locx(idx) = zzero + call psb_loc_to_glob(idx,desc_a,info) + globx(idx) = zzero end if end do - loc_rows = desc_a%get_local_rows() - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if - call psb_gather(ctxt,loc_rows,szs,root=root) - if ((me == root).or.(root == -1)) then - if (sum(szs) /= m) then - info=psb_err_internal_error_ - call psb_errpush(info,name) - goto 9999 - end if - call psb_realloc(m,globx,info) - if (info /= psb_success_) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - end if - call psb_gatherv(ctxt,locx(1:loc_rows),globx,szs,root=root) + + call psb_sum(ctxt,globx(1:m),root=root) call psb_erractionrestore(err_act) return From d718ef1e6dcb9ef39bab4c2418da799b2120ea97 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 7 Nov 2023 10:56:23 +0100 Subject: [PATCH 10/48] Always allocate szs in psb_gather --- base/comm/psb_cgather.f90 | 4 +--- base/comm/psb_dgather.f90 | 4 +--- base/comm/psb_igather.f90 | 4 +--- base/comm/psb_lgather.f90 | 4 +--- base/comm/psb_sgather.f90 | 4 +--- base/comm/psb_zgather.f90 | 4 +--- 6 files changed, 6 insertions(+), 18 deletions(-) diff --git a/base/comm/psb_cgather.f90 b/base/comm/psb_cgather.f90 index fc7ba7fb..bc5302f5 100644 --- a/base/comm/psb_cgather.f90 +++ b/base/comm/psb_cgather.f90 @@ -136,9 +136,7 @@ subroutine psb_cgather_vect(globx, locx, desc_a, info, iroot) end if end do - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if + allocate(szs(np)) loc_rows = desc_a%get_local_rows() call psb_gather(ctxt,loc_rows,szs,root=root) if ((me == root).or.(root == -1)) then diff --git a/base/comm/psb_dgather.f90 b/base/comm/psb_dgather.f90 index a12be1e4..ed0591e8 100644 --- a/base/comm/psb_dgather.f90 +++ b/base/comm/psb_dgather.f90 @@ -136,9 +136,7 @@ subroutine psb_dgather_vect(globx, locx, desc_a, info, iroot) end if end do - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if + allocate(szs(np)) loc_rows = desc_a%get_local_rows() call psb_gather(ctxt,loc_rows,szs,root=root) if ((me == root).or.(root == -1)) then diff --git a/base/comm/psb_igather.f90 b/base/comm/psb_igather.f90 index 62a84173..acfdf52a 100644 --- a/base/comm/psb_igather.f90 +++ b/base/comm/psb_igather.f90 @@ -136,9 +136,7 @@ subroutine psb_igather_vect(globx, locx, desc_a, info, iroot) end if end do - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if + allocate(szs(np)) loc_rows = desc_a%get_local_rows() call psb_gather(ctxt,loc_rows,szs,root=root) if ((me == root).or.(root == -1)) then diff --git a/base/comm/psb_lgather.f90 b/base/comm/psb_lgather.f90 index 7b4e7ac9..17359bce 100644 --- a/base/comm/psb_lgather.f90 +++ b/base/comm/psb_lgather.f90 @@ -136,9 +136,7 @@ subroutine psb_lgather_vect(globx, locx, desc_a, info, iroot) end if end do - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if + allocate(szs(np)) loc_rows = desc_a%get_local_rows() call psb_gather(ctxt,loc_rows,szs,root=root) if ((me == root).or.(root == -1)) then diff --git a/base/comm/psb_sgather.f90 b/base/comm/psb_sgather.f90 index 30d25440..59cecc17 100644 --- a/base/comm/psb_sgather.f90 +++ b/base/comm/psb_sgather.f90 @@ -136,9 +136,7 @@ subroutine psb_sgather_vect(globx, locx, desc_a, info, iroot) end if end do - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if + allocate(szs(np)) loc_rows = desc_a%get_local_rows() call psb_gather(ctxt,loc_rows,szs,root=root) if ((me == root).or.(root == -1)) then diff --git a/base/comm/psb_zgather.f90 b/base/comm/psb_zgather.f90 index d60f15c6..5cf445a9 100644 --- a/base/comm/psb_zgather.f90 +++ b/base/comm/psb_zgather.f90 @@ -136,9 +136,7 @@ subroutine psb_zgather_vect(globx, locx, desc_a, info, iroot) end if end do - if ((me == root).or.(root == -1)) then - allocate(szs(np)) - end if + allocate(szs(np)) loc_rows = desc_a%get_local_rows() call psb_gather(ctxt,loc_rows,szs,root=root) if ((me == root).or.(root == -1)) then From a2788bdf0ba1a4b6e4bf3bb447b49da76d4db931 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 7 Nov 2023 13:39:44 +0100 Subject: [PATCH 11/48] New version with ND product --- base/psblas/psb_cspmm.f90 | 23 ++++++++++++++----- base/psblas/psb_dspmm.f90 | 44 ++++++++++++++++++++++++++----------- base/psblas/psb_sspmm.f90 | 23 ++++++++++++++----- base/psblas/psb_zspmm.f90 | 23 ++++++++++++++----- base/tools/psb_cspasb.f90 | 2 +- base/tools/psb_sspasb.f90 | 2 +- base/tools/psb_zspasb.f90 | 2 +- test/pargen/psb_d_pde3d.F90 | 4 ++-- test/pargen/runs/ppde.inp | 2 +- 9 files changed, 88 insertions(+), 37 deletions(-) diff --git a/base/psblas/psb_cspmm.f90 b/base/psblas/psb_cspmm.f90 index 84d8a7d8..25a6bc56 100644 --- a/base/psblas/psb_cspmm.f90 +++ b/base/psblas/psb_cspmm.f90 @@ -180,12 +180,23 @@ subroutine psb_cspmv_vect(alpha,a,x,beta,y,desc_a,info,& ! Matrix is not transposed if (allocated(a%ad)) then - if (doswap_) call psi_swapdata(psb_swap_send_,& - & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%and%spmm(alpha,x%v,cone,y%v,info) + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + if (do_timings) t1= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_send_,& + & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t2= psb_wtime() + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) t3= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t4= psb_wtime() + call a%and%spmm(alpha,x%v,cone,y%v,info) + if (do_timings) t5= psb_wtime() + if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 + end block else if (doswap_) then diff --git a/base/psblas/psb_dspmm.f90 b/base/psblas/psb_dspmm.f90 index d5897f82..7888188a 100644 --- a/base/psblas/psb_dspmm.f90 +++ b/base/psblas/psb_dspmm.f90 @@ -180,22 +180,40 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& ! Matrix is not transposed if (allocated(a%ad)) then - if (doswap_) call psi_swapdata(psb_swap_send_,& - & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%and%spmm(alpha,x%v,done,y%v,info) + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + if (do_timings) t1= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_send_,& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t2= psb_wtime() + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) t3= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t4= psb_wtime() + call a%and%spmm(alpha,x%v,done,y%v,info) + if (do_timings) t5= psb_wtime() + if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 + end block else - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& - & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + if (do_timings) t1= psb_wtime() + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + if (do_timings) t2= psb_wtime() + call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) t3= psb_wtime() + if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2 + end block end if - - call psb_csmm(alpha,a,x,beta,y,info) - - end if if(info /= psb_success_) then info = psb_err_from_subroutine_non_ diff --git a/base/psblas/psb_sspmm.f90 b/base/psblas/psb_sspmm.f90 index 7c1e0ab3..cf8919f0 100644 --- a/base/psblas/psb_sspmm.f90 +++ b/base/psblas/psb_sspmm.f90 @@ -180,12 +180,23 @@ subroutine psb_sspmv_vect(alpha,a,x,beta,y,desc_a,info,& ! Matrix is not transposed if (allocated(a%ad)) then - if (doswap_) call psi_swapdata(psb_swap_send_,& - & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%and%spmm(alpha,x%v,sone,y%v,info) + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + if (do_timings) t1= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_send_,& + & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t2= psb_wtime() + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) t3= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t4= psb_wtime() + call a%and%spmm(alpha,x%v,sone,y%v,info) + if (do_timings) t5= psb_wtime() + if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 + end block else if (doswap_) then diff --git a/base/psblas/psb_zspmm.f90 b/base/psblas/psb_zspmm.f90 index 4dc73f83..629fcf2b 100644 --- a/base/psblas/psb_zspmm.f90 +++ b/base/psblas/psb_zspmm.f90 @@ -180,12 +180,23 @@ subroutine psb_zspmv_vect(alpha,a,x,beta,y,desc_a,info,& ! Matrix is not transposed if (allocated(a%ad)) then - if (doswap_) call psi_swapdata(psb_swap_send_,& - & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - call a%and%spmm(alpha,x%v,zone,y%v,info) + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + if (do_timings) t1= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_send_,& + & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t2= psb_wtime() + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) t3= psb_wtime() + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) t4= psb_wtime() + call a%and%spmm(alpha,x%v,zone,y%v,info) + if (do_timings) t5= psb_wtime() + if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 + end block else if (doswap_) then diff --git a/base/tools/psb_cspasb.f90 b/base/tools/psb_cspasb.f90 index 46258139..8263e309 100644 --- a/base/tools/psb_cspasb.f90 +++ b/base/tools/psb_cspasb.f90 @@ -183,7 +183,7 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold, bld_and) type(psb_c_coo_sparse_mat) :: acoo type(psb_c_csr_sparse_mat), allocatable :: aclip type(psb_c_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.false. + logical, parameter :: use_ecsr=.true. allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) diff --git a/base/tools/psb_sspasb.f90 b/base/tools/psb_sspasb.f90 index 0edae30e..f273c7f4 100644 --- a/base/tools/psb_sspasb.f90 +++ b/base/tools/psb_sspasb.f90 @@ -183,7 +183,7 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold, bld_and) type(psb_s_coo_sparse_mat) :: acoo type(psb_s_csr_sparse_mat), allocatable :: aclip type(psb_s_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.false. + logical, parameter :: use_ecsr=.true. allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) diff --git a/base/tools/psb_zspasb.f90 b/base/tools/psb_zspasb.f90 index cd77de15..1a381303 100644 --- a/base/tools/psb_zspasb.f90 +++ b/base/tools/psb_zspasb.f90 @@ -183,7 +183,7 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold, bld_and) type(psb_z_coo_sparse_mat) :: acoo type(psb_z_csr_sparse_mat), allocatable :: aclip type(psb_z_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.false. + logical, parameter :: use_ecsr=.true. allocate(aclip) call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) allocate(a%ad,mold=a%a) diff --git a/test/pargen/psb_d_pde3d.F90 b/test/pargen/psb_d_pde3d.F90 index cd503d29..e802736e 100644 --- a/test/pargen/psb_d_pde3d.F90 +++ b/test/pargen/psb_d_pde3d.F90 @@ -680,9 +680,9 @@ contains t1 = psb_wtime() if (info == psb_success_) then if (present(amold)) then - call psb_spasb(a,desc_a,info,mold=amold,bld_and=.true.) + call psb_spasb(a,desc_a,info,mold=amold,bld_and=.false.) else - call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=.true.) + call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=.false.) end if end if call psb_barrier(ctxt) diff --git a/test/pargen/runs/ppde.inp b/test/pargen/runs/ppde.inp index c70a973f..44dac085 100644 --- a/test/pargen/runs/ppde.inp +++ b/test/pargen/runs/ppde.inp @@ -5,7 +5,7 @@ CSR Storage format for matrix A: CSR COO 200 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) 3 Partition: 1 BLOCK 3 3D 2 Stopping criterion 1 2 -0300 MAXIT +0008 MAXIT 10 ITRACE 002 IRST restart for RGMRES and BiCGSTABL ILU Block Solver ILU,ILUT,INVK,AINVT,AORTH From a6ec655a97d14dbfdf3003e3aecfc32955b45080 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 7 Nov 2023 17:57:32 +0100 Subject: [PATCH 12/48] Prepare merge --- .../impl/{psb_c_csc_impl.f90 => psb_c_csc_impl.F90} | 0 .../impl/{psb_c_csr_impl.f90 => psb_c_csr_impl.F90} | 0 .../impl/{psb_d_csc_impl.f90 => psb_d_csc_impl.F90} | 0 .../impl/{psb_d_csr_impl.f90 => psb_d_csr_impl.F90} | 0 .../impl/{psb_s_csc_impl.f90 => psb_s_csc_impl.F90} | 0 .../impl/{psb_s_csr_impl.f90 => psb_s_csr_impl.F90} | 0 .../impl/{psb_z_csc_impl.f90 => psb_z_csc_impl.F90} | 0 .../impl/{psb_z_csr_impl.f90 => psb_z_csr_impl.F90} | 0 test/hello/Makefile | 10 ++++++++-- 9 files changed, 8 insertions(+), 2 deletions(-) rename base/serial/impl/{psb_c_csc_impl.f90 => psb_c_csc_impl.F90} (100%) rename base/serial/impl/{psb_c_csr_impl.f90 => psb_c_csr_impl.F90} (100%) rename base/serial/impl/{psb_d_csc_impl.f90 => psb_d_csc_impl.F90} (100%) rename base/serial/impl/{psb_d_csr_impl.f90 => psb_d_csr_impl.F90} (100%) rename base/serial/impl/{psb_s_csc_impl.f90 => psb_s_csc_impl.F90} (100%) rename base/serial/impl/{psb_s_csr_impl.f90 => psb_s_csr_impl.F90} (100%) rename base/serial/impl/{psb_z_csc_impl.f90 => psb_z_csc_impl.F90} (100%) rename base/serial/impl/{psb_z_csr_impl.f90 => psb_z_csr_impl.F90} (100%) diff --git a/base/serial/impl/psb_c_csc_impl.f90 b/base/serial/impl/psb_c_csc_impl.F90 similarity index 100% rename from base/serial/impl/psb_c_csc_impl.f90 rename to base/serial/impl/psb_c_csc_impl.F90 diff --git a/base/serial/impl/psb_c_csr_impl.f90 b/base/serial/impl/psb_c_csr_impl.F90 similarity index 100% rename from base/serial/impl/psb_c_csr_impl.f90 rename to base/serial/impl/psb_c_csr_impl.F90 diff --git a/base/serial/impl/psb_d_csc_impl.f90 b/base/serial/impl/psb_d_csc_impl.F90 similarity index 100% rename from base/serial/impl/psb_d_csc_impl.f90 rename to base/serial/impl/psb_d_csc_impl.F90 diff --git a/base/serial/impl/psb_d_csr_impl.f90 b/base/serial/impl/psb_d_csr_impl.F90 similarity index 100% rename from base/serial/impl/psb_d_csr_impl.f90 rename to base/serial/impl/psb_d_csr_impl.F90 diff --git a/base/serial/impl/psb_s_csc_impl.f90 b/base/serial/impl/psb_s_csc_impl.F90 similarity index 100% rename from base/serial/impl/psb_s_csc_impl.f90 rename to base/serial/impl/psb_s_csc_impl.F90 diff --git a/base/serial/impl/psb_s_csr_impl.f90 b/base/serial/impl/psb_s_csr_impl.F90 similarity index 100% rename from base/serial/impl/psb_s_csr_impl.f90 rename to base/serial/impl/psb_s_csr_impl.F90 diff --git a/base/serial/impl/psb_z_csc_impl.f90 b/base/serial/impl/psb_z_csc_impl.F90 similarity index 100% rename from base/serial/impl/psb_z_csc_impl.f90 rename to base/serial/impl/psb_z_csc_impl.F90 diff --git a/base/serial/impl/psb_z_csr_impl.f90 b/base/serial/impl/psb_z_csr_impl.F90 similarity index 100% rename from base/serial/impl/psb_z_csr_impl.f90 rename to base/serial/impl/psb_z_csr_impl.F90 diff --git a/test/hello/Makefile b/test/hello/Makefile index a6811ea7..f16ff75e 100644 --- a/test/hello/Makefile +++ b/test/hello/Makefile @@ -16,7 +16,7 @@ FINCLUDES=$(FMFLAG)$(MODDIR) $(FMFLAG). EXEDIR=./runs -all: runsd hello pingpong +all: runsd hello pingpong tsum tsum1 runsd: (if test ! -d runs ; then mkdir runs; fi) @@ -28,11 +28,17 @@ hello: hello.o pingpong: pingpong.o $(FLINK) pingpong.o -o pingpong $(PSBLAS_LIB) $(LDLIBS) /bin/mv pingpong $(EXEDIR) +tsum: tsum.o + $(FLINK) tsum.o -o tsum $(PSBLAS_LIB) $(LDLIBS) + /bin/mv tsum $(EXEDIR) +tsum1: tsum1.o + $(FLINK) tsum1.o -o tsum1 $(PSBLAS_LIB) $(LDLIBS) + /bin/mv tsum1 $(EXEDIR) clean: - /bin/rm -f hello.o pingpong.o + /bin/rm -f hello.o pingpong.o tsum.o tsum1.o $(EXEDIR)/hello verycleanlib: (cd ../..; make veryclean) From e9d1238b43c3f1c6b9689920701ead8d65ce8d0f Mon Sep 17 00:00:00 2001 From: sfilippone Date: Wed, 20 Dec 2023 13:30:09 +0100 Subject: [PATCH 13/48] Add detailed measurements. --- base/psblas/psb_dspmm.f90 | 39 ++++++++++++++++++++++++++++--------- test/kernel/pdgenspmv.f90 | 27 +++++++++++++++---------- test/pargen/psb_d_pde3d.F90 | 4 ++-- test/pargen/runs/ppde.inp | 2 +- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/base/psblas/psb_dspmm.f90 b/base/psblas/psb_dspmm.f90 index 7888188a..780b4d24 100644 --- a/base/psblas/psb_dspmm.f90 +++ b/base/psblas/psb_dspmm.f90 @@ -83,6 +83,9 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& character(len=20) :: name, ch_err logical :: aliw, doswap_ integer(psb_ipk_) :: debug_level, debug_unit + logical, parameter :: do_timings=.true. + integer(psb_ipk_), save :: mv_phase1=-1, mv_phase2=-1, mv_phase3=-1, mv_phase4=-1 + integer(psb_ipk_), save :: mv_phase11=-1, mv_phase12=-1 name='psb_dspmv' info=psb_success_ @@ -130,6 +133,19 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& call psb_errpush(info,name) goto 9999 end if + if ((do_timings).and.(mv_phase1==-1)) & + & mv_phase1 = psb_get_timer_idx("SPMM: and send ") + if ((do_timings).and.(mv_phase2==-1)) & + & mv_phase2 = psb_get_timer_idx("SPMM: and cmp ad") + if ((do_timings).and.(mv_phase3==-1)) & + & mv_phase3 = psb_get_timer_idx("SPMM: and rcv") + if ((do_timings).and.(mv_phase4==-1)) & + & mv_phase4 = psb_get_timer_idx("SPMM: and cmp and") + if ((do_timings).and.(mv_phase11==-1)) & + & mv_phase11 = psb_get_timer_idx("SPMM: noand exch ") + if ((do_timings).and.(mv_phase12==-1)) & + & mv_phase12 = psb_get_timer_idx("SPMM: noand cmp") + m = desc_a%get_global_rows() n = desc_a%get_global_cols() @@ -184,18 +200,22 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& logical, parameter :: do_timings=.true. real(psb_dpk_) :: t1, t2, t3, t4, t5 if (do_timings) call psb_barrier(ctxt) - if (do_timings) t1= psb_wtime() + if (do_timings) call psb_tic(mv_phase1) if (doswap_) call psi_swapdata(psb_swap_send_,& & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t2= psb_wtime() + if (do_timings) call psb_toc(mv_phase1) + if (do_timings) call psb_tic(mv_phase2) call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (do_timings) t3= psb_wtime() + if (do_timings) call psb_toc(mv_phase2) + if (do_timings) call psb_tic(mv_phase3) if (doswap_) call psi_swapdata(psb_swap_recv_,& & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) call psb_toc(mv_phase3) + if (do_timings) call psb_tic(mv_phase4) if (do_timings) t4= psb_wtime() call a%and%spmm(alpha,x%v,done,y%v,info) - if (do_timings) t5= psb_wtime() - if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 + if (do_timings) call psb_toc(mv_phase4) + end block else @@ -203,15 +223,16 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& logical, parameter :: do_timings=.true. real(psb_dpk_) :: t1, t2, t3, t4, t5 if (do_timings) call psb_barrier(ctxt) - if (do_timings) t1= psb_wtime() + + if (do_timings) call psb_tic(mv_phase11) if (doswap_) then call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) end if - if (do_timings) t2= psb_wtime() + if (do_timings) call psb_toc(mv_phase11) + if (do_timings) call psb_tic(mv_phase12) call psb_csmm(alpha,a,x,beta,y,info) - if (do_timings) t3= psb_wtime() - if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2 + if (do_timings) call psb_toc(mv_phase12) end block end if diff --git a/test/kernel/pdgenspmv.f90 b/test/kernel/pdgenspmv.f90 index d5fd9ba4..b7204edd 100644 --- a/test/kernel/pdgenspmv.f90 +++ b/test/kernel/pdgenspmv.f90 @@ -142,7 +142,7 @@ contains ! the rhs. ! subroutine psb_d_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,afmt,info,& - & f,amold,vmold,imold,partition,nrl,iv) + & f,amold,vmold,imold,partition,nrl,iv,tnd) use psb_base_mod use psb_util_mod ! @@ -173,7 +173,7 @@ contains class(psb_d_base_vect_type), optional :: vmold class(psb_i_base_vect_type), optional :: imold integer(psb_ipk_), optional :: partition, nrl,iv(:) - + logical, optional :: tnd ! Local variables. integer(psb_ipk_), parameter :: nb=20 @@ -202,6 +202,7 @@ contains real(psb_dpk_) :: t0, t1, t2, t3, tasb, talc, ttot, tgen, tcdasb integer(psb_ipk_) :: err_act procedure(d_func_3d), pointer :: f_ + logical :: tnd_ character(len=20) :: name, ch_err,tmpfmt info = psb_success_ @@ -495,9 +496,9 @@ contains t1 = psb_wtime() if (info == psb_success_) then if (present(amold)) then - call psb_spasb(a,desc_a,info,mold=amold) + call psb_spasb(a,desc_a,info,mold=amold,bld_and=tnd) else - call psb_spasb(a,desc_a,info,afmt=afmt) + call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=tnd) end if end if call psb_barrier(ctxt) @@ -549,13 +550,14 @@ program pdgenspmv use psb_base_mod use psb_util_mod use psb_d_pde3d_mod + implicit none ! input parameters character(len=20) :: kmethd, ptype character(len=5) :: afmt integer(psb_ipk_) :: idim - + logical :: tnd ! miscellaneous real(psb_dpk_), parameter :: one = done real(psb_dpk_) :: t1, t2, tprec, flops, tflops, tt1, tt2, bdwdth @@ -606,14 +608,14 @@ program pdgenspmv ! ! get parameters ! - call get_parms(ctxt,afmt,idim) - + call get_parms(ctxt,afmt,idim,tnd) + call psb_init_timers() ! ! allocate and fill in the coefficient matrix, rhs and initial guess ! call psb_barrier(ctxt) t1 = psb_wtime() - call psb_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,afmt,info) + call psb_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,afmt,info,tnd=tnd) call psb_barrier(ctxt) t2 = psb_wtime() - t1 if(info /= psb_success_) then @@ -694,7 +696,7 @@ program pdgenspmv write(psb_out_unit,'("Total memory occupation for DESC_A: ",i12)')descsize end if - + call psb_print_timers(ctxt) ! ! cleanup storage and exit @@ -721,10 +723,11 @@ contains ! ! get iteration parameters from standard input ! - subroutine get_parms(ctxt,afmt,idim) + subroutine get_parms(ctxt,afmt,idim,tnd) type(psb_ctxt_type) :: ctxt character(len=*) :: afmt integer(psb_ipk_) :: idim + logical :: tnd integer(psb_ipk_) :: np, iam integer(psb_ipk_) :: intbuf(10), ip @@ -733,9 +736,11 @@ contains if (iam == 0) then read(psb_inp_unit,*) afmt read(psb_inp_unit,*) idim + read(psb_inp_unit,*) tnd endif call psb_bcast(ctxt,afmt) call psb_bcast(ctxt,idim) + call psb_bcast(ctxt,tnd) if (iam == 0) then write(psb_out_unit,'("Testing matrix : ell1")') @@ -743,6 +748,8 @@ contains write(psb_out_unit,'("Number of processors : ",i0)')np write(psb_out_unit,'("Data distribution : BLOCK")') write(psb_out_unit,'(" ")') + write(psb_out_unit,'("Storage format ",a)') afmt + write(psb_out_unit,'("Testing overlap ND ",l8)') tnd end if return diff --git a/test/pargen/psb_d_pde3d.F90 b/test/pargen/psb_d_pde3d.F90 index 62bb8b40..4748569c 100644 --- a/test/pargen/psb_d_pde3d.F90 +++ b/test/pargen/psb_d_pde3d.F90 @@ -868,8 +868,8 @@ program psb_d_pde3d call psb_errpush(info,name,a_err=ch_err) goto 9999 end if - - call psb_exit(ctxt) + call psb_print_timers(ctxt) + call psb_exit(ctxt) stop 9999 call psb_error(ctxt) diff --git a/test/pargen/runs/ppde.inp b/test/pargen/runs/ppde.inp index cf7179ac..470bcf58 100644 --- a/test/pargen/runs/ppde.inp +++ b/test/pargen/runs/ppde.inp @@ -5,7 +5,7 @@ CSR Storage format for matrix A: CSR COO 200 Domain size (acutal system is this**3 (pde3d) or **2 (pde2d) ) 3 Partition: 1 BLOCK 3 3D 2 Stopping criterion 1 2 -0008 MAXIT +0200 MAXIT 10 ITRACE 002 IRST restart for RGMRES and BiCGSTABL INVK Block Solver ILU,ILUT,INVK,AINVT,AORTH From be7571f56868362159e9ad6ea459434009fae794 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Wed, 20 Dec 2023 13:45:43 +0100 Subject: [PATCH 14/48] Fix missing directive --- base/serial/impl/psb_c_csr_impl.F90 | 1 + base/serial/impl/psb_d_csr_impl.F90 | 1 + base/serial/impl/psb_s_csr_impl.F90 | 1 + base/serial/impl/psb_z_csr_impl.F90 | 1 + 4 files changed, 4 insertions(+) diff --git a/base/serial/impl/psb_c_csr_impl.F90 b/base/serial/impl/psb_c_csr_impl.F90 index f6426f49..6c21f639 100644 --- a/base/serial/impl/psb_c_csr_impl.F90 +++ b/base/serial/impl/psb_c_csr_impl.F90 @@ -4310,6 +4310,7 @@ contains end subroutine csr_spspmm end subroutine psb_ccsrspspmm +#endif subroutine psb_c_ecsr_mold(a,b,info) use psb_c_csr_mat_mod, psb_protect_name => psb_c_ecsr_mold diff --git a/base/serial/impl/psb_d_csr_impl.F90 b/base/serial/impl/psb_d_csr_impl.F90 index 40d97bc0..9f1d509c 100644 --- a/base/serial/impl/psb_d_csr_impl.F90 +++ b/base/serial/impl/psb_d_csr_impl.F90 @@ -4310,6 +4310,7 @@ contains end subroutine csr_spspmm end subroutine psb_dcsrspspmm +#endif subroutine psb_d_ecsr_mold(a,b,info) use psb_d_csr_mat_mod, psb_protect_name => psb_d_ecsr_mold diff --git a/base/serial/impl/psb_s_csr_impl.F90 b/base/serial/impl/psb_s_csr_impl.F90 index abce0086..a4e1ab82 100644 --- a/base/serial/impl/psb_s_csr_impl.F90 +++ b/base/serial/impl/psb_s_csr_impl.F90 @@ -4310,6 +4310,7 @@ contains end subroutine csr_spspmm end subroutine psb_scsrspspmm +#endif subroutine psb_s_ecsr_mold(a,b,info) use psb_s_csr_mat_mod, psb_protect_name => psb_s_ecsr_mold diff --git a/base/serial/impl/psb_z_csr_impl.F90 b/base/serial/impl/psb_z_csr_impl.F90 index b550e8f1..28ac121e 100644 --- a/base/serial/impl/psb_z_csr_impl.F90 +++ b/base/serial/impl/psb_z_csr_impl.F90 @@ -4310,6 +4310,7 @@ contains end subroutine csr_spspmm end subroutine psb_zcsrspspmm +#endif subroutine psb_z_ecsr_mold(a,b,info) use psb_z_csr_mat_mod, psb_protect_name => psb_z_ecsr_mold From 49e99a3e82a49d78726bc0b82eda0e07b5eb80db Mon Sep 17 00:00:00 2001 From: sfilippone Date: Fri, 22 Dec 2023 12:01:41 +0100 Subject: [PATCH 15/48] Fix conversion and product to enable overlap with GPU --- base/psblas/psb_dspmm.f90 | 1 + base/serial/impl/psb_d_mat_impl.F90 | 313 ++++++++++++++++++++-------- 2 files changed, 231 insertions(+), 83 deletions(-) diff --git a/base/psblas/psb_dspmm.f90 b/base/psblas/psb_dspmm.f90 index 780b4d24..8e48c4c2 100644 --- a/base/psblas/psb_dspmm.f90 +++ b/base/psblas/psb_dspmm.f90 @@ -199,6 +199,7 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& block logical, parameter :: do_timings=.true. real(psb_dpk_) :: t1, t2, t3, t4, t5 + !write(0,*) 'Going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt() if (do_timings) call psb_barrier(ctxt) if (do_timings) call psb_tic(mv_phase1) if (doswap_) call psi_swapdata(psb_swap_send_,& diff --git a/base/serial/impl/psb_d_mat_impl.F90 b/base/serial/impl/psb_d_mat_impl.F90 index 2a6fb9a5..caf725d1 100644 --- a/base/serial/impl/psb_d_mat_impl.F90 +++ b/base/serial/impl/psb_d_mat_impl.F90 @@ -1246,54 +1246,66 @@ subroutine psb_d_cscnv(a,b,info,type,mold,upd,dupl) goto 9999 end if - if (present(mold)) then - - allocate(altmp, mold=mold,stat=info) - - else if (present(type)) then - - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_d_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_d_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_d_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) - goto 9999 - end select - else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) +!!$ if (present(mold)) then +!!$ +!!$ allocate(altmp, mold=mold,stat=info) +!!$ +!!$ else if (present(type)) then +!!$ +!!$ select case (psb_toupper(type)) +!!$ case ('CSR') +!!$ allocate(psb_d_csr_sparse_mat :: altmp, stat=info) +!!$ case ('COO') +!!$ allocate(psb_d_coo_sparse_mat :: altmp, stat=info) +!!$ case ('CSC') +!!$ allocate(psb_d_csc_sparse_mat :: altmp, stat=info) +!!$ case default +!!$ info = psb_err_format_unknown_ +!!$ call psb_errpush(info,name,a_err=type) +!!$ goto 9999 +!!$ end select +!!$ else +!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) +!!$ end if +!!$ +!!$ if (info /= psb_success_) then +!!$ info = psb_err_alloc_dealloc_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ end if +!!$ +!!$ +!!$ if (present(dupl)) then +!!$ call altmp%set_dupl(dupl) +!!$ else if (a%is_bld()) then +!!$ ! Does this make sense at all?? Who knows.. +!!$ call altmp%set_dupl(psb_dupl_def_) +!!$ end if +!!$ +!!$ if (debug) write(psb_err_unit,*) 'Converting from ',& +!!$ & a%get_fmt(),' to ',altmp%get_fmt() +!!$ +!!$ call altmp%cp_from_fmt(a%a, info) +!!$ +!!$ if (info /= psb_success_) then +!!$ info = psb_err_from_subroutine_ +!!$ call psb_errpush(info,name,a_err="mv_from") +!!$ goto 9999 +!!$ end if +!!$ +!!$ call move_alloc(altmp,b%a) + call inner_cp_alloc(a%a,b%a,info,type,mold) + if (info /= 0) goto 9999 + if (allocated(a%ad)) then + call inner_cp_alloc(a%ad,b%ad,info,type,mold) + if (info /= 0) goto 9999 end if - - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - - if (present(dupl)) then - call altmp%set_dupl(dupl) - else if (a%is_bld()) then - ! Does this make sense at all?? Who knows.. - call altmp%set_dupl(psb_dupl_def_) + if (allocated(a%and)) then + call inner_cp_alloc(a%and,b%and,info,type,mold) + if (info /= 0) goto 9999 end if - if (debug) write(psb_err_unit,*) 'Converting from ',& - & a%get_fmt(),' to ',altmp%get_fmt() - - call altmp%cp_from_fmt(a%a, info) - - if (info /= psb_success_) then - info = psb_err_from_subroutine_ - call psb_errpush(info,name,a_err="mv_from") - goto 9999 - end if - call move_alloc(altmp,b%a) call b%trim() call b%set_asb() call psb_erractionrestore(err_act) @@ -1303,6 +1315,69 @@ subroutine psb_d_cscnv(a,b,info,type,mold,upd,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_cp_alloc(a,b,info,type,mold) + class(psb_d_base_sparse_mat), intent(in) :: a + class(psb_d_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + character(len=*), optional, intent(in) :: type + class(psb_d_base_sparse_mat), intent(in), optional :: mold + + class(psb_d_base_sparse_mat), allocatable :: altmp + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_d_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_d_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else +!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) + allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%cp_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,b) + + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_cp_alloc end subroutine psb_d_cscnv @@ -1345,46 +1420,57 @@ subroutine psb_d_cscnv_ip(a,info,type,mold,dupl) goto 9999 end if - if (present(mold)) then - - allocate(altmp, mold=mold,stat=info) - - else if (present(type)) then - - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_d_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_d_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_d_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) - goto 9999 - end select - else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) +!!$ if (present(mold)) then +!!$ +!!$ allocate(altmp, mold=mold,stat=info) +!!$ +!!$ else if (present(type)) then +!!$ +!!$ select case (psb_toupper(type)) +!!$ case ('CSR') +!!$ allocate(psb_d_csr_sparse_mat :: altmp, stat=info) +!!$ case ('COO') +!!$ allocate(psb_d_coo_sparse_mat :: altmp, stat=info) +!!$ case ('CSC') +!!$ allocate(psb_d_csc_sparse_mat :: altmp, stat=info) +!!$ case default +!!$ info = psb_err_format_unknown_ +!!$ call psb_errpush(info,name,a_err=type) +!!$ goto 9999 +!!$ end select +!!$ else +!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) +!!$ end if +!!$ +!!$ if (info /= psb_success_) then +!!$ info = psb_err_alloc_dealloc_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ end if +!!$ +!!$ if (debug) write(psb_err_unit,*) 'Converting in-place from ',& +!!$ & a%get_fmt(),' to ',altmp%get_fmt() +!!$ +!!$ call altmp%mv_from_fmt(a%a, info) +!!$ +!!$ if (info /= psb_success_) then +!!$ info = psb_err_from_subroutine_ +!!$ call psb_errpush(info,name,a_err="mv_from") +!!$ goto 9999 +!!$ end if +!!$ +!!$ call move_alloc(altmp,a%a) + + call inner_mv_alloc(a%a,info,type,mold) + if (info /= 0) goto 9999 + if (allocated(a%ad)) then + call inner_mv_alloc(a%ad,info,type,mold) + if (info /= 0) goto 9999 end if - - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 + if (allocated(a%and)) then + call inner_mv_alloc(a%and,info,type,mold) + if (info /= 0) goto 9999 end if - - if (debug) write(psb_err_unit,*) 'Converting in-place from ',& - & a%get_fmt(),' to ',altmp%get_fmt() - - call altmp%mv_from_fmt(a%a, info) - - if (info /= psb_success_) then - info = psb_err_from_subroutine_ - call psb_errpush(info,name,a_err="mv_from") - goto 9999 - end if - - call move_alloc(altmp,a%a) call a%trim() call a%set_asb() call psb_erractionrestore(err_act) @@ -1394,7 +1480,68 @@ subroutine psb_d_cscnv_ip(a,info,type,mold,dupl) 9999 call psb_error_handler(err_act) return - +contains + subroutine inner_mv_alloc(a,info,type,mold) + class(psb_d_base_sparse_mat), intent(inout), allocatable :: a + integer(psb_ipk_), intent(out) :: info + character(len=*), optional, intent(in) :: type + class(psb_d_base_sparse_mat), intent(in), optional :: mold + + class(psb_d_base_sparse_mat), allocatable :: altmp + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_d_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_d_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else +!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) + allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%mv_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,a) + + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_mv_alloc end subroutine psb_d_cscnv_ip From 4d051c777d52eea024c1e6da36483b98e650ba77 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Fri, 22 Dec 2023 12:01:59 +0100 Subject: [PATCH 16/48] Fix makefile and test program --- test/cudakern/Makefile | 5 ++++- test/cudakern/dpdegenmv.F90 | 38 +++++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/test/cudakern/Makefile b/test/cudakern/Makefile index fdd4f588..5d938973 100755 --- a/test/cudakern/Makefile +++ b/test/cudakern/Makefile @@ -24,9 +24,12 @@ DPGOBJS=dpdegenmv.o SPGOBJS=spdegenmv.o EXEDIR=./runs -all: pgen file +all: dir pgen file pgen: dpdegenmv spdegenmv file: s_file_spmv c_file_spmv d_file_spmv z_file_spmv +dpdegenmv spdegenmv s_file_spmv c_file_spmv d_file_spmv z_file_spmv: dir +dir: + (if test ! -d $(EXEDIR); then mkdir $(EXEDIR); fi) dpdegenmv: $(DPGOBJS) $(FLINK) $(LOPT) $(DPGOBJS) -fopenmp -o dpdegenmv $(FINCLUDES) $(PSBLAS_LIB) $(LDLIBS) diff --git a/test/cudakern/dpdegenmv.F90 b/test/cudakern/dpdegenmv.F90 index db845d71..85059e81 100644 --- a/test/cudakern/dpdegenmv.F90 +++ b/test/cudakern/dpdegenmv.F90 @@ -70,6 +70,16 @@ contains ! ! functions parametrizing the differential equation ! + + ! + ! Note: b1, b2 and b3 are the coefficients of the first + ! derivative of the unknown function. The default + ! we apply here is to have them zero, so that the resulting + ! matrix is symmetric/hermitian and suitable for + ! testing with CG and FCG. + ! When testing methods for non-hermitian matrices you can + ! change the B1/B2/B3 functions to e.g. done/sqrt((3*done)) + ! function b1(x,y,z) use psb_base_mod, only : psb_dpk_, done, dzero implicit none @@ -138,7 +148,7 @@ contains ! the rhs. ! subroutine psb_d_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,afmt,info,& - & f,amold,vmold,imold,partition,nrl,iv) + & f,amold,vmold,imold,partition,nrl,iv,tnd) use psb_base_mod use psb_util_mod ! @@ -169,7 +179,7 @@ contains class(psb_d_base_vect_type), optional :: vmold class(psb_i_base_vect_type), optional :: imold integer(psb_ipk_), optional :: partition, nrl,iv(:) - + logical, optional :: tnd ! Local variables. integer(psb_ipk_), parameter :: nb=20 @@ -198,6 +208,7 @@ contains real(psb_dpk_) :: t0, t1, t2, t3, tasb, talc, ttot, tgen, tcdasb integer(psb_ipk_) :: err_act procedure(d_func_3d), pointer :: f_ + logical :: tnd_ character(len=20) :: name, ch_err,tmpfmt info = psb_success_ @@ -492,9 +503,9 @@ contains t1 = psb_wtime() if (info == psb_success_) then if (present(amold)) then - call psb_spasb(a,desc_a,info,mold=amold) + call psb_spasb(a,desc_a,info,mold=amold,bld_and=tnd) else - call psb_spasb(a,desc_a,info,afmt=afmt) + call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=tnd) end if end if call psb_barrier(ctxt) @@ -559,7 +570,7 @@ program pdgenmv ! input parameters character(len=5) :: acfmt, agfmt integer :: idim - + logical :: tnd ! miscellaneous real(psb_dpk_), parameter :: one = 1.d0 real(psb_dpk_) :: t1, t2, tprec, flops, tflops,& @@ -646,14 +657,14 @@ program pdgenmv ! ! get parameters ! - call get_parms(ctxt,acfmt,agfmt,idim) - + call get_parms(ctxt,acfmt,agfmt,idim,tnd) + call psb_init_timers() ! ! allocate and fill in the coefficient matrix and initial vectors ! call psb_barrier(ctxt) t1 = psb_wtime() - call psb_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,'CSR ',info,partition=3) + call psb_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,'CSR ',info,partition=3,tnd=tnd) call psb_barrier(ctxt) t2 = psb_wtime() - t1 if(info /= psb_success_) then @@ -935,6 +946,7 @@ program pdgenmv write(psb_out_unit,'("Total memory occupation for DESC_A: ",i12)')descsize end if + call psb_print_timers(ctxt) ! ! cleanup storage and exit @@ -962,10 +974,11 @@ contains ! ! get iteration parameters from standard input ! - subroutine get_parms(ctxt,acfmt,agfmt,idim) + subroutine get_parms(ctxt,acfmt,agfmt,idim,tnd) type(psb_ctxt_type) :: ctxt character(len=*) :: agfmt, acfmt integer :: idim + logical :: tnd integer :: np, iam integer :: intbuf(10), ip @@ -978,17 +991,22 @@ contains read(psb_inp_unit,*) agfmt write(*,*) 'Size of discretization cube?' read(psb_inp_unit,*) idim + write(*,*) 'Try comm/comp overlap?' + read(psb_inp_unit,*) tnd endif call psb_bcast(ctxt,acfmt) call psb_bcast(ctxt,agfmt) call psb_bcast(ctxt,idim) - + call psb_bcast(ctxt,tnd) + if (iam == 0) then write(psb_out_unit,'("Testing matrix : ell1")') write(psb_out_unit,'("Grid dimensions : ",i4,"x",i4,"x",i4)')idim,idim,idim write(psb_out_unit,'("Number of processors : ",i0)')np write(psb_out_unit,'("Data distribution : BLOCK")') write(psb_out_unit,'(" ")') + write(psb_out_unit,'("Storage formats ",a)') acfmt,' ',agfmt + write(psb_out_unit,'("Testing overlap ND ",l8)') tnd end if return From 3aa3c795e98fc2a8a4dc9f8ee3aa36e74f13fef8 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Sat, 23 Dec 2023 13:15:01 +0100 Subject: [PATCH 17/48] Refactor assembly and cnv --- base/modules/serial/psb_c_mat_mod.F90 | 15 +- base/modules/serial/psb_d_mat_mod.F90 | 15 +- base/modules/serial/psb_s_mat_mod.F90 | 15 +- base/modules/serial/psb_z_mat_mod.F90 | 15 +- base/psblas/psb_cspmm.f90 | 74 ++++-- base/psblas/psb_dspmm.f90 | 71 +++-- base/psblas/psb_sspmm.f90 | 74 ++++-- base/psblas/psb_zspmm.f90 | 74 ++++-- base/serial/impl/psb_c_mat_impl.F90 | 356 ++++++++++++++++++++----- base/serial/impl/psb_d_mat_impl.F90 | 369 +++++++++++++++----------- base/serial/impl/psb_s_mat_impl.F90 | 356 ++++++++++++++++++++----- base/serial/impl/psb_z_mat_impl.F90 | 356 ++++++++++++++++++++----- base/tools/psb_cspasb.f90 | 73 ++--- base/tools/psb_dspasb.f90 | 73 ++--- base/tools/psb_sspasb.f90 | 73 ++--- base/tools/psb_zspasb.f90 | 73 ++--- 16 files changed, 1457 insertions(+), 625 deletions(-) diff --git a/base/modules/serial/psb_c_mat_mod.F90 b/base/modules/serial/psb_c_mat_mod.F90 index aa891381..ee819535 100644 --- a/base/modules/serial/psb_c_mat_mod.F90 +++ b/base/modules/serial/psb_c_mat_mod.F90 @@ -204,6 +204,7 @@ module psb_c_mat_mod procedure, pass(a) :: cscnv_ip => psb_c_cscnv_ip procedure, pass(a) :: cscnv_base => psb_c_cscnv_base generic, public :: cscnv => cscnv_np, cscnv_ip, cscnv_base + procedure, pass(a) :: split_nd => psb_c_split_nd procedure, pass(a) :: clone => psb_cspmat_clone procedure, pass(a) :: move_alloc => psb_cspmat_type_move ! @@ -842,6 +843,18 @@ module psb_c_mat_mod ! ! + interface + subroutine psb_c_split_nd(a,n_rows,n_cols,info) + import :: psb_ipk_, psb_lpk_, psb_cspmat_type, psb_spk_, psb_c_base_sparse_mat + class(psb_cspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_c_base_sparse_mat), intent(in), optional :: mold + end subroutine psb_c_split_nd + end interface + ! ! CSCNV: switches to a different internal derived type. ! 3 versions: copying to target @@ -861,7 +874,6 @@ module psb_c_mat_mod end subroutine psb_c_cscnv end interface - interface subroutine psb_c_cscnv_ip(a,iinfo,type,mold,dupl) import :: psb_ipk_, psb_lpk_, psb_cspmat_type, psb_spk_, psb_c_base_sparse_mat @@ -873,7 +885,6 @@ module psb_c_mat_mod end subroutine psb_c_cscnv_ip end interface - interface subroutine psb_c_cscnv_base(a,b,info,dupl) import :: psb_ipk_, psb_lpk_, psb_cspmat_type, psb_spk_, psb_c_base_sparse_mat diff --git a/base/modules/serial/psb_d_mat_mod.F90 b/base/modules/serial/psb_d_mat_mod.F90 index c647e76b..82d2e822 100644 --- a/base/modules/serial/psb_d_mat_mod.F90 +++ b/base/modules/serial/psb_d_mat_mod.F90 @@ -204,6 +204,7 @@ module psb_d_mat_mod procedure, pass(a) :: cscnv_ip => psb_d_cscnv_ip procedure, pass(a) :: cscnv_base => psb_d_cscnv_base generic, public :: cscnv => cscnv_np, cscnv_ip, cscnv_base + procedure, pass(a) :: split_nd => psb_d_split_nd procedure, pass(a) :: clone => psb_dspmat_clone procedure, pass(a) :: move_alloc => psb_dspmat_type_move ! @@ -842,6 +843,18 @@ module psb_d_mat_mod ! ! + interface + subroutine psb_d_split_nd(a,n_rows,n_cols,info) + import :: psb_ipk_, psb_lpk_, psb_dspmat_type, psb_dpk_, psb_d_base_sparse_mat + class(psb_dspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_d_base_sparse_mat), intent(in), optional :: mold + end subroutine psb_d_split_nd + end interface + ! ! CSCNV: switches to a different internal derived type. ! 3 versions: copying to target @@ -861,7 +874,6 @@ module psb_d_mat_mod end subroutine psb_d_cscnv end interface - interface subroutine psb_d_cscnv_ip(a,iinfo,type,mold,dupl) import :: psb_ipk_, psb_lpk_, psb_dspmat_type, psb_dpk_, psb_d_base_sparse_mat @@ -873,7 +885,6 @@ module psb_d_mat_mod end subroutine psb_d_cscnv_ip end interface - interface subroutine psb_d_cscnv_base(a,b,info,dupl) import :: psb_ipk_, psb_lpk_, psb_dspmat_type, psb_dpk_, psb_d_base_sparse_mat diff --git a/base/modules/serial/psb_s_mat_mod.F90 b/base/modules/serial/psb_s_mat_mod.F90 index 3e6b286a..d8a2e6ae 100644 --- a/base/modules/serial/psb_s_mat_mod.F90 +++ b/base/modules/serial/psb_s_mat_mod.F90 @@ -204,6 +204,7 @@ module psb_s_mat_mod procedure, pass(a) :: cscnv_ip => psb_s_cscnv_ip procedure, pass(a) :: cscnv_base => psb_s_cscnv_base generic, public :: cscnv => cscnv_np, cscnv_ip, cscnv_base + procedure, pass(a) :: split_nd => psb_s_split_nd procedure, pass(a) :: clone => psb_sspmat_clone procedure, pass(a) :: move_alloc => psb_sspmat_type_move ! @@ -842,6 +843,18 @@ module psb_s_mat_mod ! ! + interface + subroutine psb_s_split_nd(a,n_rows,n_cols,info) + import :: psb_ipk_, psb_lpk_, psb_sspmat_type, psb_spk_, psb_s_base_sparse_mat + class(psb_sspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_s_base_sparse_mat), intent(in), optional :: mold + end subroutine psb_s_split_nd + end interface + ! ! CSCNV: switches to a different internal derived type. ! 3 versions: copying to target @@ -861,7 +874,6 @@ module psb_s_mat_mod end subroutine psb_s_cscnv end interface - interface subroutine psb_s_cscnv_ip(a,iinfo,type,mold,dupl) import :: psb_ipk_, psb_lpk_, psb_sspmat_type, psb_spk_, psb_s_base_sparse_mat @@ -873,7 +885,6 @@ module psb_s_mat_mod end subroutine psb_s_cscnv_ip end interface - interface subroutine psb_s_cscnv_base(a,b,info,dupl) import :: psb_ipk_, psb_lpk_, psb_sspmat_type, psb_spk_, psb_s_base_sparse_mat diff --git a/base/modules/serial/psb_z_mat_mod.F90 b/base/modules/serial/psb_z_mat_mod.F90 index 148e9ab9..694d4efc 100644 --- a/base/modules/serial/psb_z_mat_mod.F90 +++ b/base/modules/serial/psb_z_mat_mod.F90 @@ -204,6 +204,7 @@ module psb_z_mat_mod procedure, pass(a) :: cscnv_ip => psb_z_cscnv_ip procedure, pass(a) :: cscnv_base => psb_z_cscnv_base generic, public :: cscnv => cscnv_np, cscnv_ip, cscnv_base + procedure, pass(a) :: split_nd => psb_z_split_nd procedure, pass(a) :: clone => psb_zspmat_clone procedure, pass(a) :: move_alloc => psb_zspmat_type_move ! @@ -842,6 +843,18 @@ module psb_z_mat_mod ! ! + interface + subroutine psb_z_split_nd(a,n_rows,n_cols,info) + import :: psb_ipk_, psb_lpk_, psb_zspmat_type, psb_dpk_, psb_z_base_sparse_mat + class(psb_zspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_z_base_sparse_mat), intent(in), optional :: mold + end subroutine psb_z_split_nd + end interface + ! ! CSCNV: switches to a different internal derived type. ! 3 versions: copying to target @@ -861,7 +874,6 @@ module psb_z_mat_mod end subroutine psb_z_cscnv end interface - interface subroutine psb_z_cscnv_ip(a,iinfo,type,mold,dupl) import :: psb_ipk_, psb_lpk_, psb_zspmat_type, psb_dpk_, psb_z_base_sparse_mat @@ -873,7 +885,6 @@ module psb_z_mat_mod end subroutine psb_z_cscnv_ip end interface - interface subroutine psb_z_cscnv_base(a,b,info,dupl) import :: psb_ipk_, psb_lpk_, psb_zspmat_type, psb_dpk_, psb_z_base_sparse_mat diff --git a/base/psblas/psb_cspmm.f90 b/base/psblas/psb_cspmm.f90 index 25a6bc56..22c6408f 100644 --- a/base/psblas/psb_cspmm.f90 +++ b/base/psblas/psb_cspmm.f90 @@ -83,6 +83,9 @@ subroutine psb_cspmv_vect(alpha,a,x,beta,y,desc_a,info,& character(len=20) :: name, ch_err logical :: aliw, doswap_ integer(psb_ipk_) :: debug_level, debug_unit + logical, parameter :: do_timings=.true. + integer(psb_ipk_), save :: mv_phase1=-1, mv_phase2=-1, mv_phase3=-1, mv_phase4=-1 + integer(psb_ipk_), save :: mv_phase11=-1, mv_phase12=-1 name='psb_cspmv' info=psb_success_ @@ -130,6 +133,19 @@ subroutine psb_cspmv_vect(alpha,a,x,beta,y,desc_a,info,& call psb_errpush(info,name) goto 9999 end if + if ((do_timings).and.(mv_phase1==-1)) & + & mv_phase1 = psb_get_timer_idx("SPMM: and send ") + if ((do_timings).and.(mv_phase2==-1)) & + & mv_phase2 = psb_get_timer_idx("SPMM: and cmp ad") + if ((do_timings).and.(mv_phase3==-1)) & + & mv_phase3 = psb_get_timer_idx("SPMM: and rcv") + if ((do_timings).and.(mv_phase4==-1)) & + & mv_phase4 = psb_get_timer_idx("SPMM: and cmp and") + if ((do_timings).and.(mv_phase11==-1)) & + & mv_phase11 = psb_get_timer_idx("SPMM: noand exch ") + if ((do_timings).and.(mv_phase12==-1)) & + & mv_phase12 = psb_get_timer_idx("SPMM: noand cmp") + m = desc_a%get_global_rows() n = desc_a%get_global_cols() @@ -178,34 +194,44 @@ subroutine psb_cspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - + if (allocated(a%ad)) then block - logical, parameter :: do_timings=.true. - real(psb_dpk_) :: t1, t2, t3, t4, t5 - if (do_timings) call psb_barrier(ctxt) - if (do_timings) t1= psb_wtime() - if (doswap_) call psi_swapdata(psb_swap_send_,& - & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t2= psb_wtime() - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (do_timings) t3= psb_wtime() - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t4= psb_wtime() - call a%and%spmm(alpha,x%v,cone,y%v,info) - if (do_timings) t5= psb_wtime() - if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 - end block - - else - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + !if (me==0) write(0,*) 'going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt() + if (do_timings) call psb_barrier(ctxt) + if (do_timings) call psb_tic(mv_phase1) + if (doswap_) call psi_swapdata(psb_swap_send_,& & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - end if - - call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase1) + if (do_timings) call psb_tic(mv_phase2) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) call psb_tic(mv_phase3) + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) call psb_toc(mv_phase3) + if (do_timings) call psb_tic(mv_phase4) + call a%and%spmm(alpha,x%v,cone,y%v,info) + if (do_timings) call psb_toc(mv_phase4) + end block + else + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + + if (do_timings) call psb_tic(mv_phase11) + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & czero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + if (do_timings) call psb_toc(mv_phase11) + if (do_timings) call psb_tic(mv_phase12) + call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase12) + end block end if if(info /= psb_success_) then diff --git a/base/psblas/psb_dspmm.f90 b/base/psblas/psb_dspmm.f90 index 8e48c4c2..fa256276 100644 --- a/base/psblas/psb_dspmm.f90 +++ b/base/psblas/psb_dspmm.f90 @@ -194,48 +194,45 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - + if (allocated(a%ad)) then block - logical, parameter :: do_timings=.true. - real(psb_dpk_) :: t1, t2, t3, t4, t5 - !write(0,*) 'Going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt() - if (do_timings) call psb_barrier(ctxt) - if (do_timings) call psb_tic(mv_phase1) - if (doswap_) call psi_swapdata(psb_swap_send_,& - & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) call psb_toc(mv_phase1) - if (do_timings) call psb_tic(mv_phase2) - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (do_timings) call psb_toc(mv_phase2) - if (do_timings) call psb_tic(mv_phase3) - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) call psb_toc(mv_phase3) - if (do_timings) call psb_tic(mv_phase4) - if (do_timings) t4= psb_wtime() - call a%and%spmm(alpha,x%v,done,y%v,info) - if (do_timings) call psb_toc(mv_phase4) - - end block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + !if (me==0) write(0,*) 'going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt() + if (do_timings) call psb_barrier(ctxt) + if (do_timings) call psb_tic(mv_phase1) + if (doswap_) call psi_swapdata(psb_swap_send_,& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) call psb_toc(mv_phase1) + if (do_timings) call psb_tic(mv_phase2) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) call psb_tic(mv_phase3) + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) call psb_toc(mv_phase3) + if (do_timings) call psb_tic(mv_phase4) + call a%and%spmm(alpha,x%v,done,y%v,info) + if (do_timings) call psb_toc(mv_phase4) + end block else block - logical, parameter :: do_timings=.true. - real(psb_dpk_) :: t1, t2, t3, t4, t5 - if (do_timings) call psb_barrier(ctxt) - - if (do_timings) call psb_tic(mv_phase11) - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& - & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - end if - if (do_timings) call psb_toc(mv_phase11) - if (do_timings) call psb_tic(mv_phase12) - call psb_csmm(alpha,a,x,beta,y,info) - if (do_timings) call psb_toc(mv_phase12) - end block - end if + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + + if (do_timings) call psb_tic(mv_phase11) + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & dzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + if (do_timings) call psb_toc(mv_phase11) + if (do_timings) call psb_tic(mv_phase12) + call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase12) + end block + end if if(info /= psb_success_) then info = psb_err_from_subroutine_non_ diff --git a/base/psblas/psb_sspmm.f90 b/base/psblas/psb_sspmm.f90 index cf8919f0..6c723831 100644 --- a/base/psblas/psb_sspmm.f90 +++ b/base/psblas/psb_sspmm.f90 @@ -83,6 +83,9 @@ subroutine psb_sspmv_vect(alpha,a,x,beta,y,desc_a,info,& character(len=20) :: name, ch_err logical :: aliw, doswap_ integer(psb_ipk_) :: debug_level, debug_unit + logical, parameter :: do_timings=.true. + integer(psb_ipk_), save :: mv_phase1=-1, mv_phase2=-1, mv_phase3=-1, mv_phase4=-1 + integer(psb_ipk_), save :: mv_phase11=-1, mv_phase12=-1 name='psb_sspmv' info=psb_success_ @@ -130,6 +133,19 @@ subroutine psb_sspmv_vect(alpha,a,x,beta,y,desc_a,info,& call psb_errpush(info,name) goto 9999 end if + if ((do_timings).and.(mv_phase1==-1)) & + & mv_phase1 = psb_get_timer_idx("SPMM: and send ") + if ((do_timings).and.(mv_phase2==-1)) & + & mv_phase2 = psb_get_timer_idx("SPMM: and cmp ad") + if ((do_timings).and.(mv_phase3==-1)) & + & mv_phase3 = psb_get_timer_idx("SPMM: and rcv") + if ((do_timings).and.(mv_phase4==-1)) & + & mv_phase4 = psb_get_timer_idx("SPMM: and cmp and") + if ((do_timings).and.(mv_phase11==-1)) & + & mv_phase11 = psb_get_timer_idx("SPMM: noand exch ") + if ((do_timings).and.(mv_phase12==-1)) & + & mv_phase12 = psb_get_timer_idx("SPMM: noand cmp") + m = desc_a%get_global_rows() n = desc_a%get_global_cols() @@ -178,34 +194,44 @@ subroutine psb_sspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - + if (allocated(a%ad)) then block - logical, parameter :: do_timings=.true. - real(psb_dpk_) :: t1, t2, t3, t4, t5 - if (do_timings) call psb_barrier(ctxt) - if (do_timings) t1= psb_wtime() - if (doswap_) call psi_swapdata(psb_swap_send_,& - & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t2= psb_wtime() - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (do_timings) t3= psb_wtime() - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t4= psb_wtime() - call a%and%spmm(alpha,x%v,sone,y%v,info) - if (do_timings) t5= psb_wtime() - if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 - end block - - else - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + !if (me==0) write(0,*) 'going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt() + if (do_timings) call psb_barrier(ctxt) + if (do_timings) call psb_tic(mv_phase1) + if (doswap_) call psi_swapdata(psb_swap_send_,& & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - end if - - call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase1) + if (do_timings) call psb_tic(mv_phase2) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) call psb_tic(mv_phase3) + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) call psb_toc(mv_phase3) + if (do_timings) call psb_tic(mv_phase4) + call a%and%spmm(alpha,x%v,sone,y%v,info) + if (do_timings) call psb_toc(mv_phase4) + end block + else + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + + if (do_timings) call psb_tic(mv_phase11) + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & szero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + if (do_timings) call psb_toc(mv_phase11) + if (do_timings) call psb_tic(mv_phase12) + call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase12) + end block end if if(info /= psb_success_) then diff --git a/base/psblas/psb_zspmm.f90 b/base/psblas/psb_zspmm.f90 index 629fcf2b..179e4fad 100644 --- a/base/psblas/psb_zspmm.f90 +++ b/base/psblas/psb_zspmm.f90 @@ -83,6 +83,9 @@ subroutine psb_zspmv_vect(alpha,a,x,beta,y,desc_a,info,& character(len=20) :: name, ch_err logical :: aliw, doswap_ integer(psb_ipk_) :: debug_level, debug_unit + logical, parameter :: do_timings=.true. + integer(psb_ipk_), save :: mv_phase1=-1, mv_phase2=-1, mv_phase3=-1, mv_phase4=-1 + integer(psb_ipk_), save :: mv_phase11=-1, mv_phase12=-1 name='psb_zspmv' info=psb_success_ @@ -130,6 +133,19 @@ subroutine psb_zspmv_vect(alpha,a,x,beta,y,desc_a,info,& call psb_errpush(info,name) goto 9999 end if + if ((do_timings).and.(mv_phase1==-1)) & + & mv_phase1 = psb_get_timer_idx("SPMM: and send ") + if ((do_timings).and.(mv_phase2==-1)) & + & mv_phase2 = psb_get_timer_idx("SPMM: and cmp ad") + if ((do_timings).and.(mv_phase3==-1)) & + & mv_phase3 = psb_get_timer_idx("SPMM: and rcv") + if ((do_timings).and.(mv_phase4==-1)) & + & mv_phase4 = psb_get_timer_idx("SPMM: and cmp and") + if ((do_timings).and.(mv_phase11==-1)) & + & mv_phase11 = psb_get_timer_idx("SPMM: noand exch ") + if ((do_timings).and.(mv_phase12==-1)) & + & mv_phase12 = psb_get_timer_idx("SPMM: noand cmp") + m = desc_a%get_global_rows() n = desc_a%get_global_cols() @@ -178,34 +194,44 @@ subroutine psb_zspmv_vect(alpha,a,x,beta,y,desc_a,info,& if (trans_ == 'N') then ! Matrix is not transposed - + if (allocated(a%ad)) then block - logical, parameter :: do_timings=.true. - real(psb_dpk_) :: t1, t2, t3, t4, t5 - if (do_timings) call psb_barrier(ctxt) - if (do_timings) t1= psb_wtime() - if (doswap_) call psi_swapdata(psb_swap_send_,& - & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t2= psb_wtime() - call a%ad%spmm(alpha,x%v,beta,y%v,info) - if (do_timings) t3= psb_wtime() - if (doswap_) call psi_swapdata(psb_swap_recv_,& - & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - if (do_timings) t4= psb_wtime() - call a%and%spmm(alpha,x%v,zone,y%v,info) - if (do_timings) t5= psb_wtime() - if (do_timings) write(0,*) me,' SPMM:',t2-t1,t3-t2,t4-t3,t5-t4 - end block - - else - if (doswap_) then - call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + !if (me==0) write(0,*) 'going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt() + if (do_timings) call psb_barrier(ctxt) + if (do_timings) call psb_tic(mv_phase1) + if (doswap_) call psi_swapdata(psb_swap_send_,& & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) - end if - - call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase1) + if (do_timings) call psb_tic(mv_phase2) + call a%ad%spmm(alpha,x%v,beta,y%v,info) + if (do_timings) call psb_tic(mv_phase3) + if (doswap_) call psi_swapdata(psb_swap_recv_,& + & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + if (do_timings) call psb_toc(mv_phase3) + if (do_timings) call psb_tic(mv_phase4) + call a%and%spmm(alpha,x%v,zone,y%v,info) + if (do_timings) call psb_toc(mv_phase4) + end block + else + block + logical, parameter :: do_timings=.true. + real(psb_dpk_) :: t1, t2, t3, t4, t5 + if (do_timings) call psb_barrier(ctxt) + + if (do_timings) call psb_tic(mv_phase11) + if (doswap_) then + call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),& + & zzero,x%v,desc_a,iwork,info,data=psb_comm_halo_) + end if + if (do_timings) call psb_toc(mv_phase11) + if (do_timings) call psb_tic(mv_phase12) + call psb_csmm(alpha,a,x,beta,y,info) + if (do_timings) call psb_toc(mv_phase12) + end block end if if(info /= psb_success_) then diff --git a/base/serial/impl/psb_c_mat_impl.F90 b/base/serial/impl/psb_c_mat_impl.F90 index df5c4cd9..bbac0406 100644 --- a/base/serial/impl/psb_c_mat_impl.F90 +++ b/base/serial/impl/psb_c_mat_impl.F90 @@ -1213,6 +1213,56 @@ subroutine psb_c_b_csclip(a,b,info,& end subroutine psb_c_b_csclip +subroutine psb_c_split_nd(a,n_rows,n_cols,info) + use psb_error_mod + use psb_string_mod + use psb_c_mat_mod, psb_protect_name => psb_c_split_nd + implicit none + class(psb_cspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_c_base_sparse_mat), intent(in), optional :: mold + type(psb_c_coo_sparse_mat) :: acoo + type(psb_c_csr_sparse_mat), allocatable :: aclip + type(psb_c_ecsr_sparse_mat), allocatable :: andclip + logical, parameter :: use_ecsr=.true. + character(len=20) :: name, ch_err + integer(psb_ipk_) :: err_act + + info = psb_success_ + name = 'psb_split' + call psb_erractionsave(err_act) + allocate(aclip) + call a%a%csclip(acoo,info,jmax=n_rows,rscale=.false.,cscale=.false.) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_rows+1,jmax=n_cols,rscale=.false.,cscale=.false.) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if + + if (psb_errstatus_fatal()) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='cscnv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +end subroutine psb_c_split_nd + subroutine psb_c_cscnv(a,b,info,type,mold,upd,dupl) use psb_error_mod use psb_string_mod @@ -1246,54 +1296,65 @@ subroutine psb_c_cscnv(a,b,info,type,mold,upd,dupl) goto 9999 end if - if (present(mold)) then - - allocate(altmp, mold=mold,stat=info) - - else if (present(type)) then + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_c_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_c_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_c_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_c_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_c_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_c_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) goto 9999 - end select - else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) - end if + end if - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if - if (present(dupl)) then - call altmp%set_dupl(dupl) - else if (a%is_bld()) then - ! Does this make sense at all?? Who knows.. - call altmp%set_dupl(psb_dupl_def_) - end if + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() - if (debug) write(psb_err_unit,*) 'Converting from ',& - & a%get_fmt(),' to ',altmp%get_fmt() + call altmp%cp_from_fmt(a%a, info) - call altmp%cp_from_fmt(a%a, info) + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if - if (info /= psb_success_) then - info = psb_err_from_subroutine_ - call psb_errpush(info,name,a_err="mv_from") - goto 9999 + call move_alloc(altmp,b%a) + else + call inner_cp_fmt(a%a,b%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_cp_fmt(a%ad,b%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_cp_fmt(a%and,b%and,info,type,mold,dupl) + end if end if - call move_alloc(altmp,b%a) call b%trim() call b%set_asb() call psb_erractionrestore(err_act) @@ -1303,7 +1364,79 @@ subroutine psb_c_cscnv(a,b,info,type,mold,upd,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_cp_fmt(a,b,info,type,mold,dupl) + class(psb_c_base_sparse_mat), intent(in) :: a + class(psb_c_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type + class(psb_c_base_sparse_mat), intent(in), optional :: mold + + class(psb_c_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_c_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_c_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_c_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(psb_c_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%cp_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,b) + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_cp_fmt end subroutine psb_c_cscnv subroutine psb_c_cscnv_ip(a,info,type,mold,dupl) @@ -1312,13 +1445,12 @@ subroutine psb_c_cscnv_ip(a,info,type,mold,dupl) use psb_c_mat_mod, psb_protect_name => psb_c_cscnv_ip implicit none - class(psb_cspmat_type), intent(inout) :: a - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_),optional, intent(in) :: dupl - character(len=*), optional, intent(in) :: type + class(psb_cspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type class(psb_c_base_sparse_mat), intent(in), optional :: mold - class(psb_c_base_sparse_mat), allocatable :: altmp integer(psb_ipk_) :: err_act character(len=20) :: name='cscnv_ip' @@ -1345,46 +1477,55 @@ subroutine psb_c_cscnv_ip(a,info,type,mold,dupl) goto 9999 end if - if (present(mold)) then + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_c_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_c_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_c_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if - allocate(altmp, mold=mold,stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if - else if (present(type)) then + if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + & a%get_fmt(),' to ',altmp%get_fmt() - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_c_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_c_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_c_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) - goto 9999 - end select + call altmp%mv_from_fmt(a%a, info) + call move_alloc(altmp,a%a) else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) + call inner_mv_fmt(a%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_mv_fmt(a%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_mv_fmt(a%and,info,type,mold,dupl) + end if end if - - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - if (debug) write(psb_err_unit,*) 'Converting in-place from ',& - & a%get_fmt(),' to ',altmp%get_fmt() - - call altmp%mv_from_fmt(a%a, info) - if (info /= psb_success_) then info = psb_err_from_subroutine_ call psb_errpush(info,name,a_err="mv_from") goto 9999 end if - call move_alloc(altmp,a%a) call a%trim() call a%set_asb() call psb_erractionrestore(err_act) @@ -1394,6 +1535,77 @@ subroutine psb_c_cscnv_ip(a,info,type,mold,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_mv_fmt(a,info,type,mold,dupl) + class(psb_c_base_sparse_mat), intent(inout), allocatable :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type + class(psb_c_base_sparse_mat), intent(in), optional :: mold + class(psb_c_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_c_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_c_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_c_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(psb_c_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%mv_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,a) + + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_mv_fmt end subroutine psb_c_cscnv_ip diff --git a/base/serial/impl/psb_d_mat_impl.F90 b/base/serial/impl/psb_d_mat_impl.F90 index caf725d1..9af64b3f 100644 --- a/base/serial/impl/psb_d_mat_impl.F90 +++ b/base/serial/impl/psb_d_mat_impl.F90 @@ -1213,6 +1213,56 @@ subroutine psb_d_b_csclip(a,b,info,& end subroutine psb_d_b_csclip +subroutine psb_d_split_nd(a,n_rows,n_cols,info) + use psb_error_mod + use psb_string_mod + use psb_d_mat_mod, psb_protect_name => psb_d_split_nd + implicit none + class(psb_dspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_d_base_sparse_mat), intent(in), optional :: mold + type(psb_d_coo_sparse_mat) :: acoo + type(psb_d_csr_sparse_mat), allocatable :: aclip + type(psb_d_ecsr_sparse_mat), allocatable :: andclip + logical, parameter :: use_ecsr=.true. + character(len=20) :: name, ch_err + integer(psb_ipk_) :: err_act + + info = psb_success_ + name = 'psb_split' + call psb_erractionsave(err_act) + allocate(aclip) + call a%a%csclip(acoo,info,jmax=n_rows,rscale=.false.,cscale=.false.) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_rows+1,jmax=n_cols,rscale=.false.,cscale=.false.) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if + + if (psb_errstatus_fatal()) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='cscnv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +end subroutine psb_d_split_nd + subroutine psb_d_cscnv(a,b,info,type,mold,upd,dupl) use psb_error_mod use psb_string_mod @@ -1246,65 +1296,64 @@ subroutine psb_d_cscnv(a,b,info,type,mold,upd,dupl) goto 9999 end if -!!$ if (present(mold)) then -!!$ -!!$ allocate(altmp, mold=mold,stat=info) -!!$ -!!$ else if (present(type)) then -!!$ -!!$ select case (psb_toupper(type)) -!!$ case ('CSR') -!!$ allocate(psb_d_csr_sparse_mat :: altmp, stat=info) -!!$ case ('COO') -!!$ allocate(psb_d_coo_sparse_mat :: altmp, stat=info) -!!$ case ('CSC') -!!$ allocate(psb_d_csc_sparse_mat :: altmp, stat=info) -!!$ case default -!!$ info = psb_err_format_unknown_ -!!$ call psb_errpush(info,name,a_err=type) -!!$ goto 9999 -!!$ end select -!!$ else -!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) -!!$ end if -!!$ -!!$ if (info /= psb_success_) then -!!$ info = psb_err_alloc_dealloc_ -!!$ call psb_errpush(info,name) -!!$ goto 9999 -!!$ end if -!!$ -!!$ -!!$ if (present(dupl)) then -!!$ call altmp%set_dupl(dupl) -!!$ else if (a%is_bld()) then -!!$ ! Does this make sense at all?? Who knows.. -!!$ call altmp%set_dupl(psb_dupl_def_) -!!$ end if -!!$ -!!$ if (debug) write(psb_err_unit,*) 'Converting from ',& -!!$ & a%get_fmt(),' to ',altmp%get_fmt() -!!$ -!!$ call altmp%cp_from_fmt(a%a, info) -!!$ -!!$ if (info /= psb_success_) then -!!$ info = psb_err_from_subroutine_ -!!$ call psb_errpush(info,name,a_err="mv_from") -!!$ goto 9999 -!!$ end if -!!$ -!!$ call move_alloc(altmp,b%a) - call inner_cp_alloc(a%a,b%a,info,type,mold) - if (info /= 0) goto 9999 - if (allocated(a%ad)) then - call inner_cp_alloc(a%ad,b%ad,info,type,mold) - if (info /= 0) goto 9999 - end if - if (allocated(a%and)) then - call inner_cp_alloc(a%and,b%and,info,type,mold) - if (info /= 0) goto 9999 - end if + if (.false.) then + if (present(mold)) then + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_d_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_d_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%cp_from_fmt(a%a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,b%a) + else + call inner_cp_fmt(a%a,b%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_cp_fmt(a%ad,b%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_cp_fmt(a%and,b%and,info,type,mold,dupl) + end if + end if call b%trim() call b%set_asb() @@ -1316,24 +1365,26 @@ subroutine psb_d_cscnv(a,b,info,type,mold,upd,dupl) return contains - subroutine inner_cp_alloc(a,b,info,type,mold) + subroutine inner_cp_fmt(a,b,info,type,mold,dupl) class(psb_d_base_sparse_mat), intent(in) :: a - class(psb_d_base_sparse_mat), intent(inout), allocatable :: b - integer(psb_ipk_), intent(out) :: info + class(psb_d_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl character(len=*), optional, intent(in) :: type class(psb_d_base_sparse_mat), intent(in), optional :: mold - + class(psb_d_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act info = psb_success_ call psb_erractionsave(err_act) - + if (present(mold)) then - + allocate(altmp, mold=mold,stat=info) - + else if (present(type)) then - + select case (psb_toupper(type)) case ('CSR') allocate(psb_d_csr_sparse_mat :: altmp, stat=info) @@ -1347,38 +1398,45 @@ contains goto 9999 end select else -!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) end if - + if (info /= psb_success_) then info = psb_err_alloc_dealloc_ call psb_errpush(info,name) goto 9999 end if - - if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& & a%get_fmt(),' to ',altmp%get_fmt() - + call altmp%cp_from_fmt(a, info) - + if (info /= psb_success_) then info = psb_err_from_subroutine_ call psb_errpush(info,name,a_err="mv_from") goto 9999 end if - + call move_alloc(altmp,b) - + call psb_erractionrestore(err_act) return - - + + 9999 call psb_error_handler(err_act) - - return - end subroutine inner_cp_alloc + return + end subroutine inner_cp_fmt end subroutine psb_d_cscnv subroutine psb_d_cscnv_ip(a,info,type,mold,dupl) @@ -1387,13 +1445,12 @@ subroutine psb_d_cscnv_ip(a,info,type,mold,dupl) use psb_d_mat_mod, psb_protect_name => psb_d_cscnv_ip implicit none - class(psb_dspmat_type), intent(inout) :: a - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_),optional, intent(in) :: dupl - character(len=*), optional, intent(in) :: type + class(psb_dspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type class(psb_d_base_sparse_mat), intent(in), optional :: mold - class(psb_d_base_sparse_mat), allocatable :: altmp integer(psb_ipk_) :: err_act character(len=20) :: name='cscnv_ip' @@ -1420,57 +1477,55 @@ subroutine psb_d_cscnv_ip(a,info,type,mold,dupl) goto 9999 end if -!!$ if (present(mold)) then -!!$ -!!$ allocate(altmp, mold=mold,stat=info) -!!$ -!!$ else if (present(type)) then -!!$ -!!$ select case (psb_toupper(type)) -!!$ case ('CSR') -!!$ allocate(psb_d_csr_sparse_mat :: altmp, stat=info) -!!$ case ('COO') -!!$ allocate(psb_d_coo_sparse_mat :: altmp, stat=info) -!!$ case ('CSC') -!!$ allocate(psb_d_csc_sparse_mat :: altmp, stat=info) -!!$ case default -!!$ info = psb_err_format_unknown_ -!!$ call psb_errpush(info,name,a_err=type) -!!$ goto 9999 -!!$ end select -!!$ else -!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) -!!$ end if -!!$ -!!$ if (info /= psb_success_) then -!!$ info = psb_err_alloc_dealloc_ -!!$ call psb_errpush(info,name) -!!$ goto 9999 -!!$ end if -!!$ -!!$ if (debug) write(psb_err_unit,*) 'Converting in-place from ',& -!!$ & a%get_fmt(),' to ',altmp%get_fmt() -!!$ -!!$ call altmp%mv_from_fmt(a%a, info) -!!$ -!!$ if (info /= psb_success_) then -!!$ info = psb_err_from_subroutine_ -!!$ call psb_errpush(info,name,a_err="mv_from") -!!$ goto 9999 -!!$ end if -!!$ -!!$ call move_alloc(altmp,a%a) - - call inner_mv_alloc(a%a,info,type,mold) - if (info /= 0) goto 9999 - if (allocated(a%ad)) then - call inner_mv_alloc(a%ad,info,type,mold) - if (info /= 0) goto 9999 + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_d_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_d_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%mv_from_fmt(a%a, info) + call move_alloc(altmp,a%a) + else + call inner_mv_fmt(a%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_mv_fmt(a%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_mv_fmt(a%and,info,type,mold,dupl) + end if end if - if (allocated(a%and)) then - call inner_mv_alloc(a%and,info,type,mold) - if (info /= 0) goto 9999 + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 end if + call a%trim() call a%set_asb() call psb_erractionrestore(err_act) @@ -1481,23 +1536,24 @@ subroutine psb_d_cscnv_ip(a,info,type,mold,dupl) return contains - subroutine inner_mv_alloc(a,info,type,mold) - class(psb_d_base_sparse_mat), intent(inout), allocatable :: a + subroutine inner_mv_fmt(a,info,type,mold,dupl) + class(psb_d_base_sparse_mat), intent(inout), allocatable :: a integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl character(len=*), optional, intent(in) :: type class(psb_d_base_sparse_mat), intent(in), optional :: mold - class(psb_d_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act info = psb_success_ call psb_erractionsave(err_act) - + if (present(mold)) then - + allocate(altmp, mold=mold,stat=info) - + else if (present(type)) then - + select case (psb_toupper(type)) case ('CSR') allocate(psb_d_csr_sparse_mat :: altmp, stat=info) @@ -1511,37 +1567,46 @@ contains goto 9999 end select else -!!$ allocate(altmp, mold=psb_get_mat_default(a),stat=info) allocate(psb_d_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) end if - + if (info /= psb_success_) then info = psb_err_alloc_dealloc_ call psb_errpush(info,name) goto 9999 end if - - if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& & a%get_fmt(),' to ',altmp%get_fmt() - + call altmp%mv_from_fmt(a, info) - + if (info /= psb_success_) then info = psb_err_from_subroutine_ call psb_errpush(info,name,a_err="mv_from") goto 9999 end if - + call move_alloc(altmp,a) - + call psb_erractionrestore(err_act) return - - + + 9999 call psb_error_handler(err_act) - + return - end subroutine inner_mv_alloc + end subroutine inner_mv_fmt + end subroutine psb_d_cscnv_ip diff --git a/base/serial/impl/psb_s_mat_impl.F90 b/base/serial/impl/psb_s_mat_impl.F90 index ce7ce653..c0370774 100644 --- a/base/serial/impl/psb_s_mat_impl.F90 +++ b/base/serial/impl/psb_s_mat_impl.F90 @@ -1213,6 +1213,56 @@ subroutine psb_s_b_csclip(a,b,info,& end subroutine psb_s_b_csclip +subroutine psb_s_split_nd(a,n_rows,n_cols,info) + use psb_error_mod + use psb_string_mod + use psb_s_mat_mod, psb_protect_name => psb_s_split_nd + implicit none + class(psb_sspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_s_base_sparse_mat), intent(in), optional :: mold + type(psb_s_coo_sparse_mat) :: acoo + type(psb_s_csr_sparse_mat), allocatable :: aclip + type(psb_s_ecsr_sparse_mat), allocatable :: andclip + logical, parameter :: use_ecsr=.true. + character(len=20) :: name, ch_err + integer(psb_ipk_) :: err_act + + info = psb_success_ + name = 'psb_split' + call psb_erractionsave(err_act) + allocate(aclip) + call a%a%csclip(acoo,info,jmax=n_rows,rscale=.false.,cscale=.false.) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_rows+1,jmax=n_cols,rscale=.false.,cscale=.false.) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if + + if (psb_errstatus_fatal()) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='cscnv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +end subroutine psb_s_split_nd + subroutine psb_s_cscnv(a,b,info,type,mold,upd,dupl) use psb_error_mod use psb_string_mod @@ -1246,54 +1296,65 @@ subroutine psb_s_cscnv(a,b,info,type,mold,upd,dupl) goto 9999 end if - if (present(mold)) then - - allocate(altmp, mold=mold,stat=info) - - else if (present(type)) then + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_s_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_s_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_s_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_s_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_s_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_s_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) goto 9999 - end select - else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) - end if + end if - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if - if (present(dupl)) then - call altmp%set_dupl(dupl) - else if (a%is_bld()) then - ! Does this make sense at all?? Who knows.. - call altmp%set_dupl(psb_dupl_def_) - end if + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() - if (debug) write(psb_err_unit,*) 'Converting from ',& - & a%get_fmt(),' to ',altmp%get_fmt() + call altmp%cp_from_fmt(a%a, info) - call altmp%cp_from_fmt(a%a, info) + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if - if (info /= psb_success_) then - info = psb_err_from_subroutine_ - call psb_errpush(info,name,a_err="mv_from") - goto 9999 + call move_alloc(altmp,b%a) + else + call inner_cp_fmt(a%a,b%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_cp_fmt(a%ad,b%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_cp_fmt(a%and,b%and,info,type,mold,dupl) + end if end if - call move_alloc(altmp,b%a) call b%trim() call b%set_asb() call psb_erractionrestore(err_act) @@ -1303,7 +1364,79 @@ subroutine psb_s_cscnv(a,b,info,type,mold,upd,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_cp_fmt(a,b,info,type,mold,dupl) + class(psb_s_base_sparse_mat), intent(in) :: a + class(psb_s_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type + class(psb_s_base_sparse_mat), intent(in), optional :: mold + + class(psb_s_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_s_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_s_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_s_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(psb_s_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%cp_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,b) + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_cp_fmt end subroutine psb_s_cscnv subroutine psb_s_cscnv_ip(a,info,type,mold,dupl) @@ -1312,13 +1445,12 @@ subroutine psb_s_cscnv_ip(a,info,type,mold,dupl) use psb_s_mat_mod, psb_protect_name => psb_s_cscnv_ip implicit none - class(psb_sspmat_type), intent(inout) :: a - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_),optional, intent(in) :: dupl - character(len=*), optional, intent(in) :: type + class(psb_sspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type class(psb_s_base_sparse_mat), intent(in), optional :: mold - class(psb_s_base_sparse_mat), allocatable :: altmp integer(psb_ipk_) :: err_act character(len=20) :: name='cscnv_ip' @@ -1345,46 +1477,55 @@ subroutine psb_s_cscnv_ip(a,info,type,mold,dupl) goto 9999 end if - if (present(mold)) then + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_s_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_s_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_s_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if - allocate(altmp, mold=mold,stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if - else if (present(type)) then + if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + & a%get_fmt(),' to ',altmp%get_fmt() - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_s_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_s_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_s_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) - goto 9999 - end select + call altmp%mv_from_fmt(a%a, info) + call move_alloc(altmp,a%a) else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) + call inner_mv_fmt(a%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_mv_fmt(a%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_mv_fmt(a%and,info,type,mold,dupl) + end if end if - - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - if (debug) write(psb_err_unit,*) 'Converting in-place from ',& - & a%get_fmt(),' to ',altmp%get_fmt() - - call altmp%mv_from_fmt(a%a, info) - if (info /= psb_success_) then info = psb_err_from_subroutine_ call psb_errpush(info,name,a_err="mv_from") goto 9999 end if - call move_alloc(altmp,a%a) call a%trim() call a%set_asb() call psb_erractionrestore(err_act) @@ -1394,6 +1535,77 @@ subroutine psb_s_cscnv_ip(a,info,type,mold,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_mv_fmt(a,info,type,mold,dupl) + class(psb_s_base_sparse_mat), intent(inout), allocatable :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type + class(psb_s_base_sparse_mat), intent(in), optional :: mold + class(psb_s_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_s_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_s_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_s_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(psb_s_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%mv_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,a) + + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_mv_fmt end subroutine psb_s_cscnv_ip diff --git a/base/serial/impl/psb_z_mat_impl.F90 b/base/serial/impl/psb_z_mat_impl.F90 index 2cebf9e7..20815cb0 100644 --- a/base/serial/impl/psb_z_mat_impl.F90 +++ b/base/serial/impl/psb_z_mat_impl.F90 @@ -1213,6 +1213,56 @@ subroutine psb_z_b_csclip(a,b,info,& end subroutine psb_z_b_csclip +subroutine psb_z_split_nd(a,n_rows,n_cols,info) + use psb_error_mod + use psb_string_mod + use psb_z_mat_mod, psb_protect_name => psb_z_split_nd + implicit none + class(psb_zspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(in) :: n_rows, n_cols + integer(psb_ipk_), intent(out) :: info +!!$ integer(psb_ipk_),optional, intent(in) :: dupl +!!$ character(len=*), optional, intent(in) :: type +!!$ class(psb_z_base_sparse_mat), intent(in), optional :: mold + type(psb_z_coo_sparse_mat) :: acoo + type(psb_z_csr_sparse_mat), allocatable :: aclip + type(psb_z_ecsr_sparse_mat), allocatable :: andclip + logical, parameter :: use_ecsr=.true. + character(len=20) :: name, ch_err + integer(psb_ipk_) :: err_act + + info = psb_success_ + name = 'psb_split' + call psb_erractionsave(err_act) + allocate(aclip) + call a%a%csclip(acoo,info,jmax=n_rows,rscale=.false.,cscale=.false.) + allocate(a%ad,mold=a%a) + call a%ad%mv_from_coo(acoo,info) + call a%a%csclip(acoo,info,jmin=n_rows+1,jmax=n_cols,rscale=.false.,cscale=.false.) + if (use_ecsr) then + allocate(andclip) + call andclip%mv_from_coo(acoo,info) + call move_alloc(andclip,a%and) + else + allocate(a%and,mold=a%a) + call a%and%mv_from_coo(acoo,info) + end if + + if (psb_errstatus_fatal()) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='cscnv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +end subroutine psb_z_split_nd + subroutine psb_z_cscnv(a,b,info,type,mold,upd,dupl) use psb_error_mod use psb_string_mod @@ -1246,54 +1296,65 @@ subroutine psb_z_cscnv(a,b,info,type,mold,upd,dupl) goto 9999 end if - if (present(mold)) then - - allocate(altmp, mold=mold,stat=info) - - else if (present(type)) then + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_z_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_z_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_z_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_z_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_z_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_z_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) goto 9999 - end select - else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) - end if + end if - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if - if (present(dupl)) then - call altmp%set_dupl(dupl) - else if (a%is_bld()) then - ! Does this make sense at all?? Who knows.. - call altmp%set_dupl(psb_dupl_def_) - end if + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() - if (debug) write(psb_err_unit,*) 'Converting from ',& - & a%get_fmt(),' to ',altmp%get_fmt() + call altmp%cp_from_fmt(a%a, info) - call altmp%cp_from_fmt(a%a, info) + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if - if (info /= psb_success_) then - info = psb_err_from_subroutine_ - call psb_errpush(info,name,a_err="mv_from") - goto 9999 + call move_alloc(altmp,b%a) + else + call inner_cp_fmt(a%a,b%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_cp_fmt(a%ad,b%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_cp_fmt(a%and,b%and,info,type,mold,dupl) + end if end if - call move_alloc(altmp,b%a) call b%trim() call b%set_asb() call psb_erractionrestore(err_act) @@ -1303,7 +1364,79 @@ subroutine psb_z_cscnv(a,b,info,type,mold,upd,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_cp_fmt(a,b,info,type,mold,dupl) + class(psb_z_base_sparse_mat), intent(in) :: a + class(psb_z_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type + class(psb_z_base_sparse_mat), intent(in), optional :: mold + + class(psb_z_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_z_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_z_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_z_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(psb_z_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%cp_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,b) + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_cp_fmt end subroutine psb_z_cscnv subroutine psb_z_cscnv_ip(a,info,type,mold,dupl) @@ -1312,13 +1445,12 @@ subroutine psb_z_cscnv_ip(a,info,type,mold,dupl) use psb_z_mat_mod, psb_protect_name => psb_z_cscnv_ip implicit none - class(psb_zspmat_type), intent(inout) :: a - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_),optional, intent(in) :: dupl - character(len=*), optional, intent(in) :: type + class(psb_zspmat_type), intent(inout) :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type class(psb_z_base_sparse_mat), intent(in), optional :: mold - class(psb_z_base_sparse_mat), allocatable :: altmp integer(psb_ipk_) :: err_act character(len=20) :: name='cscnv_ip' @@ -1345,46 +1477,55 @@ subroutine psb_z_cscnv_ip(a,info,type,mold,dupl) goto 9999 end if - if (present(mold)) then + if (.false.) then + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_z_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_z_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_z_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if - allocate(altmp, mold=mold,stat=info) + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if - else if (present(type)) then + if (debug) write(psb_err_unit,*) 'Converting in-place from ',& + & a%get_fmt(),' to ',altmp%get_fmt() - select case (psb_toupper(type)) - case ('CSR') - allocate(psb_z_csr_sparse_mat :: altmp, stat=info) - case ('COO') - allocate(psb_z_coo_sparse_mat :: altmp, stat=info) - case ('CSC') - allocate(psb_z_csc_sparse_mat :: altmp, stat=info) - case default - info = psb_err_format_unknown_ - call psb_errpush(info,name,a_err=type) - goto 9999 - end select + call altmp%mv_from_fmt(a%a, info) + call move_alloc(altmp,a%a) else - allocate(altmp, mold=psb_get_mat_default(a),stat=info) + call inner_mv_fmt(a%a,info,type,mold,dupl) + if (allocated(a%ad)) then + call inner_mv_fmt(a%ad,info,type,mold,dupl) + end if + if (allocated(a%and)) then + call inner_mv_fmt(a%and,info,type,mold,dupl) + end if end if - - if (info /= psb_success_) then - info = psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 - end if - - if (debug) write(psb_err_unit,*) 'Converting in-place from ',& - & a%get_fmt(),' to ',altmp%get_fmt() - - call altmp%mv_from_fmt(a%a, info) - if (info /= psb_success_) then info = psb_err_from_subroutine_ call psb_errpush(info,name,a_err="mv_from") goto 9999 end if - call move_alloc(altmp,a%a) call a%trim() call a%set_asb() call psb_erractionrestore(err_act) @@ -1394,6 +1535,77 @@ subroutine psb_z_cscnv_ip(a,info,type,mold,dupl) 9999 call psb_error_handler(err_act) return +contains + subroutine inner_mv_fmt(a,info,type,mold,dupl) + class(psb_z_base_sparse_mat), intent(inout), allocatable :: a + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_),optional, intent(in) :: dupl + character(len=*), optional, intent(in) :: type + class(psb_z_base_sparse_mat), intent(in), optional :: mold + class(psb_z_base_sparse_mat), allocatable :: altmp + integer(psb_ipk_) :: err_act + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(mold)) then + + allocate(altmp, mold=mold,stat=info) + + else if (present(type)) then + + select case (psb_toupper(type)) + case ('CSR') + allocate(psb_z_csr_sparse_mat :: altmp, stat=info) + case ('COO') + allocate(psb_z_coo_sparse_mat :: altmp, stat=info) + case ('CSC') + allocate(psb_z_csc_sparse_mat :: altmp, stat=info) + case default + info = psb_err_format_unknown_ + call psb_errpush(info,name,a_err=type) + goto 9999 + end select + else + allocate(psb_z_csr_sparse_mat :: altmp, stat=info) + !allocate(altmp, mold=psb_get_mat_default(a),stat=info) + end if + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + + if (present(dupl)) then + call altmp%set_dupl(dupl) + else if (a%is_bld()) then + ! Does this make sense at all?? Who knows.. + call altmp%set_dupl(psb_dupl_def_) + end if + + if (debug) write(psb_err_unit,*) 'Converting from ',& + & a%get_fmt(),' to ',altmp%get_fmt() + + call altmp%mv_from_fmt(a, info) + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info,name,a_err="mv_from") + goto 9999 + end if + + call move_alloc(altmp,a) + + call psb_erractionrestore(err_act) + return + + +9999 call psb_error_handler(err_act) + + return + end subroutine inner_mv_fmt end subroutine psb_z_cscnv_ip diff --git a/base/tools/psb_cspasb.f90 b/base/tools/psb_cspasb.f90 index 8263e309..db8af75a 100644 --- a/base/tools/psb_cspasb.f90 +++ b/base/tools/psb_cspasb.f90 @@ -178,41 +178,44 @@ subroutine psb_cspasb(a,desc_a, info, afmt, upd, mold, bld_and) end if if (bld_and_) then - block - character(len=1024) :: fname - type(psb_c_coo_sparse_mat) :: acoo - type(psb_c_csr_sparse_mat), allocatable :: aclip - type(psb_c_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.true. - allocate(aclip) - call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - allocate(a%ad,mold=a%a) - call a%ad%mv_from_coo(acoo,info) - call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - if (use_ecsr) then - allocate(andclip) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) - else - allocate(a%and,mold=a%a) - call a%and%mv_from_coo(acoo,info) - end if - if (.false.) then - write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' - open(25,file=fname) - call a%ad%print(25) - close(25) - write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' - open(25,file=fname) - call a%and%print(25) - close(25) - !call andclip%set_cols(n_col) - write(*,*) me,' ',trim(name),' ad ',& - &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col - write(*,*) me,' ',trim(name),' and ',& - &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col - end if - end block +!!$ allocate(a%ad,mold=a%a) +!!$ allocate(a%and,mold=a%a)o + call a%split_nd(n_row,n_col,info) +!!$ block +!!$ character(len=1024) :: fname +!!$ type(psb_c_coo_sparse_mat) :: acoo +!!$ type(psb_c_csr_sparse_mat), allocatable :: aclip +!!$ type(psb_c_ecsr_sparse_mat), allocatable :: andclip +!!$ logical, parameter :: use_ecsr=.true. +!!$ allocate(aclip) +!!$ call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) +!!$ allocate(a%ad,mold=a%a) +!!$ call a%ad%mv_from_coo(acoo,info) +!!$ call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) +!!$ if (use_ecsr) then +!!$ allocate(andclip) +!!$ call andclip%mv_from_coo(acoo,info) +!!$ call move_alloc(andclip,a%and) +!!$ else +!!$ allocate(a%and,mold=a%a) +!!$ call a%and%mv_from_coo(acoo,info) +!!$ end if +!!$ if (.false.) then +!!$ write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%ad%print(25) +!!$ close(25) +!!$ write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%and%print(25) +!!$ close(25) +!!$ !call andclip%set_cols(n_col) +!!$ write(*,*) me,' ',trim(name),' ad ',& +!!$ &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col +!!$ write(*,*) me,' ',trim(name),' and ',& +!!$ &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col +!!$ end if +!!$ end block else if (allocated(a%ad)) deallocate(a%ad) if (allocated(a%and)) deallocate(a%and) diff --git a/base/tools/psb_dspasb.f90 b/base/tools/psb_dspasb.f90 index 6beb0e6f..236568a1 100644 --- a/base/tools/psb_dspasb.f90 +++ b/base/tools/psb_dspasb.f90 @@ -178,41 +178,44 @@ subroutine psb_dspasb(a,desc_a, info, afmt, upd, mold, bld_and) end if if (bld_and_) then - block - character(len=1024) :: fname - type(psb_d_coo_sparse_mat) :: acoo - type(psb_d_csr_sparse_mat), allocatable :: aclip - type(psb_d_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.true. - allocate(aclip) - call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - allocate(a%ad,mold=a%a) - call a%ad%mv_from_coo(acoo,info) - call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - if (use_ecsr) then - allocate(andclip) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) - else - allocate(a%and,mold=a%a) - call a%and%mv_from_coo(acoo,info) - end if - if (.false.) then - write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' - open(25,file=fname) - call a%ad%print(25) - close(25) - write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' - open(25,file=fname) - call a%and%print(25) - close(25) - !call andclip%set_cols(n_col) - write(*,*) me,' ',trim(name),' ad ',& - &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col - write(*,*) me,' ',trim(name),' and ',& - &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col - end if - end block +!!$ allocate(a%ad,mold=a%a) +!!$ allocate(a%and,mold=a%a)o + call a%split_nd(n_row,n_col,info) +!!$ block +!!$ character(len=1024) :: fname +!!$ type(psb_d_coo_sparse_mat) :: acoo +!!$ type(psb_d_csr_sparse_mat), allocatable :: aclip +!!$ type(psb_d_ecsr_sparse_mat), allocatable :: andclip +!!$ logical, parameter :: use_ecsr=.true. +!!$ allocate(aclip) +!!$ call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) +!!$ allocate(a%ad,mold=a%a) +!!$ call a%ad%mv_from_coo(acoo,info) +!!$ call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) +!!$ if (use_ecsr) then +!!$ allocate(andclip) +!!$ call andclip%mv_from_coo(acoo,info) +!!$ call move_alloc(andclip,a%and) +!!$ else +!!$ allocate(a%and,mold=a%a) +!!$ call a%and%mv_from_coo(acoo,info) +!!$ end if +!!$ if (.false.) then +!!$ write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%ad%print(25) +!!$ close(25) +!!$ write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%and%print(25) +!!$ close(25) +!!$ !call andclip%set_cols(n_col) +!!$ write(*,*) me,' ',trim(name),' ad ',& +!!$ &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col +!!$ write(*,*) me,' ',trim(name),' and ',& +!!$ &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col +!!$ end if +!!$ end block else if (allocated(a%ad)) deallocate(a%ad) if (allocated(a%and)) deallocate(a%and) diff --git a/base/tools/psb_sspasb.f90 b/base/tools/psb_sspasb.f90 index f273c7f4..110097c5 100644 --- a/base/tools/psb_sspasb.f90 +++ b/base/tools/psb_sspasb.f90 @@ -178,41 +178,44 @@ subroutine psb_sspasb(a,desc_a, info, afmt, upd, mold, bld_and) end if if (bld_and_) then - block - character(len=1024) :: fname - type(psb_s_coo_sparse_mat) :: acoo - type(psb_s_csr_sparse_mat), allocatable :: aclip - type(psb_s_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.true. - allocate(aclip) - call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - allocate(a%ad,mold=a%a) - call a%ad%mv_from_coo(acoo,info) - call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - if (use_ecsr) then - allocate(andclip) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) - else - allocate(a%and,mold=a%a) - call a%and%mv_from_coo(acoo,info) - end if - if (.false.) then - write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' - open(25,file=fname) - call a%ad%print(25) - close(25) - write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' - open(25,file=fname) - call a%and%print(25) - close(25) - !call andclip%set_cols(n_col) - write(*,*) me,' ',trim(name),' ad ',& - &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col - write(*,*) me,' ',trim(name),' and ',& - &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col - end if - end block +!!$ allocate(a%ad,mold=a%a) +!!$ allocate(a%and,mold=a%a)o + call a%split_nd(n_row,n_col,info) +!!$ block +!!$ character(len=1024) :: fname +!!$ type(psb_s_coo_sparse_mat) :: acoo +!!$ type(psb_s_csr_sparse_mat), allocatable :: aclip +!!$ type(psb_s_ecsr_sparse_mat), allocatable :: andclip +!!$ logical, parameter :: use_ecsr=.true. +!!$ allocate(aclip) +!!$ call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) +!!$ allocate(a%ad,mold=a%a) +!!$ call a%ad%mv_from_coo(acoo,info) +!!$ call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) +!!$ if (use_ecsr) then +!!$ allocate(andclip) +!!$ call andclip%mv_from_coo(acoo,info) +!!$ call move_alloc(andclip,a%and) +!!$ else +!!$ allocate(a%and,mold=a%a) +!!$ call a%and%mv_from_coo(acoo,info) +!!$ end if +!!$ if (.false.) then +!!$ write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%ad%print(25) +!!$ close(25) +!!$ write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%and%print(25) +!!$ close(25) +!!$ !call andclip%set_cols(n_col) +!!$ write(*,*) me,' ',trim(name),' ad ',& +!!$ &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col +!!$ write(*,*) me,' ',trim(name),' and ',& +!!$ &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col +!!$ end if +!!$ end block else if (allocated(a%ad)) deallocate(a%ad) if (allocated(a%and)) deallocate(a%and) diff --git a/base/tools/psb_zspasb.f90 b/base/tools/psb_zspasb.f90 index 1a381303..2cb53368 100644 --- a/base/tools/psb_zspasb.f90 +++ b/base/tools/psb_zspasb.f90 @@ -178,41 +178,44 @@ subroutine psb_zspasb(a,desc_a, info, afmt, upd, mold, bld_and) end if if (bld_and_) then - block - character(len=1024) :: fname - type(psb_z_coo_sparse_mat) :: acoo - type(psb_z_csr_sparse_mat), allocatable :: aclip - type(psb_z_ecsr_sparse_mat), allocatable :: andclip - logical, parameter :: use_ecsr=.true. - allocate(aclip) - call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) - allocate(a%ad,mold=a%a) - call a%ad%mv_from_coo(acoo,info) - call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) - if (use_ecsr) then - allocate(andclip) - call andclip%mv_from_coo(acoo,info) - call move_alloc(andclip,a%and) - else - allocate(a%and,mold=a%a) - call a%and%mv_from_coo(acoo,info) - end if - if (.false.) then - write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' - open(25,file=fname) - call a%ad%print(25) - close(25) - write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' - open(25,file=fname) - call a%and%print(25) - close(25) - !call andclip%set_cols(n_col) - write(*,*) me,' ',trim(name),' ad ',& - &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col - write(*,*) me,' ',trim(name),' and ',& - &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col - end if - end block +!!$ allocate(a%ad,mold=a%a) +!!$ allocate(a%and,mold=a%a)o + call a%split_nd(n_row,n_col,info) +!!$ block +!!$ character(len=1024) :: fname +!!$ type(psb_z_coo_sparse_mat) :: acoo +!!$ type(psb_z_csr_sparse_mat), allocatable :: aclip +!!$ type(psb_z_ecsr_sparse_mat), allocatable :: andclip +!!$ logical, parameter :: use_ecsr=.true. +!!$ allocate(aclip) +!!$ call a%a%csclip(acoo,info,jmax=n_row,rscale=.false.,cscale=.false.) +!!$ allocate(a%ad,mold=a%a) +!!$ call a%ad%mv_from_coo(acoo,info) +!!$ call a%a%csclip(acoo,info,jmin=n_row+1,jmax=n_col,rscale=.false.,cscale=.false.) +!!$ if (use_ecsr) then +!!$ allocate(andclip) +!!$ call andclip%mv_from_coo(acoo,info) +!!$ call move_alloc(andclip,a%and) +!!$ else +!!$ allocate(a%and,mold=a%a) +!!$ call a%and%mv_from_coo(acoo,info) +!!$ end if +!!$ if (.false.) then +!!$ write(fname,'(a,i2.2,a)') 'adclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%ad%print(25) +!!$ close(25) +!!$ write(fname,'(a,i2.2,a)') 'andclip_',me,'.mtx' +!!$ open(25,file=fname) +!!$ call a%and%print(25) +!!$ close(25) +!!$ !call andclip%set_cols(n_col) +!!$ write(*,*) me,' ',trim(name),' ad ',& +!!$ &a%ad%get_nrows(),a%ad%get_ncols(),n_row,n_col +!!$ write(*,*) me,' ',trim(name),' and ',& +!!$ &a%and%get_nrows(),a%and%get_ncols(),n_row,n_col +!!$ end if +!!$ end block else if (allocated(a%ad)) deallocate(a%ad) if (allocated(a%and)) deallocate(a%and) From 14c4ff0f32ab7f6a92f964e124f1035919644e5d Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Mon, 12 Feb 2024 14:15:19 +0100 Subject: [PATCH 18/48] Added new methd for two combined axpbys --- base/modules/psblas/psb_c_psblas_mod.F90 | 14 +++++++++++++ base/modules/psblas/psb_d_psblas_mod.F90 | 14 +++++++++++++ base/modules/psblas/psb_s_psblas_mod.F90 | 14 +++++++++++++ base/modules/psblas/psb_z_psblas_mod.F90 | 14 +++++++++++++ base/modules/serial/psb_c_base_vect_mod.F90 | 19 ++++++++++++++++++ base/modules/serial/psb_c_vect_mod.F90 | 18 +++++++++++++++++ base/modules/serial/psb_d_base_vect_mod.F90 | 19 ++++++++++++++++++ base/modules/serial/psb_d_vect_mod.F90 | 18 +++++++++++++++++ base/modules/serial/psb_s_base_vect_mod.F90 | 19 ++++++++++++++++++ base/modules/serial/psb_s_vect_mod.F90 | 18 +++++++++++++++++ base/modules/serial/psb_z_base_vect_mod.F90 | 19 ++++++++++++++++++ base/modules/serial/psb_z_vect_mod.F90 | 18 +++++++++++++++++ cuda/psb_c_cuda_vect_mod.F90 | 22 +++++++++++++++++++++ cuda/psb_d_cuda_vect_mod.F90 | 22 +++++++++++++++++++++ cuda/psb_s_cuda_vect_mod.F90 | 22 +++++++++++++++++++++ cuda/psb_z_cuda_vect_mod.F90 | 22 +++++++++++++++++++++ test/pargen/psb_d_pde3d.F90 | 8 ++++---- 17 files changed, 296 insertions(+), 4 deletions(-) diff --git a/base/modules/psblas/psb_c_psblas_mod.F90 b/base/modules/psblas/psb_c_psblas_mod.F90 index 98deebd8..d660597a 100644 --- a/base/modules/psblas/psb_c_psblas_mod.F90 +++ b/base/modules/psblas/psb_c_psblas_mod.F90 @@ -143,6 +143,20 @@ module psb_c_psblas_mod end subroutine psb_caxpby end interface + interface psb_abgdxyx + subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_c_vect_type, psb_cspmat_type + type(psb_c_vect_type), intent (inout) :: x + type(psb_c_vect_type), intent (inout) :: y + type(psb_c_vect_type), intent (inout) :: z + complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_cabgdxyz_vect + end interface psb_abgdxyx + interface psb_geamax function psb_camax(x, desc_a, info, jx,global) import :: psb_desc_type, psb_spk_, psb_ipk_, & diff --git a/base/modules/psblas/psb_d_psblas_mod.F90 b/base/modules/psblas/psb_d_psblas_mod.F90 index e4988387..734ed633 100644 --- a/base/modules/psblas/psb_d_psblas_mod.F90 +++ b/base/modules/psblas/psb_d_psblas_mod.F90 @@ -143,6 +143,20 @@ module psb_d_psblas_mod end subroutine psb_daxpby end interface + interface psb_abgdxyx + subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_d_vect_type, psb_dspmat_type + type(psb_d_vect_type), intent (inout) :: x + type(psb_d_vect_type), intent (inout) :: y + type(psb_d_vect_type), intent (inout) :: z + real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_dabgdxyz_vect + end interface psb_abgdxyx + interface psb_geamax function psb_damax(x, desc_a, info, jx,global) import :: psb_desc_type, psb_dpk_, psb_ipk_, & diff --git a/base/modules/psblas/psb_s_psblas_mod.F90 b/base/modules/psblas/psb_s_psblas_mod.F90 index 93fe74b9..0f7d29e6 100644 --- a/base/modules/psblas/psb_s_psblas_mod.F90 +++ b/base/modules/psblas/psb_s_psblas_mod.F90 @@ -143,6 +143,20 @@ module psb_s_psblas_mod end subroutine psb_saxpby end interface + interface psb_abgdxyx + subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_s_vect_type, psb_sspmat_type + type(psb_s_vect_type), intent (inout) :: x + type(psb_s_vect_type), intent (inout) :: y + type(psb_s_vect_type), intent (inout) :: z + real(psb_spk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_sabgdxyz_vect + end interface psb_abgdxyx + interface psb_geamax function psb_samax(x, desc_a, info, jx,global) import :: psb_desc_type, psb_spk_, psb_ipk_, & diff --git a/base/modules/psblas/psb_z_psblas_mod.F90 b/base/modules/psblas/psb_z_psblas_mod.F90 index 06be1b82..17674600 100644 --- a/base/modules/psblas/psb_z_psblas_mod.F90 +++ b/base/modules/psblas/psb_z_psblas_mod.F90 @@ -143,6 +143,20 @@ module psb_z_psblas_mod end subroutine psb_zaxpby end interface + interface psb_abgdxyx + subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_z_vect_type, psb_zspmat_type + type(psb_z_vect_type), intent (inout) :: x + type(psb_z_vect_type), intent (inout) :: y + type(psb_z_vect_type), intent (inout) :: z + complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + end subroutine psb_zabgdxyz_vect + end interface psb_abgdxyx + interface psb_geamax function psb_zamax(x, desc_a, info, jx,global) import :: psb_desc_type, psb_dpk_, psb_ipk_, & diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index df15e0c9..e4f398a7 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -155,6 +155,8 @@ module psb_c_base_vect_mod procedure, pass(z) :: axpby_v2 => c_base_axpby_v2 procedure, pass(z) :: axpby_a2 => c_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => c_base_abgdxyz + ! ! Vector by vector multiplication. Need all variants ! to handle multiple requirements from preconditioners @@ -1126,6 +1128,23 @@ contains end subroutine c_base_axpby_a2 + subroutine c_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_base_vect_type), intent(inout) :: z + complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + + end subroutine c_base_abgdxyz + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_c_vect_mod.F90 b/base/modules/serial/psb_c_vect_mod.F90 index 1a336d11..8b2941ff 100644 --- a/base/modules/serial/psb_c_vect_mod.F90 +++ b/base/modules/serial/psb_c_vect_mod.F90 @@ -102,6 +102,8 @@ module psb_c_vect_mod procedure, pass(z) :: axpby_v2 => c_vect_axpby_v2 procedure, pass(z) :: axpby_a2 => c_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => c_vect_abgdxyz + procedure, pass(y) :: mlt_v => c_vect_mlt_v procedure, pass(y) :: mlt_a => c_vect_mlt_a procedure, pass(z) :: mlt_a_2 => c_vect_mlt_a_2 @@ -771,6 +773,22 @@ contains end subroutine c_vect_axpby_a2 + subroutine c_vect_abgdxyz(m,alpha,beta,gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_vect_type), intent(inout) :: x + class(psb_c_vect_type), intent(inout) :: y + class(psb_c_vect_type), intent(inout) :: z + complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (allocated(z%v)) & + call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + + end subroutine c_vect_abgdxyz + + subroutine c_vect_mlt_v(x, y, info) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 87f5b0e4..7ad2d6e7 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -155,6 +155,8 @@ module psb_d_base_vect_mod procedure, pass(z) :: axpby_v2 => d_base_axpby_v2 procedure, pass(z) :: axpby_a2 => d_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => d_base_abgdxyz + ! ! Vector by vector multiplication. Need all variants ! to handle multiple requirements from preconditioners @@ -1133,6 +1135,23 @@ contains end subroutine d_base_axpby_a2 + subroutine d_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_d_base_vect_type), intent(inout) :: x + class(psb_d_base_vect_type), intent(inout) :: y + class(psb_d_base_vect_type), intent(inout) :: z + real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + + end subroutine d_base_abgdxyz + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index 88fa3262..ef75be87 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -102,6 +102,8 @@ module psb_d_vect_mod procedure, pass(z) :: axpby_v2 => d_vect_axpby_v2 procedure, pass(z) :: axpby_a2 => d_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => d_vect_abgdxyz + procedure, pass(y) :: mlt_v => d_vect_mlt_v procedure, pass(y) :: mlt_a => d_vect_mlt_a procedure, pass(z) :: mlt_a_2 => d_vect_mlt_a_2 @@ -778,6 +780,22 @@ contains end subroutine d_vect_axpby_a2 + subroutine d_vect_abgdxyz(m,alpha,beta,gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_d_vect_type), intent(inout) :: x + class(psb_d_vect_type), intent(inout) :: y + class(psb_d_vect_type), intent(inout) :: z + real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (allocated(z%v)) & + call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + + end subroutine d_vect_abgdxyz + + subroutine d_vect_mlt_v(x, y, info) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index fccd846b..4e9c0dd3 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -155,6 +155,8 @@ module psb_s_base_vect_mod procedure, pass(z) :: axpby_v2 => s_base_axpby_v2 procedure, pass(z) :: axpby_a2 => s_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => s_base_abgdxyz + ! ! Vector by vector multiplication. Need all variants ! to handle multiple requirements from preconditioners @@ -1133,6 +1135,23 @@ contains end subroutine s_base_axpby_a2 + subroutine s_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + class(psb_s_base_vect_type), intent(inout) :: z + real(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + + end subroutine s_base_abgdxyz + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_s_vect_mod.F90 b/base/modules/serial/psb_s_vect_mod.F90 index 7a54ecf0..34479856 100644 --- a/base/modules/serial/psb_s_vect_mod.F90 +++ b/base/modules/serial/psb_s_vect_mod.F90 @@ -102,6 +102,8 @@ module psb_s_vect_mod procedure, pass(z) :: axpby_v2 => s_vect_axpby_v2 procedure, pass(z) :: axpby_a2 => s_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => s_vect_abgdxyz + procedure, pass(y) :: mlt_v => s_vect_mlt_v procedure, pass(y) :: mlt_a => s_vect_mlt_a procedure, pass(z) :: mlt_a_2 => s_vect_mlt_a_2 @@ -778,6 +780,22 @@ contains end subroutine s_vect_axpby_a2 + subroutine s_vect_abgdxyz(m,alpha,beta,gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_vect_type), intent(inout) :: x + class(psb_s_vect_type), intent(inout) :: y + class(psb_s_vect_type), intent(inout) :: z + real(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (allocated(z%v)) & + call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + + end subroutine s_vect_abgdxyz + + subroutine s_vect_mlt_v(x, y, info) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index 2a14de21..60c3c854 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -155,6 +155,8 @@ module psb_z_base_vect_mod procedure, pass(z) :: axpby_v2 => z_base_axpby_v2 procedure, pass(z) :: axpby_a2 => z_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => z_base_abgdxyz + ! ! Vector by vector multiplication. Need all variants ! to handle multiple requirements from preconditioners @@ -1126,6 +1128,23 @@ contains end subroutine z_base_axpby_a2 + subroutine z_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + class(psb_z_base_vect_type), intent(inout) :: z + complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + + end subroutine z_base_abgdxyz + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_z_vect_mod.F90 b/base/modules/serial/psb_z_vect_mod.F90 index e8a34859..54ddfebe 100644 --- a/base/modules/serial/psb_z_vect_mod.F90 +++ b/base/modules/serial/psb_z_vect_mod.F90 @@ -102,6 +102,8 @@ module psb_z_vect_mod procedure, pass(z) :: axpby_v2 => z_vect_axpby_v2 procedure, pass(z) :: axpby_a2 => z_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 + procedure, pass(z) :: abgdxyz => z_vect_abgdxyz + procedure, pass(y) :: mlt_v => z_vect_mlt_v procedure, pass(y) :: mlt_a => z_vect_mlt_a procedure, pass(z) :: mlt_a_2 => z_vect_mlt_a_2 @@ -771,6 +773,22 @@ contains end subroutine z_vect_axpby_a2 + subroutine z_vect_abgdxyz(m,alpha,beta,gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_vect_type), intent(inout) :: x + class(psb_z_vect_type), intent(inout) :: y + class(psb_z_vect_type), intent(inout) :: z + complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + if (allocated(z%v)) & + call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + + end subroutine z_vect_abgdxyz + + subroutine z_vect_mlt_v(x, y, info) use psi_serial_mod implicit none diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index c140dadb..db988e56 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -90,6 +90,7 @@ module psb_c_cuda_vect_mod procedure, pass(x) :: dot_a => c_cuda_dot_a procedure, pass(y) :: axpby_v => c_cuda_axpby_v procedure, pass(y) :: axpby_a => c_cuda_axpby_a + procedure, pass(z) :: abgdxyz => c_cuda_abgdxyz procedure, pass(y) :: mlt_v => c_cuda_mlt_v procedure, pass(y) :: mlt_a => c_cuda_mlt_a procedure, pass(z) :: mlt_a_2 => c_cuda_mlt_a_2 @@ -911,6 +912,27 @@ contains end subroutine c_cuda_axpby_v + + subroutine c_cuda_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_vect_cuda), intent(inout) :: z + complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + call z%psb_c_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) +!!$ +!!$ if (x%is_dev()) call x%sync() +!!$ +!!$ call y%axpby(m,alpha,x,beta,info) +!!$ call z%axpby(m,gamma,y,delta,info) + + end subroutine c_cuda_abgdxyz + + subroutine c_cuda_axpby_a(m,alpha, x, beta, y, info) use psi_serial_mod implicit none diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index 44381c99..7f84807b 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -90,6 +90,7 @@ module psb_d_cuda_vect_mod procedure, pass(x) :: dot_a => d_cuda_dot_a procedure, pass(y) :: axpby_v => d_cuda_axpby_v procedure, pass(y) :: axpby_a => d_cuda_axpby_a + procedure, pass(z) :: abgdxyz => d_cuda_abgdxyz procedure, pass(y) :: mlt_v => d_cuda_mlt_v procedure, pass(y) :: mlt_a => d_cuda_mlt_a procedure, pass(z) :: mlt_a_2 => d_cuda_mlt_a_2 @@ -911,6 +912,27 @@ contains end subroutine d_cuda_axpby_v + + subroutine d_cuda_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_d_base_vect_type), intent(inout) :: x + class(psb_d_base_vect_type), intent(inout) :: y + class(psb_d_vect_cuda), intent(inout) :: z + real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + call z%psb_d_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) +!!$ +!!$ if (x%is_dev()) call x%sync() +!!$ +!!$ call y%axpby(m,alpha,x,beta,info) +!!$ call z%axpby(m,gamma,y,delta,info) + + end subroutine d_cuda_abgdxyz + + subroutine d_cuda_axpby_a(m,alpha, x, beta, y, info) use psi_serial_mod implicit none diff --git a/cuda/psb_s_cuda_vect_mod.F90 b/cuda/psb_s_cuda_vect_mod.F90 index 7778eb50..8858c6d9 100644 --- a/cuda/psb_s_cuda_vect_mod.F90 +++ b/cuda/psb_s_cuda_vect_mod.F90 @@ -90,6 +90,7 @@ module psb_s_cuda_vect_mod procedure, pass(x) :: dot_a => s_cuda_dot_a procedure, pass(y) :: axpby_v => s_cuda_axpby_v procedure, pass(y) :: axpby_a => s_cuda_axpby_a + procedure, pass(z) :: abgdxyz => s_cuda_abgdxyz procedure, pass(y) :: mlt_v => s_cuda_mlt_v procedure, pass(y) :: mlt_a => s_cuda_mlt_a procedure, pass(z) :: mlt_a_2 => s_cuda_mlt_a_2 @@ -911,6 +912,27 @@ contains end subroutine s_cuda_axpby_v + + subroutine s_cuda_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + class(psb_s_vect_cuda), intent(inout) :: z + real(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + call z%psb_s_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) +!!$ +!!$ if (x%is_dev()) call x%sync() +!!$ +!!$ call y%axpby(m,alpha,x,beta,info) +!!$ call z%axpby(m,gamma,y,delta,info) + + end subroutine s_cuda_abgdxyz + + subroutine s_cuda_axpby_a(m,alpha, x, beta, y, info) use psi_serial_mod implicit none diff --git a/cuda/psb_z_cuda_vect_mod.F90 b/cuda/psb_z_cuda_vect_mod.F90 index 53484911..a7243ff9 100644 --- a/cuda/psb_z_cuda_vect_mod.F90 +++ b/cuda/psb_z_cuda_vect_mod.F90 @@ -90,6 +90,7 @@ module psb_z_cuda_vect_mod procedure, pass(x) :: dot_a => z_cuda_dot_a procedure, pass(y) :: axpby_v => z_cuda_axpby_v procedure, pass(y) :: axpby_a => z_cuda_axpby_a + procedure, pass(z) :: abgdxyz => z_cuda_abgdxyz procedure, pass(y) :: mlt_v => z_cuda_mlt_v procedure, pass(y) :: mlt_a => z_cuda_mlt_a procedure, pass(z) :: mlt_a_2 => z_cuda_mlt_a_2 @@ -911,6 +912,27 @@ contains end subroutine z_cuda_axpby_v + + subroutine z_cuda_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + class(psb_z_vect_cuda), intent(inout) :: z + complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + call z%psb_z_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) +!!$ +!!$ if (x%is_dev()) call x%sync() +!!$ +!!$ call y%axpby(m,alpha,x,beta,info) +!!$ call z%axpby(m,gamma,y,delta,info) + + end subroutine z_cuda_abgdxyz + + subroutine z_cuda_axpby_a(m,alpha, x, beta, y, info) use psi_serial_mod implicit none diff --git a/test/pargen/psb_d_pde3d.F90 b/test/pargen/psb_d_pde3d.F90 index 4748569c..6e895c00 100644 --- a/test/pargen/psb_d_pde3d.F90 +++ b/test/pargen/psb_d_pde3d.F90 @@ -592,9 +592,9 @@ contains t1 = psb_wtime() if (info == psb_success_) then if (present(amold)) then - call psb_spasb(a,desc_a,info,mold=amold,bld_and=.false.) + call psb_spasb(a,desc_a,info,mold=amold) else - call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=.false.) + call psb_spasb(a,desc_a,info,afmt=afmt) end if end if call psb_barrier(ctxt) @@ -868,8 +868,8 @@ program psb_d_pde3d call psb_errpush(info,name,a_err=ch_err) goto 9999 end if - call psb_print_timers(ctxt) - call psb_exit(ctxt) + + call psb_exit(ctxt) stop 9999 call psb_error(ctxt) From 45f00e6e1963142d6532a02007c910d5ab752e97 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Mon, 12 Feb 2024 15:10:58 +0100 Subject: [PATCH 19/48] Fixed comments --- base/modules/serial/psb_c_base_vect_mod.F90 | 23 ++++++++++++++++++--- base/modules/serial/psb_d_base_vect_mod.F90 | 23 ++++++++++++++++++--- base/modules/serial/psb_s_base_vect_mod.F90 | 23 ++++++++++++++++++--- base/modules/serial/psb_z_base_vect_mod.F90 | 23 ++++++++++++++++++--- 4 files changed, 80 insertions(+), 12 deletions(-) diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index e4f398a7..793df3bc 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -1020,7 +1020,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine c_base_axpby_v(m,alpha, x, beta, y, info) @@ -1049,7 +1049,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param y The class(base_vect) to be added !! \param z The class(base_vect) to be returned !! \param info return code @@ -1080,7 +1080,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x(:) The array to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine c_base_axpby_a(m,alpha, x, beta, y, info) @@ -1128,6 +1128,23 @@ contains end subroutine c_base_axpby_a2 + ! + ! ABGDXYZ is invoked via Z, hence the structure below. + ! + ! + !> Function base_abgdxyz + !! \memberof psb_c_base_vect_type + !! \brief ABGDXYZ combines two AXPBYS y=alpha*x+beta*y, z=gamma*y+delta*zeta + !! \param m Number of entries to be considered + !! \param alpha scalar alpha + !! \param beta scalar beta + !! \param gamma scalar gamma + !! \param delta scalar delta + !! \param x The class(base_vect) to be added + !! \param y The class(base_vect) to be added + !! \param z The class(base_vect) to be added + !! \param info return code + !! subroutine c_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 7ad2d6e7..29a2ccd8 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -1027,7 +1027,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine d_base_axpby_v(m,alpha, x, beta, y, info) @@ -1056,7 +1056,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param y The class(base_vect) to be added !! \param z The class(base_vect) to be returned !! \param info return code @@ -1087,7 +1087,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x(:) The array to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine d_base_axpby_a(m,alpha, x, beta, y, info) @@ -1135,6 +1135,23 @@ contains end subroutine d_base_axpby_a2 + ! + ! ABGDXYZ is invoked via Z, hence the structure below. + ! + ! + !> Function base_abgdxyz + !! \memberof psb_d_base_vect_type + !! \brief ABGDXYZ combines two AXPBYS y=alpha*x+beta*y, z=gamma*y+delta*zeta + !! \param m Number of entries to be considered + !! \param alpha scalar alpha + !! \param beta scalar beta + !! \param gamma scalar gamma + !! \param delta scalar delta + !! \param x The class(base_vect) to be added + !! \param y The class(base_vect) to be added + !! \param z The class(base_vect) to be added + !! \param info return code + !! subroutine d_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index 4e9c0dd3..61ae27d2 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -1027,7 +1027,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine s_base_axpby_v(m,alpha, x, beta, y, info) @@ -1056,7 +1056,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param y The class(base_vect) to be added !! \param z The class(base_vect) to be returned !! \param info return code @@ -1087,7 +1087,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x(:) The array to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine s_base_axpby_a(m,alpha, x, beta, y, info) @@ -1135,6 +1135,23 @@ contains end subroutine s_base_axpby_a2 + ! + ! ABGDXYZ is invoked via Z, hence the structure below. + ! + ! + !> Function base_abgdxyz + !! \memberof psb_s_base_vect_type + !! \brief ABGDXYZ combines two AXPBYS y=alpha*x+beta*y, z=gamma*y+delta*zeta + !! \param m Number of entries to be considered + !! \param alpha scalar alpha + !! \param beta scalar beta + !! \param gamma scalar gamma + !! \param delta scalar delta + !! \param x The class(base_vect) to be added + !! \param y The class(base_vect) to be added + !! \param z The class(base_vect) to be added + !! \param info return code + !! subroutine s_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index 60c3c854..53f3ea8e 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -1020,7 +1020,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine z_base_axpby_v(m,alpha, x, beta, y, info) @@ -1049,7 +1049,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x The class(base_vect) to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param y The class(base_vect) to be added !! \param z The class(base_vect) to be returned !! \param info return code @@ -1080,7 +1080,7 @@ contains !! \param m Number of entries to be considered !! \param alpha scalar alpha !! \param x(:) The array to be added - !! \param beta scalar alpha + !! \param beta scalar beta !! \param info return code !! subroutine z_base_axpby_a(m,alpha, x, beta, y, info) @@ -1128,6 +1128,23 @@ contains end subroutine z_base_axpby_a2 + ! + ! ABGDXYZ is invoked via Z, hence the structure below. + ! + ! + !> Function base_abgdxyz + !! \memberof psb_z_base_vect_type + !! \brief ABGDXYZ combines two AXPBYS y=alpha*x+beta*y, z=gamma*y+delta*zeta + !! \param m Number of entries to be considered + !! \param alpha scalar alpha + !! \param beta scalar beta + !! \param gamma scalar gamma + !! \param delta scalar delta + !! \param x The class(base_vect) to be added + !! \param y The class(base_vect) to be added + !! \param z The class(base_vect) to be added + !! \param info return code + !! subroutine z_base_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) use psi_serial_mod implicit none From ebc7c6b3b40fe21705db6f4d148128ae14410707 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Mon, 12 Feb 2024 16:29:48 +0100 Subject: [PATCH 20/48] Fix call to base%abgdxyz --- cuda/psb_c_cuda_vect_mod.F90 | 2 +- cuda/psb_d_cuda_vect_mod.F90 | 2 +- cuda/psb_s_cuda_vect_mod.F90 | 2 +- cuda/psb_z_cuda_vect_mod.F90 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index db988e56..56cc80e6 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -923,7 +923,7 @@ contains complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - call z%psb_c_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) + call z%psb_c_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) !!$ !!$ if (x%is_dev()) call x%sync() !!$ diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index 7f84807b..03e65f91 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -923,7 +923,7 @@ contains real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - call z%psb_d_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) + call z%psb_d_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) !!$ !!$ if (x%is_dev()) call x%sync() !!$ diff --git a/cuda/psb_s_cuda_vect_mod.F90 b/cuda/psb_s_cuda_vect_mod.F90 index 8858c6d9..9616b3a6 100644 --- a/cuda/psb_s_cuda_vect_mod.F90 +++ b/cuda/psb_s_cuda_vect_mod.F90 @@ -923,7 +923,7 @@ contains real(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - call z%psb_s_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) + call z%psb_s_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) !!$ !!$ if (x%is_dev()) call x%sync() !!$ diff --git a/cuda/psb_z_cuda_vect_mod.F90 b/cuda/psb_z_cuda_vect_mod.F90 index a7243ff9..1153fc61 100644 --- a/cuda/psb_z_cuda_vect_mod.F90 +++ b/cuda/psb_z_cuda_vect_mod.F90 @@ -923,7 +923,7 @@ contains complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - call z%psb_z_base_vect_type(m,alpha,beta,gamma,delta,x,y,info) + call z%psb_z_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) !!$ !!$ if (x%is_dev()) call x%sync() !!$ From 83ededd02b36c458c2c2ca23ac50b570d6623c8e Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 13 Feb 2024 12:54:37 +0100 Subject: [PATCH 21/48] Implementatino of abgd_xyz --- base/modules/auxil/psi_c_serial_mod.f90 | 13 ++ base/modules/auxil/psi_d_serial_mod.f90 | 13 ++ base/modules/auxil/psi_e_serial_mod.f90 | 13 ++ base/modules/auxil/psi_i2_serial_mod.f90 | 13 ++ base/modules/auxil/psi_m_serial_mod.f90 | 13 ++ base/modules/auxil/psi_s_serial_mod.f90 | 13 ++ base/modules/auxil/psi_z_serial_mod.f90 | 13 ++ base/serial/psi_c_serial_impl.F90 | 225 +++++++++++++++++++++++ base/serial/psi_d_serial_impl.F90 | 225 +++++++++++++++++++++++ base/serial/psi_e_serial_impl.F90 | 225 +++++++++++++++++++++++ base/serial/psi_i2_serial_impl.F90 | 225 +++++++++++++++++++++++ base/serial/psi_m_serial_impl.F90 | 225 +++++++++++++++++++++++ base/serial/psi_s_serial_impl.F90 | 225 +++++++++++++++++++++++ base/serial/psi_z_serial_impl.F90 | 225 +++++++++++++++++++++++ 14 files changed, 1666 insertions(+) diff --git a/base/modules/auxil/psi_c_serial_mod.f90 b/base/modules/auxil/psi_c_serial_mod.f90 index 0fdff04b..6926d6bd 100644 --- a/base/modules/auxil/psi_c_serial_mod.f90 +++ b/base/modules/auxil/psi_c_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_c_serial_mod end subroutine psi_caxpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_cabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_spk_ + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_spk_), intent (in) :: x(:) + complex(psb_spk_), intent (inout) :: y(:) + complex(psb_spk_), intent (inout) :: z(:) + complex(psb_spk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_cabgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_cgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_spk_ diff --git a/base/modules/auxil/psi_d_serial_mod.f90 b/base/modules/auxil/psi_d_serial_mod.f90 index 0ce14dbb..42185d21 100644 --- a/base/modules/auxil/psi_d_serial_mod.f90 +++ b/base/modules/auxil/psi_d_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_d_serial_mod end subroutine psi_daxpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_dabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_dpk_ + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_dpk_), intent (in) :: x(:) + real(psb_dpk_), intent (inout) :: y(:) + real(psb_dpk_), intent (inout) :: z(:) + real(psb_dpk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_dabgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_dgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_dpk_ diff --git a/base/modules/auxil/psi_e_serial_mod.f90 b/base/modules/auxil/psi_e_serial_mod.f90 index f0372e01..ffba06fd 100644 --- a/base/modules/auxil/psi_e_serial_mod.f90 +++ b/base/modules/auxil/psi_e_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_e_serial_mod end subroutine psi_eaxpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_eabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_epk_), intent (in) :: x(:) + integer(psb_epk_), intent (inout) :: y(:) + integer(psb_epk_), intent (inout) :: z(:) + integer(psb_epk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_eabgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_egthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ diff --git a/base/modules/auxil/psi_i2_serial_mod.f90 b/base/modules/auxil/psi_i2_serial_mod.f90 index 70dd95e1..d61a1146 100644 --- a/base/modules/auxil/psi_i2_serial_mod.f90 +++ b/base/modules/auxil/psi_i2_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_i2_serial_mod end subroutine psi_i2axpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_i2abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_i2pk_), intent (in) :: x(:) + integer(psb_i2pk_), intent (inout) :: y(:) + integer(psb_i2pk_), intent (inout) :: z(:) + integer(psb_i2pk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_i2abgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_i2gthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ diff --git a/base/modules/auxil/psi_m_serial_mod.f90 b/base/modules/auxil/psi_m_serial_mod.f90 index cfd1348e..76131d75 100644 --- a/base/modules/auxil/psi_m_serial_mod.f90 +++ b/base/modules/auxil/psi_m_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_m_serial_mod end subroutine psi_maxpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_mabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_mpk_), intent (in) :: x(:) + integer(psb_mpk_), intent (inout) :: y(:) + integer(psb_mpk_), intent (inout) :: z(:) + integer(psb_mpk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_mabgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_mgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ diff --git a/base/modules/auxil/psi_s_serial_mod.f90 b/base/modules/auxil/psi_s_serial_mod.f90 index 25c4a7ef..02b96311 100644 --- a/base/modules/auxil/psi_s_serial_mod.f90 +++ b/base/modules/auxil/psi_s_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_s_serial_mod end subroutine psi_saxpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_sabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_spk_ + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_spk_), intent (in) :: x(:) + real(psb_spk_), intent (inout) :: y(:) + real(psb_spk_), intent (inout) :: z(:) + real(psb_spk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_sabgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_sgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_spk_ diff --git a/base/modules/auxil/psi_z_serial_mod.f90 b/base/modules/auxil/psi_z_serial_mod.f90 index b40cf05a..a86bdd70 100644 --- a/base/modules/auxil/psi_z_serial_mod.f90 +++ b/base/modules/auxil/psi_z_serial_mod.f90 @@ -99,6 +99,19 @@ module psi_z_serial_mod end subroutine psi_zaxpbyv2 end interface psb_geaxpby + interface psi_abgdxyz + subroutine psi_zabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + import :: psb_ipk_, psb_dpk_ + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_dpk_), intent (in) :: x(:) + complex(psb_dpk_), intent (inout) :: y(:) + complex(psb_dpk_), intent (inout) :: z(:) + complex(psb_dpk_), intent (in) :: alpha, beta,gamma,delta + integer(psb_ipk_), intent(out) :: info + end subroutine psi_zabgdxyz + end interface psi_abgdxyz + interface psi_gth subroutine psi_zgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_dpk_ diff --git a/base/serial/psi_c_serial_impl.F90 b/base/serial/psi_c_serial_impl.F90 index a3898349..129e8484 100644 --- a/base/serial/psi_c_serial_impl.F90 +++ b/base/serial/psi_c_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine caxpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine caxpbyv2 + +subroutine psi_cabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_spk_), intent (in) :: x(:) + complex(psb_spk_), intent (inout) :: y(:) + complex(psb_spk_), intent (inout) :: z(:) + complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='cabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == czero) then + if (gamma == czero) then + if (alpha == czero) then + if (delta == czero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = czero + z(i) = czero + end do + else if (delta /= czero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = czero + z(i) = delta*z(i) + end do + end if + else if (alpha /= czero) then + if (delta == czero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = czero + end do + else if (delta /= czero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= czero) then + + if (alpha == czero) then + + if (delta == czero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = czero + z(i) = czero ! gamma*y(i) + end do + + else if (delta /= czero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = czero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= czero) then + + if (delta == czero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= czero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= czero) then + + if (gamma == czero) then + if (alpha == czero) then + if (delta == czero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = czero + end do + + else if (delta /= czero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= czero) then + if (delta == czero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = czero + end do + + else if (delta /= czero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= czero) then + if (alpha == czero) then + if (delta == czero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= czero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= czero) then + if (delta == czero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= czero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_cabgdxyz diff --git a/base/serial/psi_d_serial_impl.F90 b/base/serial/psi_d_serial_impl.F90 index 1b5b1442..bd0c82df 100644 --- a/base/serial/psi_d_serial_impl.F90 +++ b/base/serial/psi_d_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine daxpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine daxpbyv2 + +subroutine psi_dabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_dpk_), intent (in) :: x(:) + real(psb_dpk_), intent (inout) :: y(:) + real(psb_dpk_), intent (inout) :: z(:) + real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='dabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == dzero) then + if (gamma == dzero) then + if (alpha == dzero) then + if (delta == dzero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = dzero + z(i) = dzero + end do + else if (delta /= dzero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = dzero + z(i) = delta*z(i) + end do + end if + else if (alpha /= dzero) then + if (delta == dzero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = dzero + end do + else if (delta /= dzero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= dzero) then + + if (alpha == dzero) then + + if (delta == dzero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = dzero + z(i) = dzero ! gamma*y(i) + end do + + else if (delta /= dzero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = dzero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= dzero) then + + if (delta == dzero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= dzero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= dzero) then + + if (gamma == dzero) then + if (alpha == dzero) then + if (delta == dzero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = dzero + end do + + else if (delta /= dzero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= dzero) then + if (delta == dzero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = dzero + end do + + else if (delta /= dzero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= dzero) then + if (alpha == dzero) then + if (delta == dzero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= dzero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= dzero) then + if (delta == dzero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= dzero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_dabgdxyz diff --git a/base/serial/psi_e_serial_impl.F90 b/base/serial/psi_e_serial_impl.F90 index 9cdcdf0e..8b17aeb8 100644 --- a/base/serial/psi_e_serial_impl.F90 +++ b/base/serial/psi_e_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine eaxpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine eaxpbyv2 + +subroutine psi_eabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_epk_), intent (in) :: x(:) + integer(psb_epk_), intent (inout) :: y(:) + integer(psb_epk_), intent (inout) :: z(:) + integer(psb_epk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='eabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == ezero) then + if (gamma == ezero) then + if (alpha == ezero) then + if (delta == ezero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = ezero + z(i) = ezero + end do + else if (delta /= ezero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = ezero + z(i) = delta*z(i) + end do + end if + else if (alpha /= ezero) then + if (delta == ezero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = ezero + end do + else if (delta /= ezero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= ezero) then + + if (alpha == ezero) then + + if (delta == ezero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = ezero + z(i) = ezero ! gamma*y(i) + end do + + else if (delta /= ezero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = ezero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= ezero) then + + if (delta == ezero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= ezero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= ezero) then + + if (gamma == ezero) then + if (alpha == ezero) then + if (delta == ezero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = ezero + end do + + else if (delta /= ezero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= ezero) then + if (delta == ezero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = ezero + end do + + else if (delta /= ezero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= ezero) then + if (alpha == ezero) then + if (delta == ezero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= ezero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= ezero) then + if (delta == ezero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= ezero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_eabgdxyz diff --git a/base/serial/psi_i2_serial_impl.F90 b/base/serial/psi_i2_serial_impl.F90 index d25617a9..9a2c36c6 100644 --- a/base/serial/psi_i2_serial_impl.F90 +++ b/base/serial/psi_i2_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine i2axpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine i2axpbyv2 + +subroutine psi_i2abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_i2pk_), intent (in) :: x(:) + integer(psb_i2pk_), intent (inout) :: y(:) + integer(psb_i2pk_), intent (inout) :: z(:) + integer(psb_i2pk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='i2abgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == i2zero) then + if (gamma == i2zero) then + if (alpha == i2zero) then + if (delta == i2zero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = i2zero + z(i) = i2zero + end do + else if (delta /= i2zero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = i2zero + z(i) = delta*z(i) + end do + end if + else if (alpha /= i2zero) then + if (delta == i2zero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = i2zero + end do + else if (delta /= i2zero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= i2zero) then + + if (alpha == i2zero) then + + if (delta == i2zero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = i2zero + z(i) = i2zero ! gamma*y(i) + end do + + else if (delta /= i2zero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = i2zero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= i2zero) then + + if (delta == i2zero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= i2zero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= i2zero) then + + if (gamma == i2zero) then + if (alpha == i2zero) then + if (delta == i2zero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = i2zero + end do + + else if (delta /= i2zero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= i2zero) then + if (delta == i2zero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = i2zero + end do + + else if (delta /= i2zero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= i2zero) then + if (alpha == i2zero) then + if (delta == i2zero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= i2zero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= i2zero) then + if (delta == i2zero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= i2zero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_i2abgdxyz diff --git a/base/serial/psi_m_serial_impl.F90 b/base/serial/psi_m_serial_impl.F90 index 05c8e60f..dd114a45 100644 --- a/base/serial/psi_m_serial_impl.F90 +++ b/base/serial/psi_m_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine maxpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine maxpbyv2 + +subroutine psi_mabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_mpk_), intent (in) :: x(:) + integer(psb_mpk_), intent (inout) :: y(:) + integer(psb_mpk_), intent (inout) :: z(:) + integer(psb_mpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='mabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == mzero) then + if (gamma == mzero) then + if (alpha == mzero) then + if (delta == mzero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = mzero + z(i) = mzero + end do + else if (delta /= mzero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = mzero + z(i) = delta*z(i) + end do + end if + else if (alpha /= mzero) then + if (delta == mzero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = mzero + end do + else if (delta /= mzero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= mzero) then + + if (alpha == mzero) then + + if (delta == mzero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = mzero + z(i) = mzero ! gamma*y(i) + end do + + else if (delta /= mzero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = mzero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= mzero) then + + if (delta == mzero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= mzero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= mzero) then + + if (gamma == mzero) then + if (alpha == mzero) then + if (delta == mzero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = mzero + end do + + else if (delta /= mzero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= mzero) then + if (delta == mzero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = mzero + end do + + else if (delta /= mzero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= mzero) then + if (alpha == mzero) then + if (delta == mzero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= mzero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= mzero) then + if (delta == mzero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= mzero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_mabgdxyz diff --git a/base/serial/psi_s_serial_impl.F90 b/base/serial/psi_s_serial_impl.F90 index 26a57e68..8e2dda0f 100644 --- a/base/serial/psi_s_serial_impl.F90 +++ b/base/serial/psi_s_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine saxpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine saxpbyv2 + +subroutine psi_sabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_spk_), intent (in) :: x(:) + real(psb_spk_), intent (inout) :: y(:) + real(psb_spk_), intent (inout) :: z(:) + real(psb_spk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='sabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == szero) then + if (gamma == szero) then + if (alpha == szero) then + if (delta == szero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = szero + z(i) = szero + end do + else if (delta /= szero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = szero + z(i) = delta*z(i) + end do + end if + else if (alpha /= szero) then + if (delta == szero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = szero + end do + else if (delta /= szero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= szero) then + + if (alpha == szero) then + + if (delta == szero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = szero + z(i) = szero ! gamma*y(i) + end do + + else if (delta /= szero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = szero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= szero) then + + if (delta == szero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= szero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= szero) then + + if (gamma == szero) then + if (alpha == szero) then + if (delta == szero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = szero + end do + + else if (delta /= szero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= szero) then + if (delta == szero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = szero + end do + + else if (delta /= szero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= szero) then + if (alpha == szero) then + if (delta == szero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= szero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= szero) then + if (delta == szero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= szero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_sabgdxyz diff --git a/base/serial/psi_z_serial_impl.F90 b/base/serial/psi_z_serial_impl.F90 index 0b15b2d6..c6a7e01d 100644 --- a/base/serial/psi_z_serial_impl.F90 +++ b/base/serial/psi_z_serial_impl.F90 @@ -1567,3 +1567,228 @@ subroutine zaxpbyv2(m, n, alpha, X, lldx, beta, Y, lldy, Z, lldz, info) return end subroutine zaxpbyv2 + +subroutine psi_zabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_dpk_), intent (in) :: x(:) + complex(psb_dpk_), intent (inout) :: y(:) + complex(psb_dpk_), intent (inout) :: z(:) + complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='zabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if (beta == zzero) then + if (gamma == zzero) then + if (alpha == zzero) then + if (delta == zzero) then + ! a 0 b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = zzero + z(i) = zzero + end do + else if (delta /= zzero) then + ! a 0 b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = zzero + z(i) = delta*z(i) + end do + end if + else if (alpha /= zzero) then + if (delta == zzero) then + ! a n b 0 g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = zzero + end do + else if (delta /= zzero) then + ! a n b 0 g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = delta*z(i) + end do + + end if + + end if + + else if (gamma /= zzero) then + + if (alpha == zzero) then + + if (delta == zzero) then + ! a 0 b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = zzero + z(i) = zzero ! gamma*y(i) + end do + + else if (delta /= zzero) then + ! a 0 b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = zzero + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= zzero) then + + if (delta == zzero) then + ! a n b 0 g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i) + end do + + else if (delta /= zzero) then + ! a n b 0 g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + end if + + end if + + else if (beta /= zzero) then + + if (gamma == zzero) then + if (alpha == zzero) then + if (delta == zzero) then + ! a 0 b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = zzero + end do + + else if (delta /= zzero) then + ! a 0 b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = delta*z(i) + end do + + end if + + else if (alpha /= zzero) then + if (delta == zzero) then + ! a n b n g 0 d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = zzero + end do + + else if (delta /= zzero) then + ! a n b n g 0 d n + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = delta*z(i) + end do + + end if + + end if + else if (gamma /= zzero) then + if (alpha == zzero) then + if (delta == zzero) then + ! a 0 b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= zzero) then + ! a 0 b n g n d n + !$omp parallel do private(i) + do i=1,m + y(i) = beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + + else if (alpha /= zzero) then + if (delta == zzero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i) + end do + + else if (delta /= zzero) then + ! a n b n g n d 0 + !$omp parallel do private(i) + do i=1,m + y(i) = alpha*x(i)+beta*y(i) + z(i) = gamma*y(i)+delta*z(i) + end do + + end if + end if + end if + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_zabgdxyz From 6c53b6ec79dce2d6d1d0c985f11f015622b9f7fd Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 13 Feb 2024 15:48:43 +0100 Subject: [PATCH 22/48] Fix typo in interface for psb_abgdxyz --- base/modules/psblas/psb_c_psblas_mod.F90 | 4 ++-- base/modules/psblas/psb_d_psblas_mod.F90 | 4 ++-- base/modules/psblas/psb_s_psblas_mod.F90 | 4 ++-- base/modules/psblas/psb_z_psblas_mod.F90 | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/base/modules/psblas/psb_c_psblas_mod.F90 b/base/modules/psblas/psb_c_psblas_mod.F90 index d660597a..7f7f937c 100644 --- a/base/modules/psblas/psb_c_psblas_mod.F90 +++ b/base/modules/psblas/psb_c_psblas_mod.F90 @@ -143,7 +143,7 @@ module psb_c_psblas_mod end subroutine psb_caxpby end interface - interface psb_abgdxyx + interface psb_abgdxyz subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) import :: psb_desc_type, psb_spk_, psb_ipk_, & @@ -155,7 +155,7 @@ module psb_c_psblas_mod type(psb_desc_type), intent (in) :: desc_a integer(psb_ipk_), intent(out) :: info end subroutine psb_cabgdxyz_vect - end interface psb_abgdxyx + end interface psb_abgdxyz interface psb_geamax function psb_camax(x, desc_a, info, jx,global) diff --git a/base/modules/psblas/psb_d_psblas_mod.F90 b/base/modules/psblas/psb_d_psblas_mod.F90 index 734ed633..12090956 100644 --- a/base/modules/psblas/psb_d_psblas_mod.F90 +++ b/base/modules/psblas/psb_d_psblas_mod.F90 @@ -143,7 +143,7 @@ module psb_d_psblas_mod end subroutine psb_daxpby end interface - interface psb_abgdxyx + interface psb_abgdxyz subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) import :: psb_desc_type, psb_dpk_, psb_ipk_, & @@ -155,7 +155,7 @@ module psb_d_psblas_mod type(psb_desc_type), intent (in) :: desc_a integer(psb_ipk_), intent(out) :: info end subroutine psb_dabgdxyz_vect - end interface psb_abgdxyx + end interface psb_abgdxyz interface psb_geamax function psb_damax(x, desc_a, info, jx,global) diff --git a/base/modules/psblas/psb_s_psblas_mod.F90 b/base/modules/psblas/psb_s_psblas_mod.F90 index 0f7d29e6..7a7ce783 100644 --- a/base/modules/psblas/psb_s_psblas_mod.F90 +++ b/base/modules/psblas/psb_s_psblas_mod.F90 @@ -143,7 +143,7 @@ module psb_s_psblas_mod end subroutine psb_saxpby end interface - interface psb_abgdxyx + interface psb_abgdxyz subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) import :: psb_desc_type, psb_spk_, psb_ipk_, & @@ -155,7 +155,7 @@ module psb_s_psblas_mod type(psb_desc_type), intent (in) :: desc_a integer(psb_ipk_), intent(out) :: info end subroutine psb_sabgdxyz_vect - end interface psb_abgdxyx + end interface psb_abgdxyz interface psb_geamax function psb_samax(x, desc_a, info, jx,global) diff --git a/base/modules/psblas/psb_z_psblas_mod.F90 b/base/modules/psblas/psb_z_psblas_mod.F90 index 17674600..bcfe9caa 100644 --- a/base/modules/psblas/psb_z_psblas_mod.F90 +++ b/base/modules/psblas/psb_z_psblas_mod.F90 @@ -143,7 +143,7 @@ module psb_z_psblas_mod end subroutine psb_zaxpby end interface - interface psb_abgdxyx + interface psb_abgdxyz subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) import :: psb_desc_type, psb_dpk_, psb_ipk_, & @@ -155,7 +155,7 @@ module psb_z_psblas_mod type(psb_desc_type), intent (in) :: desc_a integer(psb_ipk_), intent(out) :: info end subroutine psb_zabgdxyz_vect - end interface psb_abgdxyx + end interface psb_abgdxyz interface psb_geamax function psb_zamax(x, desc_a, info, jx,global) From 29669b56a24868b42c9bee2836fae05f9a57f480 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 13 Feb 2024 16:07:06 +0100 Subject: [PATCH 23/48] Implementation of psb_abgdxyz --- base/psblas/psb_caxpby.f90 | 82 ++++++++++++++++++++++++++++++++++++++ base/psblas/psb_daxpby.f90 | 82 ++++++++++++++++++++++++++++++++++++++ base/psblas/psb_saxpby.f90 | 82 ++++++++++++++++++++++++++++++++++++++ base/psblas/psb_zaxpby.f90 | 82 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 328 insertions(+) diff --git a/base/psblas/psb_caxpby.f90 b/base/psblas/psb_caxpby.f90 index da3dd93b..a41e6ef2 100644 --- a/base/psblas/psb_caxpby.f90 +++ b/base/psblas/psb_caxpby.f90 @@ -741,3 +741,85 @@ subroutine psb_caddconst_vect(x,b,z,desc_a,info) return end subroutine psb_caddconst_vect + + +subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_c_vect_type, psb_cspmat_type + type(psb_c_vect_type), intent (inout) :: x + type(psb_c_vect_type), intent (inout) :: y + type(psb_c_vect_type), intent (inout) :: z + complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy + integer(psb_lpk_) :: ix, ijx, iy, ijy, m + character(len=20) :: name, ch_err + + name='psb_c_addconst_vect' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(z%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + iy = ione + + m = desc_a%get_global_rows() + + ! check vector correctness + call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + call psb_chkvect(m,lone,z%get_nrows(),iy,lone,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 2' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if(desc_a%get_local_rows() > 0) then + call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_cabgdxyz_vect + diff --git a/base/psblas/psb_daxpby.f90 b/base/psblas/psb_daxpby.f90 index c386f8f2..4805727e 100644 --- a/base/psblas/psb_daxpby.f90 +++ b/base/psblas/psb_daxpby.f90 @@ -741,3 +741,85 @@ subroutine psb_daddconst_vect(x,b,z,desc_a,info) return end subroutine psb_daddconst_vect + + +subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_d_vect_type, psb_dspmat_type + type(psb_d_vect_type), intent (inout) :: x + type(psb_d_vect_type), intent (inout) :: y + type(psb_d_vect_type), intent (inout) :: z + real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy + integer(psb_lpk_) :: ix, ijx, iy, ijy, m + character(len=20) :: name, ch_err + + name='psb_d_addconst_vect' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(z%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + iy = ione + + m = desc_a%get_global_rows() + + ! check vector correctness + call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + call psb_chkvect(m,lone,z%get_nrows(),iy,lone,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 2' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if(desc_a%get_local_rows() > 0) then + call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_dabgdxyz_vect + diff --git a/base/psblas/psb_saxpby.f90 b/base/psblas/psb_saxpby.f90 index 78f4d01a..581d64cd 100644 --- a/base/psblas/psb_saxpby.f90 +++ b/base/psblas/psb_saxpby.f90 @@ -741,3 +741,85 @@ subroutine psb_saddconst_vect(x,b,z,desc_a,info) return end subroutine psb_saddconst_vect + + +subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_s_vect_type, psb_sspmat_type + type(psb_s_vect_type), intent (inout) :: x + type(psb_s_vect_type), intent (inout) :: y + type(psb_s_vect_type), intent (inout) :: z + real(psb_spk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy + integer(psb_lpk_) :: ix, ijx, iy, ijy, m + character(len=20) :: name, ch_err + + name='psb_s_addconst_vect' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(z%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + iy = ione + + m = desc_a%get_global_rows() + + ! check vector correctness + call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + call psb_chkvect(m,lone,z%get_nrows(),iy,lone,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 2' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if(desc_a%get_local_rows() > 0) then + call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_sabgdxyz_vect + diff --git a/base/psblas/psb_zaxpby.f90 b/base/psblas/psb_zaxpby.f90 index 2258f38f..df13f242 100644 --- a/base/psblas/psb_zaxpby.f90 +++ b/base/psblas/psb_zaxpby.f90 @@ -741,3 +741,85 @@ subroutine psb_zaddconst_vect(x,b,z,desc_a,info) return end subroutine psb_zaddconst_vect + + +subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& + & desc_a, info) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_z_vect_type, psb_zspmat_type + type(psb_z_vect_type), intent (inout) :: x + type(psb_z_vect_type), intent (inout) :: y + type(psb_z_vect_type), intent (inout) :: z + complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta + type(psb_desc_type), intent (in) :: desc_a + integer(psb_ipk_), intent(out) :: info + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy + integer(psb_lpk_) :: ix, ijx, iy, ijy, m + character(len=20) :: name, ch_err + + name='psb_z_addconst_vect' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(z%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + iy = ione + + m = desc_a%get_global_rows() + + ! check vector correctness + call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + call psb_chkvect(m,lone,z%get_nrows(),iy,lone,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 2' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if(desc_a%get_local_rows() > 0) then + call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_zabgdxyz_vect + From 5c3d5f023582a8b4773ad6df999ee038206d6954 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 13 Feb 2024 16:13:06 +0100 Subject: [PATCH 24/48] Silly bug in abgdxyz implementation --- base/psblas/psb_caxpby.f90 | 4 ++-- base/psblas/psb_daxpby.f90 | 4 ++-- base/psblas/psb_saxpby.f90 | 4 ++-- base/psblas/psb_zaxpby.f90 | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/base/psblas/psb_caxpby.f90 b/base/psblas/psb_caxpby.f90 index a41e6ef2..f19f2caf 100644 --- a/base/psblas/psb_caxpby.f90 +++ b/base/psblas/psb_caxpby.f90 @@ -745,8 +745,8 @@ end subroutine psb_caddconst_vect subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) - import :: psb_desc_type, psb_spk_, psb_ipk_, & - & psb_c_vect_type, psb_cspmat_type + use psb_base_mod, psb_protect_name => psb_cabgdxyz_vect + implicit none type(psb_c_vect_type), intent (inout) :: x type(psb_c_vect_type), intent (inout) :: y type(psb_c_vect_type), intent (inout) :: z diff --git a/base/psblas/psb_daxpby.f90 b/base/psblas/psb_daxpby.f90 index 4805727e..690c5080 100644 --- a/base/psblas/psb_daxpby.f90 +++ b/base/psblas/psb_daxpby.f90 @@ -745,8 +745,8 @@ end subroutine psb_daddconst_vect subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) - import :: psb_desc_type, psb_dpk_, psb_ipk_, & - & psb_d_vect_type, psb_dspmat_type + use psb_base_mod, psb_protect_name => psb_dabgdxyz_vect + implicit none type(psb_d_vect_type), intent (inout) :: x type(psb_d_vect_type), intent (inout) :: y type(psb_d_vect_type), intent (inout) :: z diff --git a/base/psblas/psb_saxpby.f90 b/base/psblas/psb_saxpby.f90 index 581d64cd..4b48f363 100644 --- a/base/psblas/psb_saxpby.f90 +++ b/base/psblas/psb_saxpby.f90 @@ -745,8 +745,8 @@ end subroutine psb_saddconst_vect subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) - import :: psb_desc_type, psb_spk_, psb_ipk_, & - & psb_s_vect_type, psb_sspmat_type + use psb_base_mod, psb_protect_name => psb_sabgdxyz_vect + implicit none type(psb_s_vect_type), intent (inout) :: x type(psb_s_vect_type), intent (inout) :: y type(psb_s_vect_type), intent (inout) :: z diff --git a/base/psblas/psb_zaxpby.f90 b/base/psblas/psb_zaxpby.f90 index df13f242..6bacacda 100644 --- a/base/psblas/psb_zaxpby.f90 +++ b/base/psblas/psb_zaxpby.f90 @@ -745,8 +745,8 @@ end subroutine psb_zaddconst_vect subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& & desc_a, info) - import :: psb_desc_type, psb_dpk_, psb_ipk_, & - & psb_z_vect_type, psb_zspmat_type + use psb_base_mod, psb_protect_name => psb_zabgdxyz_vect + implicit none type(psb_z_vect_type), intent (inout) :: x type(psb_z_vect_type), intent (inout) :: y type(psb_z_vect_type), intent (inout) :: z From 3121c435822da38de4e22dfe0a7c99a519728243 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 13 Feb 2024 16:16:13 +0100 Subject: [PATCH 25/48] Silly bug in abgdxyz implementation --- base/psblas/psb_caxpby.f90 | 7 ++++--- base/psblas/psb_daxpby.f90 | 7 ++++--- base/psblas/psb_saxpby.f90 | 7 ++++--- base/psblas/psb_zaxpby.f90 | 7 ++++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/base/psblas/psb_caxpby.f90 b/base/psblas/psb_caxpby.f90 index f19f2caf..3351149b 100644 --- a/base/psblas/psb_caxpby.f90 +++ b/base/psblas/psb_caxpby.f90 @@ -757,7 +757,7 @@ subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr character(len=20) :: name, ch_err name='psb_c_addconst_vect' @@ -792,7 +792,8 @@ subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ix = ione iy = ione - m = desc_a%get_global_rows() + m = desc_a%get_global_rows() + nr = desc_a%get_local_rows() ! check vector correctness call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) @@ -811,7 +812,7 @@ subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& end if if(desc_a%get_local_rows() > 0) then - call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + call z%abgdxyz(nr,alpha,beta,gamma,delta,x,y,info) end if call psb_erractionrestore(err_act) diff --git a/base/psblas/psb_daxpby.f90 b/base/psblas/psb_daxpby.f90 index 690c5080..8d43b6ac 100644 --- a/base/psblas/psb_daxpby.f90 +++ b/base/psblas/psb_daxpby.f90 @@ -757,7 +757,7 @@ subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr character(len=20) :: name, ch_err name='psb_d_addconst_vect' @@ -792,7 +792,8 @@ subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ix = ione iy = ione - m = desc_a%get_global_rows() + m = desc_a%get_global_rows() + nr = desc_a%get_local_rows() ! check vector correctness call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) @@ -811,7 +812,7 @@ subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& end if if(desc_a%get_local_rows() > 0) then - call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + call z%abgdxyz(nr,alpha,beta,gamma,delta,x,y,info) end if call psb_erractionrestore(err_act) diff --git a/base/psblas/psb_saxpby.f90 b/base/psblas/psb_saxpby.f90 index 4b48f363..6a5441cd 100644 --- a/base/psblas/psb_saxpby.f90 +++ b/base/psblas/psb_saxpby.f90 @@ -757,7 +757,7 @@ subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr character(len=20) :: name, ch_err name='psb_s_addconst_vect' @@ -792,7 +792,8 @@ subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ix = ione iy = ione - m = desc_a%get_global_rows() + m = desc_a%get_global_rows() + nr = desc_a%get_local_rows() ! check vector correctness call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) @@ -811,7 +812,7 @@ subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& end if if(desc_a%get_local_rows() > 0) then - call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + call z%abgdxyz(nr,alpha,beta,gamma,delta,x,y,info) end if call psb_erractionrestore(err_act) diff --git a/base/psblas/psb_zaxpby.f90 b/base/psblas/psb_zaxpby.f90 index 6bacacda..75f16ea8 100644 --- a/base/psblas/psb_zaxpby.f90 +++ b/base/psblas/psb_zaxpby.f90 @@ -757,7 +757,7 @@ subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr character(len=20) :: name, ch_err name='psb_z_addconst_vect' @@ -792,7 +792,8 @@ subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ix = ione iy = ione - m = desc_a%get_global_rows() + m = desc_a%get_global_rows() + nr = desc_a%get_local_rows() ! check vector correctness call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) @@ -811,7 +812,7 @@ subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& end if if(desc_a%get_local_rows() > 0) then - call z%abgdxyz(alpha,beta,gamma,delta,x,y,info) + call z%abgdxyz(nr,alpha,beta,gamma,delta,x,y,info) end if call psb_erractionrestore(err_act) From 9ced67634dc17248602b5871bbc3e2aa6ccdbdb8 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 14 Feb 2024 08:52:38 +0100 Subject: [PATCH 26/48] Fix KIND for NR in axpby --- base/psblas/psb_caxpby.f90 | 4 ++-- base/psblas/psb_daxpby.f90 | 4 ++-- base/psblas/psb_saxpby.f90 | 4 ++-- base/psblas/psb_zaxpby.f90 | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/base/psblas/psb_caxpby.f90 b/base/psblas/psb_caxpby.f90 index 3351149b..7c22bb06 100644 --- a/base/psblas/psb_caxpby.f90 +++ b/base/psblas/psb_caxpby.f90 @@ -756,8 +756,8 @@ subroutine psb_cabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ! locals type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m character(len=20) :: name, ch_err name='psb_c_addconst_vect' diff --git a/base/psblas/psb_daxpby.f90 b/base/psblas/psb_daxpby.f90 index 8d43b6ac..1de77647 100644 --- a/base/psblas/psb_daxpby.f90 +++ b/base/psblas/psb_daxpby.f90 @@ -756,8 +756,8 @@ subroutine psb_dabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ! locals type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m character(len=20) :: name, ch_err name='psb_d_addconst_vect' diff --git a/base/psblas/psb_saxpby.f90 b/base/psblas/psb_saxpby.f90 index 6a5441cd..1b1f24e6 100644 --- a/base/psblas/psb_saxpby.f90 +++ b/base/psblas/psb_saxpby.f90 @@ -756,8 +756,8 @@ subroutine psb_sabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ! locals type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m character(len=20) :: name, ch_err name='psb_s_addconst_vect' diff --git a/base/psblas/psb_zaxpby.f90 b/base/psblas/psb_zaxpby.f90 index 75f16ea8..0f37a1f4 100644 --- a/base/psblas/psb_zaxpby.f90 +++ b/base/psblas/psb_zaxpby.f90 @@ -756,8 +756,8 @@ subroutine psb_zabgdxyz_vect(alpha, beta, gamma, delta, x, y, z,& ! locals type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nr + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m character(len=20) :: name, ch_err name='psb_z_addconst_vect' From 4e611bb078d54f1eea74e6439db722414c9c269a Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 14 Feb 2024 15:55:55 +0100 Subject: [PATCH 27/48] Enable psi_abgdxyz --- base/modules/serial/psb_c_base_vect_mod.F90 | 15 +++++++++++---- base/modules/serial/psb_c_vect_mod.F90 | 2 +- base/modules/serial/psb_d_base_vect_mod.F90 | 15 +++++++++++---- base/modules/serial/psb_d_vect_mod.F90 | 2 +- base/modules/serial/psb_s_base_vect_mod.F90 | 15 +++++++++++---- base/modules/serial/psb_s_vect_mod.F90 | 2 +- base/modules/serial/psb_z_base_vect_mod.F90 | 15 +++++++++++---- base/modules/serial/psb_z_vect_mod.F90 | 2 +- base/serial/psi_c_serial_impl.F90 | 4 ++-- base/serial/psi_d_serial_impl.F90 | 4 ++-- base/serial/psi_e_serial_impl.F90 | 4 ++-- base/serial/psi_i2_serial_impl.F90 | 4 ++-- base/serial/psi_m_serial_impl.F90 | 4 ++-- base/serial/psi_s_serial_impl.F90 | 4 ++-- base/serial/psi_z_serial_impl.F90 | 4 ++-- 15 files changed, 62 insertions(+), 34 deletions(-) diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index 793df3bc..a4772103 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -1155,10 +1155,17 @@ contains complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) + if (.false.) then + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + else + if (x%is_dev().and.(alpha/=czero))) call x%sync() + if (y%is_dev().and.(beta/=czero)) call y%sync() + if (z%is_dev().and.(delta/=czero)) call z%sync() + call psi_cabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + end if end subroutine c_base_abgdxyz diff --git a/base/modules/serial/psb_c_vect_mod.F90 b/base/modules/serial/psb_c_vect_mod.F90 index 8b2941ff..2eebb0da 100644 --- a/base/modules/serial/psb_c_vect_mod.F90 +++ b/base/modules/serial/psb_c_vect_mod.F90 @@ -1152,7 +1152,7 @@ contains end if end function c_vect_nrm2_weight - + function c_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 29a2ccd8..59b43fce 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -1162,10 +1162,17 @@ contains real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) + if (.false.) then + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + else + if (x%is_dev().and.(alpha/=dzero))) call x%sync() + if (y%is_dev().and.(beta/=dzero)) call y%sync() + if (z%is_dev().and.(delta/=dzero)) call z%sync() + call psi_dabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + end if end subroutine d_base_abgdxyz diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index ef75be87..bbb966ed 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -1159,7 +1159,7 @@ contains end if end function d_vect_nrm2_weight - + function d_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index 61ae27d2..dee48ca5 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -1162,10 +1162,17 @@ contains real(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) + if (.false.) then + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + else + if (x%is_dev().and.(alpha/=szero))) call x%sync() + if (y%is_dev().and.(beta/=szero)) call y%sync() + if (z%is_dev().and.(delta/=szero)) call z%sync() + call psi_sabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + end if end subroutine s_base_abgdxyz diff --git a/base/modules/serial/psb_s_vect_mod.F90 b/base/modules/serial/psb_s_vect_mod.F90 index 34479856..0ffd199f 100644 --- a/base/modules/serial/psb_s_vect_mod.F90 +++ b/base/modules/serial/psb_s_vect_mod.F90 @@ -1159,7 +1159,7 @@ contains end if end function s_vect_nrm2_weight - + function s_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res) use psi_serial_mod implicit none diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index 53f3ea8e..0ab2f945 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -1155,10 +1155,17 @@ contains complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) + if (.false.) then + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x,beta,info) + call z%axpby(m,gamma,y,delta,info) + else + if (x%is_dev().and.(alpha/=zzero))) call x%sync() + if (y%is_dev().and.(beta/=zzero)) call y%sync() + if (z%is_dev().and.(delta/=zzero)) call z%sync() + call psi_zabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + end if end subroutine z_base_abgdxyz diff --git a/base/modules/serial/psb_z_vect_mod.F90 b/base/modules/serial/psb_z_vect_mod.F90 index 54ddfebe..1ea1fd4a 100644 --- a/base/modules/serial/psb_z_vect_mod.F90 +++ b/base/modules/serial/psb_z_vect_mod.F90 @@ -1152,7 +1152,7 @@ contains end if end function z_vect_nrm2_weight - + function z_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res) use psi_serial_mod implicit none diff --git a/base/serial/psi_c_serial_impl.F90 b/base/serial/psi_c_serial_impl.F90 index 129e8484..557220e5 100644 --- a/base/serial/psi_c_serial_impl.F90 +++ b/base/serial/psi_c_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_cabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == czero) then if (gamma == czero) then if (alpha == czero) then @@ -1773,7 +1773,7 @@ subroutine psi_cabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= czero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) diff --git a/base/serial/psi_d_serial_impl.F90 b/base/serial/psi_d_serial_impl.F90 index bd0c82df..d423b401 100644 --- a/base/serial/psi_d_serial_impl.F90 +++ b/base/serial/psi_d_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_dabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == dzero) then if (gamma == dzero) then if (alpha == dzero) then @@ -1773,7 +1773,7 @@ subroutine psi_dabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= dzero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) diff --git a/base/serial/psi_e_serial_impl.F90 b/base/serial/psi_e_serial_impl.F90 index 8b17aeb8..c7977c35 100644 --- a/base/serial/psi_e_serial_impl.F90 +++ b/base/serial/psi_e_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_eabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == ezero) then if (gamma == ezero) then if (alpha == ezero) then @@ -1773,7 +1773,7 @@ subroutine psi_eabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= ezero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) diff --git a/base/serial/psi_i2_serial_impl.F90 b/base/serial/psi_i2_serial_impl.F90 index 9a2c36c6..ce4aff80 100644 --- a/base/serial/psi_i2_serial_impl.F90 +++ b/base/serial/psi_i2_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_i2abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == i2zero) then if (gamma == i2zero) then if (alpha == i2zero) then @@ -1773,7 +1773,7 @@ subroutine psi_i2abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= i2zero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) diff --git a/base/serial/psi_m_serial_impl.F90 b/base/serial/psi_m_serial_impl.F90 index dd114a45..8d9d19f4 100644 --- a/base/serial/psi_m_serial_impl.F90 +++ b/base/serial/psi_m_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_mabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == mzero) then if (gamma == mzero) then if (alpha == mzero) then @@ -1773,7 +1773,7 @@ subroutine psi_mabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= mzero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) diff --git a/base/serial/psi_s_serial_impl.F90 b/base/serial/psi_s_serial_impl.F90 index 8e2dda0f..df251b27 100644 --- a/base/serial/psi_s_serial_impl.F90 +++ b/base/serial/psi_s_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_sabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == szero) then if (gamma == szero) then if (alpha == szero) then @@ -1773,7 +1773,7 @@ subroutine psi_sabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= szero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) diff --git a/base/serial/psi_z_serial_impl.F90 b/base/serial/psi_z_serial_impl.F90 index c6a7e01d..44ea5ae7 100644 --- a/base/serial/psi_z_serial_impl.F90 +++ b/base/serial/psi_z_serial_impl.F90 @@ -1616,7 +1616,7 @@ subroutine psi_zabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) call fcpsb_errpush(info,name,int_err) goto 9999 endif - + if (beta == zzero) then if (gamma == zzero) then if (alpha == zzero) then @@ -1773,7 +1773,7 @@ subroutine psi_zabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) end do else if (delta /= zzero) then - ! a n b n g n d 0 + ! a n b n g n d n !$omp parallel do private(i) do i=1,m y(i) = alpha*x(i)+beta*y(i) From 2a40b82b5830d17dfaa8a731abe76ec8bae5fdba Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 14 Feb 2024 16:01:16 +0100 Subject: [PATCH 28/48] Fix typo in base_vect_mod --- base/modules/serial/psb_c_base_vect_mod.F90 | 2 +- base/modules/serial/psb_d_base_vect_mod.F90 | 2 +- base/modules/serial/psb_s_base_vect_mod.F90 | 2 +- base/modules/serial/psb_z_base_vect_mod.F90 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index a4772103..b158ac64 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -1161,7 +1161,7 @@ contains call y%axpby(m,alpha,x,beta,info) call z%axpby(m,gamma,y,delta,info) else - if (x%is_dev().and.(alpha/=czero))) call x%sync() + if (x%is_dev().and.(alpha/=czero)) call x%sync() if (y%is_dev().and.(beta/=czero)) call y%sync() if (z%is_dev().and.(delta/=czero)) call z%sync() call psi_cabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 59b43fce..f53bc590 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -1168,7 +1168,7 @@ contains call y%axpby(m,alpha,x,beta,info) call z%axpby(m,gamma,y,delta,info) else - if (x%is_dev().and.(alpha/=dzero))) call x%sync() + if (x%is_dev().and.(alpha/=dzero)) call x%sync() if (y%is_dev().and.(beta/=dzero)) call y%sync() if (z%is_dev().and.(delta/=dzero)) call z%sync() call psi_dabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index dee48ca5..12626c72 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -1168,7 +1168,7 @@ contains call y%axpby(m,alpha,x,beta,info) call z%axpby(m,gamma,y,delta,info) else - if (x%is_dev().and.(alpha/=szero))) call x%sync() + if (x%is_dev().and.(alpha/=szero)) call x%sync() if (y%is_dev().and.(beta/=szero)) call y%sync() if (z%is_dev().and.(delta/=szero)) call z%sync() call psi_sabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index 0ab2f945..fe990a9e 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -1161,7 +1161,7 @@ contains call y%axpby(m,alpha,x,beta,info) call z%axpby(m,gamma,y,delta,info) else - if (x%is_dev().and.(alpha/=zzero))) call x%sync() + if (x%is_dev().and.(alpha/=zzero)) call x%sync() if (y%is_dev().and.(beta/=zzero)) call y%sync() if (z%is_dev().and.(delta/=zzero)) call z%sync() call psi_zabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) From b8f9badf954aec07365af71b4d7fc1209fcc890a Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 14 Feb 2024 20:05:52 +0100 Subject: [PATCH 29/48] Fix interface between vect and base_vect%ABGD --- base/modules/serial/psb_c_vect_mod.F90 | 2 +- base/modules/serial/psb_d_vect_mod.F90 | 2 +- base/modules/serial/psb_s_vect_mod.F90 | 2 +- base/modules/serial/psb_z_vect_mod.F90 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/base/modules/serial/psb_c_vect_mod.F90 b/base/modules/serial/psb_c_vect_mod.F90 index 2eebb0da..e0488def 100644 --- a/base/modules/serial/psb_c_vect_mod.F90 +++ b/base/modules/serial/psb_c_vect_mod.F90 @@ -784,7 +784,7 @@ contains integer(psb_ipk_), intent(out) :: info if (allocated(z%v)) & - call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + call z%v%abgdxyz(m,alpha,beta,gamma,delta,x%v,y%v,info) end subroutine c_vect_abgdxyz diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index bbb966ed..07007452 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -791,7 +791,7 @@ contains integer(psb_ipk_), intent(out) :: info if (allocated(z%v)) & - call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + call z%v%abgdxyz(m,alpha,beta,gamma,delta,x%v,y%v,info) end subroutine d_vect_abgdxyz diff --git a/base/modules/serial/psb_s_vect_mod.F90 b/base/modules/serial/psb_s_vect_mod.F90 index 0ffd199f..aa16a04d 100644 --- a/base/modules/serial/psb_s_vect_mod.F90 +++ b/base/modules/serial/psb_s_vect_mod.F90 @@ -791,7 +791,7 @@ contains integer(psb_ipk_), intent(out) :: info if (allocated(z%v)) & - call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + call z%v%abgdxyz(m,alpha,beta,gamma,delta,x%v,y%v,info) end subroutine s_vect_abgdxyz diff --git a/base/modules/serial/psb_z_vect_mod.F90 b/base/modules/serial/psb_z_vect_mod.F90 index 1ea1fd4a..58bf6b18 100644 --- a/base/modules/serial/psb_z_vect_mod.F90 +++ b/base/modules/serial/psb_z_vect_mod.F90 @@ -784,7 +784,7 @@ contains integer(psb_ipk_), intent(out) :: info if (allocated(z%v)) & - call z%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) + call z%v%abgdxyz(m,alpha,beta,gamma,delta,x%v,y%v,info) end subroutine z_vect_abgdxyz From f4c7604f610fd91127bc3e68b8467a1638acb301 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Sat, 17 Feb 2024 09:40:09 +0100 Subject: [PATCH 30/48] Fix base implementation of abgdxyz to call set_host --- base/modules/serial/psb_c_base_vect_mod.F90 | 2 ++ base/modules/serial/psb_d_base_vect_mod.F90 | 2 ++ base/modules/serial/psb_s_base_vect_mod.F90 | 2 ++ base/modules/serial/psb_z_base_vect_mod.F90 | 2 ++ 4 files changed, 8 insertions(+) diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index b158ac64..5a468d55 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -1165,6 +1165,8 @@ contains if (y%is_dev().and.(beta/=czero)) call y%sync() if (z%is_dev().and.(delta/=czero)) call z%sync() call psi_cabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() end if end subroutine c_base_abgdxyz diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index f53bc590..8f583cd3 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -1172,6 +1172,8 @@ contains if (y%is_dev().and.(beta/=dzero)) call y%sync() if (z%is_dev().and.(delta/=dzero)) call z%sync() call psi_dabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() end if end subroutine d_base_abgdxyz diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index 12626c72..85bb3bda 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -1172,6 +1172,8 @@ contains if (y%is_dev().and.(beta/=szero)) call y%sync() if (z%is_dev().and.(delta/=szero)) call z%sync() call psi_sabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() end if end subroutine s_base_abgdxyz diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index fe990a9e..b30b1586 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -1165,6 +1165,8 @@ contains if (y%is_dev().and.(beta/=zzero)) call y%sync() if (z%is_dev().and.(delta/=zzero)) call z%sync() call psi_zabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() end if end subroutine z_base_abgdxyz From a41b209144ed25837f1b8c8196c1e4b87569b02a Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Sat, 17 Feb 2024 17:18:59 +0100 Subject: [PATCH 31/48] Better AXPBY implementation in CUDA. --- cuda/spgpu/kernels/caxpby.cu | 40 ++++++++++++++++++++++++++++++----- cuda/spgpu/kernels/daxpby.cu | 41 +++++++++++++++++++++++++++++++----- cuda/spgpu/kernels/saxpby.cu | 30 +++++++++++++++++++++++--- cuda/spgpu/kernels/zaxpby.cu | 29 ++++++++++++++++++++++--- 4 files changed, 124 insertions(+), 16 deletions(-) diff --git a/cuda/spgpu/kernels/caxpby.cu b/cuda/spgpu/kernels/caxpby.cu index d3d326ef..16eb87ed 100644 --- a/cuda/spgpu/kernels/caxpby.cu +++ b/cuda/spgpu/kernels/caxpby.cu @@ -32,8 +32,9 @@ extern "C" __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, cuFloatComplex *y, cuFloatComplex alpha, cuFloatComplex* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; - - if (id < n) + unsigned int gridSize = blockDim.x * gridDim.x; + for ( ; id < n; id +=gridSize) + //if (id,n) { // Since z, x and y are accessed with the same offset by the same thread, // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). @@ -45,7 +46,30 @@ __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, c } } +#if 1 +void spgpuCaxpby(spgpuHandle_t handle, + __device cuFloatComplex *z, + int n, + cuFloatComplex beta, + __device cuFloatComplex *y, + cuFloatComplex alpha, + __device cuFloatComplex* x) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuCaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); +} +#else void spgpuCaxpby_(spgpuHandle_t handle, __device cuFloatComplex *z, int n, @@ -55,9 +79,15 @@ void spgpuCaxpby_(spgpuHandle_t handle, __device cuFloatComplex* x) { int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; - + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - dim3 grid(msize); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); spgpuCaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } @@ -86,7 +116,7 @@ void spgpuCaxpby(spgpuHandle_t handle, cudaCheckError("CUDA error on saxpby"); } - +#endif void spgpuCmaxpby(spgpuHandle_t handle, __device cuFloatComplex *z, int n, diff --git a/cuda/spgpu/kernels/daxpby.cu b/cuda/spgpu/kernels/daxpby.cu index 83724ce2..a0a163a2 100644 --- a/cuda/spgpu/kernels/daxpby.cu +++ b/cuda/spgpu/kernels/daxpby.cu @@ -16,6 +16,7 @@ #include "cudadebug.h" #include "cudalang.h" +#include extern "C" { @@ -31,8 +32,9 @@ extern "C" __global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double alpha, double* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; - - if (id < n) + unsigned int gridSize = blockDim.x * gridDim.x; + for ( ; id < n; id +=gridSize) + //if (id,n) { // Since z, x and y are accessed with the same offset by the same thread, // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). @@ -44,8 +46,9 @@ __global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double } } +#if 1 -void spgpuDaxpby_(spgpuHandle_t handle, +void spgpuDaxpby(spgpuHandle_t handle, __device double *z, int n, double beta, @@ -54,9 +57,37 @@ void spgpuDaxpby_(spgpuHandle_t handle, __device double* x) { int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + spgpuDaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); +} +#else +void spgpuDaxpby_(spgpuHandle_t handle, + __device double *z, + int n, + double beta, + __device double *y, + double alpha, + __device double* x) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - dim3 grid(msize); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); spgpuDaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } @@ -84,7 +115,7 @@ void spgpuDaxpby(spgpuHandle_t handle, cudaCheckError("CUDA error on daxpby"); } - +#endif void spgpuDmaxpby(spgpuHandle_t handle, __device double *z, int n, diff --git a/cuda/spgpu/kernels/saxpby.cu b/cuda/spgpu/kernels/saxpby.cu index 2c46f19e..42e2a7a7 100644 --- a/cuda/spgpu/kernels/saxpby.cu +++ b/cuda/spgpu/kernels/saxpby.cu @@ -30,8 +30,9 @@ extern "C" __global__ void spgpuSaxpby_krn(float *z, int n, float beta, float *y, float alpha, float* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; - - if (id < n) + unsigned int gridSize = blockDim.x * gridDim.x; + for ( ; id < n; id +=gridSize) + //if (id,n) { // Since z, x and y are accessed with the same offset by the same thread, // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). @@ -44,6 +45,29 @@ __global__ void spgpuSaxpby_krn(float *z, int n, float beta, float *y, float alp } +#if 1 +void spgpuSaxpby(spgpuHandle_t handle, + __device float *z, + int n, + float beta, + __device float *y, + float alpha, + __device float* x) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuSaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); +} +#else void spgpuSaxpby_(spgpuHandle_t handle, __device float *z, int n, @@ -83,7 +107,7 @@ void spgpuSaxpby(spgpuHandle_t handle, cudaCheckError("CUDA error on saxpby"); } - +#endif void spgpuSmaxpby(spgpuHandle_t handle, __device float *z, int n, diff --git a/cuda/spgpu/kernels/zaxpby.cu b/cuda/spgpu/kernels/zaxpby.cu index 7f9d5797..da438fc2 100644 --- a/cuda/spgpu/kernels/zaxpby.cu +++ b/cuda/spgpu/kernels/zaxpby.cu @@ -33,8 +33,9 @@ extern "C" __global__ void spgpuZaxpby_krn(cuDoubleComplex *z, int n, cuDoubleComplex beta, cuDoubleComplex *y, cuDoubleComplex alpha, cuDoubleComplex* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; - - if (id < n) + unsigned int gridSize = blockDim.x * gridDim.x; + for ( ; id < n; id +=gridSize) + //if (id,n) { // Since z, x and y are accessed with the same offset by the same thread, // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). @@ -46,7 +47,29 @@ __global__ void spgpuZaxpby_krn(cuDoubleComplex *z, int n, cuDoubleComplex beta, } } +#if 1 +void spgpuZaxpby(spgpuHandle_t handle, + __device cuDoubleComplex *z, + int n, + cuDoubleComplex beta, + __device cuDoubleComplex *y, + cuDoubleComplex alpha, + __device cuDoubleComplex* x) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + spgpuZaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); +} +#else void spgpuZaxpby_(spgpuHandle_t handle, __device cuDoubleComplex *z, int n, @@ -86,7 +109,7 @@ void spgpuZaxpby(spgpuHandle_t handle, cudaCheckError("CUDA error on daxpby"); } - +#endif void spgpuZmaxpby(spgpuHandle_t handle, __device cuDoubleComplex *z, int n, From 864872ecacff43d40eed244a1044c0eb293db117 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Sat, 17 Feb 2024 17:28:32 +0100 Subject: [PATCH 32/48] Intermediate implementation of abgdxyz on cuda --- cuda/psb_d_cuda_vect_mod.F90 | 39 ++++++++++++++++++++++++++++++------ cuda/spgpu/vector.h | 15 +++++++++++++- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index 03e65f91..fe5d3a38 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -923,13 +923,40 @@ contains real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - call z%psb_d_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) -!!$ -!!$ if (x%is_dev()) call x%sync() -!!$ -!!$ call y%axpby(m,alpha,x,beta,info) -!!$ call z%axpby(m,gamma,y,delta,info) + + info = psb_success_ + if (.false.) then + + select type(xx => x) + type is (psb_d_vect_cuda) + ! Do something different here + if ((beta /= dzero).and.y%is_host())& + & call y%sync() + if (xx%is_host()) call xx%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(y%deviceVect) + if ((nx Date: Sat, 17 Feb 2024 17:46:09 +0100 Subject: [PATCH 33/48] Intermediate impl of ABGDXYZ --- cuda/psb_d_cuda_vect_mod.F90 | 54 +++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index fe5d3a38..36fac14e 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -922,33 +922,47 @@ contains class(psb_d_vect_cuda), intent(inout) :: z real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - + integer(psb_ipk_) :: nx, ny, nz + logical :: gpu_done info = psb_success_ if (.false.) then - + gpu_done = .false. select type(xx => x) - type is (psb_d_vect_cuda) - ! Do something different here - if ((beta /= dzero).and.y%is_host())& - & call y%sync() - if (xx%is_host()) call xx%sync() - nx = getMultiVecDeviceSize(xx%deviceVect) - ny = getMultiVecDeviceSize(y%deviceVect) - if ((nx y) + class is (psb_d_vect_cuda) + select type(zz => z) + class is (psb_d_vect_cuda) + ! Do something different here + if ((beta /= dzero).and.yy%is_host())& + & call yy%sync() + if ((delta /= dzero).and.zz%is_host())& + & call zz%sync() + if (xx%is_host()) call xx%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + if ((nx Date: Sat, 17 Feb 2024 18:20:12 +0100 Subject: [PATCH 34/48] New implementation for ABGDXYZ in CUDA --- cuda/dvectordev.c | 24 +++++++++++ cuda/psb_d_cuda_vect_mod.F90 | 3 +- cuda/psb_d_vectordev_mod.F90 | 13 ++++++ cuda/spgpu/kernels/Makefile | 2 +- cuda/spgpu/kernels/dabgdxyz.cu | 79 ++++++++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 cuda/spgpu/kernels/dabgdxyz.cu diff --git a/cuda/dvectordev.c b/cuda/dvectordev.c index 39aa5b2a..785753dd 100644 --- a/cuda/dvectordev.c +++ b/cuda/dvectordev.c @@ -241,6 +241,30 @@ int axpbyMultiVecDeviceDouble(int n,double alpha, void* devMultiVecX, return(i); } + +int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + +#if 1 + spgpuDabgdxyz(handle,n, alpha,beta,gamma,delta, + (double*)devVecX->v_,(double*) devVecY->v_,(double*) devVecZ->v_); +#else + for(j=0;jcount_;j++) + spgpuDaxpby(handle,(double*)devVecY->v_+pitch*j, n, beta, + (double*)devVecY->v_+pitch*j, alpha,(double*) devVecX->v_+pitch*j); +#endif + return(i); +} + int axyMultiVecDeviceDouble(int n, double alpha, void *deviceVecA, void *deviceVecB) { int i = 0; struct MultiVectDevice *devVecA = (struct MultiVectDevice *) deviceVecA; diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index 36fac14e..8256eaa0 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -947,7 +947,8 @@ contains if ((nx + +extern "C" +{ +#include "core.h" +#include "vector.h" +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuDabgdxyz_krn(int n, double alpha, double beta, double gamma, double delta, + double* x, double *y, double *z) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + double t; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + if (beta == 0.0) + t = PREC_DMUL(alpha,x[id]); + else + t = PREC_DADD(PREC_DMUL(alpha, x[id]), PREC_DMUL(beta,y[id])); + if (delta == 0.0) + z[id] = gamma * t; + else + z[id] = PREC_DADD(PREC_DMUL(gamma, t), PREC_DMUL(delta,z[id])); + y[id] = t; + } +} + + +void spgpuDabgdxyz(spgpuHandle_t handle, + int n, + double alpha, + double beta, + double gamma, + double delta, + __device double* x, + __device double* y, + __device double *z) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + num_mp = deviceProp.multiProcessorCount; + max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuDabgdxyz_krn<<currentStream>>>(n, alpha, beta, gamma, delta, + x, y, z); +} + From f9677bc8920a187bdb0184e2879da80932fe6a55 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Sat, 17 Feb 2024 18:42:56 +0100 Subject: [PATCH 35/48] Enabled new CUDA version of ABGDXYZ --- cuda/psb_d_cuda_vect_mod.F90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index 8256eaa0..f2ef2be3 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -927,7 +927,7 @@ contains info = psb_success_ - if (.false.) then + if (.true.) then gpu_done = .false. select type(xx => x) class is (psb_d_vect_cuda) From 1ba8dfc7b7079eed0e86e90539e9b7c7daffe701 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Sun, 18 Feb 2024 10:31:32 +0100 Subject: [PATCH 36/48] Switch FOR and IF in AXPBY --- cuda/spgpu/kernels/caxpby.cu | 25 +++++++++++++++---------- cuda/spgpu/kernels/daxpby.cu | 25 +++++++++++++++---------- cuda/spgpu/kernels/saxpby.cu | 25 +++++++++++++++---------- cuda/spgpu/kernels/zaxpby.cu | 25 +++++++++++++++---------- 4 files changed, 60 insertions(+), 40 deletions(-) diff --git a/cuda/spgpu/kernels/caxpby.cu b/cuda/spgpu/kernels/caxpby.cu index 16eb87ed..33deecbc 100644 --- a/cuda/spgpu/kernels/caxpby.cu +++ b/cuda/spgpu/kernels/caxpby.cu @@ -33,16 +33,21 @@ __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, c { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; unsigned int gridSize = blockDim.x * gridDim.x; - for ( ; id < n; id +=gridSize) - //if (id,n) - { - // Since z, x and y are accessed with the same offset by the same thread, - // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). - - if (cuFloatComplex_isZero(beta)) - z[id] = cuCmulf(alpha,x[id]); - else - z[id] = cuCfmaf(beta, y[id], cuCmulf(alpha, x[id])); + if (cuFloatComplex_isZero(beta)) { + for ( ; id < n; id +=gridSize) + //if (id,n) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + z[id] = cuCmulf(alpha,x[id]); + } + } else { + for ( ; id < n; id +=gridSize) + //if (id,n) + { + z[id] = cuCfmaf(beta, y[id], cuCmulf(alpha, x[id])); + } } } diff --git a/cuda/spgpu/kernels/daxpby.cu b/cuda/spgpu/kernels/daxpby.cu index a0a163a2..ce7c0dd4 100644 --- a/cuda/spgpu/kernels/daxpby.cu +++ b/cuda/spgpu/kernels/daxpby.cu @@ -33,16 +33,21 @@ __global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; unsigned int gridSize = blockDim.x * gridDim.x; - for ( ; id < n; id +=gridSize) - //if (id,n) - { - // Since z, x and y are accessed with the same offset by the same thread, - // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). - - if (beta == 0.0) - z[id] = PREC_DMUL(alpha,x[id]); - else - z[id] = PREC_DADD(PREC_DMUL(alpha, x[id]), PREC_DMUL(beta,y[id])); + if (beta == 0.0) { + for ( ; id < n; id +=gridSize) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + z[id] = PREC_DMUL(alpha,x[id]); + } + } else { + for ( ; id < n; id +=gridSize) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + z[id] = PREC_DADD(PREC_DMUL(alpha, x[id]), PREC_DMUL(beta,y[id])); + } } } diff --git a/cuda/spgpu/kernels/saxpby.cu b/cuda/spgpu/kernels/saxpby.cu index 42e2a7a7..36c3cdbe 100644 --- a/cuda/spgpu/kernels/saxpby.cu +++ b/cuda/spgpu/kernels/saxpby.cu @@ -31,16 +31,21 @@ __global__ void spgpuSaxpby_krn(float *z, int n, float beta, float *y, float alp { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; unsigned int gridSize = blockDim.x * gridDim.x; - for ( ; id < n; id +=gridSize) - //if (id,n) - { - // Since z, x and y are accessed with the same offset by the same thread, - // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). - - if (beta == 0.0f) - z[id] = PREC_FMUL(alpha,x[id]); - else - z[id] = PREC_FADD(PREC_FMUL(alpha, x[id]), PREC_FMUL(beta,y[id])); + if (beta == 0.0f) { + for ( ; id < n; id +=gridSize) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + z[id] = PREC_FMUL(alpha,x[id]); + } + } else { + for ( ; id < n; id +=gridSize) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + z[id] = PREC_FADD(PREC_FMUL(alpha, x[id]), PREC_FMUL(beta,y[id])); + } } } diff --git a/cuda/spgpu/kernels/zaxpby.cu b/cuda/spgpu/kernels/zaxpby.cu index da438fc2..8aec3e17 100644 --- a/cuda/spgpu/kernels/zaxpby.cu +++ b/cuda/spgpu/kernels/zaxpby.cu @@ -34,16 +34,21 @@ __global__ void spgpuZaxpby_krn(cuDoubleComplex *z, int n, cuDoubleComplex beta, { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; unsigned int gridSize = blockDim.x * gridDim.x; - for ( ; id < n; id +=gridSize) - //if (id,n) - { - // Since z, x and y are accessed with the same offset by the same thread, - // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). - - if (cuDoubleComplex_isZero(beta)) - z[id] = cuCmul(alpha,x[id]); - else - z[id] = cuCfma(alpha, x[id], cuCmul(beta,y[id])); + if (cuDoubleComplex_isZero(beta)) { + for ( ; id < n; id +=gridSize) + //if (id,n) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + z[id] = cuCmul(alpha,x[id]); + } + } else { + for ( ; id < n; id +=gridSize) + //if (id,n) + { + z[id] = cuCfma(beta, y[id], cuCmul(alpha, x[id])); + } } } From 35d68aa4e326dc0b11a18cdd4f15ec62ec3802f4 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Sun, 18 Feb 2024 16:44:37 +0100 Subject: [PATCH 37/48] Reuse calls to getDeviceProperties done at init time --- cuda/spgpu/kernels/caxpby.cu | 29 ++++++++++++++++++++++++----- cuda/spgpu/kernels/dabgdxyz.cu | 8 ++++---- cuda/spgpu/kernels/daxpby.cu | 34 ++++++++++++++++++++++++---------- cuda/spgpu/kernels/saxpby.cu | 33 +++++++++++++++++++++++++++------ cuda/spgpu/kernels/zaxpby.cu | 28 +++++++++++++++++++++++----- 5 files changed, 102 insertions(+), 30 deletions(-) diff --git a/cuda/spgpu/kernels/caxpby.cu b/cuda/spgpu/kernels/caxpby.cu index 33deecbc..3e97f75f 100644 --- a/cuda/spgpu/kernels/caxpby.cu +++ b/cuda/spgpu/kernels/caxpby.cu @@ -22,6 +22,9 @@ extern "C" { #include "core.h" #include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); + //#include "cuda_util.h" } @@ -29,6 +32,8 @@ extern "C" #define BLOCK_SIZE 512 +#if 1 + __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, cuFloatComplex *y, cuFloatComplex alpha, cuFloatComplex* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; @@ -51,7 +56,6 @@ __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, c } } -#if 1 void spgpuCaxpby(spgpuHandle_t handle, __device cuFloatComplex *z, int n, @@ -63,10 +67,8 @@ void spgpuCaxpby(spgpuHandle_t handle, int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); num_blocks_mp = max_threads_mp/BLOCK_SIZE; num_blocks = num_blocks_mp*num_mp; dim3 grid(num_blocks); @@ -75,6 +77,23 @@ void spgpuCaxpby(spgpuHandle_t handle, } #else + +__global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, cuFloatComplex *y, cuFloatComplex alpha, cuFloatComplex* x) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + + if (id < n) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + if (cuFloatComplex_isZero(beta)) + z[id] = cuCmulf(alpha,x[id]); + else + z[id] = cuCfmaf(beta, y[id], cuCmulf(alpha, x[id])); + } +} + void spgpuCaxpby_(spgpuHandle_t handle, __device cuFloatComplex *z, int n, diff --git a/cuda/spgpu/kernels/dabgdxyz.cu b/cuda/spgpu/kernels/dabgdxyz.cu index 525371d3..f2b18e02 100644 --- a/cuda/spgpu/kernels/dabgdxyz.cu +++ b/cuda/spgpu/kernels/dabgdxyz.cu @@ -22,6 +22,8 @@ extern "C" { #include "core.h" #include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); } @@ -65,10 +67,8 @@ void spgpuDabgdxyz(spgpuHandle_t handle, int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); num_blocks_mp = max_threads_mp/BLOCK_SIZE; num_blocks = num_blocks_mp*num_mp; dim3 grid(num_blocks); diff --git a/cuda/spgpu/kernels/daxpby.cu b/cuda/spgpu/kernels/daxpby.cu index ce7c0dd4..fa87d996 100644 --- a/cuda/spgpu/kernels/daxpby.cu +++ b/cuda/spgpu/kernels/daxpby.cu @@ -22,6 +22,9 @@ extern "C" { #include "core.h" #include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); + //#include "cuda_util.h" } @@ -29,6 +32,8 @@ extern "C" #define BLOCK_SIZE 512 + +#if 1 __global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double alpha, double* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; @@ -36,23 +41,17 @@ __global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double if (beta == 0.0) { for ( ; id < n; id +=gridSize) { - // Since z, x and y are accessed with the same offset by the same thread, - // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). z[id] = PREC_DMUL(alpha,x[id]); } } else { for ( ; id < n; id +=gridSize) { - // Since z, x and y are accessed with the same offset by the same thread, - // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). z[id] = PREC_DADD(PREC_DMUL(alpha, x[id]), PREC_DMUL(beta,y[id])); } } } -#if 1 - void spgpuDaxpby(spgpuHandle_t handle, __device double *z, int n, @@ -64,10 +63,8 @@ void spgpuDaxpby(spgpuHandle_t handle, int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); num_blocks_mp = max_threads_mp/BLOCK_SIZE; num_blocks = num_blocks_mp*num_mp; dim3 grid(num_blocks); @@ -75,6 +72,23 @@ void spgpuDaxpby(spgpuHandle_t handle, spgpuDaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } #else + +__global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double alpha, double* x) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + + if (id < n) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + if (beta == 0.0) + z[id] = PREC_DMUL(alpha,x[id]); + else + z[id] = PREC_DADD(PREC_DMUL(alpha, x[id]), PREC_DMUL(beta,y[id])); + } +} + void spgpuDaxpby_(spgpuHandle_t handle, __device double *z, int n, diff --git a/cuda/spgpu/kernels/saxpby.cu b/cuda/spgpu/kernels/saxpby.cu index 36c3cdbe..2f06e39c 100644 --- a/cuda/spgpu/kernels/saxpby.cu +++ b/cuda/spgpu/kernels/saxpby.cu @@ -20,6 +20,9 @@ extern "C" { #include "core.h" #include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); + //#include "cuda_util.h" } @@ -27,6 +30,8 @@ extern "C" #define BLOCK_SIZE 512 + +#if 1 __global__ void spgpuSaxpby_krn(float *z, int n, float beta, float *y, float alpha, float* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; @@ -49,8 +54,6 @@ __global__ void spgpuSaxpby_krn(float *z, int n, float beta, float *y, float alp } } - -#if 1 void spgpuSaxpby(spgpuHandle_t handle, __device float *z, int n, @@ -62,17 +65,35 @@ void spgpuSaxpby(spgpuHandle_t handle, int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); num_blocks_mp = max_threads_mp/BLOCK_SIZE; num_blocks = num_blocks_mp*num_mp; dim3 grid(num_blocks); spgpuSaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } + #else + +__global__ void spgpuSaxpby_krn(float *z, int n, float beta, float *y, float alpha, float* x) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + + if (id < n) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + if (beta == 0.0f) + z[id] = PREC_FMUL(alpha,x[id]); + else + z[id] = PREC_FADD(PREC_FMUL(alpha, x[id]), PREC_FMUL(beta,y[id])); + } +} + + + void spgpuSaxpby_(spgpuHandle_t handle, __device float *z, int n, diff --git a/cuda/spgpu/kernels/zaxpby.cu b/cuda/spgpu/kernels/zaxpby.cu index 8aec3e17..8efc40d2 100644 --- a/cuda/spgpu/kernels/zaxpby.cu +++ b/cuda/spgpu/kernels/zaxpby.cu @@ -23,6 +23,9 @@ extern "C" { #include "core.h" #include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); + //#include "cuda_util.h" } @@ -30,6 +33,7 @@ extern "C" #define BLOCK_SIZE 512 +#if 1 __global__ void spgpuZaxpby_krn(cuDoubleComplex *z, int n, cuDoubleComplex beta, cuDoubleComplex *y, cuDoubleComplex alpha, cuDoubleComplex* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; @@ -52,7 +56,6 @@ __global__ void spgpuZaxpby_krn(cuDoubleComplex *z, int n, cuDoubleComplex beta, } } -#if 1 void spgpuZaxpby(spgpuHandle_t handle, __device cuDoubleComplex *z, int n, @@ -64,10 +67,8 @@ void spgpuZaxpby(spgpuHandle_t handle, int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; int num_mp, max_threads_mp, num_blocks_mp, num_blocks; dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); num_blocks_mp = max_threads_mp/BLOCK_SIZE; num_blocks = num_blocks_mp*num_mp; dim3 grid(num_blocks); @@ -75,6 +76,23 @@ void spgpuZaxpby(spgpuHandle_t handle, spgpuZaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } #else +__global__ void spgpuZaxpby_krn(cuDoubleComplex *z, int n, cuDoubleComplex beta, cuDoubleComplex *y, cuDoubleComplex alpha, cuDoubleComplex* x) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + + if (id < n) + { + // Since z, x and y are accessed with the same offset by the same thread, + // and the write to z follows the x and y read, x, y and z can share the same base address (in-place computing). + + if (cuDoubleComplex_isZero(beta)) + z[id] = cuCmul(alpha,x[id]); + else + z[id] = cuCfma(alpha, x[id], cuCmul(beta,y[id])); + } +} + + void spgpuZaxpby_(spgpuHandle_t handle, __device cuDoubleComplex *z, int n, From 0568a83734e7148c6c8676669009103d0f6c6455 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Mon, 19 Feb 2024 10:53:23 +0100 Subject: [PATCH 38/48] Fix ifdef and old code --- cuda/spgpu/kernels/caxpby.cu | 12 ++++-------- cuda/spgpu/kernels/daxpby.cu | 12 ++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/cuda/spgpu/kernels/caxpby.cu b/cuda/spgpu/kernels/caxpby.cu index 3e97f75f..817fdf53 100644 --- a/cuda/spgpu/kernels/caxpby.cu +++ b/cuda/spgpu/kernels/caxpby.cu @@ -78,6 +78,7 @@ void spgpuCaxpby(spgpuHandle_t handle, #else + __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, cuFloatComplex *y, cuFloatComplex alpha, cuFloatComplex* x) { int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; @@ -94,6 +95,7 @@ __global__ void spgpuCaxpby_krn(cuFloatComplex *z, int n, cuFloatComplex beta, c } } + void spgpuCaxpby_(spgpuHandle_t handle, __device cuFloatComplex *z, int n, @@ -103,15 +105,9 @@ void spgpuCaxpby_(spgpuHandle_t handle, __device cuFloatComplex* x) { int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; - int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; - num_blocks_mp = max_threads_mp/BLOCK_SIZE; - num_blocks = num_blocks_mp*num_mp; - dim3 grid(num_blocks); + dim3 grid(msize); spgpuCaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } diff --git a/cuda/spgpu/kernels/daxpby.cu b/cuda/spgpu/kernels/daxpby.cu index fa87d996..e4823b34 100644 --- a/cuda/spgpu/kernels/daxpby.cu +++ b/cuda/spgpu/kernels/daxpby.cu @@ -89,6 +89,7 @@ __global__ void spgpuDaxpby_krn(double *z, int n, double beta, double *y, double } } + void spgpuDaxpby_(spgpuHandle_t handle, __device double *z, int n, @@ -98,15 +99,9 @@ void spgpuDaxpby_(spgpuHandle_t handle, __device double* x) { int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; - int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - num_mp = deviceProp.multiProcessorCount; - max_threads_mp = deviceProp.maxThreadsPerMultiProcessor; - num_blocks_mp = max_threads_mp/BLOCK_SIZE; - num_blocks = num_blocks_mp*num_mp; - dim3 grid(num_blocks); + dim3 grid(msize); spgpuDaxpby_krn<<currentStream>>>(z, n, beta, y, alpha, x); } @@ -134,6 +129,7 @@ void spgpuDaxpby(spgpuHandle_t handle, cudaCheckError("CUDA error on daxpby"); } + #endif void spgpuDmaxpby(spgpuHandle_t handle, __device double *z, From 93c71c43162fb6663cca9c2f0fff0e8f2ff4c47e Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 20 Feb 2024 10:25:31 +0100 Subject: [PATCH 39/48] Fix %ZERO() on cuda --- cuda/psb_c_cuda_vect_mod.F90 | 6 ++-- cuda/psb_d_cuda_vect_mod.F90 | 60 ++++++------------------------------ cuda/psb_d_vectordev_mod.F90 | 13 -------- cuda/psb_i_cuda_vect_mod.F90 | 6 ++-- cuda/psb_s_cuda_vect_mod.F90 | 6 ++-- cuda/psb_z_cuda_vect_mod.F90 | 6 ++-- 6 files changed, 21 insertions(+), 76 deletions(-) diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index 56cc80e6..fca1c616 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -668,9 +668,9 @@ contains use psi_serial_mod implicit none class(psb_c_vect_cuda), intent(inout) :: x - - if (allocated(x%v)) x%v=czero - call x%set_host() + + call x%set_scal(czero) + end subroutine c_cuda_zero subroutine c_cuda_asb_m(n, x, info) diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index f2ef2be3..2220b26c 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -668,9 +668,9 @@ contains use psi_serial_mod implicit none class(psb_d_vect_cuda), intent(inout) :: x - - if (allocated(x%v)) x%v=dzero - call x%set_host() + + call x%set_scal(dzero) + end subroutine d_cuda_zero subroutine d_cuda_asb_m(n, x, info) @@ -922,56 +922,14 @@ contains class(psb_d_vect_cuda), intent(inout) :: z real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: nx, ny, nz - logical :: gpu_done - info = psb_success_ + call z%psb_d_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) +!!$ +!!$ if (x%is_dev()) call x%sync() +!!$ +!!$ call y%axpby(m,alpha,x,beta,info) +!!$ call z%axpby(m,gamma,y,delta,info) - if (.true.) then - gpu_done = .false. - select type(xx => x) - class is (psb_d_vect_cuda) - select type(yy => y) - class is (psb_d_vect_cuda) - select type(zz => z) - class is (psb_d_vect_cuda) - ! Do something different here - if ((beta /= dzero).and.yy%is_host())& - & call yy%sync() - if ((delta /= dzero).and.zz%is_host())& - & call zz%sync() - if (xx%is_host()) call xx%sync() - nx = getMultiVecDeviceSize(xx%deviceVect) - ny = getMultiVecDeviceSize(yy%deviceVect) - nz = getMultiVecDeviceSize(zz%deviceVect) - if ((nx Date: Tue, 20 Feb 2024 12:30:07 +0100 Subject: [PATCH 40/48] X_cuda_vect%abgdxyz --- cuda/psb_c_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++------ cuda/psb_d_cuda_vect_mod.F90 | 56 +++++++++++++++++++++++++++---- cuda/psb_i_cuda_vect_mod.F90 | 8 ++--- cuda/psb_s_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++------ cuda/psb_z_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++------ 5 files changed, 216 insertions(+), 40 deletions(-) diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index fca1c616..9b3b6fb1 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -922,13 +922,57 @@ contains class(psb_c_vect_cuda), intent(inout) :: z complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz + logical :: gpu_done + + info = psb_success_ + + if (.true.) then + gpu_done = .false. + select type(xx => x) + class is (psb_c_vect_cuda) + select type(yy => y) + class is (psb_c_vect_cuda) + select type(zz => z) + class is (psb_c_vect_cuda) + ! Do something different here + if ((beta /= czero).and.yy%is_host())& + & call yy%sync() + if ((delta /= czero).and.zz%is_host())& + & call zz%sync() + if (xx%is_host()) call xx%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + if ((nx x) !!$ type is (psb_c_base_multivect_type) -!!$ if ((beta /= dzero).and.(y%is_dev()))& +!!$ if ((beta /= czero).and.(y%is_dev()))& !!$ & call y%sync() !!$ call psb_geaxpby(m,alpha,xx%v,beta,y%v,info) !!$ call y%set_host() !!$ type is (psb_c_multivect_cuda) !!$ ! Do something different here -!!$ if ((beta /= dzero).and.y%is_host())& +!!$ if ((beta /= czero).and.y%is_host())& !!$ & call y%sync() !!$ if (xx%is_host()) call xx%sync() !!$ nx = getMultiVecDeviceSize(xx%deviceVect) @@ -1817,7 +1861,7 @@ contains implicit none class(psb_c_multivect_cuda), intent(inout) :: x - if (allocated(x%v)) x%v=dzero + if (allocated(x%v)) x%v=czero call x%set_host() end subroutine c_cuda_multi_zero diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index 2220b26c..c98d66f6 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -922,13 +922,57 @@ contains class(psb_d_vect_cuda), intent(inout) :: z real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz + logical :: gpu_done + + info = psb_success_ + + if (.true.) then + gpu_done = .false. + select type(xx => x) + class is (psb_d_vect_cuda) + select type(yy => y) + class is (psb_d_vect_cuda) + select type(zz => z) + class is (psb_d_vect_cuda) + ! Do something different here + if ((beta /= dzero).and.yy%is_host())& + & call yy%sync() + if ((delta /= dzero).and.zz%is_host())& + & call zz%sync() + if (xx%is_host()) call xx%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + if ((nx x) !!$ type is (psb_i_base_multivect_type) -!!$ if ((beta /= dzero).and.(y%is_dev()))& +!!$ if ((beta /= izero).and.(y%is_dev()))& !!$ & call y%sync() !!$ call psb_geaxpby(m,alpha,xx%v,beta,y%v,info) !!$ call y%set_host() !!$ type is (psb_i_multivect_cuda) !!$ ! Do something different here -!!$ if ((beta /= dzero).and.y%is_host())& +!!$ if ((beta /= izero).and.y%is_host())& !!$ & call y%sync() !!$ if (xx%is_host()) call xx%sync() !!$ nx = getMultiVecDeviceSize(xx%deviceVect) @@ -1477,7 +1477,7 @@ contains implicit none class(psb_i_multivect_cuda), intent(inout) :: x - if (allocated(x%v)) x%v=dzero + if (allocated(x%v)) x%v=izero call x%set_host() end subroutine i_cuda_multi_zero diff --git a/cuda/psb_s_cuda_vect_mod.F90 b/cuda/psb_s_cuda_vect_mod.F90 index 80c60bc3..55ed4a7d 100644 --- a/cuda/psb_s_cuda_vect_mod.F90 +++ b/cuda/psb_s_cuda_vect_mod.F90 @@ -922,13 +922,57 @@ contains class(psb_s_vect_cuda), intent(inout) :: z real(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz + logical :: gpu_done + + info = psb_success_ + + if (.true.) then + gpu_done = .false. + select type(xx => x) + class is (psb_s_vect_cuda) + select type(yy => y) + class is (psb_s_vect_cuda) + select type(zz => z) + class is (psb_s_vect_cuda) + ! Do something different here + if ((beta /= szero).and.yy%is_host())& + & call yy%sync() + if ((delta /= szero).and.zz%is_host())& + & call zz%sync() + if (xx%is_host()) call xx%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + if ((nx x) !!$ type is (psb_s_base_multivect_type) -!!$ if ((beta /= dzero).and.(y%is_dev()))& +!!$ if ((beta /= szero).and.(y%is_dev()))& !!$ & call y%sync() !!$ call psb_geaxpby(m,alpha,xx%v,beta,y%v,info) !!$ call y%set_host() !!$ type is (psb_s_multivect_cuda) !!$ ! Do something different here -!!$ if ((beta /= dzero).and.y%is_host())& +!!$ if ((beta /= szero).and.y%is_host())& !!$ & call y%sync() !!$ if (xx%is_host()) call xx%sync() !!$ nx = getMultiVecDeviceSize(xx%deviceVect) @@ -1817,7 +1861,7 @@ contains implicit none class(psb_s_multivect_cuda), intent(inout) :: x - if (allocated(x%v)) x%v=dzero + if (allocated(x%v)) x%v=szero call x%set_host() end subroutine s_cuda_multi_zero diff --git a/cuda/psb_z_cuda_vect_mod.F90 b/cuda/psb_z_cuda_vect_mod.F90 index 9f801742..2114723b 100644 --- a/cuda/psb_z_cuda_vect_mod.F90 +++ b/cuda/psb_z_cuda_vect_mod.F90 @@ -922,13 +922,57 @@ contains class(psb_z_vect_cuda), intent(inout) :: z complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz + logical :: gpu_done + + info = psb_success_ + + if (.true.) then + gpu_done = .false. + select type(xx => x) + class is (psb_z_vect_cuda) + select type(yy => y) + class is (psb_z_vect_cuda) + select type(zz => z) + class is (psb_z_vect_cuda) + ! Do something different here + if ((beta /= zzero).and.yy%is_host())& + & call yy%sync() + if ((delta /= zzero).and.zz%is_host())& + & call zz%sync() + if (xx%is_host()) call xx%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + if ((nx x) !!$ type is (psb_z_base_multivect_type) -!!$ if ((beta /= dzero).and.(y%is_dev()))& +!!$ if ((beta /= zzero).and.(y%is_dev()))& !!$ & call y%sync() !!$ call psb_geaxpby(m,alpha,xx%v,beta,y%v,info) !!$ call y%set_host() !!$ type is (psb_z_multivect_cuda) !!$ ! Do something different here -!!$ if ((beta /= dzero).and.y%is_host())& +!!$ if ((beta /= zzero).and.y%is_host())& !!$ & call y%sync() !!$ if (xx%is_host()) call xx%sync() !!$ nx = getMultiVecDeviceSize(xx%deviceVect) @@ -1817,7 +1861,7 @@ contains implicit none class(psb_z_multivect_cuda), intent(inout) :: x - if (allocated(x%v)) x%v=dzero + if (allocated(x%v)) x%v=zzero call x%set_host() end subroutine z_cuda_multi_zero From 2a75d677d05da6616103c86e3ba769de16159d34 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 20 Feb 2024 13:04:40 +0100 Subject: [PATCH 41/48] ABGDXYZ in vectordev_mod --- cuda/psb_c_vectordev_mod.F90 | 13 ++++++++++++- cuda/psb_d_vectordev_mod.F90 | 13 ++++++++++++- cuda/psb_s_vectordev_mod.F90 | 13 ++++++++++++- cuda/psb_z_vectordev_mod.F90 | 13 ++++++++++++- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/cuda/psb_c_vectordev_mod.F90 b/cuda/psb_c_vectordev_mod.F90 index b15b2371..88888f61 100644 --- a/cuda/psb_c_vectordev_mod.F90 +++ b/cuda/psb_c_vectordev_mod.F90 @@ -304,7 +304,6 @@ module psb_c_vectordev_mod end function asumMultiVecDeviceFloatComplex end interface - interface axpbyMultiVecDevice function axpbyMultiVecDeviceFloatComplex(n,alpha,deviceVecA,beta,deviceVecB) & & result(res) bind(c,name='axpbyMultiVecDeviceFloatComplex') @@ -316,6 +315,18 @@ module psb_c_vectordev_mod end function axpbyMultiVecDeviceFloatComplex end interface + interface abgdxyzMultiVecDevice + function abgdxyzMultiVecDeviceFloatComplex(n,alpha,beta,gamma,delta,deviceVecX,& + & deviceVecY,deviceVecZ) & + & result(res) bind(c,name='abgdxyzMultiVecDeviceFloatComplex') + use iso_c_binding + integer(c_int) :: res + integer(c_int), value :: n + type(c_float_complex), value :: alpha, beta,gamma,delta + type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ + end function abgdxyzMultiVecDeviceFloatComplex + end interface + interface axyMultiVecDevice function axyMultiVecDeviceFloatComplex(n,alpha,deviceVecA,deviceVecB) & & result(res) bind(c,name='axyMultiVecDeviceFloatComplex') diff --git a/cuda/psb_d_vectordev_mod.F90 b/cuda/psb_d_vectordev_mod.F90 index 802add96..176e8a6e 100644 --- a/cuda/psb_d_vectordev_mod.F90 +++ b/cuda/psb_d_vectordev_mod.F90 @@ -304,7 +304,6 @@ module psb_d_vectordev_mod end function asumMultiVecDeviceDouble end interface - interface axpbyMultiVecDevice function axpbyMultiVecDeviceDouble(n,alpha,deviceVecA,beta,deviceVecB) & & result(res) bind(c,name='axpbyMultiVecDeviceDouble') @@ -316,6 +315,18 @@ module psb_d_vectordev_mod end function axpbyMultiVecDeviceDouble end interface + interface abgdxyzMultiVecDevice + function abgdxyzMultiVecDeviceDouble(n,alpha,beta,gamma,delta,deviceVecX,& + & deviceVecY,deviceVecZ) & + & result(res) bind(c,name='abgdxyzMultiVecDeviceDouble') + use iso_c_binding + integer(c_int) :: res + integer(c_int), value :: n + type(c_double), value :: alpha, beta,gamma,delta + type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ + end function abgdxyzMultiVecDeviceDouble + end interface + interface axyMultiVecDevice function axyMultiVecDeviceDouble(n,alpha,deviceVecA,deviceVecB) & & result(res) bind(c,name='axyMultiVecDeviceDouble') diff --git a/cuda/psb_s_vectordev_mod.F90 b/cuda/psb_s_vectordev_mod.F90 index 3ecabe70..73bb7445 100644 --- a/cuda/psb_s_vectordev_mod.F90 +++ b/cuda/psb_s_vectordev_mod.F90 @@ -304,7 +304,6 @@ module psb_s_vectordev_mod end function asumMultiVecDeviceFloat end interface - interface axpbyMultiVecDevice function axpbyMultiVecDeviceFloat(n,alpha,deviceVecA,beta,deviceVecB) & & result(res) bind(c,name='axpbyMultiVecDeviceFloat') @@ -316,6 +315,18 @@ module psb_s_vectordev_mod end function axpbyMultiVecDeviceFloat end interface + interface abgdxyzMultiVecDevice + function abgdxyzMultiVecDeviceFloat(n,alpha,beta,gamma,delta,deviceVecX,& + & deviceVecY,deviceVecZ) & + & result(res) bind(c,name='abgdxyzMultiVecDeviceFloat') + use iso_c_binding + integer(c_int) :: res + integer(c_int), value :: n + type(c_float), value :: alpha, beta,gamma,delta + type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ + end function abgdxyzMultiVecDeviceFloat + end interface + interface axyMultiVecDevice function axyMultiVecDeviceFloat(n,alpha,deviceVecA,deviceVecB) & & result(res) bind(c,name='axyMultiVecDeviceFloat') diff --git a/cuda/psb_z_vectordev_mod.F90 b/cuda/psb_z_vectordev_mod.F90 index 8f07cd56..fa858acc 100644 --- a/cuda/psb_z_vectordev_mod.F90 +++ b/cuda/psb_z_vectordev_mod.F90 @@ -304,7 +304,6 @@ module psb_z_vectordev_mod end function asumMultiVecDeviceDoubleComplex end interface - interface axpbyMultiVecDevice function axpbyMultiVecDeviceDoubleComplex(n,alpha,deviceVecA,beta,deviceVecB) & & result(res) bind(c,name='axpbyMultiVecDeviceDoubleComplex') @@ -316,6 +315,18 @@ module psb_z_vectordev_mod end function axpbyMultiVecDeviceDoubleComplex end interface + interface abgdxyzMultiVecDevice + function abgdxyzMultiVecDeviceDoubleComplex(n,alpha,beta,gamma,delta,deviceVecX,& + & deviceVecY,deviceVecZ) & + & result(res) bind(c,name='abgdxyzMultiVecDeviceDoubleComplex') + use iso_c_binding + integer(c_int) :: res + integer(c_int), value :: n + type(c_double_complex), value :: alpha, beta,gamma,delta + type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ + end function abgdxyzMultiVecDeviceDoubleComplex + end interface + interface axyMultiVecDevice function axyMultiVecDeviceDoubleComplex(n,alpha,deviceVecA,deviceVecB) & & result(res) bind(c,name='axyMultiVecDeviceDoubleComplex') From 2d3773df9887885d4758d4263871c69d0a53c6e4 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 20 Feb 2024 13:05:16 +0100 Subject: [PATCH 42/48] CUDA kernels for ABGDXYZ --- cuda/cvectordev.c | 18 ++++++++ cuda/cvectordev.h | 3 ++ cuda/dvectordev.c | 7 --- cuda/dvectordev.h | 2 + cuda/spgpu/kernels/Makefile | 6 +-- cuda/spgpu/kernels/cabgdxyz.cu | 80 ++++++++++++++++++++++++++++++++++ cuda/spgpu/kernels/sabgdxyz.cu | 79 +++++++++++++++++++++++++++++++++ cuda/spgpu/kernels/zabgdxyz.cu | 80 ++++++++++++++++++++++++++++++++++ cuda/spgpu/vector.h | 36 +++++++++++++++ cuda/svectordev.c | 17 ++++++++ cuda/svectordev.h | 2 + cuda/zvectordev.c | 18 ++++++++ cuda/zvectordev.h | 3 ++ 13 files changed, 341 insertions(+), 10 deletions(-) create mode 100644 cuda/spgpu/kernels/cabgdxyz.cu create mode 100644 cuda/spgpu/kernels/sabgdxyz.cu create mode 100644 cuda/spgpu/kernels/zabgdxyz.cu diff --git a/cuda/cvectordev.c b/cuda/cvectordev.c index 518154d5..9db5202e 100644 --- a/cuda/cvectordev.c +++ b/cuda/cvectordev.c @@ -255,6 +255,24 @@ int axpbyMultiVecDeviceFloatComplex(int n,cuFloatComplex alpha, void* devMultiVe return(i); } +int abgdxyzMultiVecDeviceFloatComplex(int n,cuFloatComplex alpha,cuFloatComplex beta, + cuFloatComplex gamma, cuFloatComplex delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuCabgdxyz(handle,n, alpha,beta,gamma,delta, + (cuFloatComplex *)devVecX->v_,(cuFloatComplex *) devVecY->v_,(cuFloatComplex *) devVecZ->v_); + return(i); +} + int axyMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void *deviceVecA, void *deviceVecB) { int i = 0; diff --git a/cuda/cvectordev.h b/cuda/cvectordev.h index 27c8984a..fc18e328 100644 --- a/cuda/cvectordev.h +++ b/cuda/cvectordev.h @@ -69,6 +69,9 @@ int asumMultiVecDeviceFloatComplex(cuFloatComplex* y_res, int n, void* devVecA); int dotMultiVecDeviceFloatComplex(cuFloatComplex* y_res, int n, void* devVecA, void* devVecB); int axpbyMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void* devVecX, cuFloatComplex beta, void* devVecY); +int abgdxyzMultiVecDeviceFloatComplex(int n,cuFloatComplex alpha,cuFloatComplex beta, + cuFloatComplex gamma, cuFloatComplex delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); int axyMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void *deviceVecA, void *deviceVecB, cuFloatComplex beta, void *deviceVecZ); diff --git a/cuda/dvectordev.c b/cuda/dvectordev.c index 785753dd..b4ca95f4 100644 --- a/cuda/dvectordev.c +++ b/cuda/dvectordev.c @@ -241,7 +241,6 @@ int axpbyMultiVecDeviceDouble(int n,double alpha, void* devMultiVecX, return(i); } - int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta, void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ) { int j=0, i=0; @@ -254,14 +253,8 @@ int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, do if ((n > devVecY->size_) || (n>devVecX->size_ )) return SPGPU_UNSUPPORTED; -#if 1 spgpuDabgdxyz(handle,n, alpha,beta,gamma,delta, (double*)devVecX->v_,(double*) devVecY->v_,(double*) devVecZ->v_); -#else - for(j=0;jcount_;j++) - spgpuDaxpby(handle,(double*)devVecY->v_+pitch*j, n, beta, - (double*)devVecY->v_+pitch*j, alpha,(double*) devVecX->v_+pitch*j); -#endif return(i); } diff --git a/cuda/dvectordev.h b/cuda/dvectordev.h index 25905c60..81a2e8f6 100644 --- a/cuda/dvectordev.h +++ b/cuda/dvectordev.h @@ -67,6 +67,8 @@ int asumMultiVecDeviceDouble(double* y_res, int n, void* devVecA); int dotMultiVecDeviceDouble(double* y_res, int n, void* devVecA, void* devVecB); int axpbyMultiVecDeviceDouble(int n, double alpha, void* devVecX, double beta, void* devVecY); +int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); int axyMultiVecDeviceDouble(int n, double alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceDouble(int n, double alpha, void *deviceVecA, void *deviceVecB, double beta, void *deviceVecZ); diff --git a/cuda/spgpu/kernels/Makefile b/cuda/spgpu/kernels/Makefile index 5ada698a..3e668b4e 100644 --- a/cuda/spgpu/kernels/Makefile +++ b/cuda/spgpu/kernels/Makefile @@ -11,15 +11,15 @@ LIBNAME=$(UP)/libspgpu.a CINCLUDES=-I$(INCDIR) OBJS=cabs.o camax.o casum.o caxpby.o caxy.o cdot.o cgath.o \ - cnrm2.o cscal.o cscat.o csetscal.o \ + cnrm2.o cscal.o cscat.o csetscal.o cabgdxyz.o\ dabs.o damax.o dasum.o daxpby.o daxy.o ddot.o dgath.o dabgdxyz.o\ dia_cspmv.o dia_dspmv.o dia_sspmv.o dia_zspmv.o dnrm2.o \ dscal.o dscat.o dsetscal.o ell_ccsput.o ell_cspmv.o \ ell_dcsput.o ell_dspmv.o ell_scsput.o ell_sspmv.o ell_zcsput.o ell_zspmv.o \ hdia_cspmv.o hdia_dspmv.o hdia_sspmv.o hdia_zspmv.o hell_cspmv.o hell_dspmv.o \ hell_sspmv.o hell_zspmv.o igath.o iscat.o isetscal.o sabs.o samax.o sasum.o \ - saxpby.o saxy.o sdot.o sgath.o snrm2.o sscal.o sscat.o ssetscal.o zabs.o zamax.o \ - zasum.o zaxpby.o zaxy.o zdot.o zgath.o znrm2.o zscal.o zscat.o zsetscal.o + saxpby.o saxy.o sdot.o sgath.o snrm2.o sscal.o sscat.o ssetscal.o zabs.o zamax.o sabgdxyz.o\ + zasum.o zaxpby.o zaxy.o zdot.o zgath.o znrm2.o zscal.o zscat.o zsetscal.o zabgdxyz.o objs: $(OBJS) lib: objs diff --git a/cuda/spgpu/kernels/cabgdxyz.cu b/cuda/spgpu/kernels/cabgdxyz.cu new file mode 100644 index 00000000..00dc6ab4 --- /dev/null +++ b/cuda/spgpu/kernels/cabgdxyz.cu @@ -0,0 +1,80 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuCabgdxyz_krn(int n, cuFloatComplex alpha, cuFloatComplex beta, + cuFloatComplex gamma, cuFloatComplex delta, + cuFloatComplex * x, cuFloatComplex *y, cuFloatComplex *z) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + cuFloatComplex t; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + if (cuFloatComplex_isZero(beta)) + t = cuCmulf(alpha,x[id]); + else + t = cuCfmaf(alpha, x[id], cuCmulf(beta,y[id])); + if (cuFloatComplex_isZero(delta)) + z[id] = cuCmulf(gamma, t); + else + z[id] = cuCfmafmulf(gamma, t, cuCmulf(delta,z[id])); + y[id] = t; + } +} + + +void spgpuCabgdxyz(spgpuHandle_t handle, + int n, + cuFloatComplex alpha, + cuFloatComplex beta, + cuFloatComplex gamma, + cuFloatComplex delta, + __device cuFloatComplex * x, + __device cuFloatComplex * y, + __device cuFloatComplex *z) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuCabgdxyz_krn<<currentStream>>>(n, alpha, beta, gamma, delta, + x, y, z); +} + diff --git a/cuda/spgpu/kernels/sabgdxyz.cu b/cuda/spgpu/kernels/sabgdxyz.cu new file mode 100644 index 00000000..8c137ed3 --- /dev/null +++ b/cuda/spgpu/kernels/sabgdxyz.cu @@ -0,0 +1,79 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuSabgdxyz_krn(int n, float alpha, float beta, float gamma, float delta, + float* x, float *y, float *z) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + float t; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + if (beta == 0.0) + t = PREC_FMUL(alpha,x[id]); + else + t = PREC_FADD(PREC_FMUL(alpha, x[id]), PREC_FMUL(beta,y[id])); + if (delta == 0.0) + z[id] = gamma * t; + else + z[id] = PREC_FADD(PREC_FMUL(gamma, t), PREC_FMUL(delta,z[id])); + y[id] = t; + } +} + + +void spgpuSabgdxyz(spgpuHandle_t handle, + int n, + float alpha, + float beta, + float gamma, + float delta, + __device float* x, + __device float* y, + __device float *z) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuSabgdxyz_krn<<currentStream>>>(n, alpha, beta, gamma, delta, + x, y, z); +} + diff --git a/cuda/spgpu/kernels/zabgdxyz.cu b/cuda/spgpu/kernels/zabgdxyz.cu new file mode 100644 index 00000000..48def937 --- /dev/null +++ b/cuda/spgpu/kernels/zabgdxyz.cu @@ -0,0 +1,80 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuZabgdxyz_krn(int n, cuDoubleComplex alpha, cuDoubleComplex beta, + cuDoubleComplex gamma, cuDoubleComplex delta, + cuDoubleComplex * x, cuDoubleComplex *y, cuDoubleComplex *z) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + cuDoubleComplex t; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + if (cuDoubleComplex_isZero(beta)) + t = cuCmul(alpha,x[id]); + else + t = cuCfma(alpha, x[id], cuCmul(beta,y[id])); + if (cuDoubleComplex_isZero(delta)) + z[id] = cuCmul(gamma, t); + else + z[id] = cuCfma(gamma, t, cuCmul(delta,z[id])); + y[id] = t; + } +} + + +void spgpuZabgdxyz(spgpuHandle_t handle, + int n, + cuDoubleComplex alpha, + cuDoubleComplex beta, + cuDoubleComplex gamma, + cuDoubleComplex delta, + __device cuDoubleComplex * x, + __device cuDoubleComplex * y, + __device cuDoubleComplex *z) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuZabgdxyz_krn<<currentStream>>>(n, alpha, beta, gamma, delta, + x, y, z); +} + diff --git a/cuda/spgpu/vector.h b/cuda/spgpu/vector.h index 69ffedf0..9fc3e658 100644 --- a/cuda/spgpu/vector.h +++ b/cuda/spgpu/vector.h @@ -181,6 +181,18 @@ void spgpuSaxpby(spgpuHandle_t handle, float alpha, __device float* x); + +void spgpuSabgdxyz(spgpuHandle_t handle, + int n, + float alpha, + float beta, + float gamma, + float delta, + __device float* x, + __device float *y, + __device float *z) +; + /** * \fn void spgpuSmaxpby(spgpuHandle_t handle, __device float *z, int n, float beta, __device float *y, float alpha, __device float* x, int count, int pitch) * Computes the single precision z = beta * y + alpha * x of x and y multivectors. z could be exactly x or y (without offset) or another vector. @@ -755,6 +767,18 @@ void spgpuCaxpby(spgpuHandle_t handle, cuFloatComplex alpha, __device cuFloatComplex* x); + +void spgpuCabgdxyz(spgpuHandle_t handle, + int n, + cuFloatComplex alpha, + cuFloatComplex beta, + cuFloatComplex gamma, + cuFloatComplex delta, + __device cuFloatComplex* x, + __device cuFloatComplex *y, + __device cuFloatComplex *z) +; + /** * \fn void spgpuCmaxpby(spgpuHandle_t handle, __device cuFloatComplex *z, int n, cuFloatComplex beta, __device cuFloatComplex *y, cuFloatComplex alpha, __device cuFloatComplex* x, int count, int pitch) * Computes the single precision complex z = beta * y + alpha * x of x and y multivectors. z could be exactly x or y (without offset) or another vector. @@ -1034,6 +1058,18 @@ void spgpuZaxpby(spgpuHandle_t handle, cuDoubleComplex alpha, __device cuDoubleComplex* x); + +void spgpuZabgdxyz(spgpuHandle_t handle, + int n, + cuDoubleComplex alpha, + cuDoubleComplex beta, + cuDoubleComplex gamma, + cuDoubleComplex delta, + __device cuDoubleComplex* x, + __device cuDoubleComplex *y, + __device cuDoubleComplex *z) +; + /** * \fn void spgpuZmaxpby(spgpuHandle_t handle, __device cuDoubleComplex *z, int n, cuDoubleComplex beta, __device cuDoubleComplex *y, cuDoubleComplex alpha, __device cuDoubleComplex* x, int count, int pitch) * Computes the double precision complex z = beta * y + alpha * x of x and y multivectors. z could be exactly x or y (without offset) or another vector. diff --git a/cuda/svectordev.c b/cuda/svectordev.c index bfa4061a..b84718f5 100644 --- a/cuda/svectordev.c +++ b/cuda/svectordev.c @@ -241,6 +241,23 @@ int axpbyMultiVecDeviceFloat(int n,float alpha, void* devMultiVecX, return(i); } +int abgdxyzMultiVecDeviceFloat(int n,float alpha,float beta, float gamma, float delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuSabgdxyz(handle,n, alpha,beta,gamma,delta, + (float*)devVecX->v_,(float*) devVecY->v_,(float*) devVecZ->v_); + return(i); +} + int axyMultiVecDeviceFloat(int n, float alpha, void *deviceVecA, void *deviceVecB) { int i = 0; struct MultiVectDevice *devVecA = (struct MultiVectDevice *) deviceVecA; diff --git a/cuda/svectordev.h b/cuda/svectordev.h index bf25fcb1..730f929a 100644 --- a/cuda/svectordev.h +++ b/cuda/svectordev.h @@ -67,6 +67,8 @@ int asumMultiVecDeviceFloat(float* y_res, int n, void* devVecA); int dotMultiVecDeviceFloat(float* y_res, int n, void* devVecA, void* devVecB); int axpbyMultiVecDeviceFloat(int n, float alpha, void* devVecX, float beta, void* devVecY); +int abgdxyzMultiVecDeviceFloat(int n,float alpha,float beta, float gamma, float delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); int axyMultiVecDeviceFloat(int n, float alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceFloat(int n, float alpha, void *deviceVecA, void *deviceVecB, float beta, void *deviceVecZ); diff --git a/cuda/zvectordev.c b/cuda/zvectordev.c index 0fb1c67e..d1f23f2a 100644 --- a/cuda/zvectordev.c +++ b/cuda/zvectordev.c @@ -234,6 +234,24 @@ int dotMultiVecDeviceDoubleComplex(cuDoubleComplex* y_res, int n, void* devMulti return(i); } +int abgdxyzMultiVecDeviceDoubleComplex(int n,cuDoubleComplex alpha, + cuDoubleComplex beta, cuDoubleComplex gamma, cuDoubleComplex delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuZabgdxyz(handle,n, alpha,beta,gamma,delta, + (cuDoubleComplex *)devVecX->v_,(cuDoubleComplex *) devVecY->v_,(cuDoubleComplex *) devVecZ->v_); + return(i); +} + int axpbyMultiVecDeviceDoubleComplex(int n,cuDoubleComplex alpha, void* devMultiVecX, cuDoubleComplex beta, void* devMultiVecY) { int j=0, i=0; diff --git a/cuda/zvectordev.h b/cuda/zvectordev.h index 96330a7a..4c32f11c 100644 --- a/cuda/zvectordev.h +++ b/cuda/zvectordev.h @@ -77,6 +77,9 @@ int dotMultiVecDeviceDoubleComplex(cuDoubleComplex* y_res, int n, int axpbyMultiVecDeviceDoubleComplex(int n, cuDoubleComplex alpha, void* devVecX, cuDoubleComplex beta, void* devVecY); +int abgdxyzMultiVecDeviceDoubleComplex(int n,cuDoubleComplex alpha, + cuDoubleComplex beta, cuDoubleComplex gamma, cuDoubleComplex delta, + void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); int axyMultiVecDeviceDoubleComplex(int n, cuDoubleComplex alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceDoubleComplex(int n, cuDoubleComplex alpha, void *deviceVecA, From d95077ffd66389d11e5782ca76f26432c05cc270 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Tue, 20 Feb 2024 13:46:58 +0100 Subject: [PATCH 43/48] Fix typo in vectordev_mod --- cuda/psb_c_vectordev_mod.F90 | 4 +--- cuda/psb_d_vectordev_mod.F90 | 4 +--- cuda/psb_i_vectordev_mod.F90 | 2 -- cuda/psb_s_vectordev_mod.F90 | 4 +--- cuda/psb_z_vectordev_mod.F90 | 4 +--- 5 files changed, 4 insertions(+), 14 deletions(-) diff --git a/cuda/psb_c_vectordev_mod.F90 b/cuda/psb_c_vectordev_mod.F90 index 88888f61..20a4ac3f 100644 --- a/cuda/psb_c_vectordev_mod.F90 +++ b/cuda/psb_c_vectordev_mod.F90 @@ -28,8 +28,6 @@ ! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ! POSSIBILITY OF SUCH DAMAGE. ! - - module psb_c_vectordev_mod use psb_base_vectordev_mod @@ -322,7 +320,7 @@ module psb_c_vectordev_mod use iso_c_binding integer(c_int) :: res integer(c_int), value :: n - type(c_float_complex), value :: alpha, beta,gamma,delta + complex(c_float_complex), value :: alpha, beta,gamma,delta type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ end function abgdxyzMultiVecDeviceFloatComplex end interface diff --git a/cuda/psb_d_vectordev_mod.F90 b/cuda/psb_d_vectordev_mod.F90 index 176e8a6e..080b27fe 100644 --- a/cuda/psb_d_vectordev_mod.F90 +++ b/cuda/psb_d_vectordev_mod.F90 @@ -28,8 +28,6 @@ ! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ! POSSIBILITY OF SUCH DAMAGE. ! - - module psb_d_vectordev_mod use psb_base_vectordev_mod @@ -322,7 +320,7 @@ module psb_d_vectordev_mod use iso_c_binding integer(c_int) :: res integer(c_int), value :: n - type(c_double), value :: alpha, beta,gamma,delta + real(c_double), value :: alpha, beta,gamma,delta type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ end function abgdxyzMultiVecDeviceDouble end interface diff --git a/cuda/psb_i_vectordev_mod.F90 b/cuda/psb_i_vectordev_mod.F90 index 84037aaf..ce3dc5e1 100644 --- a/cuda/psb_i_vectordev_mod.F90 +++ b/cuda/psb_i_vectordev_mod.F90 @@ -28,8 +28,6 @@ ! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ! POSSIBILITY OF SUCH DAMAGE. ! - - module psb_i_vectordev_mod use psb_base_vectordev_mod diff --git a/cuda/psb_s_vectordev_mod.F90 b/cuda/psb_s_vectordev_mod.F90 index 73bb7445..19776cbc 100644 --- a/cuda/psb_s_vectordev_mod.F90 +++ b/cuda/psb_s_vectordev_mod.F90 @@ -28,8 +28,6 @@ ! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ! POSSIBILITY OF SUCH DAMAGE. ! - - module psb_s_vectordev_mod use psb_base_vectordev_mod @@ -322,7 +320,7 @@ module psb_s_vectordev_mod use iso_c_binding integer(c_int) :: res integer(c_int), value :: n - type(c_float), value :: alpha, beta,gamma,delta + real(c_float), value :: alpha, beta,gamma,delta type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ end function abgdxyzMultiVecDeviceFloat end interface diff --git a/cuda/psb_z_vectordev_mod.F90 b/cuda/psb_z_vectordev_mod.F90 index fa858acc..07e4ba37 100644 --- a/cuda/psb_z_vectordev_mod.F90 +++ b/cuda/psb_z_vectordev_mod.F90 @@ -28,8 +28,6 @@ ! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ! POSSIBILITY OF SUCH DAMAGE. ! - - module psb_z_vectordev_mod use psb_base_vectordev_mod @@ -322,7 +320,7 @@ module psb_z_vectordev_mod use iso_c_binding integer(c_int) :: res integer(c_int), value :: n - type(c_double_complex), value :: alpha, beta,gamma,delta + complex(c_double_complex), value :: alpha, beta,gamma,delta type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ end function abgdxyzMultiVecDeviceDoubleComplex end interface From 0e269ed6418dd4e953ce57b6b04ce74b585f5a3e Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 20 Feb 2024 13:57:32 +0100 Subject: [PATCH 44/48] typo in Cabgdxyz --- cuda/spgpu/kernels/cabgdxyz.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda/spgpu/kernels/cabgdxyz.cu b/cuda/spgpu/kernels/cabgdxyz.cu index 00dc6ab4..a85b3873 100644 --- a/cuda/spgpu/kernels/cabgdxyz.cu +++ b/cuda/spgpu/kernels/cabgdxyz.cu @@ -49,7 +49,7 @@ __global__ void spgpuCabgdxyz_krn(int n, cuFloatComplex alpha, cuFloatComplex if (cuFloatComplex_isZero(delta)) z[id] = cuCmulf(gamma, t); else - z[id] = cuCfmafmulf(gamma, t, cuCmulf(delta,z[id])); + z[id] = cuCfmaf(gamma, t, cuCmulf(delta,z[id])); y[id] = t; } } From b5d5f9766107f1d2b6c5a61885fab6df388b4cda Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 20 Feb 2024 14:04:23 +0100 Subject: [PATCH 45/48] Improve cuda%zero() --- cuda/psb_c_cuda_vect_mod.F90 | 4 +++- cuda/psb_d_cuda_vect_mod.F90 | 4 +++- cuda/psb_i_cuda_vect_mod.F90 | 4 +++- cuda/psb_s_cuda_vect_mod.F90 | 4 +++- cuda/psb_z_cuda_vect_mod.F90 | 4 +++- 5 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index 9b3b6fb1..7eee128f 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -668,7 +668,9 @@ contains use psi_serial_mod implicit none class(psb_c_vect_cuda), intent(inout) :: x - + ! Since we are overwriting, make sure to do it + ! on the GPU side + call x%set_dev() call x%set_scal(czero) end subroutine c_cuda_zero diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index c98d66f6..1e6e9f2a 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -668,7 +668,9 @@ contains use psi_serial_mod implicit none class(psb_d_vect_cuda), intent(inout) :: x - + ! Since we are overwriting, make sure to do it + ! on the GPU side + call x%set_dev() call x%set_scal(dzero) end subroutine d_cuda_zero diff --git a/cuda/psb_i_cuda_vect_mod.F90 b/cuda/psb_i_cuda_vect_mod.F90 index a018713e..903c4a08 100644 --- a/cuda/psb_i_cuda_vect_mod.F90 +++ b/cuda/psb_i_cuda_vect_mod.F90 @@ -650,7 +650,9 @@ contains use psi_serial_mod implicit none class(psb_i_vect_cuda), intent(inout) :: x - + ! Since we are overwriting, make sure to do it + ! on the GPU side + call x%set_dev() call x%set_scal(izero) end subroutine i_cuda_zero diff --git a/cuda/psb_s_cuda_vect_mod.F90 b/cuda/psb_s_cuda_vect_mod.F90 index 55ed4a7d..3497c33e 100644 --- a/cuda/psb_s_cuda_vect_mod.F90 +++ b/cuda/psb_s_cuda_vect_mod.F90 @@ -668,7 +668,9 @@ contains use psi_serial_mod implicit none class(psb_s_vect_cuda), intent(inout) :: x - + ! Since we are overwriting, make sure to do it + ! on the GPU side + call x%set_dev() call x%set_scal(szero) end subroutine s_cuda_zero diff --git a/cuda/psb_z_cuda_vect_mod.F90 b/cuda/psb_z_cuda_vect_mod.F90 index 2114723b..8483544c 100644 --- a/cuda/psb_z_cuda_vect_mod.F90 +++ b/cuda/psb_z_cuda_vect_mod.F90 @@ -668,7 +668,9 @@ contains use psi_serial_mod implicit none class(psb_z_vect_cuda), intent(inout) :: x - + ! Since we are overwriting, make sure to do it + ! on the GPU side + call x%set_dev() call x%set_scal(zzero) end subroutine z_cuda_zero From 86be8ebcd0452b108be6580d8af58fa55692a08a Mon Sep 17 00:00:00 2001 From: sfilippone Date: Mon, 4 Mar 2024 16:29:47 +0100 Subject: [PATCH 46/48] New method W%XYZW() --- base/modules/auxil/psi_c_serial_mod.f90 | 14 ++++ base/modules/auxil/psi_d_serial_mod.f90 | 14 ++++ base/modules/auxil/psi_e_serial_mod.f90 | 14 ++++ base/modules/auxil/psi_i2_serial_mod.f90 | 14 ++++ base/modules/auxil/psi_m_serial_mod.f90 | 14 ++++ base/modules/auxil/psi_s_serial_mod.f90 | 14 ++++ base/modules/auxil/psi_z_serial_mod.f90 | 14 ++++ base/modules/serial/psb_c_base_vect_mod.F90 | 44 +++++++++---- base/modules/serial/psb_c_vect_mod.F90 | 17 +++++ base/modules/serial/psb_d_base_vect_mod.F90 | 44 +++++++++---- base/modules/serial/psb_d_vect_mod.F90 | 17 +++++ base/modules/serial/psb_s_base_vect_mod.F90 | 44 +++++++++---- base/modules/serial/psb_s_vect_mod.F90 | 17 +++++ base/modules/serial/psb_z_base_vect_mod.F90 | 44 +++++++++---- base/modules/serial/psb_z_vect_mod.F90 | 17 +++++ base/serial/psi_c_serial_impl.F90 | 72 +++++++++++++++++++++ base/serial/psi_d_serial_impl.F90 | 72 +++++++++++++++++++++ base/serial/psi_e_serial_impl.F90 | 72 +++++++++++++++++++++ base/serial/psi_i2_serial_impl.F90 | 72 +++++++++++++++++++++ base/serial/psi_m_serial_impl.F90 | 72 +++++++++++++++++++++ base/serial/psi_s_serial_impl.F90 | 72 +++++++++++++++++++++ base/serial/psi_z_serial_impl.F90 | 72 +++++++++++++++++++++ 22 files changed, 790 insertions(+), 56 deletions(-) diff --git a/base/modules/auxil/psi_c_serial_mod.f90 b/base/modules/auxil/psi_c_serial_mod.f90 index 6926d6bd..38b740a7 100644 --- a/base/modules/auxil/psi_c_serial_mod.f90 +++ b/base/modules/auxil/psi_c_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_c_serial_mod end subroutine psi_cabgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_cxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_spk_ + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_spk_), intent (in) :: x(:) + complex(psb_spk_), intent (inout) :: y(:) + complex(psb_spk_), intent (inout) :: z(:) + complex(psb_spk_), intent (inout) :: w(:) + complex(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_cxyzw + end interface psi_xyzw + interface psi_gth subroutine psi_cgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_spk_ diff --git a/base/modules/auxil/psi_d_serial_mod.f90 b/base/modules/auxil/psi_d_serial_mod.f90 index 42185d21..1d65c5f6 100644 --- a/base/modules/auxil/psi_d_serial_mod.f90 +++ b/base/modules/auxil/psi_d_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_d_serial_mod end subroutine psi_dabgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_dxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_dpk_ + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_dpk_), intent (in) :: x(:) + real(psb_dpk_), intent (inout) :: y(:) + real(psb_dpk_), intent (inout) :: z(:) + real(psb_dpk_), intent (inout) :: w(:) + real(psb_dpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_dxyzw + end interface psi_xyzw + interface psi_gth subroutine psi_dgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_dpk_ diff --git a/base/modules/auxil/psi_e_serial_mod.f90 b/base/modules/auxil/psi_e_serial_mod.f90 index ffba06fd..6f4e8c06 100644 --- a/base/modules/auxil/psi_e_serial_mod.f90 +++ b/base/modules/auxil/psi_e_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_e_serial_mod end subroutine psi_eabgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_exyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_epk_), intent (in) :: x(:) + integer(psb_epk_), intent (inout) :: y(:) + integer(psb_epk_), intent (inout) :: z(:) + integer(psb_epk_), intent (inout) :: w(:) + integer(psb_epk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_exyzw + end interface psi_xyzw + interface psi_gth subroutine psi_egthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ diff --git a/base/modules/auxil/psi_i2_serial_mod.f90 b/base/modules/auxil/psi_i2_serial_mod.f90 index d61a1146..ffa14059 100644 --- a/base/modules/auxil/psi_i2_serial_mod.f90 +++ b/base/modules/auxil/psi_i2_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_i2_serial_mod end subroutine psi_i2abgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_i2xyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_i2pk_), intent (in) :: x(:) + integer(psb_i2pk_), intent (inout) :: y(:) + integer(psb_i2pk_), intent (inout) :: z(:) + integer(psb_i2pk_), intent (inout) :: w(:) + integer(psb_i2pk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_i2xyzw + end interface psi_xyzw + interface psi_gth subroutine psi_i2gthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ diff --git a/base/modules/auxil/psi_m_serial_mod.f90 b/base/modules/auxil/psi_m_serial_mod.f90 index 76131d75..5661fdbf 100644 --- a/base/modules/auxil/psi_m_serial_mod.f90 +++ b/base/modules/auxil/psi_m_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_m_serial_mod end subroutine psi_mabgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_mxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_mpk_), intent (in) :: x(:) + integer(psb_mpk_), intent (inout) :: y(:) + integer(psb_mpk_), intent (inout) :: z(:) + integer(psb_mpk_), intent (inout) :: w(:) + integer(psb_mpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_mxyzw + end interface psi_xyzw + interface psi_gth subroutine psi_mgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_lpk_,psb_mpk_, psb_epk_ diff --git a/base/modules/auxil/psi_s_serial_mod.f90 b/base/modules/auxil/psi_s_serial_mod.f90 index 02b96311..5cc17d58 100644 --- a/base/modules/auxil/psi_s_serial_mod.f90 +++ b/base/modules/auxil/psi_s_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_s_serial_mod end subroutine psi_sabgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_sxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_spk_ + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_spk_), intent (in) :: x(:) + real(psb_spk_), intent (inout) :: y(:) + real(psb_spk_), intent (inout) :: z(:) + real(psb_spk_), intent (inout) :: w(:) + real(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_sxyzw + end interface psi_xyzw + interface psi_gth subroutine psi_sgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_spk_ diff --git a/base/modules/auxil/psi_z_serial_mod.f90 b/base/modules/auxil/psi_z_serial_mod.f90 index a86bdd70..8a3f053d 100644 --- a/base/modules/auxil/psi_z_serial_mod.f90 +++ b/base/modules/auxil/psi_z_serial_mod.f90 @@ -112,6 +112,20 @@ module psi_z_serial_mod end subroutine psi_zabgdxyz end interface psi_abgdxyz + interface psi_xyzw + subroutine psi_zxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + import :: psb_ipk_, psb_dpk_ + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_dpk_), intent (in) :: x(:) + complex(psb_dpk_), intent (inout) :: y(:) + complex(psb_dpk_), intent (inout) :: z(:) + complex(psb_dpk_), intent (inout) :: w(:) + complex(psb_dpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + end subroutine psi_zxyzw + end interface psi_xyzw + interface psi_gth subroutine psi_zgthmv(n,k,idx,alpha,x,beta,y) import :: psb_ipk_, psb_dpk_ diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index 5a468d55..41bab5ab 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -156,6 +156,7 @@ module psb_c_base_vect_mod procedure, pass(z) :: axpby_a2 => c_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => c_base_abgdxyz + procedure, pass(w) :: xyzw => c_base_xyzw ! ! Vector by vector multiplication. Need all variants @@ -1155,22 +1156,37 @@ contains complex(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (.false.) then - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) - else - if (x%is_dev().and.(alpha/=czero)) call x%sync() - if (y%is_dev().and.(beta/=czero)) call y%sync() - if (z%is_dev().and.(delta/=czero)) call z%sync() - call psi_cabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) - call y%set_host() - call z%set_host() - end if - + if (x%is_dev().and.(alpha/=czero)) call x%sync() + if (y%is_dev().and.(beta/=czero)) call y%sync() + if (z%is_dev().and.(delta/=czero)) call z%sync() + call psi_abgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() + end subroutine c_base_abgdxyz + subroutine c_base_xyzw(m,a,b,c,d,e,f,x, y, z, w,info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_base_vect_type), intent(inout) :: z + class(psb_c_base_vect_type), intent(inout) :: w + complex(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev().and.(a/=czero)) call x%sync() + if (y%is_dev().and.(b/=czero)) call y%sync() + if (z%is_dev().and.(d/=czero)) call z%sync() + if (w%is_dev().and.(f/=czero)) call w%sync() + call psi_xyzw(m,a,b,c,d,e,f,x%v, y%v, z%v, w%v, info) + call y%set_host() + call z%set_host() + call w%set_host() + + end subroutine c_base_xyzw + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_c_vect_mod.F90 b/base/modules/serial/psb_c_vect_mod.F90 index e0488def..865f9456 100644 --- a/base/modules/serial/psb_c_vect_mod.F90 +++ b/base/modules/serial/psb_c_vect_mod.F90 @@ -103,6 +103,7 @@ module psb_c_vect_mod procedure, pass(z) :: axpby_a2 => c_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => c_vect_abgdxyz + procedure, pass(z) :: xyzw => c_vect_xyzw procedure, pass(y) :: mlt_v => c_vect_mlt_v procedure, pass(y) :: mlt_a => c_vect_mlt_a @@ -788,6 +789,22 @@ contains end subroutine c_vect_abgdxyz + subroutine c_vect_xyzw(m,a,b,c,d,e,f,x, y, z, w, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_vect_type), intent(inout) :: x + class(psb_c_vect_type), intent(inout) :: y + class(psb_c_vect_type), intent(inout) :: z + class(psb_c_vect_type), intent(inout) :: w + complex(psb_spk_), intent (in) :: a, b, c, d, e, f + integer(psb_ipk_), intent(out) :: info + + if (allocated(w%v)) & + call w%v%xyzw(m,a,b,c,d,e,f,x%v,y%v,z%v,info) + + end subroutine c_vect_xyzw + subroutine c_vect_mlt_v(x, y, info) use psi_serial_mod diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 8f583cd3..1ad1ffa5 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -156,6 +156,7 @@ module psb_d_base_vect_mod procedure, pass(z) :: axpby_a2 => d_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => d_base_abgdxyz + procedure, pass(w) :: xyzw => d_base_xyzw ! ! Vector by vector multiplication. Need all variants @@ -1162,22 +1163,37 @@ contains real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (.false.) then - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) - else - if (x%is_dev().and.(alpha/=dzero)) call x%sync() - if (y%is_dev().and.(beta/=dzero)) call y%sync() - if (z%is_dev().and.(delta/=dzero)) call z%sync() - call psi_dabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) - call y%set_host() - call z%set_host() - end if - + if (x%is_dev().and.(alpha/=dzero)) call x%sync() + if (y%is_dev().and.(beta/=dzero)) call y%sync() + if (z%is_dev().and.(delta/=dzero)) call z%sync() + call psi_abgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() + end subroutine d_base_abgdxyz + subroutine d_base_xyzw(m,a,b,c,d,e,f,x, y, z, w,info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_d_base_vect_type), intent(inout) :: x + class(psb_d_base_vect_type), intent(inout) :: y + class(psb_d_base_vect_type), intent(inout) :: z + class(psb_d_base_vect_type), intent(inout) :: w + real(psb_dpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev().and.(a/=dzero)) call x%sync() + if (y%is_dev().and.(b/=dzero)) call y%sync() + if (z%is_dev().and.(d/=dzero)) call z%sync() + if (w%is_dev().and.(f/=dzero)) call w%sync() + call psi_xyzw(m,a,b,c,d,e,f,x%v, y%v, z%v, w%v, info) + call y%set_host() + call z%set_host() + call w%set_host() + + end subroutine d_base_xyzw + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index 07007452..55dd8230 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -103,6 +103,7 @@ module psb_d_vect_mod procedure, pass(z) :: axpby_a2 => d_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => d_vect_abgdxyz + procedure, pass(z) :: xyzw => d_vect_xyzw procedure, pass(y) :: mlt_v => d_vect_mlt_v procedure, pass(y) :: mlt_a => d_vect_mlt_a @@ -795,6 +796,22 @@ contains end subroutine d_vect_abgdxyz + subroutine d_vect_xyzw(m,a,b,c,d,e,f,x, y, z, w, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_d_vect_type), intent(inout) :: x + class(psb_d_vect_type), intent(inout) :: y + class(psb_d_vect_type), intent(inout) :: z + class(psb_d_vect_type), intent(inout) :: w + real(psb_dpk_), intent (in) :: a, b, c, d, e, f + integer(psb_ipk_), intent(out) :: info + + if (allocated(w%v)) & + call w%v%xyzw(m,a,b,c,d,e,f,x%v,y%v,z%v,info) + + end subroutine d_vect_xyzw + subroutine d_vect_mlt_v(x, y, info) use psi_serial_mod diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index 85bb3bda..26b82c31 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -156,6 +156,7 @@ module psb_s_base_vect_mod procedure, pass(z) :: axpby_a2 => s_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => s_base_abgdxyz + procedure, pass(w) :: xyzw => s_base_xyzw ! ! Vector by vector multiplication. Need all variants @@ -1162,22 +1163,37 @@ contains real(psb_spk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (.false.) then - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) - else - if (x%is_dev().and.(alpha/=szero)) call x%sync() - if (y%is_dev().and.(beta/=szero)) call y%sync() - if (z%is_dev().and.(delta/=szero)) call z%sync() - call psi_sabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) - call y%set_host() - call z%set_host() - end if - + if (x%is_dev().and.(alpha/=szero)) call x%sync() + if (y%is_dev().and.(beta/=szero)) call y%sync() + if (z%is_dev().and.(delta/=szero)) call z%sync() + call psi_abgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() + end subroutine s_base_abgdxyz + subroutine s_base_xyzw(m,a,b,c,d,e,f,x, y, z, w,info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + class(psb_s_base_vect_type), intent(inout) :: z + class(psb_s_base_vect_type), intent(inout) :: w + real(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev().and.(a/=szero)) call x%sync() + if (y%is_dev().and.(b/=szero)) call y%sync() + if (z%is_dev().and.(d/=szero)) call z%sync() + if (w%is_dev().and.(f/=szero)) call w%sync() + call psi_xyzw(m,a,b,c,d,e,f,x%v, y%v, z%v, w%v, info) + call y%set_host() + call z%set_host() + call w%set_host() + + end subroutine s_base_xyzw + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_s_vect_mod.F90 b/base/modules/serial/psb_s_vect_mod.F90 index aa16a04d..a50b2a0a 100644 --- a/base/modules/serial/psb_s_vect_mod.F90 +++ b/base/modules/serial/psb_s_vect_mod.F90 @@ -103,6 +103,7 @@ module psb_s_vect_mod procedure, pass(z) :: axpby_a2 => s_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => s_vect_abgdxyz + procedure, pass(z) :: xyzw => s_vect_xyzw procedure, pass(y) :: mlt_v => s_vect_mlt_v procedure, pass(y) :: mlt_a => s_vect_mlt_a @@ -795,6 +796,22 @@ contains end subroutine s_vect_abgdxyz + subroutine s_vect_xyzw(m,a,b,c,d,e,f,x, y, z, w, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_vect_type), intent(inout) :: x + class(psb_s_vect_type), intent(inout) :: y + class(psb_s_vect_type), intent(inout) :: z + class(psb_s_vect_type), intent(inout) :: w + real(psb_spk_), intent (in) :: a, b, c, d, e, f + integer(psb_ipk_), intent(out) :: info + + if (allocated(w%v)) & + call w%v%xyzw(m,a,b,c,d,e,f,x%v,y%v,z%v,info) + + end subroutine s_vect_xyzw + subroutine s_vect_mlt_v(x, y, info) use psi_serial_mod diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index b30b1586..a3afc9c1 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -156,6 +156,7 @@ module psb_z_base_vect_mod procedure, pass(z) :: axpby_a2 => z_base_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => z_base_abgdxyz + procedure, pass(w) :: xyzw => z_base_xyzw ! ! Vector by vector multiplication. Need all variants @@ -1155,22 +1156,37 @@ contains complex(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - if (.false.) then - if (x%is_dev()) call x%sync() - - call y%axpby(m,alpha,x,beta,info) - call z%axpby(m,gamma,y,delta,info) - else - if (x%is_dev().and.(alpha/=zzero)) call x%sync() - if (y%is_dev().and.(beta/=zzero)) call y%sync() - if (z%is_dev().and.(delta/=zzero)) call z%sync() - call psi_zabgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) - call y%set_host() - call z%set_host() - end if - + if (x%is_dev().and.(alpha/=zzero)) call x%sync() + if (y%is_dev().and.(beta/=zzero)) call y%sync() + if (z%is_dev().and.(delta/=zzero)) call z%sync() + call psi_abgdxyz(m,alpha, beta, gamma,delta,x%v, y%v, z%v, info) + call y%set_host() + call z%set_host() + end subroutine z_base_abgdxyz + subroutine z_base_xyzw(m,a,b,c,d,e,f,x, y, z, w,info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + class(psb_z_base_vect_type), intent(inout) :: z + class(psb_z_base_vect_type), intent(inout) :: w + complex(psb_dpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + if (x%is_dev().and.(a/=zzero)) call x%sync() + if (y%is_dev().and.(b/=zzero)) call y%sync() + if (z%is_dev().and.(d/=zzero)) call z%sync() + if (w%is_dev().and.(f/=zzero)) call w%sync() + call psi_xyzw(m,a,b,c,d,e,f,x%v, y%v, z%v, w%v, info) + call y%set_host() + call z%set_host() + call w%set_host() + + end subroutine z_base_xyzw + ! ! Multiple variants of two operations: diff --git a/base/modules/serial/psb_z_vect_mod.F90 b/base/modules/serial/psb_z_vect_mod.F90 index 58bf6b18..21e0c546 100644 --- a/base/modules/serial/psb_z_vect_mod.F90 +++ b/base/modules/serial/psb_z_vect_mod.F90 @@ -103,6 +103,7 @@ module psb_z_vect_mod procedure, pass(z) :: axpby_a2 => z_vect_axpby_a2 generic, public :: axpby => axpby_v, axpby_a, axpby_v2, axpby_a2 procedure, pass(z) :: abgdxyz => z_vect_abgdxyz + procedure, pass(z) :: xyzw => z_vect_xyzw procedure, pass(y) :: mlt_v => z_vect_mlt_v procedure, pass(y) :: mlt_a => z_vect_mlt_a @@ -788,6 +789,22 @@ contains end subroutine z_vect_abgdxyz + subroutine z_vect_xyzw(m,a,b,c,d,e,f,x, y, z, w, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_vect_type), intent(inout) :: x + class(psb_z_vect_type), intent(inout) :: y + class(psb_z_vect_type), intent(inout) :: z + class(psb_z_vect_type), intent(inout) :: w + complex(psb_dpk_), intent (in) :: a, b, c, d, e, f + integer(psb_ipk_), intent(out) :: info + + if (allocated(w%v)) & + call w%v%xyzw(m,a,b,c,d,e,f,x%v,y%v,z%v,info) + + end subroutine z_vect_xyzw + subroutine z_vect_mlt_v(x, y, info) use psi_serial_mod diff --git a/base/serial/psi_c_serial_impl.F90 b/base/serial/psi_c_serial_impl.F90 index 557220e5..e230a1e0 100644 --- a/base/serial/psi_c_serial_impl.F90 +++ b/base/serial/psi_c_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_cabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_cabgdxyz + +subroutine psi_cxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_spk_), intent (in) :: x(:) + complex(psb_spk_), intent (inout) :: y(:) + complex(psb_spk_), intent (inout) :: z(:) + complex(psb_spk_), intent (inout) :: w(:) + complex(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='cabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==czero).or.(b==czero).or. & + & (c==czero).or.(d==czero).or.& + & (e==czero).or.(f==czero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_cxyzw diff --git a/base/serial/psi_d_serial_impl.F90 b/base/serial/psi_d_serial_impl.F90 index d423b401..bf1b2917 100644 --- a/base/serial/psi_d_serial_impl.F90 +++ b/base/serial/psi_d_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_dabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_dabgdxyz + +subroutine psi_dxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_dpk_), intent (in) :: x(:) + real(psb_dpk_), intent (inout) :: y(:) + real(psb_dpk_), intent (inout) :: z(:) + real(psb_dpk_), intent (inout) :: w(:) + real(psb_dpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='dabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==dzero).or.(b==dzero).or. & + & (c==dzero).or.(d==dzero).or.& + & (e==dzero).or.(f==dzero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_dxyzw diff --git a/base/serial/psi_e_serial_impl.F90 b/base/serial/psi_e_serial_impl.F90 index c7977c35..911ab4ec 100644 --- a/base/serial/psi_e_serial_impl.F90 +++ b/base/serial/psi_e_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_eabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_eabgdxyz + +subroutine psi_exyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_epk_), intent (in) :: x(:) + integer(psb_epk_), intent (inout) :: y(:) + integer(psb_epk_), intent (inout) :: z(:) + integer(psb_epk_), intent (inout) :: w(:) + integer(psb_epk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='eabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==ezero).or.(b==ezero).or. & + & (c==ezero).or.(d==ezero).or.& + & (e==ezero).or.(f==ezero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_exyzw diff --git a/base/serial/psi_i2_serial_impl.F90 b/base/serial/psi_i2_serial_impl.F90 index ce4aff80..fb42dfcd 100644 --- a/base/serial/psi_i2_serial_impl.F90 +++ b/base/serial/psi_i2_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_i2abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_i2abgdxyz + +subroutine psi_i2xyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_i2pk_), intent (in) :: x(:) + integer(psb_i2pk_), intent (inout) :: y(:) + integer(psb_i2pk_), intent (inout) :: z(:) + integer(psb_i2pk_), intent (inout) :: w(:) + integer(psb_i2pk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='i2abgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==i2zero).or.(b==i2zero).or. & + & (c==i2zero).or.(d==i2zero).or.& + & (e==i2zero).or.(f==i2zero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_i2xyzw diff --git a/base/serial/psi_m_serial_impl.F90 b/base/serial/psi_m_serial_impl.F90 index 8d9d19f4..346fd897 100644 --- a/base/serial/psi_m_serial_impl.F90 +++ b/base/serial/psi_m_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_mabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_mabgdxyz + +subroutine psi_mxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + integer(psb_mpk_), intent (in) :: x(:) + integer(psb_mpk_), intent (inout) :: y(:) + integer(psb_mpk_), intent (inout) :: z(:) + integer(psb_mpk_), intent (inout) :: w(:) + integer(psb_mpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='mabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==mzero).or.(b==mzero).or. & + & (c==mzero).or.(d==mzero).or.& + & (e==mzero).or.(f==mzero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_mxyzw diff --git a/base/serial/psi_s_serial_impl.F90 b/base/serial/psi_s_serial_impl.F90 index df251b27..52f86bcd 100644 --- a/base/serial/psi_s_serial_impl.F90 +++ b/base/serial/psi_s_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_sabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_sabgdxyz + +subroutine psi_sxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_spk_), intent (in) :: x(:) + real(psb_spk_), intent (inout) :: y(:) + real(psb_spk_), intent (inout) :: z(:) + real(psb_spk_), intent (inout) :: w(:) + real(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='sabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==szero).or.(b==szero).or. & + & (c==szero).or.(d==szero).or.& + & (e==szero).or.(f==szero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_sxyzw diff --git a/base/serial/psi_z_serial_impl.F90 b/base/serial/psi_z_serial_impl.F90 index 44ea5ae7..7e680273 100644 --- a/base/serial/psi_z_serial_impl.F90 +++ b/base/serial/psi_z_serial_impl.F90 @@ -1792,3 +1792,75 @@ subroutine psi_zabgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) return end subroutine psi_zabgdxyz + +subroutine psi_zxyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psb_const_mod + use psb_error_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_dpk_), intent (in) :: x(:) + complex(psb_dpk_), intent (inout) :: y(:) + complex(psb_dpk_), intent (inout) :: z(:) + complex(psb_dpk_), intent (inout) :: w(:) + complex(psb_dpk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + integer(psb_ipk_) :: int_err(5) + character name*20 + name='zabgdxyz' + + info = psb_success_ + if (m.lt.0) then + info=psb_err_iarg_neg_ + int_err(1)=1 + int_err(2)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(x).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=6 + int_err(2)=1 + int_err(3)=size(x) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(y).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=7 + int_err(2)=1 + int_err(3)=size(y) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + else if (size(z).lt.max(1,m)) then + info=psb_err_iarg_not_gtia_ii_ + int_err(1)=8 + int_err(2)=1 + int_err(3)=size(z) + int_err(4)=m + call fcpsb_errpush(info,name,int_err) + goto 9999 + endif + + if ((a==zzero).or.(b==zzero).or. & + & (c==zzero).or.(d==zzero).or.& + & (e==zzero).or.(f==zzero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + !$omp parallel do private(i) + do i=1,m + y(i) = a*x(i)+b*y(i) + z(i) = c*y(i)+d*z(i) + w(i) = e*z(i)+f*w(i) + end do + + end if + + return + +9999 continue + call fcpsb_serror() + return + +end subroutine psi_zxyzw From a11f328e62785f3e1684fb667b4a512ce7f4e77e Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 5 Mar 2024 12:42:21 +0100 Subject: [PATCH 47/48] Added CUDA version of XYZW --- cuda/cvectordev.c | 22 +++++++++++++ cuda/cvectordev.h | 5 +++ cuda/dvectordev.c | 19 +++++++++++ cuda/dvectordev.h | 3 ++ cuda/psb_c_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++++++-- cuda/psb_c_vectordev_mod.F90 | 12 +++++++ cuda/psb_d_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++++++-- cuda/psb_d_vectordev_mod.F90 | 12 +++++++ cuda/psb_s_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++++++-- cuda/psb_s_vectordev_mod.F90 | 12 +++++++ cuda/psb_z_cuda_vect_mod.F90 | 64 ++++++++++++++++++++++++++++++++++-- cuda/psb_z_vectordev_mod.F90 | 12 +++++++ cuda/spgpu/kernels/Makefile | 3 +- cuda/spgpu/vector.h | 46 ++++++++++++++++++++++++++ cuda/svectordev.c | 21 ++++++++++++ cuda/svectordev.h | 3 ++ cuda/zvectordev.c | 24 +++++++++++++- cuda/zvectordev.h | 5 +++ 18 files changed, 445 insertions(+), 10 deletions(-) diff --git a/cuda/cvectordev.c b/cuda/cvectordev.c index 9db5202e..cdfda481 100644 --- a/cuda/cvectordev.c +++ b/cuda/cvectordev.c @@ -273,6 +273,28 @@ int abgdxyzMultiVecDeviceFloatComplex(int n,cuFloatComplex alpha,cuFloatComplex return(i); } +int xyzwMultiVecDeviceFloatComplex(int n,cuFloatComplex a,cuFloatComplex b, + cuFloatComplex c, cuFloatComplex d, + cuFloatComplex e, cuFloatComplex f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + struct MultiVectDevice *devVecW = (struct MultiVectDevice *) devMultiVecW; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuCxyzw(handle,n, a,b,c,d,e,f, + (cuFloatComplex *)devVecX->v_,(cuFloatComplex *) devVecY->v_, + (cuFloatComplex *) devVecZ->v_,(cuFloatComplex *) devVecW->v_); + return(i); +} + int axyMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void *deviceVecA, void *deviceVecB) { int i = 0; diff --git a/cuda/cvectordev.h b/cuda/cvectordev.h index fc18e328..62693e27 100644 --- a/cuda/cvectordev.h +++ b/cuda/cvectordev.h @@ -72,6 +72,11 @@ int axpbyMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void* devVecX, int abgdxyzMultiVecDeviceFloatComplex(int n,cuFloatComplex alpha,cuFloatComplex beta, cuFloatComplex gamma, cuFloatComplex delta, void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); +int xyzwMultiVecDeviceFloatComplex(int n,cuFloatComplex a,cuFloatComplex b, + cuFloatComplex c, cuFloatComplex d, + cuFloatComplex e, cuFloatComplex f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW); int axyMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceFloatComplex(int n, cuFloatComplex alpha, void *deviceVecA, void *deviceVecB, cuFloatComplex beta, void *deviceVecZ); diff --git a/cuda/dvectordev.c b/cuda/dvectordev.c index b4ca95f4..723f48d8 100644 --- a/cuda/dvectordev.c +++ b/cuda/dvectordev.c @@ -258,6 +258,25 @@ int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, do return(i); } +int xyzwMultiVecDeviceDouble(int n,double a, double b, double c, double d, double e, double f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + struct MultiVectDevice *devVecW = (struct MultiVectDevice *) devMultiVecW; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuDxyzw(handle,n, a,b,c,d,e,f, + (double*)devVecX->v_,(double*) devVecY->v_,(double*) devVecZ->v_,(double*) devVecW->v_); + return(i); +} + int axyMultiVecDeviceDouble(int n, double alpha, void *deviceVecA, void *deviceVecB) { int i = 0; struct MultiVectDevice *devVecA = (struct MultiVectDevice *) deviceVecA; diff --git a/cuda/dvectordev.h b/cuda/dvectordev.h index 81a2e8f6..c2bfa1b5 100644 --- a/cuda/dvectordev.h +++ b/cuda/dvectordev.h @@ -69,6 +69,9 @@ int dotMultiVecDeviceDouble(double* y_res, int n, void* devVecA, void* devVecB); int axpbyMultiVecDeviceDouble(int n, double alpha, void* devVecX, double beta, void* devVecY); int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta, void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); +int xyzwMultiVecDeviceDouble(int n,double a, double b, double c, double d, double e, double f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW); int axyMultiVecDeviceDouble(int n, double alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceDouble(int n, double alpha, void *deviceVecA, void *deviceVecB, double beta, void *deviceVecZ); diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index 7eee128f..727249df 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -914,7 +914,6 @@ contains end subroutine c_cuda_axpby_v - subroutine c_cuda_abgdxyz(m,alpha, beta, gamma,delta,x, y, z, info) use psi_serial_mod implicit none @@ -975,9 +974,70 @@ contains call z%axpby(m,gamma,y,delta,info) end if - end subroutine c_cuda_abgdxyz + subroutine c_cuda_xyzw(m,a,b,c,d,e,f,x, y, z,w, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_base_vect_type), intent(inout) :: z + class(psb_c_vect_cuda), intent(inout) :: w + complex(psb_spk_), intent (in) :: a,b,c,d,e,f + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz, nw + logical :: gpu_done + + info = psb_success_ + + gpu_done = .false. + if ((a==czero).or.(b==czero).or. & + & (c==czero).or.(d==czero).or.& + & (e==czero).or.(f==czero)) then + write(0,*) 'XYZW assumes a,b,c,d,e,f are all nonzero' + else + select type(xx => x) + class is (psb_c_vect_cuda) + select type(yy => y) + class is (psb_c_vect_cuda) + select type(zz => z) + class is (psb_c_vect_cuda) + ! Do something different here + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if (zz%is_host()) call zz%sync() + if (w%is_host()) call w%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + nw = getMultiVecDeviceSize(w%deviceVect) + if ((nx x) + class is (psb_d_vect_cuda) + select type(yy => y) + class is (psb_d_vect_cuda) + select type(zz => z) + class is (psb_d_vect_cuda) + ! Do something different here + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if (zz%is_host()) call zz%sync() + if (w%is_host()) call w%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + nw = getMultiVecDeviceSize(w%deviceVect) + if ((nx x) + class is (psb_s_vect_cuda) + select type(yy => y) + class is (psb_s_vect_cuda) + select type(zz => z) + class is (psb_s_vect_cuda) + ! Do something different here + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if (zz%is_host()) call zz%sync() + if (w%is_host()) call w%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + nw = getMultiVecDeviceSize(w%deviceVect) + if ((nx x) + class is (psb_z_vect_cuda) + select type(yy => y) + class is (psb_z_vect_cuda) + select type(zz => z) + class is (psb_z_vect_cuda) + ! Do something different here + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if (zz%is_host()) call zz%sync() + if (w%is_host()) call w%sync() + nx = getMultiVecDeviceSize(xx%deviceVect) + ny = getMultiVecDeviceSize(yy%deviceVect) + nz = getMultiVecDeviceSize(zz%deviceVect) + nw = getMultiVecDeviceSize(w%deviceVect) + if ((nxpitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuSxyzw(handle,n, a,b,c,d,e,f, + (float*)devVecX->v_,(float*) devVecY->v_, + (float*) devVecZ->v_,(float*) devVecW->v_); + return(i); +} + int axyMultiVecDeviceFloat(int n, float alpha, void *deviceVecA, void *deviceVecB) { int i = 0; struct MultiVectDevice *devVecA = (struct MultiVectDevice *) deviceVecA; diff --git a/cuda/svectordev.h b/cuda/svectordev.h index 730f929a..363c0108 100644 --- a/cuda/svectordev.h +++ b/cuda/svectordev.h @@ -69,6 +69,9 @@ int dotMultiVecDeviceFloat(float* y_res, int n, void* devVecA, void* devVecB); int axpbyMultiVecDeviceFloat(int n, float alpha, void* devVecX, float beta, void* devVecY); int abgdxyzMultiVecDeviceFloat(int n,float alpha,float beta, float gamma, float delta, void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); +int xyzwMultiVecDeviceFloat(int n,float a,float b, float c, float d, float e, float f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW); int axyMultiVecDeviceFloat(int n, float alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceFloat(int n, float alpha, void *deviceVecA, void *deviceVecB, float beta, void *deviceVecZ); diff --git a/cuda/zvectordev.c b/cuda/zvectordev.c index d1f23f2a..e9f0cec7 100644 --- a/cuda/zvectordev.c +++ b/cuda/zvectordev.c @@ -251,7 +251,29 @@ int abgdxyzMultiVecDeviceDoubleComplex(int n,cuDoubleComplex alpha, (cuDoubleComplex *)devVecX->v_,(cuDoubleComplex *) devVecY->v_,(cuDoubleComplex *) devVecZ->v_); return(i); } - + +int xyzwMultiVecDeviceDoubleComplex(int n,cuDoubleComplex a, cuDoubleComplex b, + cuDoubleComplex c, cuDoubleComplex d, + cuDoubleComplex e, cuDoubleComplex f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW) +{ int j=0, i=0; + int pitch = 0; + struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX; + struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY; + struct MultiVectDevice *devVecZ = (struct MultiVectDevice *) devMultiVecZ; + struct MultiVectDevice *devVecW = (struct MultiVectDevice *) devMultiVecW; + spgpuHandle_t handle=psb_cudaGetHandle(); + pitch = devVecY->pitch_; + if ((n > devVecY->size_) || (n>devVecX->size_ )) + return SPGPU_UNSUPPORTED; + + spgpuZxyzw(handle,n, a,b,c,d,e,f, + (cuDoubleComplex *)devVecX->v_,(cuDoubleComplex *) devVecY->v_, + (cuDoubleComplex *) devVecZ->v_,(cuDoubleComplex *) devVecW->v_); + return(i); +} + int axpbyMultiVecDeviceDoubleComplex(int n,cuDoubleComplex alpha, void* devMultiVecX, cuDoubleComplex beta, void* devMultiVecY) { int j=0, i=0; diff --git a/cuda/zvectordev.h b/cuda/zvectordev.h index 4c32f11c..ae623bdb 100644 --- a/cuda/zvectordev.h +++ b/cuda/zvectordev.h @@ -80,6 +80,11 @@ int axpbyMultiVecDeviceDoubleComplex(int n, cuDoubleComplex alpha, void* devVecX int abgdxyzMultiVecDeviceDoubleComplex(int n,cuDoubleComplex alpha, cuDoubleComplex beta, cuDoubleComplex gamma, cuDoubleComplex delta, void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); +int xyzwMultiVecDeviceDoubleComplex(int n,cuDoubleComplex a, cuDoubleComplex b, + cuDoubleComplex c, cuDoubleComplex d, + cuDoubleComplex e, cuDoubleComplex f, + void* devMultiVecX, void* devMultiVecY, + void* devMultiVecZ, void* devMultiVecW); int axyMultiVecDeviceDoubleComplex(int n, cuDoubleComplex alpha, void *deviceVecA, void *deviceVecB); int axybzMultiVecDeviceDoubleComplex(int n, cuDoubleComplex alpha, void *deviceVecA, From 48455190ecd4e78a43be5a5d6f9c9749cce606a2 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 5 Mar 2024 13:57:03 +0100 Subject: [PATCH 48/48] Add GPU version of XYZW --- cuda/spgpu/kernels/cxyzw.cu | 78 +++++++++++++++++++++++++++++++++++++ cuda/spgpu/kernels/dxyzw.cu | 78 +++++++++++++++++++++++++++++++++++++ cuda/spgpu/kernels/sxyzw.cu | 78 +++++++++++++++++++++++++++++++++++++ cuda/spgpu/kernels/zxyzw.cu | 78 +++++++++++++++++++++++++++++++++++++ 4 files changed, 312 insertions(+) create mode 100644 cuda/spgpu/kernels/cxyzw.cu create mode 100644 cuda/spgpu/kernels/dxyzw.cu create mode 100644 cuda/spgpu/kernels/sxyzw.cu create mode 100644 cuda/spgpu/kernels/zxyzw.cu diff --git a/cuda/spgpu/kernels/cxyzw.cu b/cuda/spgpu/kernels/cxyzw.cu new file mode 100644 index 00000000..d2b332b1 --- /dev/null +++ b/cuda/spgpu/kernels/cxyzw.cu @@ -0,0 +1,78 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuCxyzw_krn(int n, cuFloatComplex a, cuFloatComplex b, + cuFloatComplex c, cuFloatComplex d, + cuFloatComplex e, cuFloatComplex f, + cuFloatComplex * x, cuFloatComplex *y, + cuFloatComplex *z, cuFloatComplex *w) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + cuFloatComplex ty, tz; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + ty = cuCfmaf(a, x[id], cuCmulf(b,y[id])); + tz = cuCfmaf(c, ty, cuCmulf(d,z[id])); + w[id] = cuCfmaf(e, tz, cuCmulf(f,w[id])); + y[id] = ty; + z[id] = tz; + } +} + + +void spgpuCxyzw(spgpuHandle_t handle, + int n, + cuFloatComplex a, cuFloatComplex b, + cuFloatComplex c, cuFloatComplex d, + cuFloatComplex e, cuFloatComplex f, + __device cuFloatComplex * x, + __device cuFloatComplex * y, + __device cuFloatComplex * z, + __device cuFloatComplex *w) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuCxyzw_krn<<currentStream>>>(n, a,b,c,d,e,f, + x, y, z,w); +} + diff --git a/cuda/spgpu/kernels/dxyzw.cu b/cuda/spgpu/kernels/dxyzw.cu new file mode 100644 index 00000000..afd36651 --- /dev/null +++ b/cuda/spgpu/kernels/dxyzw.cu @@ -0,0 +1,78 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuDxyzw_krn(int n, double a, double b, + double c, double d, + double e, double f, + double * x, double *y, + double *z, double *w) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + double ty, tz; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + ty = PREC_DADD(PREC_DADD(a, x[id]), PREC_DMUL(b,y[id])); + tz = PREC_DADD(PREC_DADD(c, ty), PREC_DMUL(d,z[id])); + w[id] = PREC_DADD(PREC_DADD(e, tz), PREC_DMUL(f,w[id])); + y[id] = ty; + z[id] = tz; + } +} + + +void spgpuDxyzw(spgpuHandle_t handle, + int n, + double a, double b, + double c, double d, + double e, double f, + __device double * x, + __device double * y, + __device double * z, + __device double *w) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuDxyzw_krn<<currentStream>>>(n, a,b,c,d,e,f, + x, y, z,w); +} + diff --git a/cuda/spgpu/kernels/sxyzw.cu b/cuda/spgpu/kernels/sxyzw.cu new file mode 100644 index 00000000..9cedd02f --- /dev/null +++ b/cuda/spgpu/kernels/sxyzw.cu @@ -0,0 +1,78 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuSxyzw_krn(int n, float a, float b, + float c, float d, + float e, float f, + float * x, float *y, + float *z, float *w) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + float ty, tz; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + ty = PREC_FADD(PREC_FMUL(a, x[id]), PREC_FMUL(b,y[id])); + tz = PREC_FADD(PREC_FMUL(c, ty), PREC_FMUL(d,z[id])); + w[id] = PREC_FADD(PREC_FMUL(e, tz), PREC_FMUL(f,w[id])); + y[id] = ty; + z[id] = tz; + } +} + + +void spgpuSxyzw(spgpuHandle_t handle, + int n, + float a, float b, + float c, float d, + float e, float f, + __device float * x, + __device float * y, + __device float * z, + __device float *w) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuSxyzw_krn<<currentStream>>>(n, a,b,c,d,e,f, + x, y, z,w); +} + diff --git a/cuda/spgpu/kernels/zxyzw.cu b/cuda/spgpu/kernels/zxyzw.cu new file mode 100644 index 00000000..7a61edee --- /dev/null +++ b/cuda/spgpu/kernels/zxyzw.cu @@ -0,0 +1,78 @@ +/* + * spGPU - Sparse matrices on GPU library. + * + * Copyright (C) 2010 - 2012 + * Davide Barbieri - University of Rome Tor Vergata + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * version 3 as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include "cudadebug.h" +#include "cudalang.h" +#include + +extern "C" +{ +#include "core.h" +#include "vector.h" + int getGPUMultiProcessors(); + int getGPUMaxThreadsPerMP(); +} + + +#include "debug.h" + +#define BLOCK_SIZE 512 + +__global__ void spgpuZxyzw_krn(int n, cuDoubleComplex a, cuDoubleComplex b, + cuDoubleComplex c, cuDoubleComplex d, + cuDoubleComplex e, cuDoubleComplex f, + cuDoubleComplex * x, cuDoubleComplex *y, + cuDoubleComplex *z, cuDoubleComplex *w) +{ + int id = threadIdx.x + BLOCK_SIZE*blockIdx.x; + unsigned int gridSize = blockDim.x * gridDim.x; + cuDoubleComplex ty, tz; + for ( ; id < n; id +=gridSize) + //if (id,n) + { + + ty = cuCfma(a, x[id], cuCmul(b,y[id])); + tz = cuCfma(c, ty, cuCmul(d,z[id])); + w[id] = cuCfma(e, tz, cuCmul(f,w[id])); + y[id] = ty; + z[id] = tz; + } +} + + +void spgpuZxyzw(spgpuHandle_t handle, + int n, + cuDoubleComplex a, cuDoubleComplex b, + cuDoubleComplex c, cuDoubleComplex d, + cuDoubleComplex e, cuDoubleComplex f, + __device cuDoubleComplex * x, + __device cuDoubleComplex * y, + __device cuDoubleComplex * z, + __device cuDoubleComplex *w) +{ + int msize = (n+BLOCK_SIZE-1)/BLOCK_SIZE; + int num_mp, max_threads_mp, num_blocks_mp, num_blocks; + dim3 block(BLOCK_SIZE); + num_mp = getGPUMultiProcessors(); + max_threads_mp = getGPUMaxThreadsPerMP(); + num_blocks_mp = max_threads_mp/BLOCK_SIZE; + num_blocks = num_blocks_mp*num_mp; + dim3 grid(num_blocks); + + spgpuZxyzw_krn<<currentStream>>>(n, a,b,c,d,e,f, + x, y, z,w); +} +