Fix OpenACC version of ELL vect_mv

oacc_loloum
sfilippone 5 months ago
parent fa5e7ff945
commit 95c546aadd

@ -10,13 +10,13 @@ contains
integer(psb_ipk_), intent(out) :: info
character, optional, intent(in) :: trans
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nzt, nc
info = psb_success_
m = a%get_nrows()
n = a%get_ncols()
nzt = a%nzt
nc = size(a%ja,2)
if ((n /= size(x%v)) .or. (m /= size(y%v))) then
write(0,*) 'Size error ', m, n, size(x%v), size(y%v)
info = psb_err_invalid_mat_state_
@ -27,14 +27,15 @@ contains
if (x%is_host()) call x%sync()
if (y%is_host()) call y%sync()
call inner_spmv(m, n, nzt, alpha, a%val, a%ja, x%v, beta, y%v, info)
call inner_spmv(m, n, nc, alpha, a%val, a%ja, x%v, beta, y%v, info)
call y%set_dev()
contains
subroutine inner_spmv(m, n, nzt, alpha, val, ja, x, beta, y, info)
subroutine inner_spmv(m, n, nc, alpha, val, ja, x, beta, y, info)
implicit none
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nc
complex(psb_spk_), intent(in) :: alpha, beta
complex(psb_spk_) :: val(:,:), x(:), y(:)
integer(psb_ipk_) :: ja(:,:)
@ -52,7 +53,7 @@ contains
do i = ii, ii + isz - 1
tmp = 0.0_psb_dpk_
!$acc loop seq
do j = 1, nzt
do j = 1, nc
if (ja(i,j) > 0) then
tmp = tmp + val(i,j) * x(ja(i,j))
end if

@ -10,13 +10,13 @@ contains
integer(psb_ipk_), intent(out) :: info
character, optional, intent(in) :: trans
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nzt, nc
info = psb_success_
m = a%get_nrows()
n = a%get_ncols()
nzt = a%nzt
nc = size(a%ja,2)
if ((n /= size(x%v)) .or. (m /= size(y%v))) then
write(0,*) 'Size error ', m, n, size(x%v), size(y%v)
info = psb_err_invalid_mat_state_
@ -27,14 +27,15 @@ contains
if (x%is_host()) call x%sync()
if (y%is_host()) call y%sync()
call inner_spmv(m, n, nzt, alpha, a%val, a%ja, x%v, beta, y%v, info)
call inner_spmv(m, n, nc, alpha, a%val, a%ja, x%v, beta, y%v, info)
call y%set_dev()
contains
subroutine inner_spmv(m, n, nzt, alpha, val, ja, x, beta, y, info)
subroutine inner_spmv(m, n, nc, alpha, val, ja, x, beta, y, info)
implicit none
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nc
real(psb_dpk_), intent(in) :: alpha, beta
real(psb_dpk_) :: val(:,:), x(:), y(:)
integer(psb_ipk_) :: ja(:,:)
@ -52,7 +53,7 @@ contains
do i = ii, ii + isz - 1
tmp = 0.0_psb_dpk_
!$acc loop seq
do j = 1, nzt
do j = 1, nc
if (ja(i,j) > 0) then
tmp = tmp + val(i,j) * x(ja(i,j))
end if

@ -10,13 +10,13 @@ contains
integer(psb_ipk_), intent(out) :: info
character, optional, intent(in) :: trans
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nzt, nc
info = psb_success_
m = a%get_nrows()
n = a%get_ncols()
nzt = a%nzt
nc = size(a%ja,2)
if ((n /= size(x%v)) .or. (m /= size(y%v))) then
write(0,*) 'Size error ', m, n, size(x%v), size(y%v)
info = psb_err_invalid_mat_state_
@ -27,14 +27,15 @@ contains
if (x%is_host()) call x%sync()
if (y%is_host()) call y%sync()
call inner_spmv(m, n, nzt, alpha, a%val, a%ja, x%v, beta, y%v, info)
call inner_spmv(m, n, nc, alpha, a%val, a%ja, x%v, beta, y%v, info)
call y%set_dev()
contains
subroutine inner_spmv(m, n, nzt, alpha, val, ja, x, beta, y, info)
subroutine inner_spmv(m, n, nc, alpha, val, ja, x, beta, y, info)
implicit none
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nc
real(psb_spk_), intent(in) :: alpha, beta
real(psb_spk_) :: val(:,:), x(:), y(:)
integer(psb_ipk_) :: ja(:,:)
@ -52,7 +53,7 @@ contains
do i = ii, ii + isz - 1
tmp = 0.0_psb_dpk_
!$acc loop seq
do j = 1, nzt
do j = 1, nc
if (ja(i,j) > 0) then
tmp = tmp + val(i,j) * x(ja(i,j))
end if

@ -10,13 +10,13 @@ contains
integer(psb_ipk_), intent(out) :: info
character, optional, intent(in) :: trans
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nzt, nc
info = psb_success_
m = a%get_nrows()
n = a%get_ncols()
nzt = a%nzt
nc = size(a%ja,2)
if ((n /= size(x%v)) .or. (m /= size(y%v))) then
write(0,*) 'Size error ', m, n, size(x%v), size(y%v)
info = psb_err_invalid_mat_state_
@ -27,14 +27,15 @@ contains
if (x%is_host()) call x%sync()
if (y%is_host()) call y%sync()
call inner_spmv(m, n, nzt, alpha, a%val, a%ja, x%v, beta, y%v, info)
call inner_spmv(m, n, nc, alpha, a%val, a%ja, x%v, beta, y%v, info)
call y%set_dev()
contains
subroutine inner_spmv(m, n, nzt, alpha, val, ja, x, beta, y, info)
subroutine inner_spmv(m, n, nc, alpha, val, ja, x, beta, y, info)
implicit none
integer(psb_ipk_) :: m, n, nzt
integer(psb_ipk_) :: m, n, nc
complex(psb_dpk_), intent(in) :: alpha, beta
complex(psb_dpk_) :: val(:,:), x(:), y(:)
integer(psb_ipk_) :: ja(:,:)
@ -52,7 +53,7 @@ contains
do i = ii, ii + isz - 1
tmp = 0.0_psb_dpk_
!$acc loop seq
do j = 1, nzt
do j = 1, nc
if (ja(i,j) > 0) then
tmp = tmp + val(i,j) * x(ja(i,j))
end if

Loading…
Cancel
Save