Merged CSRE experiment.

psblas-caf-ext
Salvatore Filippone 8 years ago
parent 89235a0dea
commit ff0cfac9cf

@ -101,6 +101,15 @@ module psb_d_csr_mat_mod
end type psb_d_csr_sparse_mat
type, extends(psb_d_csr_sparse_mat) :: psb_d_csre_sparse_mat
contains
procedure, pass(a) :: csmv => psb_d_csre_csmv
procedure, nopass :: get_fmt => d_csre_get_fmt
end type psb_d_csre_sparse_mat
private :: d_csr_get_nzeros, d_csr_free, d_csr_get_fmt, &
& d_csr_get_size, d_csr_sizeof, d_csr_get_nz_row, &
& d_csr_is_by_rows
@ -494,6 +503,19 @@ module psb_d_csr_mat_mod
end subroutine psb_d_csr_scals
end interface
!> \memberof psb_d_csre_sparse_mat
!! \see psb_d_base_mat_mod::psb_d_base_csmv
interface
subroutine psb_d_csre_csmv(alpha,a,x,beta,y,info,trans)
import :: psb_ipk_, psb_d_csre_sparse_mat, psb_dpk_
class(psb_d_csre_sparse_mat), intent(in) :: a
real(psb_dpk_), intent(in) :: alpha, beta, x(:)
real(psb_dpk_), intent(inout) :: y(:)
integer(psb_ipk_), intent(out) :: info
character, optional, intent(in) :: trans
end subroutine psb_d_csre_csmv
end interface
contains
@ -538,6 +560,12 @@ contains
res = 'CSR'
end function d_csr_get_fmt
function d_csre_get_fmt() result(res)
implicit none
character(len=5) :: res
res = 'CSRe'
end function d_csre_get_fmt
function d_csr_get_nzeros(a) result(res)
implicit none
class(psb_d_csr_sparse_mat), intent(in) :: a

@ -72,7 +72,7 @@
module psb_d_mat_mod
use psb_d_base_mat_mod
use psb_d_csr_mat_mod, only : psb_d_csr_sparse_mat
use psb_d_csr_mat_mod, only : psb_d_csr_sparse_mat, psb_d_csre_sparse_mat
use psb_d_csc_mat_mod, only : psb_d_csc_sparse_mat
type :: psb_dspmat_type

