diff --git a/base/modules/psb_c_serial_mod.f90 b/base/modules/psb_c_serial_mod.f90 index 35e6f761..f5532aa5 100644 --- a/base/modules/psb_c_serial_mod.f90 +++ b/base/modules/psb_c_serial_mod.f90 @@ -51,6 +51,33 @@ module psb_c_serial_mod end function psb_casum_s end interface psb_asum + interface psb_spspmm + subroutine psb_cspspmm(a,b,c,info) + use psb_c_mat_mod, only : psb_cspmat_type + import :: psb_ipk_ + implicit none + type(psb_cspmat_type), intent(in) :: a,b + type(psb_cspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_cspspmm + subroutine psb_ccsrspspmm(a,b,c,info) + use psb_c_mat_mod, only : psb_c_csr_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_c_csr_sparse_mat), intent(in) :: a,b + type(psb_c_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_ccsrspspmm + subroutine psb_ccscspspmm(a,b,c,info) + use psb_c_mat_mod, only : psb_c_csc_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_c_csc_sparse_mat), intent(in) :: a,b + type(psb_c_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_ccscspspmm + end interface + interface psb_symbmm subroutine psb_csymbmm(a,b,c,info) use psb_c_mat_mod, only : psb_cspmat_type diff --git a/base/modules/psb_d_serial_mod.f90 b/base/modules/psb_d_serial_mod.f90 index 5fc42d12..97a77ee6 100644 --- a/base/modules/psb_d_serial_mod.f90 +++ b/base/modules/psb_d_serial_mod.f90 @@ -51,6 +51,33 @@ module psb_d_serial_mod end function psb_dasum_s end interface psb_asum + interface psb_spspmm + subroutine psb_dspspmm(a,b,c,info) + use psb_d_mat_mod, only : psb_dspmat_type + import :: psb_ipk_ + implicit none + type(psb_dspmat_type), intent(in) :: a,b + type(psb_dspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_dspspmm + subroutine psb_dcsrspspmm(a,b,c,info) + use psb_d_mat_mod, only : psb_d_csr_sparse_mat + import :: psb_ipk_ + implicit none + 
class(psb_d_csr_sparse_mat), intent(in) :: a,b + type(psb_d_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_dcsrspspmm + subroutine psb_dcscspspmm(a,b,c,info) + use psb_d_mat_mod, only : psb_d_csc_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_d_csc_sparse_mat), intent(in) :: a,b + type(psb_d_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_dcscspspmm + end interface + interface psb_symbmm subroutine psb_dsymbmm(a,b,c,info) use psb_d_mat_mod, only : psb_dspmat_type diff --git a/base/modules/psb_s_serial_mod.f90 b/base/modules/psb_s_serial_mod.f90 index 2de84bc3..75e39d7f 100644 --- a/base/modules/psb_s_serial_mod.f90 +++ b/base/modules/psb_s_serial_mod.f90 @@ -51,6 +51,33 @@ module psb_s_serial_mod end function psb_sasum_s end interface psb_asum + interface psb_spspmm + subroutine psb_sspspmm(a,b,c,info) + use psb_s_mat_mod, only : psb_sspmat_type + import :: psb_ipk_ + implicit none + type(psb_sspmat_type), intent(in) :: a,b + type(psb_sspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_sspspmm + subroutine psb_scsrspspmm(a,b,c,info) + use psb_s_mat_mod, only : psb_s_csr_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_s_csr_sparse_mat), intent(in) :: a,b + type(psb_s_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_scsrspspmm + subroutine psb_scscspspmm(a,b,c,info) + use psb_s_mat_mod, only : psb_s_csc_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_s_csc_sparse_mat), intent(in) :: a,b + type(psb_s_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_scscspspmm + end interface + interface psb_symbmm subroutine psb_ssymbmm(a,b,c,info) use psb_s_mat_mod, only : psb_sspmat_type diff --git a/base/modules/psb_z_serial_mod.f90 b/base/modules/psb_z_serial_mod.f90 index b2f372af..8b28b9c6 100644 --- 
a/base/modules/psb_z_serial_mod.f90 +++ b/base/modules/psb_z_serial_mod.f90 @@ -51,6 +51,33 @@ module psb_z_serial_mod end function psb_zasum_s end interface psb_asum + interface psb_spspmm + subroutine psb_zspspmm(a,b,c,info) + use psb_z_mat_mod, only : psb_zspmat_type + import :: psb_ipk_ + implicit none + type(psb_zspmat_type), intent(in) :: a,b + type(psb_zspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_zspspmm + subroutine psb_zcsrspspmm(a,b,c,info) + use psb_z_mat_mod, only : psb_z_csr_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_z_csr_sparse_mat), intent(in) :: a,b + type(psb_z_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_zcsrspspmm + subroutine psb_zcscspspmm(a,b,c,info) + use psb_z_mat_mod, only : psb_z_csc_sparse_mat + import :: psb_ipk_ + implicit none + class(psb_z_csc_sparse_mat), intent(in) :: a,b + type(psb_z_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine psb_zcscspspmm + end interface + interface psb_symbmm subroutine psb_zsymbmm(a,b,c,info) use psb_z_mat_mod, only : psb_zspmat_type diff --git a/base/serial/Makefile b/base/serial/Makefile index e8572326..7184dc4a 100644 --- a/base/serial/Makefile +++ b/base/serial/Makefile @@ -3,6 +3,7 @@ include ../../Make.inc FOBJS = psb_lsame.o psi_serial_impl.o psb_sort_impl.o \ psb_srwextd.o psb_drwextd.o psb_crwextd.o psb_zrwextd.o \ + psb_sspspmm.o psb_dspspmm.o psb_cspspmm.o psb_zspspmm.o \ psb_ssymbmm.o psb_dsymbmm.o psb_csymbmm.o psb_zsymbmm.o \ psb_snumbmm.o psb_dnumbmm.o psb_cnumbmm.o psb_znumbmm.o \ psb_sgeprt.o psb_dgeprt.o psb_cgeprt.o psb_zgeprt.o\ diff --git a/base/serial/impl/psb_c_csc_impl.f90 b/base/serial/impl/psb_c_csc_impl.f90 index d5314261..63875efd 100644 --- a/base/serial/impl/psb_c_csc_impl.f90 +++ b/base/serial/impl/psb_c_csc_impl.f90 @@ -2521,7 +2521,7 @@ subroutine psb_c_cp_csc_to_fmt(a,b,info) !locals type(psb_c_coo_sparse_mat) :: 
tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2535,9 +2535,11 @@ subroutine psb_c_cp_csc_to_fmt(a,b,info) type is (psb_c_csc_sparse_mat) b%psb_c_base_sparse_mat = a%psb_c_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%icp, b%icp , info) - if (info == 0) call psb_safe_cpy( a%ia , b%ia , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nc = a%get_ncols() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%icp(1:nc+1), b%icp , info) + if (info == 0) call psb_safe_cpy( a%ia(1:nz), b%ia , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -2602,7 +2604,7 @@ subroutine psb_c_cp_csc_from_fmt(a,b,info) !locals type(psb_c_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2615,9 +2617,11 @@ subroutine psb_c_cp_csc_from_fmt(a,b,info) type is (psb_c_csc_sparse_mat) a%psb_c_base_sparse_mat = b%psb_c_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%icp, a%icp , info) - if (info == 0) call psb_safe_cpy( b%ia , a%ia , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nc = b%get_ncols() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%icp(1:nc+1), a%icp , info) + if (info == 0) call psb_safe_cpy( b%ia(1:nz), a%ia , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz), a%val , info) class default call b%cp_to_coo(tmp,info) @@ -2985,3 +2989,124 @@ subroutine psb_c_csc_print(iout,a,iv,head,ivr,ivc) end subroutine psb_c_csc_print +subroutine psb_ccscspspmm(a,b,c,info) + use psb_c_mat_mod + use psb_serial_mod, psb_protect_name => psb_ccscspspmm + + implicit none 
+ + class(psb_c_csc_sparse_mat), intent(in) :: a,b + type(psb_c_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_cscspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + endif + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) ) + nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb + ! Estimate number of nonzeros on output. + ! Turns out this is often a large overestimate. + call c%allocate(ma,nb,min(nzc,nze,nzeb)) + + + call csc_spspmm(a,b,c,info) + + call c%set_asb() + + call psb_erractionrestore(err_act) + return + +9999 continue + call psb_erractionrestore(err_act) + if (err_act == psb_act_abort_) then + call psb_error() + return + end if + return + +contains + + subroutine csc_spspmm(a,b,c,info) + implicit none + type(psb_c_csc_sparse_mat), intent(in) :: a,b + type(psb_c_csc_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: icol(:), idxs(:), iaux(:) + complex(psb_spk_), allocatable :: col(:) + type(psb_int_heap) :: heap + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + complex(psb_spk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ia)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,col,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,icol,info) + if (info /= 0) return + col = dzero + icol = 0 + nzc = 1 + do j = 1,nb + c%icp(j) = nzc + nrc = 0 + do k = b%icp(j), b%icp(j+1)-1 + icl = b%ia(k) + 
cfb = b%val(k) + irwsz = a%icp(icl+1)-a%icp(icl) + do i = a%icp(icl),a%icp(icl+1)-1 + irw = a%ia(i) + if (icol(irw) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(nb*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ia,info) + if (info /= 0) return + end if + call psb_msort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ia(nzc) = irw + c%val(nzc) = col(irw) + col(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%icp(nb+1) = nzc + + end subroutine csc_spspmm + +end subroutine psb_ccscspspmm diff --git a/base/serial/impl/psb_c_csr_impl.f90 b/base/serial/impl/psb_c_csr_impl.f90 index 6d14e924..7bd2fc64 100644 --- a/base/serial/impl/psb_c_csr_impl.f90 +++ b/base/serial/impl/psb_c_csr_impl.f90 @@ -3128,7 +3128,7 @@ subroutine psb_c_cp_csr_to_fmt(a,b,info) !locals type(psb_c_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -3142,9 +3142,11 @@ subroutine psb_c_cp_csr_to_fmt(a,b,info) type is (psb_c_csr_sparse_mat) b%psb_c_base_sparse_mat = a%psb_c_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%irp, b%irp , info) - if (info == 0) call psb_safe_cpy( a%ja , b%ja , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nr = a%get_nrows() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%irp(1:nr+1), b%irp , info) + if (info == 0) call psb_safe_cpy( a%ja(1:nz), b%ja , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -3221,12 +3223,137 @@ subroutine psb_c_cp_csr_from_fmt(a,b,info) type is (psb_c_csr_sparse_mat) a%psb_c_base_sparse_mat = b%psb_c_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%irp, a%irp , info) - if (info == 0) call psb_safe_cpy( b%ja , a%ja , info) - if (info == 0) call psb_safe_cpy( b%val, a%val 
, info) + nr = b%get_nrows() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%irp(1:nr+1), a%irp , info) + if (info == 0) call psb_safe_cpy( b%ja(1:nz) , a%ja , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz) , a%val , info) class default call b%cp_to_coo(tmp,info) if (info == psb_success_) call a%mv_from_coo(tmp,info) end select end subroutine psb_c_cp_csr_from_fmt + +subroutine psb_ccsrspspmm(a,b,c,info) + use psb_c_mat_mod + use psb_serial_mod, psb_protect_name => psb_ccsrspspmm + + implicit none + + class(psb_c_csr_sparse_mat), intent(in) :: a,b + type(psb_c_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_csrspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + endif + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) ) + nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb + ! Estimate number of nonzeros on output. + ! Turns out this is often a large overestimate. 
+ call c%allocate(ma,nb,min(nzc,nze,nzeb)) + + call csr_spspmm(a,b,c,info) + + call c%set_asb() + + call psb_erractionrestore(err_act) + return + +9999 continue + call psb_erractionrestore(err_act) + if (err_act == psb_act_abort_) then + call psb_error() + return + end if + return + +contains + + subroutine csr_spspmm(a,b,c,info) + implicit none + type(psb_c_csr_sparse_mat), intent(in) :: a,b + type(psb_c_csr_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: irow(:), idxs(:) + complex(psb_spk_), allocatable :: row(:) + type(psb_int_heap) :: heap + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + complex(psb_spk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if + + call psb_msort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%irp(ma+1) = nzc + + + end subroutine csr_spspmm + +end subroutine psb_ccsrspspmm diff --git a/base/serial/impl/psb_d_csc_impl.f90 b/base/serial/impl/psb_d_csc_impl.f90 index f94bd8dd..18f829a0 100644 --- a/base/serial/impl/psb_d_csc_impl.f90 +++ b/base/serial/impl/psb_d_csc_impl.f90 @@ -2521,7 
+2521,7 @@ subroutine psb_d_cp_csc_to_fmt(a,b,info) !locals type(psb_d_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2535,9 +2535,11 @@ subroutine psb_d_cp_csc_to_fmt(a,b,info) type is (psb_d_csc_sparse_mat) b%psb_d_base_sparse_mat = a%psb_d_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%icp, b%icp , info) - if (info == 0) call psb_safe_cpy( a%ia , b%ia , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nc = a%get_ncols() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%icp(1:nc+1), b%icp , info) + if (info == 0) call psb_safe_cpy( a%ia(1:nz), b%ia , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -2602,7 +2604,7 @@ subroutine psb_d_cp_csc_from_fmt(a,b,info) !locals type(psb_d_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2615,9 +2617,11 @@ subroutine psb_d_cp_csc_from_fmt(a,b,info) type is (psb_d_csc_sparse_mat) a%psb_d_base_sparse_mat = b%psb_d_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%icp, a%icp , info) - if (info == 0) call psb_safe_cpy( b%ia , a%ia , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nc = b%get_ncols() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%icp(1:nc+1), a%icp , info) + if (info == 0) call psb_safe_cpy( b%ia(1:nz), a%ia , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz), a%val , info) class default call b%cp_to_coo(tmp,info) @@ -2985,3 +2989,124 @@ subroutine psb_d_csc_print(iout,a,iv,head,ivr,ivc) end subroutine psb_d_csc_print +subroutine psb_dcscspspmm(a,b,c,info) + use 
psb_d_mat_mod + use psb_serial_mod, psb_protect_name => psb_dcscspspmm + + implicit none + + class(psb_d_csc_sparse_mat), intent(in) :: a,b + type(psb_d_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_cscspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + endif + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) ) + nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb + ! Estimate number of nonzeros on output. + ! Turns out this is often a large overestimate. + call c%allocate(ma,nb,min(nzc,nze,nzeb)) + + + call csc_spspmm(a,b,c,info) + + call c%set_asb() + + call psb_erractionrestore(err_act) + return + +9999 continue + call psb_erractionrestore(err_act) + if (err_act == psb_act_abort_) then + call psb_error() + return + end if + return + +contains + + subroutine csc_spspmm(a,b,c,info) + implicit none + type(psb_d_csc_sparse_mat), intent(in) :: a,b + type(psb_d_csc_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: icol(:), idxs(:), iaux(:) + real(psb_dpk_), allocatable :: col(:) + type(psb_int_heap) :: heap + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + real(psb_dpk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ia)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,col,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,icol,info) + if (info /= 0) return + col = dzero + icol = 0 + nzc = 1 + do j = 
1,nb + c%icp(j) = nzc + nrc = 0 + do k = b%icp(j), b%icp(j+1)-1 + icl = b%ia(k) + cfb = b%val(k) + irwsz = a%icp(icl+1)-a%icp(icl) + do i = a%icp(icl),a%icp(icl+1)-1 + irw = a%ia(i) + if (icol(irw) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(nb*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ia,info) + if (info /= 0) return + end if + call psb_msort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ia(nzc) = irw + c%val(nzc) = col(irw) + col(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%icp(nb+1) = nzc + + end subroutine csc_spspmm + +end subroutine psb_dcscspspmm diff --git a/base/serial/impl/psb_d_csr_impl.f90 b/base/serial/impl/psb_d_csr_impl.f90 index d081e016..fa229715 100644 --- a/base/serial/impl/psb_d_csr_impl.f90 +++ b/base/serial/impl/psb_d_csr_impl.f90 @@ -3128,7 +3128,7 @@ subroutine psb_d_cp_csr_to_fmt(a,b,info) !locals type(psb_d_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -3142,9 +3142,11 @@ subroutine psb_d_cp_csr_to_fmt(a,b,info) type is (psb_d_csr_sparse_mat) b%psb_d_base_sparse_mat = a%psb_d_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%irp, b%irp , info) - if (info == 0) call psb_safe_cpy( a%ja , b%ja , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nr = a%get_nrows() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%irp(1:nr+1), b%irp , info) + if (info == 0) call psb_safe_cpy( a%ja(1:nz), b%ja , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -3221,12 +3223,137 @@ subroutine psb_d_cp_csr_from_fmt(a,b,info) type is (psb_d_csr_sparse_mat) a%psb_d_base_sparse_mat = b%psb_d_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%irp, a%irp , info) - if (info == 0) call 
psb_safe_cpy( b%ja , a%ja , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nr = b%get_nrows() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%irp(1:nr+1), a%irp , info) + if (info == 0) call psb_safe_cpy( b%ja(1:nz) , a%ja , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz) , a%val , info) class default call b%cp_to_coo(tmp,info) if (info == psb_success_) call a%mv_from_coo(tmp,info) end select end subroutine psb_d_cp_csr_from_fmt + +subroutine psb_dcsrspspmm(a,b,c,info) + use psb_d_mat_mod + use psb_serial_mod, psb_protect_name => psb_dcsrspspmm + + implicit none + + class(psb_d_csr_sparse_mat), intent(in) :: a,b + type(psb_d_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_csrspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + endif + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) ) + nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb + ! Estimate number of nonzeros on output. + ! Turns out this is often a large overestimate. 
+ call c%allocate(ma,nb,min(nzc,nze,nzeb)) + + call csr_spspmm(a,b,c,info) + + call c%set_asb() + + call psb_erractionrestore(err_act) + return + +9999 continue + call psb_erractionrestore(err_act) + if (err_act == psb_act_abort_) then + call psb_error() + return + end if + return + +contains + + subroutine csr_spspmm(a,b,c,info) + implicit none + type(psb_d_csr_sparse_mat), intent(in) :: a,b + type(psb_d_csr_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: irow(:), idxs(:) + real(psb_dpk_), allocatable :: row(:) + type(psb_int_heap) :: heap + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + real(psb_dpk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if + + call psb_msort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%irp(ma+1) = nzc + + + end subroutine csr_spspmm + +end subroutine psb_dcsrspspmm diff --git a/base/serial/impl/psb_s_csc_impl.f90 b/base/serial/impl/psb_s_csc_impl.f90 index d174908a..017c8513 100644 --- a/base/serial/impl/psb_s_csc_impl.f90 +++ b/base/serial/impl/psb_s_csc_impl.f90 @@ -2521,7 +2521,7 @@ 
subroutine psb_s_cp_csc_to_fmt(a,b,info) !locals type(psb_s_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2535,9 +2535,11 @@ subroutine psb_s_cp_csc_to_fmt(a,b,info) type is (psb_s_csc_sparse_mat) b%psb_s_base_sparse_mat = a%psb_s_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%icp, b%icp , info) - if (info == 0) call psb_safe_cpy( a%ia , b%ia , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nc = a%get_ncols() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%icp(1:nc+1), b%icp , info) + if (info == 0) call psb_safe_cpy( a%ia(1:nz), b%ia , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -2602,7 +2604,7 @@ subroutine psb_s_cp_csc_from_fmt(a,b,info) !locals type(psb_s_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2615,9 +2617,11 @@ subroutine psb_s_cp_csc_from_fmt(a,b,info) type is (psb_s_csc_sparse_mat) a%psb_s_base_sparse_mat = b%psb_s_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%icp, a%icp , info) - if (info == 0) call psb_safe_cpy( b%ia , a%ia , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nc = b%get_ncols() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%icp(1:nc+1), a%icp , info) + if (info == 0) call psb_safe_cpy( b%ia(1:nz), a%ia , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz), a%val , info) class default call b%cp_to_coo(tmp,info) @@ -2985,3 +2989,124 @@ subroutine psb_s_csc_print(iout,a,iv,head,ivr,ivc) end subroutine psb_s_csc_print +subroutine psb_scscspspmm(a,b,c,info) + use 
psb_s_mat_mod + use psb_serial_mod, psb_protect_name => psb_scscspspmm + + implicit none + + class(psb_s_csc_sparse_mat), intent(in) :: a,b + type(psb_s_csc_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_cscspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + endif + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + nze = ma*(((nza+ma-1)/ma)*((nzb+mb-1)/mb) ) + nzeb = (((nza+na-1)/na)*((nzb+nb-1)/nb))*nb + ! Estimate number of nonzeros on output. + ! Turns out this is often a large overestimate. + call c%allocate(ma,nb,min(nzc,nze,nzeb)) + + + call csc_spspmm(a,b,c,info) + + call c%set_asb() + + call psb_erractionrestore(err_act) + return + +9999 continue + call psb_erractionrestore(err_act) + if (err_act == psb_act_abort_) then + call psb_error() + return + end if + return + +contains + + subroutine csc_spspmm(a,b,c,info) + implicit none + type(psb_s_csc_sparse_mat), intent(in) :: a,b + type(psb_s_csc_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: icol(:), idxs(:), iaux(:) + real(psb_spk_), allocatable :: col(:) + type(psb_int_heap) :: heap + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + real(psb_spk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ia)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,col,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,icol,info) + if (info /= 0) return + col = dzero + icol = 0 + nzc = 1 + do j = 
1,nb + c%icp(j) = nzc + nrc = 0 + do k = b%icp(j), b%icp(j+1)-1 + icl = b%ia(k) + cfb = b%val(k) + irwsz = a%icp(icl+1)-a%icp(icl) + do i = a%icp(icl),a%icp(icl+1)-1 + irw = a%ia(i) + if (icol(irw) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(nb*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ia,info) + if (info /= 0) return + end if + call psb_msort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ia(nzc) = irw + c%val(nzc) = col(irw) + col(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%icp(nb+1) = nzc + + end subroutine csc_spspmm + +end subroutine psb_scscspspmm diff --git a/base/serial/impl/psb_s_csr_impl.f90 b/base/serial/impl/psb_s_csr_impl.f90 index d83e2c43..ce41f7c1 100644 --- a/base/serial/impl/psb_s_csr_impl.f90 +++ b/base/serial/impl/psb_s_csr_impl.f90 @@ -3128,7 +3128,7 @@ subroutine psb_s_cp_csr_to_fmt(a,b,info) !locals type(psb_s_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -3142,9 +3142,11 @@ subroutine psb_s_cp_csr_to_fmt(a,b,info) type is (psb_s_csr_sparse_mat) b%psb_s_base_sparse_mat = a%psb_s_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%irp, b%irp , info) - if (info == 0) call psb_safe_cpy( a%ja , b%ja , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nr = a%get_nrows() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%irp(1:nr+1), b%irp , info) + if (info == 0) call psb_safe_cpy( a%ja(1:nz), b%ja , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -3221,12 +3223,137 @@ subroutine psb_s_cp_csr_from_fmt(a,b,info) type is (psb_s_csr_sparse_mat) a%psb_s_base_sparse_mat = b%psb_s_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%irp, a%irp , info) - if (info == 0) call 
psb_safe_cpy( b%ja , a%ja , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nr = b%get_nrows() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%irp(1:nr+1), a%irp , info) + if (info == 0) call psb_safe_cpy( b%ja(1:nz) , a%ja , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz) , a%val , info) class default call b%cp_to_coo(tmp,info) if (info == psb_success_) call a%mv_from_coo(tmp,info) end select end subroutine psb_s_cp_csr_from_fmt

!
! psb_scsrspspmm
!   Sparse matrix by sparse matrix product C = A*B, with A, B and C all
!   held in CSR storage (single precision real).
!
!   a, b  (in)  : CSR operands; the number of columns of a must equal the
!                 number of rows of b.
!   c     (out) : CSR result, allocated here with nrows(a) rows and
!                 ncols(b) columns and marked assembled on success.
!   info  (out) : psb_success_ on success; non-zero on a size mismatch or
!                 when a (re)allocation inside the kernel fails.
!
subroutine psb_scsrspspmm(a,b,c,info)
  use psb_s_mat_mod
  use psb_serial_mod, psb_protect_name => psb_scsrspspmm

  implicit none

  class(psb_s_csr_sparse_mat), intent(in) :: a,b
  type(psb_s_csr_sparse_mat), intent(out) :: c
  integer(psb_ipk_), intent(out)          :: info
  integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb
  character(len=20) :: name
  integer(psb_ipk_) :: err_act
  name='psb_csrspspmm'
  call psb_erractionsave(err_act)
  info = psb_success_

  ma = a%get_nrows()
  na = a%get_ncols()
  mb = b%get_nrows()
  nb = b%get_ncols()

  if ( mb /= na ) then
    ! The kernel indexes b%irp with the column indices of a, so carrying on
    ! with mismatched sizes would read out of bounds: report and leave.
    ! NOTE(review): psb_errpush / psb_err_internal_error_ are assumed to be
    ! reachable through the modules used above - confirm for this revision.
    write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
    info = psb_err_internal_error_
    call psb_errpush(info,name,a_err='size mismatch')
    goto 9999
  endif
  nza = a%get_nzeros()
  nzb = b%get_nzeros()
  ! Estimate number of nonzeros on output.
  ! Turns out this is often a large overestimate.
  ! The max(.,1) divisors only matter for empty operands, where the plain
  ! dimension would be a zero divisor.
  nzc  = 2*(nza+nzb)
  nze  = ma*(((nza+ma-1)/max(ma,1))*((nzb+mb-1)/max(mb,1)))
  nzeb = (((nza+na-1)/max(na,1))*((nzb+nb-1)/max(nb,1)))*nb
  call c%allocate(ma,nb,max(0,min(nzc,nze,nzeb)))

  call csr_spspmm(a,b,c,info)
  if (info /= psb_success_) then
    ! Do not mark a partially built result as assembled.
    info = psb_err_from_subroutine_
    call psb_errpush(info,name,a_err='csr_spspmm')
    goto 9999
  end if

  call c%set_asb()

  call psb_erractionrestore(err_act)
  return

9999 continue
  call psb_erractionrestore(err_act)
  if (err_act == psb_act_abort_) then
    call psb_error()
    return
  end if
  return

contains

  !
  ! Row-by-row kernel. For every row j of a, the scaled rows of b selected
  ! by the nonzeros of a(j,:) are accumulated into the dense work vector
  ! row(:); idxs(1:nrc) records which columns were touched and irow(:)
  ! marks the last row that touched each column. The touched entries are
  ! then copied to c in increasing column order and row(:) is reset.
  ! c%val / c%ja are grown on demand.
  !
  subroutine csr_spspmm(a,b,c,info)
    implicit none
    type(psb_s_csr_sparse_mat), intent(in)    :: a,b
    type(psb_s_csr_sparse_mat), intent(inout) :: c
    integer(psb_ipk_), intent(out)            :: info
    real(psb_spk_), parameter      :: wzero = 0.0_psb_spk_
    integer(psb_ipk_)              :: ma,na,mb,nb
    integer(psb_ipk_), allocatable :: irow(:), idxs(:)
    real(psb_spk_), allocatable    :: row(:)
    integer(psb_ipk_) :: i,j,k,irw,icl, nzc, isz, nrc, nze
    real(psb_spk_)    :: cfb


    info = psb_success_
    ma = a%get_nrows()
    na = a%get_ncols()
    mb = b%get_nrows()
    nb = b%get_ncols()

    nze = min(size(c%val),size(c%ja))
    isz = max(ma,na,mb,nb)
    call psb_realloc(isz,row,info)
    if (info == 0) call psb_realloc(isz,idxs,info)
    if (info == 0) call psb_realloc(isz,irow,info)
    if (info /= 0) return
    row  = wzero
    irow = 0
    nzc  = 1
    do j = 1,ma
      c%irp(j) = nzc
      nrc = 0
      ! NOTE(review): in the copy under review the text between
      ! "if (irow(icl)" and "0 ) then" was lost; the accumulation below is
      ! reconstructed from the surviving declarations and the drain loop.
      ! Confirm against the upstream source.
      do k = a%irp(j), a%irp(j+1)-1
        irw = a%ja(k)
        cfb = a%val(k)
        do i = b%irp(irw),b%irp(irw+1)-1
          icl = b%ja(i)
          if (irow(icl) < j) then
            ! First contribution to column icl in this row.
            nrc       = nrc + 1
            idxs(nrc) = icl
            irow(icl) = j
          end if
          row(icl) = row(icl) + cfb*b%val(i)
        end do
      end do
      if (nrc > 0 ) then
        if ((nzc+nrc)>nze) then
          ! Grow using the average fill per row seen so far, but always
          ! by enough to hold the current row.
          nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
          call psb_realloc(nze,c%val,info)
          if (info == 0) call psb_realloc(nze,c%ja,info)
          if (info /= 0) return
        end if

        call psb_msort(idxs(1:nrc))
        do i=1, nrc
          icl        = idxs(i)
          c%ja(nzc)  = icl
          c%val(nzc) = row(icl)
          row(icl)   = wzero
          nzc        = nzc + 1
        end do
      end if
    end do

    c%irp(ma+1) = nzc

  end subroutine csr_spspmm

end subroutine psb_scsrspspmm
diff --git a/base/serial/impl/psb_z_csc_impl.f90 b/base/serial/impl/psb_z_csc_impl.f90 index ab0766e8..7b575a41 100644 --- a/base/serial/impl/psb_z_csc_impl.f90 +++ b/base/serial/impl/psb_z_csc_impl.f90 @@ -2521,7 +2521,7 @@
subroutine psb_z_cp_csc_to_fmt(a,b,info) !locals type(psb_z_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2535,9 +2535,11 @@ subroutine psb_z_cp_csc_to_fmt(a,b,info) type is (psb_z_csc_sparse_mat) b%psb_z_base_sparse_mat = a%psb_z_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%icp, b%icp , info) - if (info == 0) call psb_safe_cpy( a%ia , b%ia , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nc = a%get_ncols() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%icp(1:nc+1), b%icp , info) + if (info == 0) call psb_safe_cpy( a%ia(1:nz), b%ia , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -2602,7 +2604,7 @@ subroutine psb_z_cp_csc_from_fmt(a,b,info) !locals type(psb_z_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -2615,9 +2617,11 @@ subroutine psb_z_cp_csc_from_fmt(a,b,info) type is (psb_z_csc_sparse_mat) a%psb_z_base_sparse_mat = b%psb_z_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%icp, a%icp , info) - if (info == 0) call psb_safe_cpy( b%ia , a%ia , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nc = b%get_ncols() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%icp(1:nc+1), a%icp , info) + if (info == 0) call psb_safe_cpy( b%ia(1:nz), a%ia , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz), a%val , info) class default call b%cp_to_coo(tmp,info) @@ -2985,3 +2989,124 @@ subroutine psb_z_csc_print(iout,a,iv,head,ivr,ivc) end subroutine psb_z_csc_print

!
! psb_zcscspspmm
!   Sparse matrix by sparse matrix product C = A*B, with A, B and C all
!   held in CSC storage (double precision complex).
!
!   a, b  (in)  : CSC operands; the number of columns of a must equal the
!                 number of rows of b.
!   c     (out) : CSC result, allocated here with nrows(a) rows and
!                 ncols(b) columns and marked assembled on success.
!   info  (out) : psb_success_ on success; non-zero on a size mismatch or
!                 when a (re)allocation inside the kernel fails.
!
subroutine psb_zcscspspmm(a,b,c,info)
  use psb_z_mat_mod
  use psb_serial_mod, psb_protect_name => psb_zcscspspmm

  implicit none

  class(psb_z_csc_sparse_mat), intent(in) :: a,b
  type(psb_z_csc_sparse_mat), intent(out) :: c
  integer(psb_ipk_), intent(out)          :: info
  integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb
  character(len=20) :: name
  integer(psb_ipk_) :: err_act
  name='psb_cscspspmm'
  call psb_erractionsave(err_act)
  info = psb_success_

  ma = a%get_nrows()
  na = a%get_ncols()
  mb = b%get_nrows()
  nb = b%get_ncols()

  if ( mb /= na ) then
    ! The kernel indexes a%icp with the row indices of b, so carrying on
    ! with mismatched sizes would read out of bounds: report and leave.
    ! NOTE(review): psb_errpush / psb_err_internal_error_ are assumed to be
    ! reachable through the modules used above - confirm for this revision.
    write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
    info = psb_err_internal_error_
    call psb_errpush(info,name,a_err='size mismatch')
    goto 9999
  endif
  nza = a%get_nzeros()
  nzb = b%get_nzeros()
  ! Estimate number of nonzeros on output.
  ! Turns out this is often a large overestimate.
  ! The max(.,1) divisors only matter for empty operands, where the plain
  ! dimension would be a zero divisor.
  nzc  = 2*(nza+nzb)
  nze  = ma*(((nza+ma-1)/max(ma,1))*((nzb+mb-1)/max(mb,1)))
  nzeb = (((nza+na-1)/max(na,1))*((nzb+nb-1)/max(nb,1)))*nb
  call c%allocate(ma,nb,max(0,min(nzc,nze,nzeb)))

  call csc_spspmm(a,b,c,info)
  if (info /= psb_success_) then
    ! Do not mark a partially built result as assembled.
    info = psb_err_from_subroutine_
    call psb_errpush(info,name,a_err='csc_spspmm')
    goto 9999
  end if

  call c%set_asb()

  call psb_erractionrestore(err_act)
  return

9999 continue
  call psb_erractionrestore(err_act)
  if (err_act == psb_act_abort_) then
    call psb_error()
    return
  end if
  return

contains

  !
  ! Column-by-column kernel. For every column j of b, the scaled columns
  ! of a selected by the nonzeros of b(:,j) are accumulated into the dense
  ! work vector col(:); idxs(1:nrc) records which rows were touched and
  ! icol(:) marks the last column that touched each row. The touched
  ! entries are then copied to c in increasing row order and col(:) is
  ! reset. c%val / c%ia are grown on demand.
  !
  subroutine csc_spspmm(a,b,c,info)
    implicit none
    type(psb_z_csc_sparse_mat), intent(in)    :: a,b
    type(psb_z_csc_sparse_mat), intent(inout) :: c
    integer(psb_ipk_), intent(out)            :: info
    complex(psb_dpk_), parameter   :: wzero = (0.0_psb_dpk_,0.0_psb_dpk_)
    integer(psb_ipk_)              :: ma,na,mb,nb
    integer(psb_ipk_), allocatable :: icol(:), idxs(:)
    complex(psb_dpk_), allocatable :: col(:)
    integer(psb_ipk_) :: i,j,k,irw,icl, nzc, isz, nrc, nze
    complex(psb_dpk_) :: cfb


    info = psb_success_
    ma = a%get_nrows()
    na = a%get_ncols()
    mb = b%get_nrows()
    nb = b%get_ncols()

    nze = min(size(c%val),size(c%ia))
    isz = max(ma,na,mb,nb)
    call psb_realloc(isz,col,info)
    if (info == 0) call psb_realloc(isz,idxs,info)
    if (info == 0) call psb_realloc(isz,icol,info)
    if (info /= 0) return
    col  = wzero
    icol = 0
    nzc  = 1
    do j = 1,nb
      c%icp(j) = nzc
      nrc = 0
      ! NOTE(review): in the copy under review the text between
      ! "if (icol(irw)" and "0 ) then" was lost; the accumulation below is
      ! reconstructed from the surviving declarations and the drain loop.
      ! Confirm against the upstream source.
      do k = b%icp(j), b%icp(j+1)-1
        icl = b%ia(k)
        cfb = b%val(k)
        do i = a%icp(icl),a%icp(icl+1)-1
          irw = a%ia(i)
          if (icol(irw) < j) then
            ! First contribution to row irw in this column.
            nrc       = nrc + 1
            idxs(nrc) = irw
            icol(irw) = j
          end if
          col(irw) = col(irw) + cfb*a%val(i)
        end do
      end do
      if (nrc > 0 ) then
        if ((nzc+nrc)>nze) then
          ! Grow using the average fill per column seen so far, but
          ! always by enough to hold the current column.
          nze = max(nb*((nzc+j-1)/j),nzc+2*nrc)
          call psb_realloc(nze,c%val,info)
          if (info == 0) call psb_realloc(nze,c%ia,info)
          if (info /= 0) return
        end if
        call psb_msort(idxs(1:nrc))
        do i=1, nrc
          irw        = idxs(i)
          c%ia(nzc)  = irw
          c%val(nzc) = col(irw)
          col(irw)   = wzero
          nzc        = nzc + 1
        end do
      end if
    end do

    c%icp(nb+1) = nzc

  end subroutine csc_spspmm

end subroutine psb_zcscspspmm
diff --git a/base/serial/impl/psb_z_csr_impl.f90 b/base/serial/impl/psb_z_csr_impl.f90 index b3e3419a..b8ec9a17 100644 --- a/base/serial/impl/psb_z_csr_impl.f90 +++ b/base/serial/impl/psb_z_csr_impl.f90 @@ -3128,7 +3128,7 @@ subroutine psb_z_cp_csr_to_fmt(a,b,info) !locals type(psb_z_coo_sparse_mat) :: tmp logical :: rwshr_ - integer(psb_ipk_) :: nza, nr, i,j,irw, err_act, nc + integer(psb_ipk_) :: nz, nr, i,j,irw, err_act, nc integer(psb_ipk_), Parameter :: maxtry=8 integer(psb_ipk_) :: debug_level, debug_unit character(len=20) :: name @@ -3142,9 +3142,11 @@ subroutine psb_z_cp_csr_to_fmt(a,b,info) type is (psb_z_csr_sparse_mat) b%psb_z_base_sparse_mat = a%psb_z_base_sparse_mat - if (info == 0) call psb_safe_cpy( a%irp, b%irp , info) - if (info == 0) call psb_safe_cpy( a%ja , b%ja , info) - if (info == 0) call psb_safe_cpy( a%val, b%val , info) + nr = a%get_nrows() + nz = a%get_nzeros() + if (info == 0) call psb_safe_cpy( a%irp(1:nr+1), b%irp , info) + if (info == 0) call psb_safe_cpy( a%ja(1:nz), b%ja , info) + if (info == 0) call psb_safe_cpy( a%val(1:nz), b%val , info) class default call a%cp_to_coo(tmp,info) @@ -3221,12 +3223,137 @@ subroutine psb_z_cp_csr_from_fmt(a,b,info) type is (psb_z_csr_sparse_mat) a%psb_z_base_sparse_mat = b%psb_z_base_sparse_mat - if (info == 0) call psb_safe_cpy( b%irp, a%irp , info) - if (info == 0)
call psb_safe_cpy( b%ja , a%ja , info) - if (info == 0) call psb_safe_cpy( b%val, a%val , info) + nr = b%get_nrows() + nz = b%get_nzeros() + if (info == 0) call psb_safe_cpy( b%irp(1:nr+1), a%irp , info) + if (info == 0) call psb_safe_cpy( b%ja(1:nz) , a%ja , info) + if (info == 0) call psb_safe_cpy( b%val(1:nz) , a%val , info) class default call b%cp_to_coo(tmp,info) if (info == psb_success_) call a%mv_from_coo(tmp,info) end select end subroutine psb_z_cp_csr_from_fmt

!
! psb_zcsrspspmm
!   Sparse matrix by sparse matrix product C = A*B, with A, B and C all
!   held in CSR storage (double precision complex).
!
!   a, b  (in)  : CSR operands; the number of columns of a must equal the
!                 number of rows of b.
!   c     (out) : CSR result, allocated here with nrows(a) rows and
!                 ncols(b) columns and marked assembled on success.
!   info  (out) : psb_success_ on success; non-zero on a size mismatch or
!                 when a (re)allocation inside the kernel fails.
!
subroutine psb_zcsrspspmm(a,b,c,info)
  use psb_z_mat_mod
  use psb_serial_mod, psb_protect_name => psb_zcsrspspmm

  implicit none

  class(psb_z_csr_sparse_mat), intent(in) :: a,b
  type(psb_z_csr_sparse_mat), intent(out) :: c
  integer(psb_ipk_), intent(out)          :: info
  integer(psb_ipk_) :: nze, ma,na,mb,nb, nzc, nza, nzb,nzeb
  character(len=20) :: name
  integer(psb_ipk_) :: err_act
  name='psb_csrspspmm'
  call psb_erractionsave(err_act)
  info = psb_success_

  ma = a%get_nrows()
  na = a%get_ncols()
  mb = b%get_nrows()
  nb = b%get_ncols()

  if ( mb /= na ) then
    ! The kernel indexes b%irp with the column indices of a, so carrying on
    ! with mismatched sizes would read out of bounds: report and leave.
    ! NOTE(review): psb_errpush / psb_err_internal_error_ are assumed to be
    ! reachable through the modules used above - confirm for this revision.
    write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
    info = psb_err_internal_error_
    call psb_errpush(info,name,a_err='size mismatch')
    goto 9999
  endif
  nza = a%get_nzeros()
  nzb = b%get_nzeros()
  ! Estimate number of nonzeros on output.
  ! Turns out this is often a large overestimate.
  ! The max(.,1) divisors only matter for empty operands, where the plain
  ! dimension would be a zero divisor.
  nzc  = 2*(nza+nzb)
  nze  = ma*(((nza+ma-1)/max(ma,1))*((nzb+mb-1)/max(mb,1)))
  nzeb = (((nza+na-1)/max(na,1))*((nzb+nb-1)/max(nb,1)))*nb
  call c%allocate(ma,nb,max(0,min(nzc,nze,nzeb)))

  call csr_spspmm(a,b,c,info)
  if (info /= psb_success_) then
    ! Do not mark a partially built result as assembled.
    info = psb_err_from_subroutine_
    call psb_errpush(info,name,a_err='csr_spspmm')
    goto 9999
  end if

  call c%set_asb()

  call psb_erractionrestore(err_act)
  return

9999 continue
  call psb_erractionrestore(err_act)
  if (err_act == psb_act_abort_) then
    call psb_error()
    return
  end if
  return

contains

  !
  ! Row-by-row kernel. For every row j of a, the scaled rows of b selected
  ! by the nonzeros of a(j,:) are accumulated into the dense work vector
  ! row(:); idxs(1:nrc) records which columns were touched and irow(:)
  ! marks the last row that touched each column. The touched entries are
  ! then copied to c in increasing column order and row(:) is reset.
  ! c%val / c%ja are grown on demand.
  !
  subroutine csr_spspmm(a,b,c,info)
    implicit none
    type(psb_z_csr_sparse_mat), intent(in)    :: a,b
    type(psb_z_csr_sparse_mat), intent(inout) :: c
    integer(psb_ipk_), intent(out)            :: info
    complex(psb_dpk_), parameter   :: wzero = (0.0_psb_dpk_,0.0_psb_dpk_)
    integer(psb_ipk_)              :: ma,na,mb,nb
    integer(psb_ipk_), allocatable :: irow(:), idxs(:)
    complex(psb_dpk_), allocatable :: row(:)
    integer(psb_ipk_) :: i,j,k,irw,icl, nzc, isz, nrc, nze
    complex(psb_dpk_) :: cfb


    info = psb_success_
    ma = a%get_nrows()
    na = a%get_ncols()
    mb = b%get_nrows()
    nb = b%get_ncols()

    nze = min(size(c%val),size(c%ja))
    isz = max(ma,na,mb,nb)
    call psb_realloc(isz,row,info)
    if (info == 0) call psb_realloc(isz,idxs,info)
    if (info == 0) call psb_realloc(isz,irow,info)
    if (info /= 0) return
    row  = wzero
    irow = 0
    nzc  = 1
    do j = 1,ma
      c%irp(j) = nzc
      nrc = 0
      ! NOTE(review): in the copy under review the text between
      ! "if (irow(icl)" and "0 ) then" was lost; the accumulation below is
      ! reconstructed from the surviving declarations and the drain loop.
      ! Confirm against the upstream source.
      do k = a%irp(j), a%irp(j+1)-1
        irw = a%ja(k)
        cfb = a%val(k)
        do i = b%irp(irw),b%irp(irw+1)-1
          icl = b%ja(i)
          if (irow(icl) < j) then
            ! First contribution to column icl in this row.
            nrc       = nrc + 1
            idxs(nrc) = icl
            irow(icl) = j
          end if
          row(icl) = row(icl) + cfb*b%val(i)
        end do
      end do
      if (nrc > 0 ) then
        if ((nzc+nrc)>nze) then
          ! Grow using the average fill per row seen so far, but always
          ! by enough to hold the current row.
          nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
          call psb_realloc(nze,c%val,info)
          if (info == 0) call psb_realloc(nze,c%ja,info)
          if (info /= 0) return
        end if

        call psb_msort(idxs(1:nrc))
        do i=1, nrc
          icl        = idxs(i)
          c%ja(nzc)  = icl
          c%val(nzc) = row(icl)
          row(icl)   = wzero
          nzc        = nzc + 1
        end do
      end if
    end do

    c%irp(ma+1) = nzc

  end subroutine csr_spspmm

end subroutine psb_zcsrspspmm