Fix DOT on CUDA vectors.

repack-newsolve
sfilippone 4 months ago
parent 949499265e
commit 3d9fee2dd7

@ -220,7 +220,8 @@ int dotMultiVecDeviceDouble(double* y_res, int n, void* devMultiVecA, void* devM
struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB; struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB;
spgpuHandle_t handle=psb_cudaGetHandle(); spgpuHandle_t handle=psb_cudaGetHandle();
spgpuDmdot(handle, y_res, n, (double*)devVecA->v_, (double*)devVecB->v_,devVecA->count_,devVecB->pitch_); spgpuDmdot(handle, y_res, n, (double*)devVecA->v_, (double*)devVecB->v_,
devVecA->count_,devVecB->pitch_);
return(i); return(i);
} }

@ -813,18 +813,6 @@ contains
call x%set_dev() call x%set_dev()
end subroutine c_cuda_set_scal end subroutine c_cuda_set_scal
!!$
!!$ subroutine c_cuda_set_vect(x,val)
!!$ class(psb_c_vect_cuda), intent(inout) :: x
!!$ complex(psb_spk_), intent(in) :: val(:)
!!$ integer(psb_ipk_) :: nr
!!$ integer(psb_ipk_) :: info
!!$
!!$ if (x%is_dev()) call x%sync()
!!$ call x%psb_c_base_vect_type%set_vect(val)
!!$ call x%set_host()
!!$
!!$ end subroutine c_cuda_set_vect
@ -834,7 +822,6 @@ contains
class(psb_c_base_vect_type), intent(inout) :: y class(psb_c_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_spk_) :: res complex(psb_spk_) :: res
complex(psb_spk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = czero res = czero
@ -844,9 +831,6 @@ contains
! TYPE psb_c_vect ! TYPE psb_c_vect
! !
select type(yy => y) select type(yy => y)
type is (psb_c_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n,x%v,1,yy%v,1)
type is (psb_c_vect_cuda) type is (psb_c_vect_cuda)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
@ -858,7 +842,7 @@ contains
class default class default
! y%sync is done in dot_a ! y%sync is done in dot_a
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n,x%v) res = y%dot(n,x%v)
end select end select
@ -870,10 +854,10 @@ contains
complex(psb_spk_), intent(in) :: y(:) complex(psb_spk_), intent(in) :: y(:)
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_spk_) :: res complex(psb_spk_) :: res
complex(psb_spk_), external :: ddot complex(psb_spk_), external :: cdot
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
res = ddot(n,y,1,x%v,1) res = cdot(n,y,1,x%v,1)
end function c_cuda_dot_a end function c_cuda_dot_a