@ -3180,3 +3180,365 @@ contains
end subroutine csr_spspmm
end subroutine psb_dcsrspspmm
! == ===================================
!
!
!
! Computational routines
!
!
!
!
!
!
! == ===================================
subroutine psb_d_csre_csmv(alpha,a,x,beta,y,info,trans)
use psb_error_mod
use psb_string_mod
use psb_d_csr_mat_mod, psb_protect_name => psb_d_csre_csmv
implicit none
class(psb_d_csre_sparse_mat), intent(in) :: a
real(psb_dpk_), intent(in) :: alpha, beta, x(:)
real(psb_dpk_), intent(inout) :: y(:)
integer(psb_ipk_), intent(out) :: info
character, optional, intent(in) :: trans
character :: trans_
integer(psb_ipk_) :: i,j,k,m,n, nnz, ir, jc
real(psb_dpk_) :: acc
logical :: tra, ctra
integer(psb_ipk_) :: err_act
integer(psb_ipk_) :: ierr(5)
character(len=20) :: name='d_csr_csmv'
logical, parameter :: debug=.false.
call psb_erractionsave(err_act)
info = psb_success_
if (a%is_dev()) call a%sync()
if (present(trans)) then
trans_ = trans
else
trans_ = 'N'
end if
if (.not.a%is_asb()) then
info = psb_err_invalid_mat_state_
call psb_errpush(info,name)
goto 9999
endif
tra = (psb_toupper(trans_) == 'T')
ctra = (psb_toupper(trans_) == 'C')
if (tra.or.ctra) then
m = a%get_ncols()
n = a%get_nrows()
else
n = a%get_ncols()
m = a%get_nrows()
end if
if (size(x,1)<n) then
info = psb_err_input_asize_small_i_
ierr(1) = 3; ierr(2) = n;
call psb_errpush(info,name,i_err=ierr)
goto 9999
end if
if (size(y,1)<m) then
info = psb_err_input_asize_small_i_
ierr(1) = 5; ierr(2) = m;
call psb_errpush(info,name,i_err=ierr)
goto 9999
end if
call psb_d_csre_csmv_inner(m,n,alpha,a%irp,a%ja,a%val,&
& a%is_triangle(),a%is_unit(),&
& x,beta,y,tra,ctra)
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(err_act)
return
contains
subroutine psb_d_csre_csmv_inner(m,n,alpha,irp,ja,val,is_triangle,is_unit,&
& x,beta,y,tra,ctra)
integer(psb_ipk_), intent(in) :: m,n,irp(*),ja(*)
real(psb_dpk_), intent(in) :: alpha, beta, x(*),val(*)
real(psb_dpk_), intent(inout) :: y(*)
logical, intent(in) :: is_triangle,is_unit,tra, ctra
integer(psb_ipk_) :: i,j,k, ir, jc
real(psb_dpk_) :: acc
if (alpha == dzero) then
if (beta == dzero) then
do i = 1, m
y(i) = dzero
enddo
else
do i = 1, m
y(i) = beta*y(i)
end do
endif
return
end if
if ((.not.tra).and.(.not.ctra)) then
if (beta == dzero) then
if (alpha == done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = acc
end do
else if (alpha == -done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = -acc
end do
else
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = alpha*acc
end do
end if
else if (beta == done) then
if (alpha == done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
if (acc /= dzero) y(i) = y(i) + acc
end do
else if (alpha == -done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
if (acc /= dzero) y(i) = y(i) -acc
end do
else
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
if (acc /= dzero) y(i) = y(i) + alpha*acc
end do
end if
else if (beta == -done) then
if (alpha == done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = -y(i) + acc
end do
else if (alpha == -done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = -y(i) -acc
end do
else
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = -y(i) + alpha*acc
end do
end if
else
if (alpha == done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = beta*y(i) + acc
end do
else if (alpha == -done) then
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = beta*y(i) - acc
end do
else
do i=1,m
acc = dzero
do j=irp(i), irp(i+1)-1
acc = acc + val(j) * x(ja(j))
enddo
y(i) = beta*y(i) + alpha*acc
end do
end if
end if
else if (tra) then
if (beta == dzero) then
do i=1, m
y(i) = dzero
end do
else if (beta == done) then
! Do nothing
else if (beta == -done) then
do i=1, m
y(i) = -y(i)
end do
else
do i=1, m
y(i) = beta*y(i)
end do
end if
if (alpha == done) then
do i=1,n
do j=irp(i), irp(i+1)-1
ir = ja(j)
y(ir) = y(ir) + val(j)*x(i)
end do
enddo
else if (alpha == -done) then
do i=1,n
do j=irp(i), irp(i+1)-1
ir = ja(j)
y(ir) = y(ir) - val(j)*x(i)
end do
enddo
else
do i=1,n
do j=irp(i), irp(i+1)-1
ir = ja(j)
y(ir) = y(ir) + alpha*val(j)*x(i)
end do
enddo
end if
else if (ctra) then
if (beta == dzero) then
do i=1, m
y(i) = dzero
end do
else if (beta == done) then
! Do nothing
else if (beta == -done) then
do i=1, m
y(i) = -y(i)
end do
else
do i=1, m
y(i) = beta*y(i)
end do
end if
if (alpha == done) then
do i=1,n
do j=irp(i), irp(i+1)-1
ir = ja(j)
y(ir) = y(ir) + (val(j))*x(i)
end do
enddo
else if (alpha == -done) then
do i=1,n
do j=irp(i), irp(i+1)-1
ir = ja(j)
y(ir) = y(ir) - (val(j))*x(i)
end do
enddo
else
do i=1,n
do j=irp(i), irp(i+1)-1
ir = ja(j)
y(ir) = y(ir) + alpha*(val(j))*x(i)
end do
enddo
end if
endif
if (is_unit) then
do i=1, min(m,n)
y(i) = y(i) + alpha*x(i)
end do
end if
end subroutine psb_d_csre_csmv_inner
end subroutine psb_d_csre_csmv

@ -49,10 +49,11 @@ program pdgenspmv
! sparse matrix and preconditioner
type(psb_dspmat_type) :: a, ad, ah
type(psb_d_csr_sparse_mat) :: acsr
type(psb_d_csre_sparse_mat) :: acsre
! descriptor
type(psb_desc_type) :: desc_a
! dense matrices
type(psb_d_vect_type) :: xv,bv, vtst
type(psb_d_vect_type) :: xv,bv, vtst,bvh
real(psb_dpk_), allocatable :: tst(:), work(:)
! blacs parameters
integer(psb_ipk_) :: ictxt, iam, np
@ -109,12 +110,14 @@ program pdgenspmv
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
call psb_geasb(bvh,desc_a,info,scratch=.true.)
if (iam == psb_root_) write(psb_out_unit,'("Overall matrix creation time : ",es12.5)')t2
if (iam == psb_root_) write(psb_out_unit,'(" ")')
nrl = desc_a%get_local_rows()
ncl = desc_a%get_local_cols()
call a%csclip(ad,info,jmax=nrl)
call a%csclip(ah,info,jmin=nrl+1,jmax=ncl,cscale=.true.)
call ah%cscnv(info,mold=acsre)
lwork = 2*ncl
allocate(work(lwork), stat=info)
@ -135,16 +138,18 @@ program pdgenspmv
do i=1,times
call psi_swapdata(psb_swap_send_,&
& dzero,xv%v,desc_a,work,info,data=psb_comm_halo_)
call psb_csmm(done,ad,xv,dzero,bv,info)
call psb_csmm(done,ad,xv,dzero,bvh,info)
call psi_swapdata(psb_swap_recv_,&
& dzero,xv%v,desc_a,work,info,data=psb_comm_halo_)
call ah%a%csmv(done,xv%v%v(nrl+1:),done,bv%v%v,info)
call ah%a%csmv(done,xv%v%v(nrl+1:),done,bvh%v%v,info)
end do
call psb_barrier(ictxt)
tt2 = psb_wtime() - tt1
call psb_amx(ictxt,tt2)
call psb_amx(ictxt,t2)
call psb_geaxpby(-done,bv,done,bvh,desc_a,info)
err = psb_genrm2(bvh,desc_a,info)
nr = desc_a%get_global_rows()
annz = a%get_nzeros()
amatsize = a%sizeof()
@ -171,7 +176,7 @@ program pdgenspmv
write(psb_out_unit,'("Time for ",i0," products (s) (trans.): ",F20.3)') times,tt2
write(psb_out_unit,'("Time per product (ms) (trans.): ",F20.3)') tt2*1.d3/(1.d0*times)
write(psb_out_unit,'("MFLOPS (trans.): ",F20.3)') tflops/1.d6
write(psb_out_unit,'("Difference : ",E20.12)') err
!
! This computation is valid for CSR
!

Loading…
Cancel
Save