From 3d9fee2dd7f65b0bcf72c0dce46b9908a5502bd1 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 8 Oct 2024 17:07:10 +0200 Subject: [PATCH] Fix DOT on CUDA vectors. --- cuda/dvectordev.c | 3 ++- cuda/psb_c_cuda_vect_mod.F90 | 22 +++------------------- cuda/psb_d_cuda_vect_mod.F90 | 18 +----------------- cuda/psb_i_cuda_vect_mod.F90 | 12 ------------ cuda/psb_s_cuda_vect_mod.F90 | 22 +++------------------- cuda/psb_z_cuda_vect_mod.F90 | 22 +++------------------- cuda/svectordev.c | 3 ++- cuda/zvectordev.c | 3 ++- 8 files changed, 16 insertions(+), 89 deletions(-) diff --git a/cuda/dvectordev.c b/cuda/dvectordev.c index a69e1b71..d4f5513b 100644 --- a/cuda/dvectordev.c +++ b/cuda/dvectordev.c @@ -220,7 +220,8 @@ int dotMultiVecDeviceDouble(double* y_res, int n, void* devMultiVecA, void* devM struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB; spgpuHandle_t handle=psb_cudaGetHandle(); - spgpuDmdot(handle, y_res, n, (double*)devVecA->v_, (double*)devVecB->v_,devVecA->count_,devVecB->pitch_); + spgpuDmdot(handle, y_res, n, (double*)devVecA->v_, (double*)devVecB->v_, + devVecA->count_,devVecB->pitch_); return(i); } diff --git a/cuda/psb_c_cuda_vect_mod.F90 b/cuda/psb_c_cuda_vect_mod.F90 index 45fafe0a..9755b386 100644 --- a/cuda/psb_c_cuda_vect_mod.F90 +++ b/cuda/psb_c_cuda_vect_mod.F90 @@ -813,18 +813,6 @@ contains call x%set_dev() end subroutine c_cuda_set_scal -!!$ -!!$ subroutine c_cuda_set_vect(x,val) -!!$ class(psb_c_vect_cuda), intent(inout) :: x -!!$ complex(psb_spk_), intent(in) :: val(:) -!!$ integer(psb_ipk_) :: nr -!!$ integer(psb_ipk_) :: info -!!$ -!!$ if (x%is_dev()) call x%sync() -!!$ call x%psb_c_base_vect_type%set_vect(val) -!!$ call x%set_host() -!!$ -!!$ end subroutine c_cuda_set_vect @@ -834,7 +822,6 @@ contains class(psb_c_base_vect_type), intent(inout) :: y integer(psb_ipk_), intent(in) :: n complex(psb_spk_) :: res - complex(psb_spk_), external :: ddot integer(psb_ipk_) :: info res = czero @@ -844,9 +831,6 @@ contains ! TYPE psb_c_vect ! select type(yy => y) - type is (psb_c_base_vect_type) - if (x%is_dev()) call x%sync() - res = ddot(n,x%v,1,yy%v,1) type is (psb_c_vect_cuda) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() @@ -858,7 +842,7 @@ contains class default ! y%sync is done in dot_a - call x%sync() + if (x%is_dev()) call x%sync() res = y%dot(n,x%v) end select @@ -870,10 +854,10 @@ contains complex(psb_spk_), intent(in) :: y(:) integer(psb_ipk_), intent(in) :: n complex(psb_spk_) :: res - complex(psb_spk_), external :: ddot + complex(psb_spk_), external :: cdot if (x%is_dev()) call x%sync() - res = ddot(n,y,1,x%v,1) + res = cdot(n,y,1,x%v,1) end function c_cuda_dot_a diff --git a/cuda/psb_d_cuda_vect_mod.F90 b/cuda/psb_d_cuda_vect_mod.F90 index e7e563ff..dfa83c60 100644 --- a/cuda/psb_d_cuda_vect_mod.F90 +++ b/cuda/psb_d_cuda_vect_mod.F90 @@ -813,18 +813,6 @@ contains call x%set_dev() end subroutine d_cuda_set_scal -!!$ -!!$ subroutine d_cuda_set_vect(x,val) -!!$ class(psb_d_vect_cuda), intent(inout) :: x -!!$ real(psb_dpk_), intent(in) :: val(:) -!!$ integer(psb_ipk_) :: nr -!!$ integer(psb_ipk_) :: info -!!$ -!!$ if (x%is_dev()) call x%sync() -!!$ call x%psb_d_base_vect_type%set_vect(val) -!!$ call x%set_host() -!!$ -!!$ end subroutine d_cuda_set_vect @@ -834,7 +822,6 @@ contains class(psb_d_base_vect_type), intent(inout) :: y integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res - real(psb_dpk_), external :: ddot integer(psb_ipk_) :: info res = dzero @@ -844,9 +831,6 @@ contains ! TYPE psb_d_vect ! select type(yy => y) - type is (psb_d_base_vect_type) - if (x%is_dev()) call x%sync() - res = ddot(n,x%v,1,yy%v,1) type is (psb_d_vect_cuda) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() @@ -858,7 +842,7 @@ contains class default ! y%sync is done in dot_a - call x%sync() + if (x%is_dev()) call x%sync() res = y%dot(n,x%v) end select diff --git a/cuda/psb_i_cuda_vect_mod.F90 b/cuda/psb_i_cuda_vect_mod.F90 index 461d84d1..4be4679c 100644 --- a/cuda/psb_i_cuda_vect_mod.F90 +++ b/cuda/psb_i_cuda_vect_mod.F90 @@ -795,18 +795,6 @@ contains call x%set_dev() end subroutine i_cuda_set_scal -!!$ -!!$ subroutine i_cuda_set_vect(x,val) -!!$ class(psb_i_vect_cuda), intent(inout) :: x -!!$ integer(psb_ipk_), intent(in) :: val(:) -!!$ integer(psb_ipk_) :: nr -!!$ integer(psb_ipk_) :: info -!!$ -!!$ if (x%is_dev()) call x%sync() -!!$ call x%psb_i_base_vect_type%set_vect(val) -!!$ call x%set_host() -!!$ -!!$ end subroutine i_cuda_set_vect diff --git a/cuda/psb_s_cuda_vect_mod.F90 b/cuda/psb_s_cuda_vect_mod.F90 index a2c69934..39a108ab 100644 --- a/cuda/psb_s_cuda_vect_mod.F90 +++ b/cuda/psb_s_cuda_vect_mod.F90 @@ -813,18 +813,6 @@ contains call x%set_dev() end subroutine s_cuda_set_scal -!!$ -!!$ subroutine s_cuda_set_vect(x,val) -!!$ class(psb_s_vect_cuda), intent(inout) :: x -!!$ real(psb_spk_), intent(in) :: val(:) -!!$ integer(psb_ipk_) :: nr -!!$ integer(psb_ipk_) :: info -!!$ -!!$ if (x%is_dev()) call x%sync() -!!$ call x%psb_s_base_vect_type%set_vect(val) -!!$ call x%set_host() -!!$ -!!$ end subroutine s_cuda_set_vect @@ -834,7 +822,6 @@ contains class(psb_s_base_vect_type), intent(inout) :: y integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res - real(psb_spk_), external :: ddot integer(psb_ipk_) :: info res = szero @@ -844,9 +831,6 @@ contains ! TYPE psb_s_vect ! select type(yy => y) - type is (psb_s_base_vect_type) - if (x%is_dev()) call x%sync() - res = ddot(n,x%v,1,yy%v,1) type is (psb_s_vect_cuda) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() @@ -858,7 +842,7 @@ contains class default ! y%sync is done in dot_a - call x%sync() + if (x%is_dev()) call x%sync() res = y%dot(n,x%v) end select @@ -870,10 +854,10 @@ contains real(psb_spk_), intent(in) :: y(:) integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res - real(psb_spk_), external :: ddot + real(psb_spk_), external :: sdot if (x%is_dev()) call x%sync() - res = ddot(n,y,1,x%v,1) + res = sdot(n,y,1,x%v,1) end function s_cuda_dot_a diff --git a/cuda/psb_z_cuda_vect_mod.F90 b/cuda/psb_z_cuda_vect_mod.F90 index dfeafa6e..d4318bea 100644 --- a/cuda/psb_z_cuda_vect_mod.F90 +++ b/cuda/psb_z_cuda_vect_mod.F90 @@ -813,18 +813,6 @@ contains call x%set_dev() end subroutine z_cuda_set_scal -!!$ -!!$ subroutine z_cuda_set_vect(x,val) -!!$ class(psb_z_vect_cuda), intent(inout) :: x -!!$ complex(psb_dpk_), intent(in) :: val(:) -!!$ integer(psb_ipk_) :: nr -!!$ integer(psb_ipk_) :: info -!!$ -!!$ if (x%is_dev()) call x%sync() -!!$ call x%psb_z_base_vect_type%set_vect(val) -!!$ call x%set_host() -!!$ -!!$ end subroutine z_cuda_set_vect @@ -834,7 +822,6 @@ contains class(psb_z_base_vect_type), intent(inout) :: y integer(psb_ipk_), intent(in) :: n complex(psb_dpk_) :: res - complex(psb_dpk_), external :: ddot integer(psb_ipk_) :: info res = zzero @@ -844,9 +831,6 @@ contains ! TYPE psb_z_vect ! select type(yy => y) - type is (psb_z_base_vect_type) - if (x%is_dev()) call x%sync() - res = ddot(n,x%v,1,yy%v,1) type is (psb_z_vect_cuda) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() @@ -858,7 +842,7 @@ contains class default ! y%sync is done in dot_a - call x%sync() + if (x%is_dev()) call x%sync() res = y%dot(n,x%v) end select @@ -870,10 +854,10 @@ contains complex(psb_dpk_), intent(in) :: y(:) integer(psb_ipk_), intent(in) :: n complex(psb_dpk_) :: res - complex(psb_dpk_), external :: ddot + complex(psb_dpk_), external :: zdot if (x%is_dev()) call x%sync() - res = ddot(n,y,1,x%v,1) + res = zdot(n,y,1,x%v,1) end function z_cuda_dot_a diff --git a/cuda/svectordev.c b/cuda/svectordev.c index cfaef5ce..ab4dd01b 100644 --- a/cuda/svectordev.c +++ b/cuda/svectordev.c @@ -220,7 +220,8 @@ int dotMultiVecDeviceFloat(float* y_res, int n, void* devMultiVecA, void* devMul struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB; spgpuHandle_t handle=psb_cudaGetHandle(); - spgpuSmdot(handle, y_res, n, (float*)devVecA->v_, (float*)devVecB->v_,devVecA->count_,devVecB->pitch_); + spgpuSmdot(handle, y_res, n, (float*)devVecA->v_, (float*)devVecB->v_, + devVecA->count_,devVecB->pitch_); return(i); } diff --git a/cuda/zvectordev.c b/cuda/zvectordev.c index d7d88f1b..3a5b0738 100644 --- a/cuda/zvectordev.c +++ b/cuda/zvectordev.c @@ -223,7 +223,8 @@ int scalMultiVecDeviceDoubleComplex(cuDoubleComplex alpha, void* devMultiVecA) return(i); } -int dotMultiVecDeviceDoubleComplex(cuDoubleComplex* y_res, int n, void* devMultiVecA, void* devMultiVecB) +int dotMultiVecDeviceDoubleComplex(cuDoubleComplex* y_res, int n, + void* devMultiVecA, void* devMultiVecB) {int i=0; struct MultiVectDevice *devVecA = (struct MultiVectDevice *) devMultiVecA; struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB;