Fix %ZERO() on cuda

nond-rep
sfilippone 10 months ago
parent 0568a83734
commit 93c71c4316

@ -669,8 +669,8 @@ contains
implicit none implicit none
class(psb_c_vect_cuda), intent(inout) :: x class(psb_c_vect_cuda), intent(inout) :: x
if (allocated(x%v)) x%v=czero call x%set_scal(czero)
call x%set_host()
end subroutine c_cuda_zero end subroutine c_cuda_zero
subroutine c_cuda_asb_m(n, x, info) subroutine c_cuda_asb_m(n, x, info)

@ -669,8 +669,8 @@ contains
implicit none implicit none
class(psb_d_vect_cuda), intent(inout) :: x class(psb_d_vect_cuda), intent(inout) :: x
if (allocated(x%v)) x%v=dzero call x%set_scal(dzero)
call x%set_host()
end subroutine d_cuda_zero end subroutine d_cuda_zero
subroutine d_cuda_asb_m(n, x, info) subroutine d_cuda_asb_m(n, x, info)
@ -922,56 +922,14 @@ contains
class(psb_d_vect_cuda), intent(inout) :: z class(psb_d_vect_cuda), intent(inout) :: z
real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: nx, ny, nz
logical :: gpu_done
info = psb_success_
if (.true.) then
gpu_done = .false.
select type(xx => x)
class is (psb_d_vect_cuda)
select type(yy => y)
class is (psb_d_vect_cuda)
select type(zz => z)
class is (psb_d_vect_cuda)
! Do something different here
if ((beta /= dzero).and.yy%is_host())&
& call yy%sync()
if ((delta /= dzero).and.zz%is_host())&
& call zz%sync()
if (xx%is_host()) call xx%sync()
nx = getMultiVecDeviceSize(xx%deviceVect)
ny = getMultiVecDeviceSize(yy%deviceVect)
nz = getMultiVecDeviceSize(zz%deviceVect)
if ((nx<m).or.(ny<m).or.(nz<m)) then
info = psb_err_internal_error_
else
info = abgdxyzMultiVecDevice(m,alpha,beta,gamma,delta,&
& xx%deviceVect,yy%deviceVect,zz%deviceVect)
end if
call yy%set_dev()
call zz%set_dev()
gpu_done = .true.
end select
end select
end select
if (.not.gpu_done) then call z%psb_d_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info)
if (x%is_host()) call x%sync() !!$
if (y%is_host()) call y%sync() !!$ if (x%is_dev()) call x%sync()
if (z%is_host()) call z%sync() !!$
call y%axpby(m,alpha,x,beta,info) !!$ call y%axpby(m,alpha,x,beta,info)
call z%axpby(m,gamma,y,delta,info) !!$ call z%axpby(m,gamma,y,delta,info)
end if
else
if (x%is_host()) call x%sync()
if (y%is_host()) call y%sync()
if (z%is_host()) call z%sync()
call y%axpby(m,alpha,x,beta,info)
call z%axpby(m,gamma,y,delta,info)
end if
end subroutine d_cuda_abgdxyz end subroutine d_cuda_abgdxyz

@ -316,19 +316,6 @@ module psb_d_vectordev_mod
end function axpbyMultiVecDeviceDouble end function axpbyMultiVecDeviceDouble
end interface end interface
interface abgdxyzMultiVecDevice
function abgdxyzMultiVecDeviceDouble(n,alpha,beta,gamma,delta,deviceVecX,&
& deviceVecY,deviceVecZ) &
& result(res) bind(c,name='abgdxyzMultiVecDeviceDouble')
use iso_c_binding
integer(c_int) :: res
integer(c_int), value :: n
real(c_double), value :: alpha, beta,gamma,delta
type(c_ptr), value :: deviceVecX, deviceVecY, deviceVecZ
end function abgdxyzMultiVecDeviceDouble
end interface abgdxyzMultiVecDevice
interface axyMultiVecDevice interface axyMultiVecDevice
function axyMultiVecDeviceDouble(n,alpha,deviceVecA,deviceVecB) & function axyMultiVecDeviceDouble(n,alpha,deviceVecA,deviceVecB) &
& result(res) bind(c,name='axyMultiVecDeviceDouble') & result(res) bind(c,name='axyMultiVecDeviceDouble')

@ -651,8 +651,8 @@ contains
implicit none implicit none
class(psb_i_vect_cuda), intent(inout) :: x class(psb_i_vect_cuda), intent(inout) :: x
if (allocated(x%v)) x%v=izero call x%set_scal(izero)
call x%set_host()
end subroutine i_cuda_zero end subroutine i_cuda_zero
subroutine i_cuda_asb_m(n, x, info) subroutine i_cuda_asb_m(n, x, info)

@ -669,8 +669,8 @@ contains
implicit none implicit none
class(psb_s_vect_cuda), intent(inout) :: x class(psb_s_vect_cuda), intent(inout) :: x
if (allocated(x%v)) x%v=szero call x%set_scal(szero)
call x%set_host()
end subroutine s_cuda_zero end subroutine s_cuda_zero
subroutine s_cuda_asb_m(n, x, info) subroutine s_cuda_asb_m(n, x, info)

@ -669,8 +669,8 @@ contains
implicit none implicit none
class(psb_z_vect_cuda), intent(inout) :: x class(psb_z_vect_cuda), intent(inout) :: x
if (allocated(x%v)) x%v=zzero call x%set_scal(zzero)
call x%set_host()
end subroutine z_cuda_zero end subroutine z_cuda_zero
subroutine z_cuda_asb_m(n, x, info) subroutine z_cuda_asb_m(n, x, info)

Loading…
Cancel
Save