From 99dc3f5d939322c1a1b521410c4af93091d20ae1 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 26 Feb 2020 18:56:06 +0000 Subject: [PATCH] New par_spmm version with 4-byte matrices (with new sphalo variant) --- base/modules/tools/psb_c_tools_mod.F90 | 26 +- base/modules/tools/psb_d_tools_mod.F90 | 26 +- base/modules/tools/psb_s_tools_mod.F90 | 26 +- base/modules/tools/psb_z_tools_mod.F90 | 26 +- base/tools/psb_c_par_csr_spspmm.f90 | 186 +++++++------ base/tools/psb_csphalo.F90 | 372 +++++++++++++++++++++++++ base/tools/psb_d_par_csr_spspmm.f90 | 186 +++++++------ base/tools/psb_dsphalo.F90 | 372 +++++++++++++++++++++++++ base/tools/psb_s_par_csr_spspmm.f90 | 186 +++++++------ base/tools/psb_ssphalo.F90 | 372 +++++++++++++++++++++++++ base/tools/psb_z_par_csr_spspmm.f90 | 186 +++++++------ base/tools/psb_zsphalo.F90 | 372 +++++++++++++++++++++++++ 12 files changed, 1964 insertions(+), 372 deletions(-) diff --git a/base/modules/tools/psb_c_tools_mod.F90 b/base/modules/tools/psb_c_tools_mod.F90 index 9f417d63..fdcc5e56 100644 --- a/base/modules/tools/psb_c_tools_mod.F90 +++ b/base/modules/tools/psb_c_tools_mod.F90 @@ -33,7 +33,8 @@ Module psb_c_tools_mod use psb_desc_mod, only : psb_desc_type, psb_spk_, psb_ipk_, psb_lpk_ use psb_c_vect_mod, only : psb_c_base_vect_type, psb_c_vect_type use psb_c_mat_mod, only : psb_cspmat_type, psb_lcspmat_type, psb_c_base_sparse_mat, & - & psb_lc_csr_sparse_mat, psb_lc_coo_sparse_mat, psb_c_coo_sparse_mat + & psb_lc_csr_sparse_mat, psb_lc_coo_sparse_mat, & + & psb_c_csr_sparse_mat, psb_c_coo_sparse_mat use psb_l_vect_mod, only : psb_l_vect_type use psb_c_multivect_mod, only : psb_c_base_multivect_type, psb_c_multivect_type @@ -221,6 +222,18 @@ Module psb_c_tools_mod integer(psb_ipk_), intent(in), optional :: data type(psb_desc_type),Intent(in), optional, target :: col_desc end Subroutine psb_lc_csr_halo + Subroutine psb_c_lc_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + import + implicit none + type(psb_c_csr_sparse_mat),Intent(in) :: a + type(psb_lc_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + end Subroutine psb_c_lc_csr_halo end interface @@ -335,6 +348,17 @@ Module psb_c_tools_mod end interface interface psb_par_spspmm + subroutine psb_c_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + import :: psb_c_csr_sparse_mat, psb_desc_type, psb_ipk_ + Implicit None + type(psb_c_csr_sparse_mat),intent(in) :: acsr + type(psb_c_csr_sparse_mat),intent(inout) :: bcsr + type(psb_c_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + End Subroutine psb_c_par_csr_spspmm subroutine psb_lc_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) import :: psb_lc_csr_sparse_mat, psb_desc_type, psb_ipk_ Implicit None diff --git a/base/modules/tools/psb_d_tools_mod.F90 b/base/modules/tools/psb_d_tools_mod.F90 index e49a4d7f..aa127872 100644 --- a/base/modules/tools/psb_d_tools_mod.F90 +++ b/base/modules/tools/psb_d_tools_mod.F90 @@ -33,7 +33,8 @@ Module psb_d_tools_mod use psb_desc_mod, only : psb_desc_type, psb_dpk_, psb_ipk_, psb_lpk_ use psb_d_vect_mod, only : psb_d_base_vect_type, psb_d_vect_type use psb_d_mat_mod, only : psb_dspmat_type, psb_ldspmat_type, psb_d_base_sparse_mat, & - & psb_ld_csr_sparse_mat, psb_ld_coo_sparse_mat, psb_d_coo_sparse_mat + & psb_ld_csr_sparse_mat, psb_ld_coo_sparse_mat, & + & psb_d_csr_sparse_mat, psb_d_coo_sparse_mat use psb_l_vect_mod, only : psb_l_vect_type use psb_d_multivect_mod, only : psb_d_base_multivect_type, psb_d_multivect_type @@ -221,6 +222,18 @@ Module psb_d_tools_mod integer(psb_ipk_), intent(in), optional :: data type(psb_desc_type),Intent(in), optional, target :: col_desc end Subroutine psb_ld_csr_halo + Subroutine psb_d_ld_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + import + implicit none + type(psb_d_csr_sparse_mat),Intent(in) :: a + type(psb_ld_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + end Subroutine psb_d_ld_csr_halo end interface @@ -335,6 +348,17 @@ Module psb_d_tools_mod end interface interface psb_par_spspmm + subroutine psb_d_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + import :: psb_d_csr_sparse_mat, psb_desc_type, psb_ipk_ + Implicit None + type(psb_d_csr_sparse_mat),intent(in) :: acsr + type(psb_d_csr_sparse_mat),intent(inout) :: bcsr + type(psb_d_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + End Subroutine psb_d_par_csr_spspmm subroutine psb_ld_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) import :: psb_ld_csr_sparse_mat, psb_desc_type, psb_ipk_ Implicit None diff --git a/base/modules/tools/psb_s_tools_mod.F90 b/base/modules/tools/psb_s_tools_mod.F90 index 0383f62b..24453728 100644 --- a/base/modules/tools/psb_s_tools_mod.F90 +++ b/base/modules/tools/psb_s_tools_mod.F90 @@ -33,7 +33,8 @@ Module psb_s_tools_mod use psb_desc_mod, only : psb_desc_type, psb_spk_, psb_ipk_, psb_lpk_ use psb_s_vect_mod, only : psb_s_base_vect_type, psb_s_vect_type use psb_s_mat_mod, only : psb_sspmat_type, psb_lsspmat_type, psb_s_base_sparse_mat, & - & psb_ls_csr_sparse_mat, psb_ls_coo_sparse_mat, psb_s_coo_sparse_mat + & psb_ls_csr_sparse_mat, psb_ls_coo_sparse_mat, & + & psb_s_csr_sparse_mat, psb_s_coo_sparse_mat use psb_l_vect_mod, only : psb_l_vect_type use psb_s_multivect_mod, only : psb_s_base_multivect_type, psb_s_multivect_type @@ -221,6 +222,18 @@ Module psb_s_tools_mod integer(psb_ipk_), intent(in), optional :: data type(psb_desc_type),Intent(in), optional, target :: col_desc end Subroutine psb_ls_csr_halo + Subroutine psb_s_ls_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + import + implicit none + type(psb_s_csr_sparse_mat),Intent(in) :: a + type(psb_ls_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + end Subroutine psb_s_ls_csr_halo end interface @@ -335,6 +348,17 @@ Module psb_s_tools_mod end interface interface psb_par_spspmm + subroutine psb_s_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + import :: psb_s_csr_sparse_mat, psb_desc_type, psb_ipk_ + Implicit None + type(psb_s_csr_sparse_mat),intent(in) :: acsr + type(psb_s_csr_sparse_mat),intent(inout) :: bcsr + type(psb_s_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + End Subroutine psb_s_par_csr_spspmm subroutine psb_ls_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) import :: psb_ls_csr_sparse_mat, psb_desc_type, psb_ipk_ Implicit None diff --git a/base/modules/tools/psb_z_tools_mod.F90 b/base/modules/tools/psb_z_tools_mod.F90 index 55455072..2b639fdc 100644 --- a/base/modules/tools/psb_z_tools_mod.F90 +++ b/base/modules/tools/psb_z_tools_mod.F90 @@ -33,7 +33,8 @@ Module psb_z_tools_mod use psb_desc_mod, only : psb_desc_type, psb_dpk_, psb_ipk_, psb_lpk_ use psb_z_vect_mod, only : psb_z_base_vect_type, psb_z_vect_type use psb_z_mat_mod, only : psb_zspmat_type, psb_lzspmat_type, psb_z_base_sparse_mat, & - & psb_lz_csr_sparse_mat, psb_lz_coo_sparse_mat, psb_z_coo_sparse_mat + & psb_lz_csr_sparse_mat, psb_lz_coo_sparse_mat, & + & psb_z_csr_sparse_mat, psb_z_coo_sparse_mat use psb_l_vect_mod, only : psb_l_vect_type use psb_z_multivect_mod, only : psb_z_base_multivect_type, psb_z_multivect_type @@ -221,6 +222,18 @@ Module psb_z_tools_mod integer(psb_ipk_), intent(in), optional :: data type(psb_desc_type),Intent(in), optional, target :: col_desc end Subroutine psb_lz_csr_halo + Subroutine psb_z_lz_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + import + implicit none + type(psb_z_csr_sparse_mat),Intent(in) :: a + type(psb_lz_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + end Subroutine psb_z_lz_csr_halo end interface @@ -335,6 +348,17 @@ Module psb_z_tools_mod end interface interface psb_par_spspmm + subroutine psb_z_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + import :: psb_z_csr_sparse_mat, psb_desc_type, psb_ipk_ + Implicit None + type(psb_z_csr_sparse_mat),intent(in) :: acsr + type(psb_z_csr_sparse_mat),intent(inout) :: bcsr + type(psb_z_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + End Subroutine psb_z_par_csr_spspmm subroutine psb_lz_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) import :: psb_lz_csr_sparse_mat, psb_desc_type, psb_ipk_ Implicit None diff --git a/base/tools/psb_c_par_csr_spspmm.f90 b/base/tools/psb_c_par_csr_spspmm.f90 index 51980ad6..058d1a62 100644 --- a/base/tools/psb_c_par_csr_spspmm.f90 +++ b/base/tools/psb_c_par_csr_spspmm.f90 @@ -61,98 +61,100 @@ ! info - integer, output. ! Error code. ! -!!$Subroutine psb_c_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) -!!$ use psb_base_mod, psb_protect_name => psb_c_par_csr_spspmm -!!$ Implicit None -!!$ -!!$ type(psb_c_csr_sparse_mat),intent(in) :: acsr -!!$ type(psb_c_csr_sparse_mat),intent(inout) :: bcsr -!!$ type(psb_c_csr_sparse_mat),intent(out) :: ccsr -!!$ type(psb_desc_type),intent(in) :: desc_a -!!$ type(psb_desc_type),intent(inout) :: desc_c -!!$ integer(psb_ipk_), intent(out) :: info -!!$ integer(psb_ipk_), intent(in), optional :: data -!!$ ! ...local scalars.... -!!$ integer(psb_ipk_) :: ictxt, np,me -!!$ integer(psb_ipk_) :: ncol, nnz -!!$ type(psb_c_csr_sparse_mat) :: tcsr1 -!!$ logical :: update_desc_c -!!$ integer(psb_ipk_) :: debug_level, debug_unit, err_act -!!$ character(len=20) :: name, ch_err -!!$ -!!$ if(psb_get_errstatus() /= 0) return -!!$ info=psb_success_ -!!$ name='psb_c_p_csr_spspmm' -!!$ call psb_erractionsave(err_act) -!!$ if (psb_errstatus_fatal()) then -!!$ info = psb_err_internal_error_ ; goto 9999 -!!$ end if -!!$ debug_unit = psb_get_debug_unit() -!!$ debug_level = psb_get_debug_level() -!!$ -!!$ ictxt = desc_a%get_context() -!!$ -!!$ call psb_info(ictxt, me, np) -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),': Start' -!!$ -!!$ update_desc_c = desc_c%is_bld() -!!$ -!!$ ! -!!$ ! This is a bit tricky. -!!$ ! DESC_A is the descriptor of (the columns of) A, and therefore -!!$ ! of the rows of B; the columns of B, in the intended usage, span -!!$ ! a different space for which we have DESC_C. -!!$ ! We are gathering the halo rows of B to multiply by A; -!!$ ! now, the columns of B would ideally be kept in -!!$ ! global numbering, so that we can call this repeatedly to accumulate -!!$ ! the product of multiple operators, and convert to local numbering -!!$ ! at the last possible moment. However, this would imply calling -!!$ ! the serial SPSPMM with a matrix B with the GLOBAL number of columns -!!$ ! and this could be very expensive in memory. The solution is to keep B -!!$ ! in local numbering, so that only columns really appearing count, but to -!!$ ! expand the descriptor when gathering the halo, because by performing -!!$ ! the products we are extending the support of the operator; hence -!!$ ! this routine is intended to be called with a temporary descriptor -!!$ ! DESC_C which is in the BUILD state, to allow for such expansion -!!$ ! across multiple products. -!!$ ! The caller will at some later point finalize the descriptor DESC_C. -!!$ ! -!!$ -!!$ ncol = desc_a%get_local_cols() -!!$ call psb_sphalo(bcsr,desc_a,tcsr1,info,& -!!$ & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) -!!$ nnz = tcsr1%get_nzeros() -!!$ if (update_desc_c) then -!!$ call desc_c%indxmap%g2lip_ins(tcsr1%ja(1:nnz),info) -!!$ else -!!$ call desc_c%indxmap%g2lip(tcsr1%ja(1:nnz),info) -!!$ end if -!!$ if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr1) -!!$ if (info == psb_success_) call tcsr1%free() -!!$ if(info /= psb_success_) then -!!$ call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') -!!$ goto 9999 -!!$ end if -!!$ call bcsr%set_ncols(desc_c%get_local_cols()) -!!$ -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm 3' -!!$ if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() -!!$ call psb_spspmm(acsr,bcsr,ccsr,info) -!!$ -!!$ call psb_erractionrestore(err_act) -!!$ return -!!$ -!!$9999 call psb_error_handler(ictxt,err_act) -!!$ -!!$ return -!!$ -!!$End Subroutine psb_c_par_csr_spspmm +Subroutine psb_c_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + use psb_base_mod, psb_protect_name => psb_c_par_csr_spspmm + Implicit None + + type(psb_c_csr_sparse_mat),intent(in) :: acsr + type(psb_c_csr_sparse_mat),intent(inout) :: bcsr + type(psb_c_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: ncol, nnz + type(psb_lc_csr_sparse_mat) :: ltcsr + type(psb_c_csr_sparse_mat) :: tcsr + logical :: update_desc_c + integer(psb_ipk_) :: debug_level, debug_unit, err_act + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_c_p_csr_spspmm' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + + call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + update_desc_c = desc_c%is_bld() + + ! + ! This is a bit tricky. + ! DESC_A is the descriptor of (the columns of) A, and therefore + ! of the rows of B; the columns of B, in the intended usage, span + ! a different space for which we have DESC_C. + ! We are gathering the halo rows of B to multiply by A; + ! now, the columns of B would ideally be kept in + ! global numbering, so that we can call this repeatedly to accumulate + ! the product of multiple operators, and convert to local numbering + ! at the last possible moment. However, this would imply calling + ! the serial SPSPMM with a matrix B with the GLOBAL number of columns + ! and this could be very expensive in memory. The solution is to keep B + ! in local numbering, so that only columns really appearing count, but to + ! expand the descriptor when gathering the halo, because by performing + ! the products we are extending the support of the operator; hence + ! this routine is intended to be called with a temporary descriptor + ! DESC_C which is in the BUILD state, to allow for such expansion + ! across multiple products. + ! The caller will at some later point finalize the descriptor DESC_C. + ! + + ncol = desc_a%get_local_cols() + call psb_sphalo(bcsr,desc_a,ltcsr,info,& + & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) + nnz = ltcsr%get_nzeros() + if (update_desc_c) then + call desc_c%indxmap%g2lip_ins(ltcsr%ja(1:nnz),info) + else + call desc_c%indxmap%g2lip(ltcsr%ja(1:nnz),info) + end if + call ltcsr%mv_to_ifmt(tcsr,info) + if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr) + if (info == psb_success_) call tcsr%free() + if(info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') + goto 9999 + end if + call bcsr%set_ncols(desc_c%get_local_cols()) + + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm 3' + if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() + call psb_spspmm(acsr,bcsr,ccsr,info) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_c_par_csr_spspmm Subroutine psb_lc_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) use psb_base_mod, psb_protect_name => psb_lc_par_csr_spspmm diff --git a/base/tools/psb_csphalo.F90 b/base/tools/psb_csphalo.F90 index 24ebb4f6..668f7d52 100644 --- a/base/tools/psb_csphalo.F90 +++ b/base/tools/psb_csphalo.F90 @@ -1239,3 +1239,375 @@ Subroutine psb_lc_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& return End Subroutine psb_lc_csr_halo + +Subroutine psb_c_lc_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + use psb_base_mod, psb_protect_name => psb_c_lc_csr_halo + +#ifdef MPI_MOD + use mpi +#endif + Implicit None +#ifdef MPI_H + include 'mpif.h' +#endif + + type(psb_c_csr_sparse_mat),Intent(in) :: a + type(psb_lc_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: counter,proc,i, n_el_send,n_el_recv,& + & n_elem, j,ipx,mat_recv, iszs, iszr,idxs,idxr,nz,& + & data_,totxch,ngtz, idx, nxs, nxr, err_act, & + & nsnds, nrcvs, ncg, jpx, tot_elem + integer(psb_lpk_) :: irmax,icmax,irmin,icmin,l1, lnr, lnc, lnnz, & + & r, k + integer(psb_mpk_) :: icomm, minfo + integer(psb_mpk_), allocatable :: brvindx(:), & + & rvsz(:), bsdindx(:),sdsz(:) + integer(psb_ipk_), allocatable :: iasnd(:), jasnd(:) + integer(psb_lpk_), allocatable :: liasnd(:), ljasnd(:) + complex(psb_spk_), allocatable :: valsnd(:) + type(psb_lc_coo_sparse_mat), allocatable :: acoo + class(psb_i_base_vect_type), pointer :: pdxv + integer(psb_ipk_), allocatable :: ipdxv(:) + logical :: rowcnv_,colcnv_,rowscale_,colscale_,outcol_glob_ + Type(psb_desc_type), pointer :: col_desc_ + character(len=5) :: outfmt_ + integer(psb_ipk_) :: debug_level, debug_unit + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_lc_csr_sphalo' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + icomm = desc_a%get_mpic() + + Call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + if (present(rowcnv)) then + rowcnv_ = rowcnv + else + rowcnv_ = .true. + endif + if (present(colcnv)) then + colcnv_ = colcnv + else + colcnv_ = .true. + endif + if (present(rowscale)) then + rowscale_ = rowscale + else + rowscale_ = .false. + endif + if (present(colscale)) then + colscale_ = colscale + else + colscale_ = .false. + endif + if (present(data)) then + data_ = data + else + data_ = psb_comm_halo_ + endif + if (present(outcol_glob)) then + outcol_glob_ = outcol_glob + else + outcol_glob_ = .false. + endif + if (present(col_desc)) then + col_desc_ => col_desc + else + col_desc_ => desc_a + end if + + Allocate(brvindx(np+1),& + & rvsz(np),sdsz(np),bsdindx(np+1), acoo,stat=info) + + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + If (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Data selector',data_ + select case(data_) + case(psb_comm_halo_,psb_comm_ext_ ) + ! Do not accept OVRLAP_INDEX any longer. + case default + call psb_errpush(psb_err_from_subroutine_,name,a_err='wrong Data selector') + goto 9999 + end select + + + sdsz(:)=0 + rvsz(:)=0 + l1 = 0 + brvindx(1) = 0 + bsdindx(1) = 0 + counter=1 + idx = 0 + idxs = 0 + idxr = 0 + + call desc_a%get_list(data_,pdxv,totxch,nxr,nxs,info) + ipdxv = pdxv%get_vect() + ! For all rows in the halo descriptor, extract the row size + lnr = 0 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + tot_elem = 0 + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + tot_elem = tot_elem+n_elem + Enddo + sdsz(proc+1) = tot_elem + lnr = lnr + n_el_recv + counter = counter+n_el_send+3 + Enddo + + ! + ! Exchange row sizes, so as to know sends/receives. + ! This is different from the halo exchange because the + ! size of the rows may vary, as opposed to fixed + ! (multi) vector row size. + ! + call mpi_alltoall(sdsz,1,psb_mpi_mpk_,& + & rvsz,1,psb_mpi_mpk_,icomm,minfo) + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='mpi_alltoall') + goto 9999 + end if + nsnds = count(sdsz /= 0) + nrcvs = count(rvsz /= 0) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done initial alltoall',nsnds,nrcvs + + idxs = 0 + idxr = 0 + counter = 1 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + bsdindx(proc+1) = idxs + idxs = idxs + sdsz(proc+1) + brvindx(proc+1) = idxr + idxr = idxr + rvsz(proc+1) + counter = counter+n_el_send+3 + Enddo + + iszr = sum(rvsz) + mat_recv = iszr + iszs = sum(sdsz) + + lnnz = max(iszr,iszs,ione) + lnc = a%get_ncols() + call acoo%allocate(lnr,lnc,lnnz) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Sizes:',acoo%get_size(),& + & ' Send:',sdsz(:),' Receive:',rvsz(:) + + call psb_ensure_size(max(iszs,1),iasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),jasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),liasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),ljasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),valsnd,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='ensure_size') + goto 9999 + end if + + l1 = 0 + ipx = 1 + counter=1 + idx = 0 + ! + ! Make sure to get all columns in csget. + ! This is necessary when sphalo is used to compute a transpose, + ! as opposed to just gathering halo for spspmm purposes. + ! + ncg = huge(ncg) + tot_elem = 0 + Do + proc = ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + call a%csget(idx,idx,ngtz,iasnd,jasnd,valsnd,info,& + & append=.true.,nzin=tot_elem,jmax=ncg) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_sp_getrow') + goto 9999 + end if + tot_elem = tot_elem+ngtz + Enddo + counter = counter+n_el_send+3 + Enddo + nz = tot_elem + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Going for alltoallv',iszs,iszr + if (rowcnv_) then + call psb_loc_to_glob(iasnd(1:nz),liasnd(1:nz),desc_a,info,iact='I') + else + liasnd(1:nz) = iasnd(1:nz) + end if + if (colcnv_) then + call psb_loc_to_glob(jasnd(1:nz),ljasnd(1:nz),col_desc_,info,iact='I') + else + ljasnd(1:nz) = jasnd(1:nz) + end if + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_loc_to_glob') + goto 9999 + end if + + select case(psb_get_sp_a2av_alg()) + case(psb_sp_a2av_smpl_triad_) + call psb_simple_triad_a2av(valsnd,liasnd,ljasnd,sdsz,bsdindx,& + & acoo%val,acoo%ia,acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_smpl_v_) + call psb_simple_a2av(valsnd,sdsz,bsdindx,& + & acoo%val,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(liasnd,sdsz,bsdindx,& + & acoo%ia,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(ljasnd,sdsz,bsdindx,& + & acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_mpi_) + call mpi_alltoallv(valsnd,sdsz,bsdindx,psb_mpi_c_spk_,& + & acoo%val,rvsz,brvindx,psb_mpi_c_spk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(liasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ia,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(ljasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ja,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo /= mpi_success) info = minfo + case default + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='wrong A2AV alg selector') + goto 9999 + end select + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='alltoallv') + goto 9999 + end if + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done alltoallv' + ! + ! Convert into local numbering + ! + if (rowcnv_) call psb_glob_to_loc(acoo%ia(1:iszr),desc_a,info,iact='I') + ! + ! This seems to be the correct output condition + ! + if (colcnv_.and.(.not.outcol_glob_)) & + & call psb_glob_to_loc(acoo%ja(1:iszr),col_desc_,info,iact='I') + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psbglob_to_loc') + goto 9999 + end if + + l1 = 0 + call acoo%set_nrows(lzero) + ! + irmin = huge(irmin) + icmin = huge(icmin) + irmax = 0 + icmax = 0 + Do i=1,iszr + r=(acoo%ia(i)) + k=(acoo%ja(i)) + ! Just in case some of the conversions were out-of-range + If ((r>0).and.(k>0)) Then + l1=l1+1 + acoo%val(l1) = acoo%val(i) + acoo%ia(l1) = r + acoo%ja(l1) = k + irmin = min(irmin,r) + irmax = max(irmax,r) + icmin = min(icmin,k) + icmax = max(icmax,k) + End If + Enddo + if (rowscale_) then + call acoo%set_nrows(max(irmax-irmin+1,0)) + acoo%ia(1:l1) = acoo%ia(1:l1) - irmin + 1 + else + call acoo%set_nrows(irmax) + end if + if (colscale_) then + call acoo%set_ncols(max(icmax-icmin+1,0)) + acoo%ja(1:l1) = acoo%ja(1:l1) - icmin + 1 + else + call acoo%set_ncols(icmax) + end if + + call acoo%set_nzeros(l1) + call acoo%set_sorted(.false.) + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),& + & ': End data exchange',counter,l1 + + call acoo%fix(info) + if (info == psb_success_) call acoo%mv_to_fmt(blk,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_spcnv') + goto 9999 + end if + + Deallocate(brvindx,bsdindx,rvsz,sdsz,& + & iasnd,jasnd,valsnd,stat=info) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done' + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_c_lc_csr_halo diff --git a/base/tools/psb_d_par_csr_spspmm.f90 b/base/tools/psb_d_par_csr_spspmm.f90 index b589351e..2e34a32c 100644 --- a/base/tools/psb_d_par_csr_spspmm.f90 +++ b/base/tools/psb_d_par_csr_spspmm.f90 @@ -61,98 +61,100 @@ ! info - integer, output. ! Error code. ! -!!$Subroutine psb_d_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) -!!$ use psb_base_mod, psb_protect_name => psb_d_par_csr_spspmm -!!$ Implicit None -!!$ -!!$ type(psb_d_csr_sparse_mat),intent(in) :: acsr -!!$ type(psb_d_csr_sparse_mat),intent(inout) :: bcsr -!!$ type(psb_d_csr_sparse_mat),intent(out) :: ccsr -!!$ type(psb_desc_type),intent(in) :: desc_a -!!$ type(psb_desc_type),intent(inout) :: desc_c -!!$ integer(psb_ipk_), intent(out) :: info -!!$ integer(psb_ipk_), intent(in), optional :: data -!!$ ! ...local scalars.... -!!$ integer(psb_ipk_) :: ictxt, np,me -!!$ integer(psb_ipk_) :: ncol, nnz -!!$ type(psb_d_csr_sparse_mat) :: tcsr1 -!!$ logical :: update_desc_c -!!$ integer(psb_ipk_) :: debug_level, debug_unit, err_act -!!$ character(len=20) :: name, ch_err -!!$ -!!$ if(psb_get_errstatus() /= 0) return -!!$ info=psb_success_ -!!$ name='psb_d_p_csr_spspmm' -!!$ call psb_erractionsave(err_act) -!!$ if (psb_errstatus_fatal()) then -!!$ info = psb_err_internal_error_ ; goto 9999 -!!$ end if -!!$ debug_unit = psb_get_debug_unit() -!!$ debug_level = psb_get_debug_level() -!!$ -!!$ ictxt = desc_a%get_context() -!!$ -!!$ call psb_info(ictxt, me, np) -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),': Start' -!!$ -!!$ update_desc_c = desc_c%is_bld() -!!$ -!!$ ! -!!$ ! This is a bit tricky. -!!$ ! DESC_A is the descriptor of (the columns of) A, and therefore -!!$ ! of the rows of B; the columns of B, in the intended usage, span -!!$ ! a different space for which we have DESC_C. -!!$ ! We are gathering the halo rows of B to multiply by A; -!!$ ! now, the columns of B would ideally be kept in -!!$ ! global numbering, so that we can call this repeatedly to accumulate -!!$ ! the product of multiple operators, and convert to local numbering -!!$ ! at the last possible moment. However, this would imply calling -!!$ ! the serial SPSPMM with a matrix B with the GLOBAL number of columns -!!$ ! and this could be very expensive in memory. The solution is to keep B -!!$ ! in local numbering, so that only columns really appearing count, but to -!!$ ! expand the descriptor when gathering the halo, because by performing -!!$ ! the products we are extending the support of the operator; hence -!!$ ! this routine is intended to be called with a temporary descriptor -!!$ ! DESC_C which is in the BUILD state, to allow for such expansion -!!$ ! across multiple products. -!!$ ! The caller will at some later point finalize the descriptor DESC_C. -!!$ ! -!!$ -!!$ ncol = desc_a%get_local_cols() -!!$ call psb_sphalo(bcsr,desc_a,tcsr1,info,& -!!$ & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) -!!$ nnz = tcsr1%get_nzeros() -!!$ if (update_desc_c) then -!!$ call desc_c%indxmap%g2lip_ins(tcsr1%ja(1:nnz),info) -!!$ else -!!$ call desc_c%indxmap%g2lip(tcsr1%ja(1:nnz),info) -!!$ end if -!!$ if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr1) -!!$ if (info == psb_success_) call tcsr1%free() -!!$ if(info /= psb_success_) then -!!$ call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') -!!$ goto 9999 -!!$ end if -!!$ call bcsr%set_ncols(desc_c%get_local_cols()) -!!$ -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm 3' -!!$ if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() -!!$ call psb_spspmm(acsr,bcsr,ccsr,info) -!!$ -!!$ call psb_erractionrestore(err_act) -!!$ return -!!$ -!!$9999 call psb_error_handler(ictxt,err_act) -!!$ -!!$ return -!!$ -!!$End Subroutine psb_d_par_csr_spspmm +Subroutine psb_d_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + use psb_base_mod, psb_protect_name => psb_d_par_csr_spspmm + Implicit None + + type(psb_d_csr_sparse_mat),intent(in) :: acsr + type(psb_d_csr_sparse_mat),intent(inout) :: bcsr + type(psb_d_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: ncol, nnz + type(psb_ld_csr_sparse_mat) :: ltcsr + type(psb_d_csr_sparse_mat) :: tcsr + logical :: update_desc_c + integer(psb_ipk_) :: debug_level, debug_unit, err_act + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_d_p_csr_spspmm' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + + call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + update_desc_c = desc_c%is_bld() + + ! + ! This is a bit tricky. + ! DESC_A is the descriptor of (the columns of) A, and therefore + ! of the rows of B; the columns of B, in the intended usage, span + ! a different space for which we have DESC_C. + ! We are gathering the halo rows of B to multiply by A; + ! now, the columns of B would ideally be kept in + ! global numbering, so that we can call this repeatedly to accumulate + ! the product of multiple operators, and convert to local numbering + ! at the last possible moment. However, this would imply calling + ! the serial SPSPMM with a matrix B with the GLOBAL number of columns + ! and this could be very expensive in memory. The solution is to keep B + ! in local numbering, so that only columns really appearing count, but to + ! expand the descriptor when gathering the halo, because by performing + ! the products we are extending the support of the operator; hence + ! this routine is intended to be called with a temporary descriptor + ! DESC_C which is in the BUILD state, to allow for such expansion + ! across multiple products. + ! The caller will at some later point finalize the descriptor DESC_C. + ! + + ncol = desc_a%get_local_cols() + call psb_sphalo(bcsr,desc_a,ltcsr,info,& + & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) + nnz = ltcsr%get_nzeros() + if (update_desc_c) then + call desc_c%indxmap%g2lip_ins(ltcsr%ja(1:nnz),info) + else + call desc_c%indxmap%g2lip(ltcsr%ja(1:nnz),info) + end if + call ltcsr%mv_to_ifmt(tcsr,info) + if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr) + if (info == psb_success_) call tcsr%free() + if(info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') + goto 9999 + end if + call bcsr%set_ncols(desc_c%get_local_cols()) + + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm 3' + if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() + call psb_spspmm(acsr,bcsr,ccsr,info) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_d_par_csr_spspmm Subroutine psb_ld_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) use psb_base_mod, psb_protect_name => psb_ld_par_csr_spspmm diff --git a/base/tools/psb_dsphalo.F90 b/base/tools/psb_dsphalo.F90 index 51c4c9ef..8d800f6d 100644 --- a/base/tools/psb_dsphalo.F90 +++ b/base/tools/psb_dsphalo.F90 @@ -1239,3 +1239,375 @@ Subroutine psb_ld_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& return End Subroutine psb_ld_csr_halo + +Subroutine psb_d_ld_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + use psb_base_mod, psb_protect_name => psb_d_ld_csr_halo + +#ifdef MPI_MOD + use mpi +#endif + Implicit None +#ifdef MPI_H + include 'mpif.h' +#endif + + type(psb_d_csr_sparse_mat),Intent(in) :: a + type(psb_ld_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: counter,proc,i, n_el_send,n_el_recv,& + & n_elem, j,ipx,mat_recv, iszs, iszr,idxs,idxr,nz,& + & data_,totxch,ngtz, idx, nxs, nxr, err_act, & + & nsnds, nrcvs, ncg, jpx, tot_elem + integer(psb_lpk_) :: irmax,icmax,irmin,icmin,l1, lnr, lnc, lnnz, & + & r, k + integer(psb_mpk_) :: icomm, minfo + integer(psb_mpk_), allocatable :: brvindx(:), & + & rvsz(:), bsdindx(:),sdsz(:) + integer(psb_ipk_), allocatable :: iasnd(:), jasnd(:) + integer(psb_lpk_), allocatable :: liasnd(:), ljasnd(:) + real(psb_dpk_), allocatable :: valsnd(:) + type(psb_ld_coo_sparse_mat), allocatable :: acoo + class(psb_i_base_vect_type), pointer :: pdxv + integer(psb_ipk_), allocatable :: ipdxv(:) + logical :: rowcnv_,colcnv_,rowscale_,colscale_,outcol_glob_ + Type(psb_desc_type), pointer :: col_desc_ + character(len=5) :: outfmt_ + integer(psb_ipk_) :: debug_level, debug_unit + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_ld_csr_sphalo' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + icomm = desc_a%get_mpic() + + Call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + if (present(rowcnv)) then + rowcnv_ = rowcnv + else + rowcnv_ = .true. + endif + if (present(colcnv)) then + colcnv_ = colcnv + else + colcnv_ = .true. + endif + if (present(rowscale)) then + rowscale_ = rowscale + else + rowscale_ = .false. + endif + if (present(colscale)) then + colscale_ = colscale + else + colscale_ = .false. + endif + if (present(data)) then + data_ = data + else + data_ = psb_comm_halo_ + endif + if (present(outcol_glob)) then + outcol_glob_ = outcol_glob + else + outcol_glob_ = .false. + endif + if (present(col_desc)) then + col_desc_ => col_desc + else + col_desc_ => desc_a + end if + + Allocate(brvindx(np+1),& + & rvsz(np),sdsz(np),bsdindx(np+1), acoo,stat=info) + + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + If (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Data selector',data_ + select case(data_) + case(psb_comm_halo_,psb_comm_ext_ ) + ! Do not accept OVRLAP_INDEX any longer. + case default + call psb_errpush(psb_err_from_subroutine_,name,a_err='wrong Data selector') + goto 9999 + end select + + + sdsz(:)=0 + rvsz(:)=0 + l1 = 0 + brvindx(1) = 0 + bsdindx(1) = 0 + counter=1 + idx = 0 + idxs = 0 + idxr = 0 + + call desc_a%get_list(data_,pdxv,totxch,nxr,nxs,info) + ipdxv = pdxv%get_vect() + ! For all rows in the halo descriptor, extract the row size + lnr = 0 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + tot_elem = 0 + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + tot_elem = tot_elem+n_elem + Enddo + sdsz(proc+1) = tot_elem + lnr = lnr + n_el_recv + counter = counter+n_el_send+3 + Enddo + + ! + ! Exchange row sizes, so as to know sends/receives. + ! This is different from the halo exchange because the + ! size of the rows may vary, as opposed to fixed + ! (multi) vector row size. + ! + call mpi_alltoall(sdsz,1,psb_mpi_mpk_,& + & rvsz,1,psb_mpi_mpk_,icomm,minfo) + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='mpi_alltoall') + goto 9999 + end if + nsnds = count(sdsz /= 0) + nrcvs = count(rvsz /= 0) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done initial alltoall',nsnds,nrcvs + + idxs = 0 + idxr = 0 + counter = 1 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + bsdindx(proc+1) = idxs + idxs = idxs + sdsz(proc+1) + brvindx(proc+1) = idxr + idxr = idxr + rvsz(proc+1) + counter = counter+n_el_send+3 + Enddo + + iszr = sum(rvsz) + mat_recv = iszr + iszs = sum(sdsz) + + lnnz = max(iszr,iszs,ione) + lnc = a%get_ncols() + call acoo%allocate(lnr,lnc,lnnz) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Sizes:',acoo%get_size(),& + & ' Send:',sdsz(:),' Receive:',rvsz(:) + + call psb_ensure_size(max(iszs,1),iasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),jasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),liasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),ljasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),valsnd,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='ensure_size') + goto 9999 + end if + + l1 = 0 + ipx = 1 + counter=1 + idx = 0 + ! + ! Make sure to get all columns in csget. + ! This is necessary when sphalo is used to compute a transpose, + ! as opposed to just gathering halo for spspmm purposes. + ! + ncg = huge(ncg) + tot_elem = 0 + Do + proc = ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + call a%csget(idx,idx,ngtz,iasnd,jasnd,valsnd,info,& + & append=.true.,nzin=tot_elem,jmax=ncg) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_sp_getrow') + goto 9999 + end if + tot_elem = tot_elem+ngtz + Enddo + counter = counter+n_el_send+3 + Enddo + nz = tot_elem + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Going for alltoallv',iszs,iszr + if (rowcnv_) then + call psb_loc_to_glob(iasnd(1:nz),liasnd(1:nz),desc_a,info,iact='I') + else + liasnd(1:nz) = iasnd(1:nz) + end if + if (colcnv_) then + call psb_loc_to_glob(jasnd(1:nz),ljasnd(1:nz),col_desc_,info,iact='I') + else + ljasnd(1:nz) = jasnd(1:nz) + end if + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_loc_to_glob') + goto 9999 + end if + + select case(psb_get_sp_a2av_alg()) + case(psb_sp_a2av_smpl_triad_) + call psb_simple_triad_a2av(valsnd,liasnd,ljasnd,sdsz,bsdindx,& + & acoo%val,acoo%ia,acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_smpl_v_) + call psb_simple_a2av(valsnd,sdsz,bsdindx,& + & acoo%val,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(liasnd,sdsz,bsdindx,& + & acoo%ia,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(ljasnd,sdsz,bsdindx,& + & acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_mpi_) + call mpi_alltoallv(valsnd,sdsz,bsdindx,psb_mpi_r_dpk_,& + & acoo%val,rvsz,brvindx,psb_mpi_r_dpk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(liasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ia,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(ljasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ja,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo /= mpi_success) info = minfo + case default + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='wrong A2AV alg selector') + goto 9999 + end select + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='alltoallv') + goto 9999 + end if + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done alltoallv' + ! + ! Convert into local numbering + ! + if (rowcnv_) call psb_glob_to_loc(acoo%ia(1:iszr),desc_a,info,iact='I') + ! + ! This seems to be the correct output condition + ! + if (colcnv_.and.(.not.outcol_glob_)) & + & call psb_glob_to_loc(acoo%ja(1:iszr),col_desc_,info,iact='I') + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psbglob_to_loc') + goto 9999 + end if + + l1 = 0 + call acoo%set_nrows(lzero) + ! + irmin = huge(irmin) + icmin = huge(icmin) + irmax = 0 + icmax = 0 + Do i=1,iszr + r=(acoo%ia(i)) + k=(acoo%ja(i)) + ! Just in case some of the conversions were out-of-range + If ((r>0).and.(k>0)) Then + l1=l1+1 + acoo%val(l1) = acoo%val(i) + acoo%ia(l1) = r + acoo%ja(l1) = k + irmin = min(irmin,r) + irmax = max(irmax,r) + icmin = min(icmin,k) + icmax = max(icmax,k) + End If + Enddo + if (rowscale_) then + call acoo%set_nrows(max(irmax-irmin+1,0)) + acoo%ia(1:l1) = acoo%ia(1:l1) - irmin + 1 + else + call acoo%set_nrows(irmax) + end if + if (colscale_) then + call acoo%set_ncols(max(icmax-icmin+1,0)) + acoo%ja(1:l1) = acoo%ja(1:l1) - icmin + 1 + else + call acoo%set_ncols(icmax) + end if + + call acoo%set_nzeros(l1) + call acoo%set_sorted(.false.) + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),& + & ': End data exchange',counter,l1 + + call acoo%fix(info) + if (info == psb_success_) call acoo%mv_to_fmt(blk,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_spcnv') + goto 9999 + end if + + Deallocate(brvindx,bsdindx,rvsz,sdsz,& + & iasnd,jasnd,valsnd,stat=info) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done' + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_d_ld_csr_halo diff --git a/base/tools/psb_s_par_csr_spspmm.f90 b/base/tools/psb_s_par_csr_spspmm.f90 index 5460b43e..fba99f27 100644 --- a/base/tools/psb_s_par_csr_spspmm.f90 +++ b/base/tools/psb_s_par_csr_spspmm.f90 @@ -61,98 +61,100 @@ ! info - integer, output. ! Error code. ! -!!$Subroutine psb_s_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) -!!$ use psb_base_mod, psb_protect_name => psb_s_par_csr_spspmm -!!$ Implicit None -!!$ -!!$ type(psb_s_csr_sparse_mat),intent(in) :: acsr -!!$ type(psb_s_csr_sparse_mat),intent(inout) :: bcsr -!!$ type(psb_s_csr_sparse_mat),intent(out) :: ccsr -!!$ type(psb_desc_type),intent(in) :: desc_a -!!$ type(psb_desc_type),intent(inout) :: desc_c -!!$ integer(psb_ipk_), intent(out) :: info -!!$ integer(psb_ipk_), intent(in), optional :: data -!!$ ! ...local scalars.... -!!$ integer(psb_ipk_) :: ictxt, np,me -!!$ integer(psb_ipk_) :: ncol, nnz -!!$ type(psb_s_csr_sparse_mat) :: tcsr1 -!!$ logical :: update_desc_c -!!$ integer(psb_ipk_) :: debug_level, debug_unit, err_act -!!$ character(len=20) :: name, ch_err -!!$ -!!$ if(psb_get_errstatus() /= 0) return -!!$ info=psb_success_ -!!$ name='psb_s_p_csr_spspmm' -!!$ call psb_erractionsave(err_act) -!!$ if (psb_errstatus_fatal()) then -!!$ info = psb_err_internal_error_ ; goto 9999 -!!$ end if -!!$ debug_unit = psb_get_debug_unit() -!!$ debug_level = psb_get_debug_level() -!!$ -!!$ ictxt = desc_a%get_context() -!!$ -!!$ call psb_info(ictxt, me, np) -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),': Start' -!!$ -!!$ update_desc_c = desc_c%is_bld() -!!$ -!!$ ! -!!$ ! This is a bit tricky. -!!$ ! DESC_A is the descriptor of (the columns of) A, and therefore -!!$ ! of the rows of B; the columns of B, in the intended usage, span -!!$ ! a different space for which we have DESC_C. -!!$ ! We are gathering the halo rows of B to multiply by A; -!!$ ! now, the columns of B would ideally be kept in -!!$ ! global numbering, so that we can call this repeatedly to accumulate -!!$ ! the product of multiple operators, and convert to local numbering -!!$ ! at the last possible moment. However, this would imply calling -!!$ ! the serial SPSPMM with a matrix B with the GLOBAL number of columns -!!$ ! and this could be very expensive in memory. The solution is to keep B -!!$ ! in local numbering, so that only columns really appearing count, but to -!!$ ! expand the descriptor when gathering the halo, because by performing -!!$ ! the products we are extending the support of the operator; hence -!!$ ! this routine is intended to be called with a temporary descriptor -!!$ ! DESC_C which is in the BUILD state, to allow for such expansion -!!$ ! across multiple products. -!!$ ! The caller will at some later point finalize the descriptor DESC_C. -!!$ ! -!!$ -!!$ ncol = desc_a%get_local_cols() -!!$ call psb_sphalo(bcsr,desc_a,tcsr1,info,& -!!$ & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) -!!$ nnz = tcsr1%get_nzeros() -!!$ if (update_desc_c) then -!!$ call desc_c%indxmap%g2lip_ins(tcsr1%ja(1:nnz),info) -!!$ else -!!$ call desc_c%indxmap%g2lip(tcsr1%ja(1:nnz),info) -!!$ end if -!!$ if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr1) -!!$ if (info == psb_success_) call tcsr1%free() -!!$ if(info /= psb_success_) then -!!$ call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') -!!$ goto 9999 -!!$ end if -!!$ call bcsr%set_ncols(desc_c%get_local_cols()) -!!$ -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm 3' -!!$ if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() -!!$ call psb_spspmm(acsr,bcsr,ccsr,info) -!!$ -!!$ call psb_erractionrestore(err_act) -!!$ return -!!$ -!!$9999 call psb_error_handler(ictxt,err_act) -!!$ -!!$ return -!!$ -!!$End Subroutine psb_s_par_csr_spspmm +Subroutine psb_s_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + use psb_base_mod, psb_protect_name => psb_s_par_csr_spspmm + Implicit None + + type(psb_s_csr_sparse_mat),intent(in) :: acsr + type(psb_s_csr_sparse_mat),intent(inout) :: bcsr + type(psb_s_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: ncol, nnz + type(psb_ls_csr_sparse_mat) :: ltcsr + type(psb_s_csr_sparse_mat) :: tcsr + logical :: update_desc_c + integer(psb_ipk_) :: debug_level, debug_unit, err_act + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_s_p_csr_spspmm' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + + call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + update_desc_c = desc_c%is_bld() + + ! + ! This is a bit tricky. + ! DESC_A is the descriptor of (the columns of) A, and therefore + ! of the rows of B; the columns of B, in the intended usage, span + ! a different space for which we have DESC_C. + ! We are gathering the halo rows of B to multiply by A; + ! now, the columns of B would ideally be kept in + ! global numbering, so that we can call this repeatedly to accumulate + ! the product of multiple operators, and convert to local numbering + ! at the last possible moment. However, this would imply calling + ! the serial SPSPMM with a matrix B with the GLOBAL number of columns + ! and this could be very expensive in memory. The solution is to keep B + ! in local numbering, so that only columns really appearing count, but to + ! expand the descriptor when gathering the halo, because by performing + ! the products we are extending the support of the operator; hence + ! this routine is intended to be called with a temporary descriptor + ! DESC_C which is in the BUILD state, to allow for such expansion + ! across multiple products. + ! The caller will at some later point finalize the descriptor DESC_C. + ! + + ncol = desc_a%get_local_cols() + call psb_sphalo(bcsr,desc_a,ltcsr,info,& + & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) + nnz = ltcsr%get_nzeros() + if (update_desc_c) then + call desc_c%indxmap%g2lip_ins(ltcsr%ja(1:nnz),info) + else + call desc_c%indxmap%g2lip(ltcsr%ja(1:nnz),info) + end if + call ltcsr%mv_to_ifmt(tcsr,info) + if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr) + if (info == psb_success_) call tcsr%free() + if(info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') + goto 9999 + end if + call bcsr%set_ncols(desc_c%get_local_cols()) + + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm 3' + if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() + call psb_spspmm(acsr,bcsr,ccsr,info) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_s_par_csr_spspmm Subroutine psb_ls_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) use psb_base_mod, psb_protect_name => psb_ls_par_csr_spspmm diff --git a/base/tools/psb_ssphalo.F90 b/base/tools/psb_ssphalo.F90 index a7c914c0..038e72a5 100644 --- a/base/tools/psb_ssphalo.F90 +++ b/base/tools/psb_ssphalo.F90 @@ -1239,3 +1239,375 @@ Subroutine psb_ls_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& return End Subroutine psb_ls_csr_halo + +Subroutine psb_s_ls_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + use psb_base_mod, psb_protect_name => psb_s_ls_csr_halo + +#ifdef MPI_MOD + use mpi +#endif + Implicit None +#ifdef MPI_H + include 'mpif.h' +#endif + + type(psb_s_csr_sparse_mat),Intent(in) :: a + type(psb_ls_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: counter,proc,i, n_el_send,n_el_recv,& + & n_elem, j,ipx,mat_recv, iszs, iszr,idxs,idxr,nz,& + & data_,totxch,ngtz, idx, nxs, nxr, err_act, & + & nsnds, nrcvs, ncg, jpx, tot_elem + integer(psb_lpk_) :: irmax,icmax,irmin,icmin,l1, lnr, lnc, lnnz, & + & r, k + integer(psb_mpk_) :: icomm, minfo + integer(psb_mpk_), allocatable :: brvindx(:), & + & rvsz(:), bsdindx(:),sdsz(:) + integer(psb_ipk_), allocatable :: iasnd(:), jasnd(:) + integer(psb_lpk_), allocatable :: liasnd(:), ljasnd(:) + real(psb_spk_), allocatable :: valsnd(:) + type(psb_ls_coo_sparse_mat), allocatable :: acoo + class(psb_i_base_vect_type), pointer :: pdxv + integer(psb_ipk_), allocatable :: ipdxv(:) + logical :: rowcnv_,colcnv_,rowscale_,colscale_,outcol_glob_ + Type(psb_desc_type), pointer :: col_desc_ + character(len=5) :: outfmt_ + integer(psb_ipk_) :: debug_level, debug_unit + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_ls_csr_sphalo' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + icomm = desc_a%get_mpic() + + Call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + if (present(rowcnv)) then + rowcnv_ = rowcnv + else + rowcnv_ = .true. + endif + if (present(colcnv)) then + colcnv_ = colcnv + else + colcnv_ = .true. + endif + if (present(rowscale)) then + rowscale_ = rowscale + else + rowscale_ = .false. + endif + if (present(colscale)) then + colscale_ = colscale + else + colscale_ = .false. + endif + if (present(data)) then + data_ = data + else + data_ = psb_comm_halo_ + endif + if (present(outcol_glob)) then + outcol_glob_ = outcol_glob + else + outcol_glob_ = .false. + endif + if (present(col_desc)) then + col_desc_ => col_desc + else + col_desc_ => desc_a + end if + + Allocate(brvindx(np+1),& + & rvsz(np),sdsz(np),bsdindx(np+1), acoo,stat=info) + + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + If (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Data selector',data_ + select case(data_) + case(psb_comm_halo_,psb_comm_ext_ ) + ! Do not accept OVRLAP_INDEX any longer. + case default + call psb_errpush(psb_err_from_subroutine_,name,a_err='wrong Data selector') + goto 9999 + end select + + + sdsz(:)=0 + rvsz(:)=0 + l1 = 0 + brvindx(1) = 0 + bsdindx(1) = 0 + counter=1 + idx = 0 + idxs = 0 + idxr = 0 + + call desc_a%get_list(data_,pdxv,totxch,nxr,nxs,info) + ipdxv = pdxv%get_vect() + ! For all rows in the halo descriptor, extract the row size + lnr = 0 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + tot_elem = 0 + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + tot_elem = tot_elem+n_elem + Enddo + sdsz(proc+1) = tot_elem + lnr = lnr + n_el_recv + counter = counter+n_el_send+3 + Enddo + + ! + ! Exchange row sizes, so as to know sends/receives. + ! This is different from the halo exchange because the + ! size of the rows may vary, as opposed to fixed + ! (multi) vector row size. + ! + call mpi_alltoall(sdsz,1,psb_mpi_mpk_,& + & rvsz,1,psb_mpi_mpk_,icomm,minfo) + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='mpi_alltoall') + goto 9999 + end if + nsnds = count(sdsz /= 0) + nrcvs = count(rvsz /= 0) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done initial alltoall',nsnds,nrcvs + + idxs = 0 + idxr = 0 + counter = 1 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + bsdindx(proc+1) = idxs + idxs = idxs + sdsz(proc+1) + brvindx(proc+1) = idxr + idxr = idxr + rvsz(proc+1) + counter = counter+n_el_send+3 + Enddo + + iszr = sum(rvsz) + mat_recv = iszr + iszs = sum(sdsz) + + lnnz = max(iszr,iszs,ione) + lnc = a%get_ncols() + call acoo%allocate(lnr,lnc,lnnz) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Sizes:',acoo%get_size(),& + & ' Send:',sdsz(:),' Receive:',rvsz(:) + + call psb_ensure_size(max(iszs,1),iasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),jasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),liasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),ljasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),valsnd,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='ensure_size') + goto 9999 + end if + + l1 = 0 + ipx = 1 + counter=1 + idx = 0 + ! + ! Make sure to get all columns in csget. + ! This is necessary when sphalo is used to compute a transpose, + ! as opposed to just gathering halo for spspmm purposes. + ! + ncg = huge(ncg) + tot_elem = 0 + Do + proc = ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + call a%csget(idx,idx,ngtz,iasnd,jasnd,valsnd,info,& + & append=.true.,nzin=tot_elem,jmax=ncg) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_sp_getrow') + goto 9999 + end if + tot_elem = tot_elem+ngtz + Enddo + counter = counter+n_el_send+3 + Enddo + nz = tot_elem + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Going for alltoallv',iszs,iszr + if (rowcnv_) then + call psb_loc_to_glob(iasnd(1:nz),liasnd(1:nz),desc_a,info,iact='I') + else + liasnd(1:nz) = iasnd(1:nz) + end if + if (colcnv_) then + call psb_loc_to_glob(jasnd(1:nz),ljasnd(1:nz),col_desc_,info,iact='I') + else + ljasnd(1:nz) = jasnd(1:nz) + end if + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_loc_to_glob') + goto 9999 + end if + + select case(psb_get_sp_a2av_alg()) + case(psb_sp_a2av_smpl_triad_) + call psb_simple_triad_a2av(valsnd,liasnd,ljasnd,sdsz,bsdindx,& + & acoo%val,acoo%ia,acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_smpl_v_) + call psb_simple_a2av(valsnd,sdsz,bsdindx,& + & acoo%val,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(liasnd,sdsz,bsdindx,& + & acoo%ia,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(ljasnd,sdsz,bsdindx,& + & acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_mpi_) + call mpi_alltoallv(valsnd,sdsz,bsdindx,psb_mpi_r_spk_,& + & acoo%val,rvsz,brvindx,psb_mpi_r_spk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(liasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ia,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(ljasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ja,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo /= mpi_success) info = minfo + case default + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='wrong A2AV alg selector') + goto 9999 + end select + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='alltoallv') + goto 9999 + end if + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done alltoallv' + ! + ! Convert into local numbering + ! + if (rowcnv_) call psb_glob_to_loc(acoo%ia(1:iszr),desc_a,info,iact='I') + ! + ! This seems to be the correct output condition + ! + if (colcnv_.and.(.not.outcol_glob_)) & + & call psb_glob_to_loc(acoo%ja(1:iszr),col_desc_,info,iact='I') + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psbglob_to_loc') + goto 9999 + end if + + l1 = 0 + call acoo%set_nrows(lzero) + ! + irmin = huge(irmin) + icmin = huge(icmin) + irmax = 0 + icmax = 0 + Do i=1,iszr + r=(acoo%ia(i)) + k=(acoo%ja(i)) + ! Just in case some of the conversions were out-of-range + If ((r>0).and.(k>0)) Then + l1=l1+1 + acoo%val(l1) = acoo%val(i) + acoo%ia(l1) = r + acoo%ja(l1) = k + irmin = min(irmin,r) + irmax = max(irmax,r) + icmin = min(icmin,k) + icmax = max(icmax,k) + End If + Enddo + if (rowscale_) then + call acoo%set_nrows(max(irmax-irmin+1,0)) + acoo%ia(1:l1) = acoo%ia(1:l1) - irmin + 1 + else + call acoo%set_nrows(irmax) + end if + if (colscale_) then + call acoo%set_ncols(max(icmax-icmin+1,0)) + acoo%ja(1:l1) = acoo%ja(1:l1) - icmin + 1 + else + call acoo%set_ncols(icmax) + end if + + call acoo%set_nzeros(l1) + call acoo%set_sorted(.false.) + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),& + & ': End data exchange',counter,l1 + + call acoo%fix(info) + if (info == psb_success_) call acoo%mv_to_fmt(blk,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_spcnv') + goto 9999 + end if + + Deallocate(brvindx,bsdindx,rvsz,sdsz,& + & iasnd,jasnd,valsnd,stat=info) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done' + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_s_ls_csr_halo diff --git a/base/tools/psb_z_par_csr_spspmm.f90 b/base/tools/psb_z_par_csr_spspmm.f90 index 4499b053..3be9d0e8 100644 --- a/base/tools/psb_z_par_csr_spspmm.f90 +++ b/base/tools/psb_z_par_csr_spspmm.f90 @@ -61,98 +61,100 @@ ! info - integer, output. ! Error code. ! -!!$Subroutine psb_z_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) -!!$ use psb_base_mod, psb_protect_name => psb_z_par_csr_spspmm -!!$ Implicit None -!!$ -!!$ type(psb_z_csr_sparse_mat),intent(in) :: acsr -!!$ type(psb_z_csr_sparse_mat),intent(inout) :: bcsr -!!$ type(psb_z_csr_sparse_mat),intent(out) :: ccsr -!!$ type(psb_desc_type),intent(in) :: desc_a -!!$ type(psb_desc_type),intent(inout) :: desc_c -!!$ integer(psb_ipk_), intent(out) :: info -!!$ integer(psb_ipk_), intent(in), optional :: data -!!$ ! ...local scalars.... -!!$ integer(psb_ipk_) :: ictxt, np,me -!!$ integer(psb_ipk_) :: ncol, nnz -!!$ type(psb_z_csr_sparse_mat) :: tcsr1 -!!$ logical :: update_desc_c -!!$ integer(psb_ipk_) :: debug_level, debug_unit, err_act -!!$ character(len=20) :: name, ch_err -!!$ -!!$ if(psb_get_errstatus() /= 0) return -!!$ info=psb_success_ -!!$ name='psb_z_p_csr_spspmm' -!!$ call psb_erractionsave(err_act) -!!$ if (psb_errstatus_fatal()) then -!!$ info = psb_err_internal_error_ ; goto 9999 -!!$ end if -!!$ debug_unit = psb_get_debug_unit() -!!$ debug_level = psb_get_debug_level() -!!$ -!!$ ictxt = desc_a%get_context() -!!$ -!!$ call psb_info(ictxt, me, np) -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),': Start' -!!$ -!!$ update_desc_c = desc_c%is_bld() -!!$ -!!$ ! -!!$ ! This is a bit tricky. -!!$ ! DESC_A is the descriptor of (the columns of) A, and therefore -!!$ ! of the rows of B; the columns of B, in the intended usage, span -!!$ ! a different space for which we have DESC_C. -!!$ ! We are gathering the halo rows of B to multiply by A; -!!$ ! now, the columns of B would ideally be kept in -!!$ ! global numbering, so that we can call this repeatedly to accumulate -!!$ ! the product of multiple operators, and convert to local numbering -!!$ ! at the last possible moment. However, this would imply calling -!!$ ! the serial SPSPMM with a matrix B with the GLOBAL number of columns -!!$ ! and this could be very expensive in memory. The solution is to keep B -!!$ ! in local numbering, so that only columns really appearing count, but to -!!$ ! expand the descriptor when gathering the halo, because by performing -!!$ ! the products we are extending the support of the operator; hence -!!$ ! this routine is intended to be called with a temporary descriptor -!!$ ! DESC_C which is in the BUILD state, to allow for such expansion -!!$ ! across multiple products. -!!$ ! The caller will at some later point finalize the descriptor DESC_C. -!!$ ! -!!$ -!!$ ncol = desc_a%get_local_cols() -!!$ call psb_sphalo(bcsr,desc_a,tcsr1,info,& -!!$ & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) -!!$ nnz = tcsr1%get_nzeros() -!!$ if (update_desc_c) then -!!$ call desc_c%indxmap%g2lip_ins(tcsr1%ja(1:nnz),info) -!!$ else -!!$ call desc_c%indxmap%g2lip(tcsr1%ja(1:nnz),info) -!!$ end if -!!$ if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr1) -!!$ if (info == psb_success_) call tcsr1%free() -!!$ if(info /= psb_success_) then -!!$ call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') -!!$ goto 9999 -!!$ end if -!!$ call bcsr%set_ncols(desc_c%get_local_cols()) -!!$ -!!$ -!!$ if (debug_level >= psb_debug_outer_) & -!!$ & write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm 3' -!!$ if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& -!!$ & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() -!!$ call psb_spspmm(acsr,bcsr,ccsr,info) -!!$ -!!$ call psb_erractionrestore(err_act) -!!$ return -!!$ -!!$9999 call psb_error_handler(ictxt,err_act) -!!$ -!!$ return -!!$ -!!$End Subroutine psb_z_par_csr_spspmm +Subroutine psb_z_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) + use psb_base_mod, psb_protect_name => psb_z_par_csr_spspmm + Implicit None + + type(psb_z_csr_sparse_mat),intent(in) :: acsr + type(psb_z_csr_sparse_mat),intent(inout) :: bcsr + type(psb_z_csr_sparse_mat),intent(out) :: ccsr + type(psb_desc_type),intent(in) :: desc_a + type(psb_desc_type),intent(inout) :: desc_c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: data + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: ncol, nnz + type(psb_lz_csr_sparse_mat) :: ltcsr + type(psb_z_csr_sparse_mat) :: tcsr + logical :: update_desc_c + integer(psb_ipk_) :: debug_level, debug_unit, err_act + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_z_p_csr_spspmm' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + + call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + update_desc_c = desc_c%is_bld() + + ! + ! This is a bit tricky. + ! DESC_A is the descriptor of (the columns of) A, and therefore + ! of the rows of B; the columns of B, in the intended usage, span + ! a different space for which we have DESC_C. + ! We are gathering the halo rows of B to multiply by A; + ! now, the columns of B would ideally be kept in + ! global numbering, so that we can call this repeatedly to accumulate + ! the product of multiple operators, and convert to local numbering + ! at the last possible moment. However, this would imply calling + ! the serial SPSPMM with a matrix B with the GLOBAL number of columns + ! and this could be very expensive in memory. The solution is to keep B + ! in local numbering, so that only columns really appearing count, but to + ! expand the descriptor when gathering the halo, because by performing + ! the products we are extending the support of the operator; hence + ! this routine is intended to be called with a temporary descriptor + ! DESC_C which is in the BUILD state, to allow for such expansion + ! across multiple products. + ! The caller will at some later point finalize the descriptor DESC_C. + ! + + ncol = desc_a%get_local_cols() + call psb_sphalo(bcsr,desc_a,ltcsr,info,& + & colcnv=.true.,rowscale=.true.,outcol_glob=.true.,col_desc=desc_c,data=data) + nnz = ltcsr%get_nzeros() + if (update_desc_c) then + call desc_c%indxmap%g2lip_ins(ltcsr%ja(1:nnz),info) + else + call desc_c%indxmap%g2lip(ltcsr%ja(1:nnz),info) + end if + call ltcsr%mv_to_ifmt(tcsr,info) + if (info == psb_success_) call psb_rwextd(ncol,bcsr,info,b=tcsr) + if (info == psb_success_) call tcsr%free() + if(info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,a_err='Extend am3') + goto 9999 + end if + call bcsr%set_ncols(desc_c%get_local_cols()) + + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm 3' + if (debug_level >= psb_debug_outer_) write(debug_unit,*) me,' ',trim(name),& + & 'starting spspmm ',acsr%get_nrows(),acsr%get_ncols(),bcsr%get_nrows(),bcsr%get_ncols() + call psb_spspmm(acsr,bcsr,ccsr,info) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_z_par_csr_spspmm Subroutine psb_lz_par_csr_spspmm(acsr,desc_a,bcsr,ccsr,desc_c,info,data) use psb_base_mod, psb_protect_name => psb_lz_par_csr_spspmm diff --git a/base/tools/psb_zsphalo.F90 b/base/tools/psb_zsphalo.F90 index 434680c4..0e32938b 100644 --- a/base/tools/psb_zsphalo.F90 +++ b/base/tools/psb_zsphalo.F90 @@ -1239,3 +1239,375 @@ Subroutine psb_lz_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& return End Subroutine psb_lz_csr_halo + +Subroutine psb_z_lz_csr_halo(a,desc_a,blk,info,rowcnv,colcnv,& + & rowscale,colscale,data,outcol_glob,col_desc) + use psb_base_mod, psb_protect_name => psb_z_lz_csr_halo + +#ifdef MPI_MOD + use mpi +#endif + Implicit None +#ifdef MPI_H + include 'mpif.h' +#endif + + type(psb_z_csr_sparse_mat),Intent(in) :: a + type(psb_lz_csr_sparse_mat),Intent(inout) :: blk + type(psb_desc_type),intent(in), target :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, optional, intent(in) :: rowcnv,colcnv,rowscale,colscale,outcol_glob + integer(psb_ipk_), intent(in), optional :: data + type(psb_desc_type),Intent(in), optional, target :: col_desc + ! ...local scalars.... + integer(psb_ipk_) :: ictxt, np,me + integer(psb_ipk_) :: counter,proc,i, n_el_send,n_el_recv,& + & n_elem, j,ipx,mat_recv, iszs, iszr,idxs,idxr,nz,& + & data_,totxch,ngtz, idx, nxs, nxr, err_act, & + & nsnds, nrcvs, ncg, jpx, tot_elem + integer(psb_lpk_) :: irmax,icmax,irmin,icmin,l1, lnr, lnc, lnnz, & + & r, k + integer(psb_mpk_) :: icomm, minfo + integer(psb_mpk_), allocatable :: brvindx(:), & + & rvsz(:), bsdindx(:),sdsz(:) + integer(psb_ipk_), allocatable :: iasnd(:), jasnd(:) + integer(psb_lpk_), allocatable :: liasnd(:), ljasnd(:) + complex(psb_dpk_), allocatable :: valsnd(:) + type(psb_lz_coo_sparse_mat), allocatable :: acoo + class(psb_i_base_vect_type), pointer :: pdxv + integer(psb_ipk_), allocatable :: ipdxv(:) + logical :: rowcnv_,colcnv_,rowscale_,colscale_,outcol_glob_ + Type(psb_desc_type), pointer :: col_desc_ + character(len=5) :: outfmt_ + integer(psb_ipk_) :: debug_level, debug_unit + character(len=20) :: name, ch_err + + if(psb_get_errstatus() /= 0) return + info=psb_success_ + name='psb_lz_csr_sphalo' + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + debug_unit = psb_get_debug_unit() + debug_level = psb_get_debug_level() + + ictxt = desc_a%get_context() + icomm = desc_a%get_mpic() + + Call psb_info(ictxt, me, np) + + if (debug_level >= psb_debug_outer_) & + & write(debug_unit,*) me,' ',trim(name),': Start' + + if (present(rowcnv)) then + rowcnv_ = rowcnv + else + rowcnv_ = .true. + endif + if (present(colcnv)) then + colcnv_ = colcnv + else + colcnv_ = .true. + endif + if (present(rowscale)) then + rowscale_ = rowscale + else + rowscale_ = .false. + endif + if (present(colscale)) then + colscale_ = colscale + else + colscale_ = .false. + endif + if (present(data)) then + data_ = data + else + data_ = psb_comm_halo_ + endif + if (present(outcol_glob)) then + outcol_glob_ = outcol_glob + else + outcol_glob_ = .false. + endif + if (present(col_desc)) then + col_desc_ => col_desc + else + col_desc_ => desc_a + end if + + Allocate(brvindx(np+1),& + & rvsz(np),sdsz(np),bsdindx(np+1), acoo,stat=info) + + if (info /= psb_success_) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + If (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Data selector',data_ + select case(data_) + case(psb_comm_halo_,psb_comm_ext_ ) + ! Do not accept OVRLAP_INDEX any longer. + case default + call psb_errpush(psb_err_from_subroutine_,name,a_err='wrong Data selector') + goto 9999 + end select + + + sdsz(:)=0 + rvsz(:)=0 + l1 = 0 + brvindx(1) = 0 + bsdindx(1) = 0 + counter=1 + idx = 0 + idxs = 0 + idxr = 0 + + call desc_a%get_list(data_,pdxv,totxch,nxr,nxs,info) + ipdxv = pdxv%get_vect() + ! For all rows in the halo descriptor, extract the row size + lnr = 0 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + tot_elem = 0 + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + tot_elem = tot_elem+n_elem + Enddo + sdsz(proc+1) = tot_elem + lnr = lnr + n_el_recv + counter = counter+n_el_send+3 + Enddo + + ! + ! Exchange row sizes, so as to know sends/receives. + ! This is different from the halo exchange because the + ! size of the rows may vary, as opposed to fixed + ! (multi) vector row size. + ! + call mpi_alltoall(sdsz,1,psb_mpi_mpk_,& + & rvsz,1,psb_mpi_mpk_,icomm,minfo) + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='mpi_alltoall') + goto 9999 + end if + nsnds = count(sdsz /= 0) + nrcvs = count(rvsz /= 0) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done initial alltoall',nsnds,nrcvs + + idxs = 0 + idxr = 0 + counter = 1 + Do + proc=ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + bsdindx(proc+1) = idxs + idxs = idxs + sdsz(proc+1) + brvindx(proc+1) = idxr + idxr = idxr + rvsz(proc+1) + counter = counter+n_el_send+3 + Enddo + + iszr = sum(rvsz) + mat_recv = iszr + iszs = sum(sdsz) + + lnnz = max(iszr,iszs,ione) + lnc = a%get_ncols() + call acoo%allocate(lnr,lnc,lnnz) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Sizes:',acoo%get_size(),& + & ' Send:',sdsz(:),' Receive:',rvsz(:) + + call psb_ensure_size(max(iszs,1),iasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),jasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),liasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),ljasnd,info) + if (info == psb_success_) call psb_ensure_size(max(iszs,1),valsnd,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='ensure_size') + goto 9999 + end if + + l1 = 0 + ipx = 1 + counter=1 + idx = 0 + ! + ! Make sure to get all columns in csget. + ! This is necessary when sphalo is used to compute a transpose, + ! as opposed to just gathering halo for spspmm purposes. + ! + ncg = huge(ncg) + tot_elem = 0 + Do + proc = ipdxv(counter) + if (proc == -1) exit + n_el_recv = ipdxv(counter+psb_n_elem_recv_) + counter = counter+n_el_recv + n_el_send = ipdxv(counter+psb_n_elem_send_) + + Do j=0,n_el_send-1 + idx = ipdxv(counter+psb_elem_send_+j) + n_elem = a%get_nz_row(idx) + call a%csget(idx,idx,ngtz,iasnd,jasnd,valsnd,info,& + & append=.true.,nzin=tot_elem,jmax=ncg) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_sp_getrow') + goto 9999 + end if + tot_elem = tot_elem+ngtz + Enddo + counter = counter+n_el_send+3 + Enddo + nz = tot_elem + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Going for alltoallv',iszs,iszr + if (rowcnv_) then + call psb_loc_to_glob(iasnd(1:nz),liasnd(1:nz),desc_a,info,iact='I') + else + liasnd(1:nz) = iasnd(1:nz) + end if + if (colcnv_) then + call psb_loc_to_glob(jasnd(1:nz),ljasnd(1:nz),col_desc_,info,iact='I') + else + ljasnd(1:nz) = jasnd(1:nz) + end if + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_loc_to_glob') + goto 9999 + end if + + select case(psb_get_sp_a2av_alg()) + case(psb_sp_a2av_smpl_triad_) + call psb_simple_triad_a2av(valsnd,liasnd,ljasnd,sdsz,bsdindx,& + & acoo%val,acoo%ia,acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_smpl_v_) + call psb_simple_a2av(valsnd,sdsz,bsdindx,& + & acoo%val,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(liasnd,sdsz,bsdindx,& + & acoo%ia,rvsz,brvindx,ictxt,info) + if (info == psb_success_) call psb_simple_a2av(ljasnd,sdsz,bsdindx,& + & acoo%ja,rvsz,brvindx,ictxt,info) + case(psb_sp_a2av_mpi_) + call mpi_alltoallv(valsnd,sdsz,bsdindx,psb_mpi_c_dpk_,& + & acoo%val,rvsz,brvindx,psb_mpi_c_dpk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(liasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ia,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo == mpi_success) & + & call mpi_alltoallv(ljasnd,sdsz,bsdindx,psb_mpi_lpk_,& + & acoo%ja,rvsz,brvindx,psb_mpi_lpk_,icomm,minfo) + if (minfo /= mpi_success) info = minfo + case default + info = psb_err_internal_error_ + call psb_errpush(info,name,a_err='wrong A2AV alg selector') + goto 9999 + end select + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='alltoallv') + goto 9999 + end if + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done alltoallv' + ! + ! Convert into local numbering + ! + if (rowcnv_) call psb_glob_to_loc(acoo%ia(1:iszr),desc_a,info,iact='I') + ! + ! This seems to be the correct output condition + ! + if (colcnv_.and.(.not.outcol_glob_)) & + & call psb_glob_to_loc(acoo%ja(1:iszr),col_desc_,info,iact='I') + + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psbglob_to_loc') + goto 9999 + end if + + l1 = 0 + call acoo%set_nrows(lzero) + ! + irmin = huge(irmin) + icmin = huge(icmin) + irmax = 0 + icmax = 0 + Do i=1,iszr + r=(acoo%ia(i)) + k=(acoo%ja(i)) + ! Just in case some of the conversions were out-of-range + If ((r>0).and.(k>0)) Then + l1=l1+1 + acoo%val(l1) = acoo%val(i) + acoo%ia(l1) = r + acoo%ja(l1) = k + irmin = min(irmin,r) + irmax = max(irmax,r) + icmin = min(icmin,k) + icmax = max(icmax,k) + End If + Enddo + if (rowscale_) then + call acoo%set_nrows(max(irmax-irmin+1,0)) + acoo%ia(1:l1) = acoo%ia(1:l1) - irmin + 1 + else + call acoo%set_nrows(irmax) + end if + if (colscale_) then + call acoo%set_ncols(max(icmax-icmin+1,0)) + acoo%ja(1:l1) = acoo%ja(1:l1) - icmin + 1 + else + call acoo%set_ncols(icmax) + end if + + call acoo%set_nzeros(l1) + call acoo%set_sorted(.false.) + + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),& + & ': End data exchange',counter,l1 + + call acoo%fix(info) + if (info == psb_success_) call acoo%mv_to_fmt(blk,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='psb_spcnv') + goto 9999 + end if + + Deallocate(brvindx,bsdindx,rvsz,sdsz,& + & iasnd,jasnd,valsnd,stat=info) + if (debug_level >= psb_debug_outer_)& + & write(debug_unit,*) me,' ',trim(name),': Done' + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ictxt,err_act) + + return + +End Subroutine psb_z_lz_csr_halo