Tighten the output nonzero estimate in psb_{c,d,s,z}csrspspmm.

Review corrections applied to this patch:
  * The declaration hunk was dropped in all four files. It added a second
    "nze" to a declaration list that already contained nze and nzeb,
    which is a duplicate declaration.
  * The second clamp now reads "na = max(na,1)". The previous text,
    "ma = max(na,1)", overwrote ma and left na unclamped, so the later
    division by na was still unprotected and the wrong first dimension
    was passed to c%allocate.
  * The index lines are kept from the original patch; regenerate them
    with git diff if exact blob ids matter.

diff --git a/base/serial/impl/psb_c_csr_impl.f90 b/base/serial/impl/psb_c_csr_impl.f90
index b3125d8f..f5d28a39 100644
--- a/base/serial/impl/psb_c_csr_impl.f90
+++ b/base/serial/impl/psb_c_csr_impl.f90
@@ -3383,12 +3383,22 @@ subroutine psb_ccsrspspmm(a,b,c,info)
     goto 9999
   endif
 
+  ! Estimate number of nonzeros on output.
   nza = a%get_nzeros()
   nzb = b%get_nzeros()
   nzc = 2*(nza+nzb)
-
-  ! Estimate number of nonzeros on output.
-  ! Turns out this is often a large overestimate.
+  ! Clamp every dimension to at least 1 so the integer divisions
+  ! below cannot divide by zero. The clamped ma and nb are also the
+  ! values passed to c%allocate.
+  ma = max(ma,1)
+  na = max(na,1)
+  mb = max(mb,1)
+  nb = max(nb,1)
+  ! (x+d-1)/d is the integer ceiling of x/d for x >= 0 and d >= 1.
+  nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) )
+  nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb
+  ! Keep the smaller of the two estimates.
+  nzc = min(nzc, nze+nzeb)
   call c%allocate(ma,nb,nzc)
 
   call csr_spspmm(a,b,c,info)
diff --git a/base/serial/impl/psb_d_csr_impl.f90 b/base/serial/impl/psb_d_csr_impl.f90
index af7f7336..39d9df8e 100644
--- a/base/serial/impl/psb_d_csr_impl.f90
+++ b/base/serial/impl/psb_d_csr_impl.f90
@@ -3383,12 +3383,22 @@ subroutine psb_dcsrspspmm(a,b,c,info)
     goto 9999
   endif
 
+  ! Estimate number of nonzeros on output.
   nza = a%get_nzeros()
   nzb = b%get_nzeros()
   nzc = 2*(nza+nzb)
-
-  ! Estimate number of nonzeros on output.
-  ! Turns out this is often a large overestimate.
+  ! Clamp every dimension to at least 1 so the integer divisions
+  ! below cannot divide by zero. The clamped ma and nb are also the
+  ! values passed to c%allocate.
+  ma = max(ma,1)
+  na = max(na,1)
+  mb = max(mb,1)
+  nb = max(nb,1)
+  ! (x+d-1)/d is the integer ceiling of x/d for x >= 0 and d >= 1.
+  nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) )
+  nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb
+  ! Keep the smaller of the two estimates.
+  nzc = min(nzc, nze+nzeb)
   call c%allocate(ma,nb,nzc)
 
   call csr_spspmm(a,b,c,info)
diff --git a/base/serial/impl/psb_s_csr_impl.f90 b/base/serial/impl/psb_s_csr_impl.f90
index 7ee2a4ea..07a6b57e 100644
--- a/base/serial/impl/psb_s_csr_impl.f90
+++ b/base/serial/impl/psb_s_csr_impl.f90
@@ -3383,12 +3383,22 @@ subroutine psb_scsrspspmm(a,b,c,info)
     goto 9999
   endif
 
+  ! Estimate number of nonzeros on output.
   nza = a%get_nzeros()
   nzb = b%get_nzeros()
   nzc = 2*(nza+nzb)
-
-  ! Estimate number of nonzeros on output.
-  ! Turns out this is often a large overestimate.
+  ! Clamp every dimension to at least 1 so the integer divisions
+  ! below cannot divide by zero. The clamped ma and nb are also the
+  ! values passed to c%allocate.
+  ma = max(ma,1)
+  na = max(na,1)
+  mb = max(mb,1)
+  nb = max(nb,1)
+  ! (x+d-1)/d is the integer ceiling of x/d for x >= 0 and d >= 1.
+  nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) )
+  nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb
+  ! Keep the smaller of the two estimates.
+  nzc = min(nzc, nze+nzeb)
   call c%allocate(ma,nb,nzc)
 
   call csr_spspmm(a,b,c,info)
diff --git a/base/serial/impl/psb_z_csr_impl.f90 b/base/serial/impl/psb_z_csr_impl.f90
index 42c8c5be..7e58e1a3 100644
--- a/base/serial/impl/psb_z_csr_impl.f90
+++ b/base/serial/impl/psb_z_csr_impl.f90
@@ -3383,12 +3383,22 @@ subroutine psb_zcsrspspmm(a,b,c,info)
     goto 9999
   endif
 
+  ! Estimate number of nonzeros on output.
   nza = a%get_nzeros()
   nzb = b%get_nzeros()
   nzc = 2*(nza+nzb)
-
-  ! Estimate number of nonzeros on output.
-  ! Turns out this is often a large overestimate.
+  ! Clamp every dimension to at least 1 so the integer divisions
+  ! below cannot divide by zero. The clamped ma and nb are also the
+  ! values passed to c%allocate.
+  ma = max(ma,1)
+  na = max(na,1)
+  mb = max(mb,1)
+  nb = max(nb,1)
+  ! (x+d-1)/d is the integer ceiling of x/d for x >= 0 and d >= 1.
+  nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) )
+  nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb
+  ! Keep the smaller of the two estimates.
+  nzc = min(nzc, nze+nzeb)
   call c%allocate(ma,nb,nzc)
 
   call csr_spspmm(a,b,c,info)