Fix DOT in OpenACC

oacc_loloum
sfilippone 2 years ago
parent 4822861b73
commit 5903c0b272

@ -798,22 +798,18 @@ contains
class(psb_c_base_vect_type), intent(inout) :: y class(psb_c_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_spk_) :: res complex(psb_spk_) :: res
complex(psb_spk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = czero res = czero
!!$ write(0,*) 'oacc_dot_v' !!$ write(0,*) 'oacc_dot_v'
select type(yy => y) select type(yy => y)
type is (psb_c_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n, x%v, 1, yy%v, 1)
type is (psb_c_vect_oacc) type is (psb_c_vect_oacc)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
res = c_inner_oacc_dot(n, x%v, yy%v) res = c_inner_oacc_dot(n, x%v, yy%v)
class default class default
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n, x%v) res = y%dot(n, x%v)
end select end select
contains contains
function c_inner_oacc_dot(n, x, y) result(res) function c_inner_oacc_dot(n, x, y) result(res)
@ -838,10 +834,10 @@ contains
complex(psb_spk_), intent(in) :: y(:) complex(psb_spk_), intent(in) :: y(:)
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_spk_) :: res complex(psb_spk_) :: res
complex(psb_spk_), external :: ddot complex(psb_spk_), external :: cdot
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
res = ddot(n, y, 1, x%v, 1) res = cdot(n, y, 1, x%v, 1)
end function c_oacc_dot_a end function c_oacc_dot_a

@ -798,22 +798,18 @@ contains
class(psb_d_base_vect_type), intent(inout) :: y class(psb_d_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
real(psb_dpk_) :: res real(psb_dpk_) :: res
real(psb_dpk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = dzero res = dzero
!!$ write(0,*) 'oacc_dot_v' !!$ write(0,*) 'oacc_dot_v'
select type(yy => y) select type(yy => y)
type is (psb_d_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n, x%v, 1, yy%v, 1)
type is (psb_d_vect_oacc) type is (psb_d_vect_oacc)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
res = d_inner_oacc_dot(n, x%v, yy%v) res = d_inner_oacc_dot(n, x%v, yy%v)
class default class default
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n, x%v) res = y%dot(n, x%v)
end select end select
contains contains
function d_inner_oacc_dot(n, x, y) result(res) function d_inner_oacc_dot(n, x, y) result(res)

@ -798,22 +798,18 @@ contains
class(psb_s_base_vect_type), intent(inout) :: y class(psb_s_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res real(psb_spk_) :: res
real(psb_spk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = szero res = szero
!!$ write(0,*) 'oacc_dot_v' !!$ write(0,*) 'oacc_dot_v'
select type(yy => y) select type(yy => y)
type is (psb_s_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n, x%v, 1, yy%v, 1)
type is (psb_s_vect_oacc) type is (psb_s_vect_oacc)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
res = s_inner_oacc_dot(n, x%v, yy%v) res = s_inner_oacc_dot(n, x%v, yy%v)
class default class default
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n, x%v) res = y%dot(n, x%v)
end select end select
contains contains
function s_inner_oacc_dot(n, x, y) result(res) function s_inner_oacc_dot(n, x, y) result(res)
@ -838,10 +834,10 @@ contains
real(psb_spk_), intent(in) :: y(:) real(psb_spk_), intent(in) :: y(:)
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res real(psb_spk_) :: res
real(psb_spk_), external :: ddot real(psb_spk_), external :: sdot
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
res = ddot(n, y, 1, x%v, 1) res = sdot(n, y, 1, x%v, 1)
end function s_oacc_dot_a end function s_oacc_dot_a

@ -798,22 +798,18 @@ contains
class(psb_z_base_vect_type), intent(inout) :: y class(psb_z_base_vect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_dpk_) :: res complex(psb_dpk_) :: res
complex(psb_dpk_), external :: ddot
integer(psb_ipk_) :: info integer(psb_ipk_) :: info
res = zzero res = zzero
!!$ write(0,*) 'oacc_dot_v' !!$ write(0,*) 'oacc_dot_v'
select type(yy => y) select type(yy => y)
type is (psb_z_base_vect_type)
if (x%is_dev()) call x%sync()
res = ddot(n, x%v, 1, yy%v, 1)
type is (psb_z_vect_oacc) type is (psb_z_vect_oacc)
if (x%is_host()) call x%sync() if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync() if (yy%is_host()) call yy%sync()
res = z_inner_oacc_dot(n, x%v, yy%v) res = z_inner_oacc_dot(n, x%v, yy%v)
class default class default
call x%sync() if (x%is_dev()) call x%sync()
res = y%dot(n, x%v) res = y%dot(n, x%v)
end select end select
contains contains
function z_inner_oacc_dot(n, x, y) result(res) function z_inner_oacc_dot(n, x, y) result(res)
@ -838,10 +834,10 @@ contains
complex(psb_dpk_), intent(in) :: y(:) complex(psb_dpk_), intent(in) :: y(:)
integer(psb_ipk_), intent(in) :: n integer(psb_ipk_), intent(in) :: n
complex(psb_dpk_) :: res complex(psb_dpk_) :: res
complex(psb_dpk_), external :: ddot complex(psb_dpk_), external :: zdot
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
res = ddot(n, y, 1, x%v, 1) res = zdot(n, y, 1, x%v, 1)
end function z_oacc_dot_a end function z_oacc_dot_a

Loading…
Cancel
Save