psb_dnrm2_weight working also for GPUs

newG2L
Cirdans-Home 4 years ago
parent faf493b861
commit 9a2ea73d81

@ -1020,10 +1020,12 @@ contains
end function c_vect_nrm2
function c_vect_nrm2_weight(n,x,w) result(res)
function c_vect_nrm2_weight(n,x,w,aux) result(res)
use psi_serial_mod
implicit none
class(psb_c_vect_type), intent(inout) :: x
class(psb_c_vect_type), intent(inout) :: w
class(psb_c_vect_type), intent(inout), optional :: aux
integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res
integer(psb_ipk_) :: info
@ -1032,8 +1034,12 @@ contains
type(psb_c_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,cone,w%v%v,czero,aux%v%v,info)
end if
else
info = -1
end if
@ -1043,29 +1049,44 @@ contains
end if
if (allocated(x%v)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = szero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function c_vect_nrm2_weight
function c_vect_nrm2_weight_mask(n,x,w,id) result(res)
function c_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res)
use psi_serial_mod
implicit none
class(psb_c_vect_type), intent(inout) :: x
class(psb_c_vect_type), intent(inout) :: w
class(psb_c_vect_type), intent(inout) :: id
integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res
integer(psb_ipk_) :: info
integer(psb_ipk_), intent(out) :: info
class(psb_c_vect_type), intent(inout), optional :: aux
! Temp vectors
type(psb_c_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,cone,w%v%v,czero,aux%v%v,info)
end if
else
info = -1
end if
@ -1076,15 +1097,25 @@ contains
if (allocated(x%v).and.allocated(id%v)) then
call wtemp%sync() ! FIXME for GPU
where( abs(id%v%v) <= szero) wtemp%v%v = szero
call wtemp%set_host() ! FIXME for GPU
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
where( abs(id%v%v) <= szero) wtemp%v%v = szero
call wtemp%set_host()
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
where( abs(id%v%v) <= szero) aux%v%v = szero
call aux%set_host()
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = szero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function c_vect_nrm2_weight_mask
function c_vect_amax(n,x) result(res)

@ -1027,10 +1027,12 @@ contains
end function d_vect_nrm2
function d_vect_nrm2_weight(n,x,w) result(res)
function d_vect_nrm2_weight(n,x,w,aux) result(res)
use psi_serial_mod
implicit none
class(psb_d_vect_type), intent(inout) :: x
class(psb_d_vect_type), intent(inout) :: w
class(psb_d_vect_type), intent(inout), optional :: aux
integer(psb_ipk_), intent(in) :: n
real(psb_dpk_) :: res
integer(psb_ipk_) :: info
@ -1039,8 +1041,12 @@ contains
type(psb_d_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,done,w%v%v,dzero,aux%v%v,info)
end if
else
info = -1
end if
@ -1050,29 +1056,44 @@ contains
end if
if (allocated(x%v)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = dzero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function d_vect_nrm2_weight
function d_vect_nrm2_weight_mask(n,x,w,id) result(res)
function d_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res)
use psi_serial_mod
implicit none
class(psb_d_vect_type), intent(inout) :: x
class(psb_d_vect_type), intent(inout) :: w
class(psb_d_vect_type), intent(inout) :: id
integer(psb_ipk_), intent(in) :: n
real(psb_dpk_) :: res
integer(psb_ipk_) :: info
integer(psb_ipk_), intent(out) :: info
class(psb_d_vect_type), intent(inout), optional :: aux
! Temp vectors
type(psb_d_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,done,w%v%v,dzero,aux%v%v,info)
end if
else
info = -1
end if
@ -1083,15 +1104,25 @@ contains
if (allocated(x%v).and.allocated(id%v)) then
call wtemp%sync() ! FIXME for GPU
where( abs(id%v%v) <= dzero) wtemp%v%v = dzero
call wtemp%set_host() ! FIXME for GPU
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
where( abs(id%v%v) <= dzero) wtemp%v%v = dzero
call wtemp%set_host()
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
where( abs(id%v%v) <= dzero) aux%v%v = dzero
call aux%set_host()
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = dzero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function d_vect_nrm2_weight_mask
function d_vect_amax(n,x) result(res)

