diff --git a/base/modules/serial/psb_base_mat_mod.F90 b/base/modules/serial/psb_base_mat_mod.F90 index 07bfbd50..2380fcb2 100644 --- a/base/modules/serial/psb_base_mat_mod.F90 +++ b/base/modules/serial/psb_base_mat_mod.F90 @@ -80,8 +80,11 @@ module psb_base_mat_mod integer(psb_ipk_), parameter :: spspmm_serial_rb_tree = 3 integer(psb_ipk_), parameter :: spspmm_omp_rb_tree = 4 integer(psb_ipk_), parameter :: spspmm_omp_two_pass = 5 +#if defined(OPENMP) + integer(psb_ipk_), save :: spspmm_impl = spspmm_omp_gustavson +#else integer(psb_ipk_), save :: spspmm_impl = spspmm_serial - +#endif ! !> \namespace psb_base_mod \class psb_base_sparse_mat diff --git a/base/serial/impl/psb_c_csr_impl.F90 b/base/serial/impl/psb_c_csr_impl.F90 index 3709ad21..2028300d 100644 --- a/base/serial/impl/psb_c_csr_impl.F90 +++ b/base/serial/impl/psb_c_csr_impl.F90 @@ -3654,130 +3654,7 @@ subroutine psb_c_csr_clean_zeros(a, info) call a%set_host() end subroutine psb_c_csr_clean_zeros -#if 0 -subroutine psb_ccsrspspmm(a,b,c,info) - use psb_c_mat_mod - use psb_serial_mod, psb_protect_name => psb_ccsrspspmm - - implicit none - - class(psb_c_csr_sparse_mat), intent(in) :: a,b - type(psb_c_csr_sparse_mat), intent(out) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb - character(len=20) :: name - integer(psb_ipk_) :: err_act - name='psb_csrspspmm' - call psb_erractionsave(err_act) - info = psb_success_ - - if (a%is_dev()) call a%sync() - if (b%is_dev()) call b%sync() - - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - - if ( mb /= na ) then - write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb - info = psb_err_invalid_matrix_sizes_ - call psb_errpush(info,name) - goto 9999 - endif - - ! Estimate number of nonzeros on output. - nza = a%get_nzeros() - nzb = b%get_nzeros() - nzc = 2*(nza+nzb) - call c%allocate(ma,nb,nzc) - - call csr_spspmm(a,b,c,info) - - call c%set_asb() - call c%set_host() - - call psb_erractionrestore(err_act) - return - -9999 call psb_error_handler(err_act) - - return - -contains - - subroutine csr_spspmm(a,b,c,info) - implicit none - type(psb_c_csr_sparse_mat), intent(in) :: a,b - type(psb_c_csr_sparse_mat), intent(inout) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb - integer(psb_ipk_), allocatable :: irow(:), idxs(:) - complex(psb_spk_), allocatable :: row(:) - integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & - & nzc,nnzre, isz, ipb, irwsz, nrc, nze - complex(psb_spk_) :: cfb - - - info = psb_success_ - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - nze = min(size(c%val),size(c%ja)) - isz = max(ma,na,mb,nb) - call psb_realloc(isz,row,info) - if (info == 0) call psb_realloc(isz,idxs,info) - if (info == 0) call psb_realloc(isz,irow,info) - if (info /= 0) return - row = dzero - irow = 0 - nzc = 1 - do j = 1,ma - c%irp(j) = nzc - nrc = 0 - do k = a%irp(j), a%irp(j+1)-1 - irw = a%ja(k) - cfb = a%val(k) - irwsz = b%irp(irw+1)-b%irp(irw) - do i = b%irp(irw),b%irp(irw+1)-1 - icl = b%ja(i) - if (irow(icl) 0 ) then - if ((nzc+nrc)>nze) then - nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) - call psb_realloc(nze,c%val,info) - if (info == 0) call psb_realloc(nze,c%ja,info) - if (info /= 0) return - end if - - call psb_qsort(idxs(1:nrc)) - do i=1, nrc - irw = idxs(i) - c%ja(nzc) = irw - c%val(nzc) = row(irw) - row(irw) = dzero - nzc = nzc + 1 - end do - end if - end do - - c%irp(ma+1) = nzc - - - end subroutine csr_spspmm - -end subroutine psb_ccsrspspmm -#else +#if defined(OPENMP) subroutine psb_ccsrspspmm(a,b,c,info) use psb_c_mat_mod use psb_serial_mod, psb_protect_name => psb_ccsrspspmm @@ -4307,6 +4184,131 @@ contains !$omp end parallel do end subroutine spmm_omp_two_pass +end subroutine psb_ccsrspspmm + +#else + +subroutine psb_ccsrspspmm(a,b,c,info) + use psb_c_mat_mod + use psb_serial_mod, psb_protect_name => psb_ccsrspspmm + + implicit none + + class(psb_c_csr_sparse_mat), intent(in) :: a,b + type(psb_c_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_csrspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + if (a%is_dev()) call a%sync() + if (b%is_dev()) call b%sync() + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + info = psb_err_invalid_matrix_sizes_ + call psb_errpush(info,name) + goto 9999 + endif + + ! Estimate number of nonzeros on output. + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + call c%allocate(ma,nb,nzc) + + call csr_spspmm(a,b,c,info) + + call c%set_asb() + call c%set_host() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +contains + + subroutine csr_spspmm(a,b,c,info) + implicit none + type(psb_c_csr_sparse_mat), intent(in) :: a,b + type(psb_c_csr_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: irow(:), idxs(:) + complex(psb_spk_), allocatable :: row(:) + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + complex(psb_spk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if + + call psb_qsort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%irp(ma+1) = nzc + + + end subroutine csr_spspmm + end subroutine psb_ccsrspspmm #endif diff --git a/base/serial/impl/psb_d_csr_impl.F90 b/base/serial/impl/psb_d_csr_impl.F90 index b8f692fa..10f99bc9 100644 --- a/base/serial/impl/psb_d_csr_impl.F90 +++ b/base/serial/impl/psb_d_csr_impl.F90 @@ -3654,130 +3654,7 @@ subroutine psb_d_csr_clean_zeros(a, info) call a%set_host() end subroutine psb_d_csr_clean_zeros -#if 0 -subroutine psb_dcsrspspmm(a,b,c,info) - use psb_d_mat_mod - use psb_serial_mod, psb_protect_name => psb_dcsrspspmm - - implicit none - - class(psb_d_csr_sparse_mat), intent(in) :: a,b - type(psb_d_csr_sparse_mat), intent(out) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb - character(len=20) :: name - integer(psb_ipk_) :: err_act - name='psb_csrspspmm' - call psb_erractionsave(err_act) - info = psb_success_ - - if (a%is_dev()) call a%sync() - if (b%is_dev()) call b%sync() - - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - - if ( mb /= na ) then - write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb - info = psb_err_invalid_matrix_sizes_ - call psb_errpush(info,name) - goto 9999 - endif - - ! Estimate number of nonzeros on output. - nza = a%get_nzeros() - nzb = b%get_nzeros() - nzc = 2*(nza+nzb) - call c%allocate(ma,nb,nzc) - - call csr_spspmm(a,b,c,info) - - call c%set_asb() - call c%set_host() - - call psb_erractionrestore(err_act) - return - -9999 call psb_error_handler(err_act) - - return - -contains - - subroutine csr_spspmm(a,b,c,info) - implicit none - type(psb_d_csr_sparse_mat), intent(in) :: a,b - type(psb_d_csr_sparse_mat), intent(inout) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb - integer(psb_ipk_), allocatable :: irow(:), idxs(:) - real(psb_dpk_), allocatable :: row(:) - integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & - & nzc,nnzre, isz, ipb, irwsz, nrc, nze - real(psb_dpk_) :: cfb - - - info = psb_success_ - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - nze = min(size(c%val),size(c%ja)) - isz = max(ma,na,mb,nb) - call psb_realloc(isz,row,info) - if (info == 0) call psb_realloc(isz,idxs,info) - if (info == 0) call psb_realloc(isz,irow,info) - if (info /= 0) return - row = dzero - irow = 0 - nzc = 1 - do j = 1,ma - c%irp(j) = nzc - nrc = 0 - do k = a%irp(j), a%irp(j+1)-1 - irw = a%ja(k) - cfb = a%val(k) - irwsz = b%irp(irw+1)-b%irp(irw) - do i = b%irp(irw),b%irp(irw+1)-1 - icl = b%ja(i) - if (irow(icl) 0 ) then - if ((nzc+nrc)>nze) then - nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) - call psb_realloc(nze,c%val,info) - if (info == 0) call psb_realloc(nze,c%ja,info) - if (info /= 0) return - end if - - call psb_qsort(idxs(1:nrc)) - do i=1, nrc - irw = idxs(i) - c%ja(nzc) = irw - c%val(nzc) = row(irw) - row(irw) = dzero - nzc = nzc + 1 - end do - end if - end do - - c%irp(ma+1) = nzc - - - end subroutine csr_spspmm - -end subroutine psb_dcsrspspmm -#else +#if defined(OPENMP) subroutine psb_dcsrspspmm(a,b,c,info) use psb_d_mat_mod use psb_serial_mod, psb_protect_name => psb_dcsrspspmm @@ -4307,6 +4184,131 @@ contains !$omp end parallel do end subroutine spmm_omp_two_pass +end subroutine psb_dcsrspspmm + +#else + +subroutine psb_dcsrspspmm(a,b,c,info) + use psb_d_mat_mod + use psb_serial_mod, psb_protect_name => psb_dcsrspspmm + + implicit none + + class(psb_d_csr_sparse_mat), intent(in) :: a,b + type(psb_d_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_csrspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + if (a%is_dev()) call a%sync() + if (b%is_dev()) call b%sync() + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + info = psb_err_invalid_matrix_sizes_ + call psb_errpush(info,name) + goto 9999 + endif + + ! Estimate number of nonzeros on output. + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + call c%allocate(ma,nb,nzc) + + call csr_spspmm(a,b,c,info) + + call c%set_asb() + call c%set_host() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +contains + + subroutine csr_spspmm(a,b,c,info) + implicit none + type(psb_d_csr_sparse_mat), intent(in) :: a,b + type(psb_d_csr_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: irow(:), idxs(:) + real(psb_dpk_), allocatable :: row(:) + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + real(psb_dpk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if + + call psb_qsort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%irp(ma+1) = nzc + + + end subroutine csr_spspmm + end subroutine psb_dcsrspspmm #endif diff --git a/base/serial/impl/psb_s_csr_impl.F90 b/base/serial/impl/psb_s_csr_impl.F90 index 59d73892..f3d5c669 100644 --- a/base/serial/impl/psb_s_csr_impl.F90 +++ b/base/serial/impl/psb_s_csr_impl.F90 @@ -3654,130 +3654,7 @@ subroutine psb_s_csr_clean_zeros(a, info) call a%set_host() end subroutine psb_s_csr_clean_zeros -#if 0 -subroutine psb_scsrspspmm(a,b,c,info) - use psb_s_mat_mod - use psb_serial_mod, psb_protect_name => psb_scsrspspmm - - implicit none - - class(psb_s_csr_sparse_mat), intent(in) :: a,b - type(psb_s_csr_sparse_mat), intent(out) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb - character(len=20) :: name - integer(psb_ipk_) :: err_act - name='psb_csrspspmm' - call psb_erractionsave(err_act) - info = psb_success_ - - if (a%is_dev()) call a%sync() - if (b%is_dev()) call b%sync() - - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - - if ( mb /= na ) then - write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb - info = psb_err_invalid_matrix_sizes_ - call psb_errpush(info,name) - goto 9999 - endif - - ! Estimate number of nonzeros on output. - nza = a%get_nzeros() - nzb = b%get_nzeros() - nzc = 2*(nza+nzb) - call c%allocate(ma,nb,nzc) - - call csr_spspmm(a,b,c,info) - - call c%set_asb() - call c%set_host() - - call psb_erractionrestore(err_act) - return - -9999 call psb_error_handler(err_act) - - return - -contains - - subroutine csr_spspmm(a,b,c,info) - implicit none - type(psb_s_csr_sparse_mat), intent(in) :: a,b - type(psb_s_csr_sparse_mat), intent(inout) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb - integer(psb_ipk_), allocatable :: irow(:), idxs(:) - real(psb_spk_), allocatable :: row(:) - integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & - & nzc,nnzre, isz, ipb, irwsz, nrc, nze - real(psb_spk_) :: cfb - - - info = psb_success_ - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - nze = min(size(c%val),size(c%ja)) - isz = max(ma,na,mb,nb) - call psb_realloc(isz,row,info) - if (info == 0) call psb_realloc(isz,idxs,info) - if (info == 0) call psb_realloc(isz,irow,info) - if (info /= 0) return - row = dzero - irow = 0 - nzc = 1 - do j = 1,ma - c%irp(j) = nzc - nrc = 0 - do k = a%irp(j), a%irp(j+1)-1 - irw = a%ja(k) - cfb = a%val(k) - irwsz = b%irp(irw+1)-b%irp(irw) - do i = b%irp(irw),b%irp(irw+1)-1 - icl = b%ja(i) - if (irow(icl) 0 ) then - if ((nzc+nrc)>nze) then - nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) - call psb_realloc(nze,c%val,info) - if (info == 0) call psb_realloc(nze,c%ja,info) - if (info /= 0) return - end if - - call psb_qsort(idxs(1:nrc)) - do i=1, nrc - irw = idxs(i) - c%ja(nzc) = irw - c%val(nzc) = row(irw) - row(irw) = dzero - nzc = nzc + 1 - end do - end if - end do - - c%irp(ma+1) = nzc - - - end subroutine csr_spspmm - -end subroutine psb_scsrspspmm -#else +#if defined(OPENMP) subroutine psb_scsrspspmm(a,b,c,info) use psb_s_mat_mod use psb_serial_mod, psb_protect_name => psb_scsrspspmm @@ -4307,6 +4184,131 @@ contains !$omp end parallel do end subroutine spmm_omp_two_pass +end subroutine psb_scsrspspmm + +#else + +subroutine psb_scsrspspmm(a,b,c,info) + use psb_s_mat_mod + use psb_serial_mod, psb_protect_name => psb_scsrspspmm + + implicit none + + class(psb_s_csr_sparse_mat), intent(in) :: a,b + type(psb_s_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_csrspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + if (a%is_dev()) call a%sync() + if (b%is_dev()) call b%sync() + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + info = psb_err_invalid_matrix_sizes_ + call psb_errpush(info,name) + goto 9999 + endif + + ! Estimate number of nonzeros on output. + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + call c%allocate(ma,nb,nzc) + + call csr_spspmm(a,b,c,info) + + call c%set_asb() + call c%set_host() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +contains + + subroutine csr_spspmm(a,b,c,info) + implicit none + type(psb_s_csr_sparse_mat), intent(in) :: a,b + type(psb_s_csr_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: irow(:), idxs(:) + real(psb_spk_), allocatable :: row(:) + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + real(psb_spk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if + + call psb_qsort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%irp(ma+1) = nzc + + + end subroutine csr_spspmm + end subroutine psb_scsrspspmm #endif diff --git a/base/serial/impl/psb_z_csr_impl.F90 b/base/serial/impl/psb_z_csr_impl.F90 index 34cd6fbc..5cf1c72d 100644 --- a/base/serial/impl/psb_z_csr_impl.F90 +++ b/base/serial/impl/psb_z_csr_impl.F90 @@ -3654,130 +3654,7 @@ subroutine psb_z_csr_clean_zeros(a, info) call a%set_host() end subroutine psb_z_csr_clean_zeros -#if 0 -subroutine psb_zcsrspspmm(a,b,c,info) - use psb_z_mat_mod - use psb_serial_mod, psb_protect_name => psb_zcsrspspmm - - implicit none - - class(psb_z_csr_sparse_mat), intent(in) :: a,b - type(psb_z_csr_sparse_mat), intent(out) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb - character(len=20) :: name - integer(psb_ipk_) :: err_act - name='psb_csrspspmm' - call psb_erractionsave(err_act) - info = psb_success_ - - if (a%is_dev()) call a%sync() - if (b%is_dev()) call b%sync() - - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - - if ( mb /= na ) then - write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb - info = psb_err_invalid_matrix_sizes_ - call psb_errpush(info,name) - goto 9999 - endif - - ! Estimate number of nonzeros on output. - nza = a%get_nzeros() - nzb = b%get_nzeros() - nzc = 2*(nza+nzb) - call c%allocate(ma,nb,nzc) - - call csr_spspmm(a,b,c,info) - - call c%set_asb() - call c%set_host() - - call psb_erractionrestore(err_act) - return - -9999 call psb_error_handler(err_act) - - return - -contains - - subroutine csr_spspmm(a,b,c,info) - implicit none - type(psb_z_csr_sparse_mat), intent(in) :: a,b - type(psb_z_csr_sparse_mat), intent(inout) :: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: ma,na,mb,nb - integer(psb_ipk_), allocatable :: irow(:), idxs(:) - complex(psb_dpk_), allocatable :: row(:) - integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & - & nzc,nnzre, isz, ipb, irwsz, nrc, nze - complex(psb_dpk_) :: cfb - - - info = psb_success_ - ma = a%get_nrows() - na = a%get_ncols() - mb = b%get_nrows() - nb = b%get_ncols() - - nze = min(size(c%val),size(c%ja)) - isz = max(ma,na,mb,nb) - call psb_realloc(isz,row,info) - if (info == 0) call psb_realloc(isz,idxs,info) - if (info == 0) call psb_realloc(isz,irow,info) - if (info /= 0) return - row = dzero - irow = 0 - nzc = 1 - do j = 1,ma - c%irp(j) = nzc - nrc = 0 - do k = a%irp(j), a%irp(j+1)-1 - irw = a%ja(k) - cfb = a%val(k) - irwsz = b%irp(irw+1)-b%irp(irw) - do i = b%irp(irw),b%irp(irw+1)-1 - icl = b%ja(i) - if (irow(icl) 0 ) then - if ((nzc+nrc)>nze) then - nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) - call psb_realloc(nze,c%val,info) - if (info == 0) call psb_realloc(nze,c%ja,info) - if (info /= 0) return - end if - - call psb_qsort(idxs(1:nrc)) - do i=1, nrc - irw = idxs(i) - c%ja(nzc) = irw - c%val(nzc) = row(irw) - row(irw) = dzero - nzc = nzc + 1 - end do - end if - end do - - c%irp(ma+1) = nzc - - - end subroutine csr_spspmm - -end subroutine psb_zcsrspspmm -#else +#if defined(OPENMP) subroutine psb_zcsrspspmm(a,b,c,info) use psb_z_mat_mod use psb_serial_mod, psb_protect_name => psb_zcsrspspmm @@ -4307,6 +4184,131 @@ contains !$omp end parallel do end subroutine spmm_omp_two_pass +end subroutine psb_zcsrspspmm + +#else + +subroutine psb_zcsrspspmm(a,b,c,info) + use psb_z_mat_mod + use psb_serial_mod, psb_protect_name => psb_zcsrspspmm + + implicit none + + class(psb_z_csr_sparse_mat), intent(in) :: a,b + type(psb_z_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb + character(len=20) :: name + integer(psb_ipk_) :: err_act + name='psb_csrspspmm' + call psb_erractionsave(err_act) + info = psb_success_ + + if (a%is_dev()) call a%sync() + if (b%is_dev()) call b%sync() + + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + + if ( mb /= na ) then + write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb + info = psb_err_invalid_matrix_sizes_ + call psb_errpush(info,name) + goto 9999 + endif + + ! Estimate number of nonzeros on output. + nza = a%get_nzeros() + nzb = b%get_nzeros() + nzc = 2*(nza+nzb) + call c%allocate(ma,nb,nzc) + + call csr_spspmm(a,b,c,info) + + call c%set_asb() + call c%set_host() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + + return + +contains + + subroutine csr_spspmm(a,b,c,info) + implicit none + type(psb_z_csr_sparse_mat), intent(in) :: a,b + type(psb_z_csr_sparse_mat), intent(inout) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: ma,na,mb,nb + integer(psb_ipk_), allocatable :: irow(:), idxs(:) + complex(psb_dpk_), allocatable :: row(:) + integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, & + & nzc,nnzre, isz, ipb, irwsz, nrc, nze + complex(psb_dpk_) :: cfb + + + info = psb_success_ + ma = a%get_nrows() + na = a%get_ncols() + mb = b%get_nrows() + nb = b%get_ncols() + + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if + + call psb_qsort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + + c%irp(ma+1) = nzc + + + end subroutine csr_spspmm + end subroutine psb_zcsrspspmm #endif