Set defaults for SPSPMM depending on OpenMP compilation.

omp-walther
sfilippone 1 year ago
parent 0d8a5d3dc2
commit 41be1357c3

@ -80,8 +80,11 @@ module psb_base_mat_mod
integer(psb_ipk_), parameter :: spspmm_serial_rb_tree = 3
integer(psb_ipk_), parameter :: spspmm_omp_rb_tree = 4
integer(psb_ipk_), parameter :: spspmm_omp_two_pass = 5
#if defined(OPENMP)
integer(psb_ipk_), save :: spspmm_impl = spspmm_omp_gustavson
#else
integer(psb_ipk_), save :: spspmm_impl = spspmm_serial
#endif
!
!> \namespace psb_base_mod \class psb_base_sparse_mat

@ -3654,130 +3654,7 @@ subroutine psb_c_csr_clean_zeros(a, info)
call a%set_host()
end subroutine psb_c_csr_clean_zeros
#if 0
subroutine psb_ccsrspspmm(a,b,c,info)
use psb_c_mat_mod
use psb_serial_mod, psb_protect_name => psb_ccsrspspmm
implicit none
class(psb_c_csr_sparse_mat), intent(in) :: a,b
type(psb_c_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_c_csr_sparse_mat), intent(in) :: a,b
type(psb_c_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
complex(psb_spk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
complex(psb_spk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_ccsrspspmm
#else
#if defined(OPENMP)
subroutine psb_ccsrspspmm(a,b,c,info)
use psb_c_mat_mod
use psb_serial_mod, psb_protect_name => psb_ccsrspspmm
@ -4307,6 +4184,131 @@ contains
!$omp end parallel do
end subroutine spmm_omp_two_pass
end subroutine psb_ccsrspspmm
#else
subroutine psb_ccsrspspmm(a,b,c,info)
use psb_c_mat_mod
use psb_serial_mod, psb_protect_name => psb_ccsrspspmm
implicit none
class(psb_c_csr_sparse_mat), intent(in) :: a,b
type(psb_c_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_c_csr_sparse_mat), intent(in) :: a,b
type(psb_c_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
complex(psb_spk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
complex(psb_spk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_ccsrspspmm
#endif

@ -3654,130 +3654,7 @@ subroutine psb_d_csr_clean_zeros(a, info)
call a%set_host()
end subroutine psb_d_csr_clean_zeros
#if 0
subroutine psb_dcsrspspmm(a,b,c,info)
use psb_d_mat_mod
use psb_serial_mod, psb_protect_name => psb_dcsrspspmm
implicit none
class(psb_d_csr_sparse_mat), intent(in) :: a,b
type(psb_d_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_d_csr_sparse_mat), intent(in) :: a,b
type(psb_d_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
real(psb_dpk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
real(psb_dpk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_dcsrspspmm
#else
#if defined(OPENMP)
subroutine psb_dcsrspspmm(a,b,c,info)
use psb_d_mat_mod
use psb_serial_mod, psb_protect_name => psb_dcsrspspmm
@ -4307,6 +4184,131 @@ contains
!$omp end parallel do
end subroutine spmm_omp_two_pass
end subroutine psb_dcsrspspmm
#else
subroutine psb_dcsrspspmm(a,b,c,info)
use psb_d_mat_mod
use psb_serial_mod, psb_protect_name => psb_dcsrspspmm
implicit none
class(psb_d_csr_sparse_mat), intent(in) :: a,b
type(psb_d_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_d_csr_sparse_mat), intent(in) :: a,b
type(psb_d_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
real(psb_dpk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
real(psb_dpk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_dcsrspspmm
#endif

@ -3654,130 +3654,7 @@ subroutine psb_s_csr_clean_zeros(a, info)
call a%set_host()
end subroutine psb_s_csr_clean_zeros
#if 0
subroutine psb_scsrspspmm(a,b,c,info)
use psb_s_mat_mod
use psb_serial_mod, psb_protect_name => psb_scsrspspmm
implicit none
class(psb_s_csr_sparse_mat), intent(in) :: a,b
type(psb_s_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_s_csr_sparse_mat), intent(in) :: a,b
type(psb_s_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
real(psb_spk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
real(psb_spk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_scsrspspmm
#else
#if defined(OPENMP)
subroutine psb_scsrspspmm(a,b,c,info)
use psb_s_mat_mod
use psb_serial_mod, psb_protect_name => psb_scsrspspmm
@ -4307,6 +4184,131 @@ contains
!$omp end parallel do
end subroutine spmm_omp_two_pass
end subroutine psb_scsrspspmm
#else
subroutine psb_scsrspspmm(a,b,c,info)
use psb_s_mat_mod
use psb_serial_mod, psb_protect_name => psb_scsrspspmm
implicit none
class(psb_s_csr_sparse_mat), intent(in) :: a,b
type(psb_s_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_s_csr_sparse_mat), intent(in) :: a,b
type(psb_s_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
real(psb_spk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
real(psb_spk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_scsrspspmm
#endif

@ -3654,130 +3654,7 @@ subroutine psb_z_csr_clean_zeros(a, info)
call a%set_host()
end subroutine psb_z_csr_clean_zeros
#if 0
subroutine psb_zcsrspspmm(a,b,c,info)
use psb_z_mat_mod
use psb_serial_mod, psb_protect_name => psb_zcsrspspmm
implicit none
class(psb_z_csr_sparse_mat), intent(in) :: a,b
type(psb_z_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_z_csr_sparse_mat), intent(in) :: a,b
type(psb_z_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
complex(psb_dpk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
complex(psb_dpk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_zcsrspspmm
#else
#if defined(OPENMP)
subroutine psb_zcsrspspmm(a,b,c,info)
use psb_z_mat_mod
use psb_serial_mod, psb_protect_name => psb_zcsrspspmm
@ -4307,6 +4184,131 @@ contains
!$omp end parallel do
end subroutine spmm_omp_two_pass
end subroutine psb_zcsrspspmm
#else
subroutine psb_zcsrspspmm(a,b,c,info)
use psb_z_mat_mod
use psb_serial_mod, psb_protect_name => psb_zcsrspspmm
implicit none
class(psb_z_csr_sparse_mat), intent(in) :: a,b
type(psb_z_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name
integer(psb_ipk_) :: err_act
name='psb_csrspspmm'
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (b%is_dev()) call b%sync()
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
if ( mb /= na ) then
write(psb_err_unit,*) 'Mismatch in SPSPMM: ',ma,na,mb,nb
info = psb_err_invalid_matrix_sizes_
call psb_errpush(info,name)
goto 9999
endif
! Estimate number of nonzeros on output.
nza = a%get_nzeros()
nzb = b%get_nzeros()
nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info)
call c%set_asb()
call c%set_host()
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine csr_spspmm(a,b,c,info)
implicit none
type(psb_z_csr_sparse_mat), intent(in) :: a,b
type(psb_z_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:)
complex(psb_dpk_), allocatable :: row(:)
integer(psb_ipk_) :: i,j,k,irw,icl,icf, iret, &
& nzc,nnzre, isz, ipb, irwsz, nrc, nze
complex(psb_dpk_) :: cfb
info = psb_success_
ma = a%get_nrows()
na = a%get_ncols()
mb = b%get_nrows()
nb = b%get_ncols()
nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info)
if (info == 0) call psb_realloc(isz,idxs,info)
if (info == 0) call psb_realloc(isz,irow,info)
if (info /= 0) return
row = dzero
irow = 0
nzc = 1
do j = 1,ma
c%irp(j) = nzc
nrc = 0
do k = a%irp(j), a%irp(j+1)-1
irw = a%ja(k)
cfb = a%val(k)
irwsz = b%irp(irw+1)-b%irp(irw)
do i = b%irp(irw),b%irp(irw+1)-1
icl = b%ja(i)
if (irow(icl)<j) then
nrc = nrc + 1
idxs(nrc) = icl
irow(icl) = j
end if
row(icl) = row(icl) + cfb*b%val(i)
end do
end do
if (nrc > 0 ) then
if ((nzc+nrc)>nze) then
nze = max(ma*((nzc+j-1)/j),nzc+2*nrc)
call psb_realloc(nze,c%val,info)
if (info == 0) call psb_realloc(nze,c%ja,info)
if (info /= 0) return
end if
call psb_qsort(idxs(1:nrc))
do i=1, nrc
irw = idxs(i)
c%ja(nzc) = irw
c%val(nzc) = row(irw)
row(irw) = dzero
nzc = nzc + 1
end do
end if
end do
c%irp(ma+1) = nzc
end subroutine csr_spspmm
end subroutine psb_zcsrspspmm
#endif

Loading…
Cancel
Save