@ -1027,10 +1027,12 @@ contains
end function s_vect_nrm2
function s_vect_nrm2_weight(n,x,w) result(res)
function s_vect_nrm2_weight(n,x,w,aux) result(res)
use psi_serial_mod
implicit none
class(psb_s_vect_type), intent(inout) :: x
class(psb_s_vect_type), intent(inout) :: w
class(psb_s_vect_type), intent(inout), optional :: aux
integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res
integer(psb_ipk_) :: info
@ -1039,8 +1041,12 @@ contains
type(psb_s_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,sone,w%v%v,szero,aux%v%v,info)
end if
else
info = -1
end if
@ -1050,29 +1056,44 @@ contains
end if
if (allocated(x%v)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = szero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function s_vect_nrm2_weight
function s_vect_nrm2_weight_mask(n,x,w,id) result(res)
function s_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res)
use psi_serial_mod
implicit none
class(psb_s_vect_type), intent(inout) :: x
class(psb_s_vect_type), intent(inout) :: w
class(psb_s_vect_type), intent(inout) :: id
integer(psb_ipk_), intent(in) :: n
real(psb_spk_) :: res
integer(psb_ipk_) :: info
integer(psb_ipk_), intent(out) :: info
class(psb_s_vect_type), intent(inout), optional :: aux
! Temp vectors
type(psb_s_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,sone,w%v%v,szero,aux%v%v,info)
end if
else
info = -1
end if
@ -1083,15 +1104,25 @@ contains
if (allocated(x%v).and.allocated(id%v)) then
call wtemp%sync() ! FIXME for GPU
where( abs(id%v%v) <= szero) wtemp%v%v = szero
call wtemp%set_host() ! FIXME for GPU
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
where( abs(id%v%v) <= szero) wtemp%v%v = szero
call wtemp%set_host()
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
where( abs(id%v%v) <= szero) aux%v%v = szero
call aux%set_host()
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = szero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function s_vect_nrm2_weight_mask
function s_vect_amax(n,x) result(res)

