From 93c71c43162fb6663cca9c2f0fff0e8f2ff4c47e Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 20 Feb 2024 10:25:31 +0100 Subject: [PATCH] Fix %ZERO() on cuda --- cuda/psb_c_cuda_vect_mod.F90 | 6 ++-- cuda/psb_d_cuda_vect_mod.F90 | 60 ++++++------------------------------ cuda/psb_d_vectordev_mod.F90 | 13 -------- cuda/psb_i_cuda_vect_mod.F90 | 6 ++-- cuda/psb_s_cuda_vect_mod.F90 | 6 ++-- cuda/psb_z_cuda_vect_mod.F90 | 6 ++-- 6 files changed, 21 insertions(+), 76 deletions(-) diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index 56cc80e6..fca1c616 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -668,9 +668,9 @@ contains use psi_serial_mod implicit none class(psb_c_vect_cuda), intent(inout) :: x - - if (allocated(x%v)) x%v=czero - call x%set_host() + + call x%set_scal(czero) + end subroutine c_cuda_zero subroutine c_cuda_asb_m(n, x, info) diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index f2ef2be3..2220b26c 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -668,9 +668,9 @@ contains use psi_serial_mod implicit none class(psb_d_vect_cuda), intent(inout) :: x - - if (allocated(x%v)) x%v=dzero - call x%set_host() + + call x%set_scal(dzero) + end subroutine d_cuda_zero subroutine d_cuda_asb_m(n, x, info) @@ -922,56 +922,14 @@ contains class(psb_d_vect_cuda), intent(inout) :: z real(psb_dpk_), intent (in) :: alpha, beta, gamma, delta integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: nx, ny, nz - logical :: gpu_done - info = psb_success_ + call z%psb_d_base_vect_type%abgdxyz(m,alpha,beta,gamma,delta,x,y,info) +!!$ +!!$ if (x%is_dev()) call x%sync() +!!$ +!!$ call y%axpby(m,alpha,x,beta,info) +!!$ call z%axpby(m,gamma,y,delta,info) - if (.true.) then - gpu_done = .false. - select type(xx => x) - class is (psb_d_vect_cuda) - select type(yy => y) - class is (psb_d_vect_cuda) - select type(zz => z) - class is (psb_d_vect_cuda) - ! Do something different here - if ((beta /= dzero).and.yy%is_host())& - & call yy%sync() - if ((delta /= dzero).and.zz%is_host())& - & call zz%sync() - if (xx%is_host()) call xx%sync() - nx = getMultiVecDeviceSize(xx%deviceVect) - ny = getMultiVecDeviceSize(yy%deviceVect) - nz = getMultiVecDeviceSize(zz%deviceVect) - if ((nx