@ -813,18 +813,6 @@ contains
call x%set_dev() call x%set_dev()
end subroutine d_cuda_set_scal end subroutine d_cuda_set_scal
!!$
!!$ subroutine d_cuda_set_vect(x,val)
!!$ class(psb_d_vect_cuda), intent(inout) :: x
!!$ real(psb_dpk_), intent(in) :: val(:)
!!$ integer(psb_ipk_) :: nr
!!$ integer(psb_ipk_) :: info
!!$
!!$ if (x%is_dev()) call x%sync()
!!$ call x%psb_d_base_vect_type%set_vect(val)
!!$ call x%set_host()
!!$
!!$ end subroutine d_cuda_set_vect
@ -834,7 +822,6 @@ contains
class(psb_d_base_vect_type), intent(inout) :: y class(psb_d_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
real(psb_dpk_) :: res real(psb_dpk_) :: res
real(psb_dpk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = dzero res = dzero
@ -844,9 +831,6 @@ contains
! TYPE psb_d_vect ! TYPE psb_d_vect
! !
select type(yy => y) select type(yy => y)
type is (psb_d_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n,x%v,1,yy%v,1)
type is (psb_d_vect_cuda) type is (psb_d_vect_cuda)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
@ -858,7 +842,7 @@ contains
class default class default
! y%sync is done in dot_a ! y%sync is done in dot_a
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n,x%v) res = y%dot(n,x%v)
end select end select

@ -795,18 +795,6 @@ contains
call x%set_dev() call x%set_dev()
end subroutine i_cuda_set_scal end subroutine i_cuda_set_scal
!!$
!!$ subroutine i_cuda_set_vect(x,val)
!!$ class(psb_i_vect_cuda), intent(inout) :: x
!!$ integer(psb_ipk_), intent(in) :: val(:)
!!$ integer(psb_ipk_) :: nr
!!$ integer(psb_ipk_) :: info
!!$
!!$ if (x%is_dev()) call x%sync()
!!$ call x%psb_i_base_vect_type%set_vect(val)
!!$ call x%set_host()
!!$
!!$ end subroutine i_cuda_set_vect

@ -813,18 +813,6 @@ contains
call x%set_dev() call x%set_dev()
end subroutine s_cuda_set_scal end subroutine s_cuda_set_scal
!!$
!!$ subroutine s_cuda_set_vect(x,val)
!!$ class(psb_s_vect_cuda), intent(inout) :: x
!!$ real(psb_spk_), intent(in) :: val(:)
!!$ integer(psb_ipk_) :: nr
!!$ integer(psb_ipk_) :: info
!!$
!!$ if (x%is_dev()) call x%sync()
!!$ call x%psb_s_base_vect_type%set_vect(val)
!!$ call x%set_host()
!!$
!!$ end subroutine s_cuda_set_vect
@ -834,7 +822,6 @@ contains
class(psb_s_base_vect_type), intent(inout) :: y class(psb_s_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res real(psb_spk_) :: res
real(psb_spk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = szero res = szero
@ -844,9 +831,6 @@ contains
! TYPE psb_s_vect ! TYPE psb_s_vect
! !
select type(yy => y) select type(yy => y)
type is (psb_s_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n,x%v,1,yy%v,1)
type is (psb_s_vect_cuda) type is (psb_s_vect_cuda)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
@ -858,7 +842,7 @@ contains
class default class default
! y%sync is done in dot_a ! y%sync is done in dot_a
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n,x%v) res = y%dot(n,x%v)
end select end select
@ -870,10 +854,10 @@ contains
real(psb_spk_), intent(in) :: y(:) real(psb_spk_), intent(in) :: y(:)
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res real(psb_spk_) :: res
real(psb_spk_), external :: ddot real(psb_spk_), external :: sdot
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
res = ddot(n,y,1,x%v,1) res = sdot(n,y,1,x%v,1)
end function s_cuda_dot_a end function s_cuda_dot_a

@ -813,18 +813,6 @@ contains
call x%set_dev() call x%set_dev()
end subroutine z_cuda_set_scal end subroutine z_cuda_set_scal
!!$
!!$ subroutine z_cuda_set_vect(x,val)
!!$ class(psb_z_vect_cuda), intent(inout) :: x
!!$ complex(psb_dpk_), intent(in) :: val(:)
!!$ integer(psb_ipk_) :: nr
!!$ integer(psb_ipk_) :: info
!!$
!!$ if (x%is_dev()) call x%sync()
!!$ call x%psb_z_base_vect_type%set_vect(val)
!!$ call x%set_host()
!!$
!!$ end subroutine z_cuda_set_vect
@ -834,7 +822,6 @@ contains
class(psb_z_base_vect_type), intent(inout) :: y class(psb_z_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_dpk_) :: res complex(psb_dpk_) :: res
complex(psb_dpk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = zzero res = zzero
@ -844,9 +831,6 @@ contains
! TYPE psb_z_vect ! TYPE psb_z_vect
! !
select type(yy => y) select type(yy => y)
type is (psb_z_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n,x%v,1,yy%v,1)
type is (psb_z_vect_cuda) type is (psb_z_vect_cuda)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
@ -858,7 +842,7 @@ contains
class default class default
! y%sync is done in dot_a ! y%sync is done in dot_a
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n,x%v) res = y%dot(n,x%v)
end select end select
@ -870,10 +854,10 @@ contains
complex(psb_dpk_), intent(in) :: y(:) complex(psb_dpk_), intent(in) :: y(:)
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_dpk_) :: res complex(psb_dpk_) :: res
complex(psb_dpk_), external :: ddot complex(psb_dpk_), external :: zdot
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
res = ddot(n,y,1,x%v,1) res = zdot(n,y,1,x%v,1)
end function z_cuda_dot_a end function z_cuda_dot_a

@ -220,7 +220,8 @@ int dotMultiVecDeviceFloat(float* y_res, int n, void* devMultiVecA, void* devMul
struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB; struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB;
spgpuHandle_t handle=psb_cudaGetHandle(); spgpuHandle_t handle=psb_cudaGetHandle();
spgpuSmdot(handle, y_res, n, (float*)devVecA->v_, (float*)devVecB->v_,devVecA->count_,devVecB->pitch_); spgpuSmdot(handle, y_res, n, (float*)devVecA->v_, (float*)devVecB->v_,
devVecA->count_,devVecB->pitch_);
return(i); return(i);
} }

@ -223,7 +223,8 @@ int scalMultiVecDeviceDoubleComplex(cuDoubleComplex alpha, void* devMultiVecA)
return(i); return(i);
} }
int dotMultiVecDeviceDoubleComplex(cuDoubleComplex* y_res, int n, void* devMultiVecA, void* devMultiVecB) int dotMultiVecDeviceDoubleComplex(cuDoubleComplex* y_res, int n,
void* devMultiVecA, void* devMultiVecB)
{int i=0; {int i=0;
struct MultiVectDevice *devVecA = (struct MultiVectDevice *) devMultiVecA; struct MultiVectDevice *devVecA = (struct MultiVectDevice *) devMultiVecA;
struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB; struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB;

Loading…
Cancel
Save