@ -1020,10 +1020,12 @@ contains
end function z_vect_nrm2
function z_vect_nrm2_weight(n,x,w) result(res)
function z_vect_nrm2_weight(n,x,w,aux) result(res)
use psi_serial_mod
implicit none
class(psb_z_vect_type), intent(inout) :: x
class(psb_z_vect_type), intent(inout) :: w
class(psb_z_vect_type), intent(inout), optional :: aux
integer(psb_ipk_), intent(in) :: n
real(psb_dpk_) :: res
integer(psb_ipk_) :: info
@ -1032,8 +1034,12 @@ contains
type(psb_z_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,zone,w%v%v,zzero,aux%v%v,info)
end if
else
info = -1
end if
@ -1043,29 +1049,44 @@ contains
end if
if (allocated(x%v)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = dzero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function z_vect_nrm2_weight
function z_vect_nrm2_weight_mask(n,x,w,id) result(res)
function z_vect_nrm2_weight_mask(n,x,w,id,info,aux) result(res)
use psi_serial_mod
implicit none
class(psb_z_vect_type), intent(inout) :: x
class(psb_z_vect_type), intent(inout) :: w
class(psb_z_vect_type), intent(inout) :: id
integer(psb_ipk_), intent(in) :: n
real(psb_dpk_) :: res
integer(psb_ipk_) :: info
integer(psb_ipk_), intent(out) :: info
class(psb_z_vect_type), intent(inout), optional :: aux
! Temp vectors
type(psb_z_vect_type) :: wtemp
if( allocated(w%v) ) then
! FIXME for GPU
allocate(wtemp%v, source=w%v, stat = info)
if (.not.present(aux)) then
allocate(wtemp%v, mold=w%v)
call wtemp%v%bld(w%get_vect())
else
call psb_geaxpby(n,zone,w%v%v,zzero,aux%v%v,info)
end if
else
info = -1
end if
@ -1076,15 +1097,25 @@ contains
if (allocated(x%v).and.allocated(id%v)) then
call wtemp%sync() ! FIXME for GPU
where( abs(id%v%v) <= dzero) wtemp%v%v = dzero
call wtemp%set_host() ! FIXME for GPU
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
if (.not.present(aux)) then
where( abs(id%v%v) <= dzero) wtemp%v%v = dzero
call wtemp%set_host()
call wtemp%v%mlt(x%v,info)
res = wtemp%v%nrm2(n)
else
where( abs(id%v%v) <= dzero) aux%v%v = dzero
call aux%set_host()
call aux%v%mlt(x%v,info)
res = aux%v%nrm2(n)
end if
else
res = dzero
end if
if (.not.present(aux)) then
call wtemp%free(info)
end if
end function z_vect_nrm2_weight_mask
function z_vect_amax(n,x) result(res)

@ -385,7 +385,7 @@ end function psb_cnrm2_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_cnrm2_weight_vect(x,w, desc_a, info,global) result(res)
function psb_cnrm2_weight_vect(x,w, desc_a, info,global,aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -399,6 +399,7 @@ function psb_cnrm2_weight_vect(x,w, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_c_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -456,7 +457,7 @@ function psb_cnrm2_weight_vect(x,w, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w)
res = x%nrm2(ndim,w,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
@ -495,7 +496,7 @@ end function psb_cnrm2_weight_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_cnrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
function psb_cnrm2_weightmask_vect(x,w,idv, desc_a, info,global, aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -510,6 +511,7 @@ function psb_cnrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_c_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -567,7 +569,7 @@ function psb_cnrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w,idv)
res = x%nrm2(ndim,w,idv,info,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()

@ -385,7 +385,7 @@ end function psb_dnrm2_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_dnrm2_weight_vect(x,w, desc_a, info,global) result(res)
function psb_dnrm2_weight_vect(x,w, desc_a, info,global,aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -399,6 +399,7 @@ function psb_dnrm2_weight_vect(x,w, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_d_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -456,7 +457,7 @@ function psb_dnrm2_weight_vect(x,w, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w)
res = x%nrm2(ndim,w,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
@ -495,7 +496,7 @@ end function psb_dnrm2_weight_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_dnrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
function psb_dnrm2_weightmask_vect(x,w,idv, desc_a, info,global, aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -510,6 +511,7 @@ function psb_dnrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_d_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -567,7 +569,7 @@ function psb_dnrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w,idv)
res = x%nrm2(ndim,w,idv,info,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()

@ -385,7 +385,7 @@ end function psb_snrm2_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_snrm2_weight_vect(x,w, desc_a, info,global) result(res)
function psb_snrm2_weight_vect(x,w, desc_a, info,global,aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -399,6 +399,7 @@ function psb_snrm2_weight_vect(x,w, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_s_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -456,7 +457,7 @@ function psb_snrm2_weight_vect(x,w, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w)
res = x%nrm2(ndim,w,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
@ -495,7 +496,7 @@ end function psb_snrm2_weight_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_snrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
function psb_snrm2_weightmask_vect(x,w,idv, desc_a, info,global, aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -510,6 +511,7 @@ function psb_snrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_s_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -567,7 +569,7 @@ function psb_snrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w,idv)
res = x%nrm2(ndim,w,idv,info,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()

@ -385,7 +385,7 @@ end function psb_znrm2_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_znrm2_weight_vect(x,w, desc_a, info,global) result(res)
function psb_znrm2_weight_vect(x,w, desc_a, info,global,aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -399,6 +399,7 @@ function psb_znrm2_weight_vect(x,w, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_z_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -456,7 +457,7 @@ function psb_znrm2_weight_vect(x,w, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w)
res = x%nrm2(ndim,w,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
@ -495,7 +496,7 @@ end function psb_znrm2_weight_vect
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_znrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
function psb_znrm2_weightmask_vect(x,w,idv, desc_a, info,global, aux) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
@ -510,6 +511,7 @@ function psb_znrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
type(psb_z_vect_type), intent(inout), optional :: aux
! locals
type(psb_ctxt_type) :: ctxt
@ -567,7 +569,7 @@ function psb_znrm2_weightmask_vect(x,w,idv, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
ndim = desc_a%get_local_rows()
res = x%nrm2(ndim,w,idv)
res = x%nrm2(ndim,w,idv,info,aux)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()

Binary file not shown.
Loading…
Cancel
Save