From 22cfaaccab427cebfa529a67f01502e83e4a4d86 Mon Sep 17 00:00:00 2001 From: Salvatore Filippone Date: Wed, 15 Apr 2015 15:37:06 +0000 Subject: [PATCH] psblas3: base/modules/psb_c_base_vect_mod.f90 base/modules/psb_d_base_vect_mod.f90 base/modules/psb_s_base_vect_mod.f90 base/modules/psb_s_tools_mod.f90 base/modules/psb_z_base_vect_mod.f90 Start merge of mods from matasb. Inner buffer methods for base vectors. --- base/modules/psb_c_base_vect_mod.f90 | 176 +++++++++++++++++++++++---- base/modules/psb_d_base_vect_mod.f90 | 176 +++++++++++++++++++++++---- base/modules/psb_s_base_vect_mod.f90 | 176 +++++++++++++++++++++++---- base/modules/psb_s_tools_mod.f90 | 2 +- base/modules/psb_z_base_vect_mod.f90 | 176 +++++++++++++++++++++++---- 5 files changed, 597 insertions(+), 109 deletions(-) diff --git a/base/modules/psb_c_base_vect_mod.f90 b/base/modules/psb_c_base_vect_mod.f90 index de09c4f9..52f17321 100644 --- a/base/modules/psb_c_base_vect_mod.f90 +++ b/base/modules/psb_c_base_vect_mod.f90 @@ -62,6 +62,8 @@ module psb_c_base_vect_mod type psb_c_base_vect_type !> Values. complex(psb_spk_), allocatable :: v(:) + complex(psb_spk_), allocatable :: combuf(:) + integer(psb_ipk_), allocatable :: comid(:,:) contains ! ! Constructors/allocators @@ -97,6 +99,17 @@ module psb_c_base_vect_mod procedure, pass(x) :: set_dev => c_base_set_dev procedure, pass(x) :: set_sync => c_base_set_sync + ! + ! These are for handling gather/scatter in new + ! comm internals implementation. + ! + procedure, nopass :: use_buffer => c_base_use_buffer + procedure, pass(x) :: new_buffer => c_base_new_buffer + procedure, nopass :: device_wait => c_base_device_wait + procedure, pass(x) :: free_buffer => c_base_free_buffer + procedure, pass(x) :: new_comid => c_base_new_comid + procedure, pass(x) :: free_comid => c_base_free_comid + ! ! Basic info procedure, pass(x) :: get_nrows => c_base_get_nrows @@ -148,10 +161,12 @@ module psb_c_base_vect_mod procedure, pass(x) :: gthab => c_base_gthab procedure, pass(x) :: gthzv => c_base_gthzv procedure, pass(x) :: gthzv_x => c_base_gthzv_x - generic, public :: gth => gthab, gthzv, gthzv_x + procedure, pass(x) :: gthzbuf => c_base_gthzbuf + generic, public :: gth => gthab, gthzv, gthzv_x, gthzbuf procedure, pass(y) :: sctb => c_base_sctb procedure, pass(y) :: sctb_x => c_base_sctb_x - generic, public :: sct => sctb, sctb_x + procedure, pass(y) :: sctb_buf => c_base_sctb_buf + generic, public :: sct => sctb, sctb_x, sctb_buf end type psb_c_base_vect_type public :: psb_c_base_vect @@ -348,8 +363,8 @@ contains case default info = 321 -! !$ call psb_errpush(info,name) -! !$ goto 9999 + ! !$ call psb_errpush(info,name) + ! !$ goto 9999 end select end if call x%set_host() @@ -360,7 +375,6 @@ contains end subroutine c_base_ins_a - subroutine c_base_ins_v(n,irl,val,dupl,x,info) use psi_serial_mod implicit none @@ -452,6 +466,8 @@ contains info = 0 if (allocated(x%v)) deallocate(x%v, stat=info) + if (info == 0) call x%free_buffer(info) + if (info == 0) call x%free_comid(info) if (info /= 0) call & & psb_errpush(psb_err_alloc_dealloc_,'vect_free') @@ -778,12 +794,9 @@ contains complex(psb_spk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info - select type(xx => x) - type is (psb_c_base_vect_type) - call psb_geaxpby(m,alpha,x%v,beta,y%v,info) - class default - call y%axpby(m,alpha,x%v,beta,info) - end select + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x%v,beta,info) end subroutine c_base_axpby_v @@ -809,7 +822,9 @@ contains complex(psb_spk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info + if (y%is_dev()) call y%sync() call psb_geaxpby(m,alpha,x,beta,y%v,info) + call y%set_host() end subroutine c_base_axpby_a @@ -838,15 +853,8 @@ contains integer(psb_ipk_) :: i, n info = 0 - select type(xx => x) - type is (psb_c_base_vect_type) - n = min(size(y%v), size(xx%v)) - do i=1, n - y%v(i) = y%v(i)*xx%v(i) - end do - class default - call y%mlt(x%v,info) - end select + if (x%is_dev()) call x%sync() + call y%mlt(x%v,info) end subroutine c_base_mlt_v @@ -866,11 +874,13 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (y%is_dev()) call y%sync() n = min(size(y%v), size(x)) do i=1, n y%v(i) = y%v(i)*x(i) end do - + call y%set_host() + end subroutine c_base_mlt_a @@ -896,6 +906,8 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (z%is_dev()) call z%sync() + n = min(size(z%v), size(x), size(y)) !!$ write(0,*) 'Mlt_a_2: ',n if (alpha == czero) then @@ -951,6 +963,8 @@ contains end if end if end if + call z%set_host() + end subroutine c_base_mlt_a_2 ! @@ -978,6 +992,8 @@ contains logical :: conjgx_, conjgy_ info = 0 + if (y%is_dev()) call y%sync() + if (x%is_dev()) call x%sync() if (.not.psb_c_is_complex_) then call z%mlt(alpha,x%v,y%v,beta,info) else @@ -1004,7 +1020,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (y%is_dev()) call y%sync() call z%mlt(alpha,x,y%v,beta,info) end subroutine c_base_mlt_av @@ -1020,7 +1036,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (x%is_dev()) call x%sync() call z%mlt(alpha,y,x,beta,info) end subroutine c_base_mlt_va @@ -1040,7 +1056,10 @@ contains class(psb_c_base_vect_type), intent(inout) :: x complex(psb_spk_), intent (in) :: alpha - if (allocated(x%v)) x%v = alpha*x%v + if (allocated(x%v)) then + x%v = alpha*x%v + call x%set_host() + end if end subroutine c_base_scal @@ -1058,6 +1077,7 @@ contains real(psb_spk_) :: res real(psb_spk_), external :: scnrm2 + if (x%is_dev()) call x%sync() res = scnrm2(n,x%v,1) end function c_base_nrm2 @@ -1073,6 +1093,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res + if (x%is_dev()) call x%sync() res = maxval(abs(x%v(1:n))) end function c_base_amax @@ -1088,6 +1109,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res + if (x%is_dev()) call x%sync() res = sum(abs(x%v(1:n))) end function c_base_asum @@ -1111,7 +1133,7 @@ contains complex(psb_spk_) :: alpha, beta, y(:) class(psb_c_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,alpha,x%v,beta,y) end subroutine c_base_gthab @@ -1131,10 +1153,108 @@ contains complex(psb_spk_) :: y(:) class(psb_c_base_vect_type) :: x + if (idx%is_dev()) call idx%sync() call x%gth(n,idx%v(i:),y) end subroutine c_base_gthzv_x + ! + ! New comm internals impl. + ! + subroutine c_base_gthzbuf(i,n,idx,x) + use psi_serial_mod + integer(psb_ipk_) :: i,n + class(psb_i_base_vect_type) :: idx + class(psb_c_base_vect_type) :: x + + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'gthzbuf') + return + end if + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + call x%gth(n,idx%v(i:),x%combuf(i:)) + + end subroutine c_base_gthzbuf + + subroutine c_base_sctb_buf(i,n,idx,beta,y) + use psi_serial_mod + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + complex(psb_spk_) :: beta + class(psb_c_base_vect_type) :: y + + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'sctb_buf') + return + end if + if (y%is_dev()) call y%sync() + if (idx%is_dev()) call idx%sync() + call y%sct(n,idx%v(i:),y%combuf(i:),beta) + call y%set_host() + + end subroutine c_base_sctb_buf + + ! + !> Function base_device_wait: + !! \memberof psb_c_base_vect_type + !! \brief device_wait: base version is a no-op. + !! + ! + subroutine c_base_device_wait() + implicit none + + end subroutine c_base_device_wait + + function c_base_use_buffer() result(res) + logical :: res + + res = .true. + end function c_base_use_buffer + + subroutine c_base_new_buffer(n,x,info) + use psb_realloc_mod + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,x%combuf,info) + end subroutine c_base_new_buffer + + subroutine c_base_new_comid(n,x,info) + use psb_realloc_mod + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,2,x%comid,info) + end subroutine c_base_new_comid + + + subroutine c_base_free_buffer(x,info) + use psb_realloc_mod + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%combuf)) & + & deallocate(x%combuf,stat=info) + end subroutine c_base_free_buffer + + subroutine c_base_free_comid(x,info) + use psb_realloc_mod + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%comid)) & + & deallocate(x%comid,stat=info) + end subroutine c_base_free_comid + + ! ! shortcut alpha=1 beta=0 ! @@ -1150,7 +1270,7 @@ contains complex(psb_spk_) :: y(:) class(psb_c_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,x%v,y) end subroutine c_base_gthzv @@ -1174,7 +1294,7 @@ contains complex(psb_spk_) :: beta, x(:) class(psb_c_base_vect_type) :: y - call y%sync() + if (y%is_dev()) call y%sync() call psi_sct(n,idx,x,beta,y%v) call y%set_host() @@ -1187,7 +1307,9 @@ contains complex(psb_spk_) :: beta, x(:) class(psb_c_base_vect_type) :: y + if (idx%is_dev()) call idx%sync() call y%sct(n,idx%v(i:),x,beta) + call y%set_host() end subroutine c_base_sctb_x diff --git a/base/modules/psb_d_base_vect_mod.f90 b/base/modules/psb_d_base_vect_mod.f90 index 4f3224cc..99662ac9 100644 --- a/base/modules/psb_d_base_vect_mod.f90 +++ b/base/modules/psb_d_base_vect_mod.f90 @@ -62,6 +62,8 @@ module psb_d_base_vect_mod type psb_d_base_vect_type !> Values. real(psb_dpk_), allocatable :: v(:) + real(psb_dpk_), allocatable :: combuf(:) + integer(psb_ipk_), allocatable :: comid(:,:) contains ! ! Constructors/allocators @@ -97,6 +99,17 @@ module psb_d_base_vect_mod procedure, pass(x) :: set_dev => d_base_set_dev procedure, pass(x) :: set_sync => d_base_set_sync + ! + ! These are for handling gather/scatter in new + ! comm internals implementation. + ! + procedure, nopass :: use_buffer => d_base_use_buffer + procedure, pass(x) :: new_buffer => d_base_new_buffer + procedure, nopass :: device_wait => d_base_device_wait + procedure, pass(x) :: free_buffer => d_base_free_buffer + procedure, pass(x) :: new_comid => d_base_new_comid + procedure, pass(x) :: free_comid => d_base_free_comid + ! ! Basic info procedure, pass(x) :: get_nrows => d_base_get_nrows @@ -148,10 +161,12 @@ module psb_d_base_vect_mod procedure, pass(x) :: gthab => d_base_gthab procedure, pass(x) :: gthzv => d_base_gthzv procedure, pass(x) :: gthzv_x => d_base_gthzv_x - generic, public :: gth => gthab, gthzv, gthzv_x + procedure, pass(x) :: gthzbuf => d_base_gthzbuf + generic, public :: gth => gthab, gthzv, gthzv_x, gthzbuf procedure, pass(y) :: sctb => d_base_sctb procedure, pass(y) :: sctb_x => d_base_sctb_x - generic, public :: sct => sctb, sctb_x + procedure, pass(y) :: sctb_buf => d_base_sctb_buf + generic, public :: sct => sctb, sctb_x, sctb_buf end type psb_d_base_vect_type public :: psb_d_base_vect @@ -348,8 +363,8 @@ contains case default info = 321 -! !$ call psb_errpush(info,name) -! !$ goto 9999 + ! !$ call psb_errpush(info,name) + ! !$ goto 9999 end select end if call x%set_host() @@ -360,7 +375,6 @@ contains end subroutine d_base_ins_a - subroutine d_base_ins_v(n,irl,val,dupl,x,info) use psi_serial_mod implicit none @@ -452,6 +466,8 @@ contains info = 0 if (allocated(x%v)) deallocate(x%v, stat=info) + if (info == 0) call x%free_buffer(info) + if (info == 0) call x%free_comid(info) if (info /= 0) call & & psb_errpush(psb_err_alloc_dealloc_,'vect_free') @@ -778,12 +794,9 @@ contains real(psb_dpk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info - select type(xx => x) - type is (psb_d_base_vect_type) - call psb_geaxpby(m,alpha,x%v,beta,y%v,info) - class default - call y%axpby(m,alpha,x%v,beta,info) - end select + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x%v,beta,info) end subroutine d_base_axpby_v @@ -809,7 +822,9 @@ contains real(psb_dpk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info + if (y%is_dev()) call y%sync() call psb_geaxpby(m,alpha,x,beta,y%v,info) + call y%set_host() end subroutine d_base_axpby_a @@ -838,15 +853,8 @@ contains integer(psb_ipk_) :: i, n info = 0 - select type(xx => x) - type is (psb_d_base_vect_type) - n = min(size(y%v), size(xx%v)) - do i=1, n - y%v(i) = y%v(i)*xx%v(i) - end do - class default - call y%mlt(x%v,info) - end select + if (x%is_dev()) call x%sync() + call y%mlt(x%v,info) end subroutine d_base_mlt_v @@ -866,11 +874,13 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (y%is_dev()) call y%sync() n = min(size(y%v), size(x)) do i=1, n y%v(i) = y%v(i)*x(i) end do - + call y%set_host() + end subroutine d_base_mlt_a @@ -896,6 +906,8 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (z%is_dev()) call z%sync() + n = min(size(z%v), size(x), size(y)) !!$ write(0,*) 'Mlt_a_2: ',n if (alpha == dzero) then @@ -951,6 +963,8 @@ contains end if end if end if + call z%set_host() + end subroutine d_base_mlt_a_2 ! @@ -978,6 +992,8 @@ contains logical :: conjgx_, conjgy_ info = 0 + if (y%is_dev()) call y%sync() + if (x%is_dev()) call x%sync() if (.not.psb_d_is_complex_) then call z%mlt(alpha,x%v,y%v,beta,info) else @@ -1004,7 +1020,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (y%is_dev()) call y%sync() call z%mlt(alpha,x,y%v,beta,info) end subroutine d_base_mlt_av @@ -1020,7 +1036,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (x%is_dev()) call x%sync() call z%mlt(alpha,y,x,beta,info) end subroutine d_base_mlt_va @@ -1040,7 +1056,10 @@ contains class(psb_d_base_vect_type), intent(inout) :: x real(psb_dpk_), intent (in) :: alpha - if (allocated(x%v)) x%v = alpha*x%v + if (allocated(x%v)) then + x%v = alpha*x%v + call x%set_host() + end if end subroutine d_base_scal @@ -1058,6 +1077,7 @@ contains real(psb_dpk_) :: res real(psb_dpk_), external :: dnrm2 + if (x%is_dev()) call x%sync() res = dnrm2(n,x%v,1) end function d_base_nrm2 @@ -1073,6 +1093,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res + if (x%is_dev()) call x%sync() res = maxval(abs(x%v(1:n))) end function d_base_amax @@ -1088,6 +1109,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res + if (x%is_dev()) call x%sync() res = sum(abs(x%v(1:n))) end function d_base_asum @@ -1111,7 +1133,7 @@ contains real(psb_dpk_) :: alpha, beta, y(:) class(psb_d_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,alpha,x%v,beta,y) end subroutine d_base_gthab @@ -1131,10 +1153,108 @@ contains real(psb_dpk_) :: y(:) class(psb_d_base_vect_type) :: x + if (idx%is_dev()) call idx%sync() call x%gth(n,idx%v(i:),y) end subroutine d_base_gthzv_x + ! + ! New comm internals impl. + ! + subroutine d_base_gthzbuf(i,n,idx,x) + use psi_serial_mod + integer(psb_ipk_) :: i,n + class(psb_i_base_vect_type) :: idx + class(psb_d_base_vect_type) :: x + + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'gthzbuf') + return + end if + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + call x%gth(n,idx%v(i:),x%combuf(i:)) + + end subroutine d_base_gthzbuf + + subroutine d_base_sctb_buf(i,n,idx,beta,y) + use psi_serial_mod + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + real(psb_dpk_) :: beta + class(psb_d_base_vect_type) :: y + + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'sctb_buf') + return + end if + if (y%is_dev()) call y%sync() + if (idx%is_dev()) call idx%sync() + call y%sct(n,idx%v(i:),y%combuf(i:),beta) + call y%set_host() + + end subroutine d_base_sctb_buf + + ! + !> Function base_device_wait: + !! \memberof psb_d_base_vect_type + !! \brief device_wait: base version is a no-op. + !! + ! + subroutine d_base_device_wait() + implicit none + + end subroutine d_base_device_wait + + function d_base_use_buffer() result(res) + logical :: res + + res = .true. + end function d_base_use_buffer + + subroutine d_base_new_buffer(n,x,info) + use psb_realloc_mod + implicit none + class(psb_d_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,x%combuf,info) + end subroutine d_base_new_buffer + + subroutine d_base_new_comid(n,x,info) + use psb_realloc_mod + implicit none + class(psb_d_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,2,x%comid,info) + end subroutine d_base_new_comid + + + subroutine d_base_free_buffer(x,info) + use psb_realloc_mod + implicit none + class(psb_d_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%combuf)) & + & deallocate(x%combuf,stat=info) + end subroutine d_base_free_buffer + + subroutine d_base_free_comid(x,info) + use psb_realloc_mod + implicit none + class(psb_d_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%comid)) & + & deallocate(x%comid,stat=info) + end subroutine d_base_free_comid + + ! ! shortcut alpha=1 beta=0 ! @@ -1150,7 +1270,7 @@ contains real(psb_dpk_) :: y(:) class(psb_d_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,x%v,y) end subroutine d_base_gthzv @@ -1174,7 +1294,7 @@ contains real(psb_dpk_) :: beta, x(:) class(psb_d_base_vect_type) :: y - call y%sync() + if (y%is_dev()) call y%sync() call psi_sct(n,idx,x,beta,y%v) call y%set_host() @@ -1187,7 +1307,9 @@ contains real(psb_dpk_) :: beta, x(:) class(psb_d_base_vect_type) :: y + if (idx%is_dev()) call idx%sync() call y%sct(n,idx%v(i:),x,beta) + call y%set_host() end subroutine d_base_sctb_x diff --git a/base/modules/psb_s_base_vect_mod.f90 b/base/modules/psb_s_base_vect_mod.f90 index 22d52a2b..8a4af041 100644 --- a/base/modules/psb_s_base_vect_mod.f90 +++ b/base/modules/psb_s_base_vect_mod.f90 @@ -62,6 +62,8 @@ module psb_s_base_vect_mod type psb_s_base_vect_type !> Values. real(psb_spk_), allocatable :: v(:) + real(psb_spk_), allocatable :: combuf(:) + integer(psb_ipk_), allocatable :: comid(:,:) contains ! ! Constructors/allocators @@ -97,6 +99,17 @@ module psb_s_base_vect_mod procedure, pass(x) :: set_dev => s_base_set_dev procedure, pass(x) :: set_sync => s_base_set_sync + ! + ! These are for handling gather/scatter in new + ! comm internals implementation. + ! + procedure, nopass :: use_buffer => s_base_use_buffer + procedure, pass(x) :: new_buffer => s_base_new_buffer + procedure, nopass :: device_wait => s_base_device_wait + procedure, pass(x) :: free_buffer => s_base_free_buffer + procedure, pass(x) :: new_comid => s_base_new_comid + procedure, pass(x) :: free_comid => s_base_free_comid + ! ! Basic info procedure, pass(x) :: get_nrows => s_base_get_nrows @@ -148,10 +161,12 @@ module psb_s_base_vect_mod procedure, pass(x) :: gthab => s_base_gthab procedure, pass(x) :: gthzv => s_base_gthzv procedure, pass(x) :: gthzv_x => s_base_gthzv_x - generic, public :: gth => gthab, gthzv, gthzv_x + procedure, pass(x) :: gthzbuf => s_base_gthzbuf + generic, public :: gth => gthab, gthzv, gthzv_x, gthzbuf procedure, pass(y) :: sctb => s_base_sctb procedure, pass(y) :: sctb_x => s_base_sctb_x - generic, public :: sct => sctb, sctb_x + procedure, pass(y) :: sctb_buf => s_base_sctb_buf + generic, public :: sct => sctb, sctb_x, sctb_buf end type psb_s_base_vect_type public :: psb_s_base_vect @@ -348,8 +363,8 @@ contains case default info = 321 -! !$ call psb_errpush(info,name) -! !$ goto 9999 + ! !$ call psb_errpush(info,name) + ! !$ goto 9999 end select end if call x%set_host() @@ -360,7 +375,6 @@ contains end subroutine s_base_ins_a - subroutine s_base_ins_v(n,irl,val,dupl,x,info) use psi_serial_mod implicit none @@ -452,6 +466,8 @@ contains info = 0 if (allocated(x%v)) deallocate(x%v, stat=info) + if (info == 0) call x%free_buffer(info) + if (info == 0) call x%free_comid(info) if (info /= 0) call & & psb_errpush(psb_err_alloc_dealloc_,'vect_free') @@ -778,12 +794,9 @@ contains real(psb_spk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info - select type(xx => x) - type is (psb_s_base_vect_type) - call psb_geaxpby(m,alpha,x%v,beta,y%v,info) - class default - call y%axpby(m,alpha,x%v,beta,info) - end select + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x%v,beta,info) end subroutine s_base_axpby_v @@ -809,7 +822,9 @@ contains real(psb_spk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info + if (y%is_dev()) call y%sync() call psb_geaxpby(m,alpha,x,beta,y%v,info) + call y%set_host() end subroutine s_base_axpby_a @@ -838,15 +853,8 @@ contains integer(psb_ipk_) :: i, n info = 0 - select type(xx => x) - type is (psb_s_base_vect_type) - n = min(size(y%v), size(xx%v)) - do i=1, n - y%v(i) = y%v(i)*xx%v(i) - end do - class default - call y%mlt(x%v,info) - end select + if (x%is_dev()) call x%sync() + call y%mlt(x%v,info) end subroutine s_base_mlt_v @@ -866,11 +874,13 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (y%is_dev()) call y%sync() n = min(size(y%v), size(x)) do i=1, n y%v(i) = y%v(i)*x(i) end do - + call y%set_host() + end subroutine s_base_mlt_a @@ -896,6 +906,8 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (z%is_dev()) call z%sync() + n = min(size(z%v), size(x), size(y)) !!$ write(0,*) 'Mlt_a_2: ',n if (alpha == szero) then @@ -951,6 +963,8 @@ contains end if end if end if + call z%set_host() + end subroutine s_base_mlt_a_2 ! @@ -978,6 +992,8 @@ contains logical :: conjgx_, conjgy_ info = 0 + if (y%is_dev()) call y%sync() + if (x%is_dev()) call x%sync() if (.not.psb_s_is_complex_) then call z%mlt(alpha,x%v,y%v,beta,info) else @@ -1004,7 +1020,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (y%is_dev()) call y%sync() call z%mlt(alpha,x,y%v,beta,info) end subroutine s_base_mlt_av @@ -1020,7 +1036,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (x%is_dev()) call x%sync() call z%mlt(alpha,y,x,beta,info) end subroutine s_base_mlt_va @@ -1040,7 +1056,10 @@ contains class(psb_s_base_vect_type), intent(inout) :: x real(psb_spk_), intent (in) :: alpha - if (allocated(x%v)) x%v = alpha*x%v + if (allocated(x%v)) then + x%v = alpha*x%v + call x%set_host() + end if end subroutine s_base_scal @@ -1058,6 +1077,7 @@ contains real(psb_spk_) :: res real(psb_spk_), external :: snrm2 + if (x%is_dev()) call x%sync() res = snrm2(n,x%v,1) end function s_base_nrm2 @@ -1073,6 +1093,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res + if (x%is_dev()) call x%sync() res = maxval(abs(x%v(1:n))) end function s_base_amax @@ -1088,6 +1109,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res + if (x%is_dev()) call x%sync() res = sum(abs(x%v(1:n))) end function s_base_asum @@ -1111,7 +1133,7 @@ contains real(psb_spk_) :: alpha, beta, y(:) class(psb_s_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,alpha,x%v,beta,y) end subroutine s_base_gthab @@ -1131,10 +1153,108 @@ contains real(psb_spk_) :: y(:) class(psb_s_base_vect_type) :: x + if (idx%is_dev()) call idx%sync() call x%gth(n,idx%v(i:),y) end subroutine s_base_gthzv_x + ! + ! New comm internals impl. + ! + subroutine s_base_gthzbuf(i,n,idx,x) + use psi_serial_mod + integer(psb_ipk_) :: i,n + class(psb_i_base_vect_type) :: idx + class(psb_s_base_vect_type) :: x + + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'gthzbuf') + return + end if + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + call x%gth(n,idx%v(i:),x%combuf(i:)) + + end subroutine s_base_gthzbuf + + subroutine s_base_sctb_buf(i,n,idx,beta,y) + use psi_serial_mod + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + real(psb_spk_) :: beta + class(psb_s_base_vect_type) :: y + + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'sctb_buf') + return + end if + if (y%is_dev()) call y%sync() + if (idx%is_dev()) call idx%sync() + call y%sct(n,idx%v(i:),y%combuf(i:),beta) + call y%set_host() + + end subroutine s_base_sctb_buf + + ! + !> Function base_device_wait: + !! \memberof psb_s_base_vect_type + !! \brief device_wait: base version is a no-op. + !! + ! + subroutine s_base_device_wait() + implicit none + + end subroutine s_base_device_wait + + function s_base_use_buffer() result(res) + logical :: res + + res = .true. + end function s_base_use_buffer + + subroutine s_base_new_buffer(n,x,info) + use psb_realloc_mod + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,x%combuf,info) + end subroutine s_base_new_buffer + + subroutine s_base_new_comid(n,x,info) + use psb_realloc_mod + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,2,x%comid,info) + end subroutine s_base_new_comid + + + subroutine s_base_free_buffer(x,info) + use psb_realloc_mod + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%combuf)) & + & deallocate(x%combuf,stat=info) + end subroutine s_base_free_buffer + + subroutine s_base_free_comid(x,info) + use psb_realloc_mod + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%comid)) & + & deallocate(x%comid,stat=info) + end subroutine s_base_free_comid + + ! ! shortcut alpha=1 beta=0 ! @@ -1150,7 +1270,7 @@ contains real(psb_spk_) :: y(:) class(psb_s_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,x%v,y) end subroutine s_base_gthzv @@ -1174,7 +1294,7 @@ contains real(psb_spk_) :: beta, x(:) class(psb_s_base_vect_type) :: y - call y%sync() + if (y%is_dev()) call y%sync() call psi_sct(n,idx,x,beta,y%v) call y%set_host() @@ -1187,7 +1307,9 @@ contains real(psb_spk_) :: beta, x(:) class(psb_s_base_vect_type) :: y + if (idx%is_dev()) call idx%sync() call y%sct(n,idx%v(i:),x,beta) + call y%set_host() end subroutine s_base_sctb_x diff --git a/base/modules/psb_s_tools_mod.f90 b/base/modules/psb_s_tools_mod.f90 index fe7f8c0c..3f2bb1e0 100644 --- a/base/modules/psb_s_tools_mod.f90 +++ b/base/modules/psb_s_tools_mod.f90 @@ -22,7 +22,7 @@ !!$ PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS !!$ BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR !!$ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -!!$ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSIESS +!!$ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS !!$ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN !!$ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) !!$ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE diff --git a/base/modules/psb_z_base_vect_mod.f90 b/base/modules/psb_z_base_vect_mod.f90 index 3c907177..b10aa9ad 100644 --- a/base/modules/psb_z_base_vect_mod.f90 +++ b/base/modules/psb_z_base_vect_mod.f90 @@ -62,6 +62,8 @@ module psb_z_base_vect_mod type psb_z_base_vect_type !> Values. complex(psb_dpk_), allocatable :: v(:) + complex(psb_dpk_), allocatable :: combuf(:) + integer(psb_ipk_), allocatable :: comid(:,:) contains ! ! Constructors/allocators @@ -97,6 +99,17 @@ module psb_z_base_vect_mod procedure, pass(x) :: set_dev => z_base_set_dev procedure, pass(x) :: set_sync => z_base_set_sync + ! + ! These are for handling gather/scatter in new + ! comm internals implementation. + ! + procedure, nopass :: use_buffer => z_base_use_buffer + procedure, pass(x) :: new_buffer => z_base_new_buffer + procedure, nopass :: device_wait => z_base_device_wait + procedure, pass(x) :: free_buffer => z_base_free_buffer + procedure, pass(x) :: new_comid => z_base_new_comid + procedure, pass(x) :: free_comid => z_base_free_comid + ! ! Basic info procedure, pass(x) :: get_nrows => z_base_get_nrows @@ -148,10 +161,12 @@ module psb_z_base_vect_mod procedure, pass(x) :: gthab => z_base_gthab procedure, pass(x) :: gthzv => z_base_gthzv procedure, pass(x) :: gthzv_x => z_base_gthzv_x - generic, public :: gth => gthab, gthzv, gthzv_x + procedure, pass(x) :: gthzbuf => z_base_gthzbuf + generic, public :: gth => gthab, gthzv, gthzv_x, gthzbuf procedure, pass(y) :: sctb => z_base_sctb procedure, pass(y) :: sctb_x => z_base_sctb_x - generic, public :: sct => sctb, sctb_x + procedure, pass(y) :: sctb_buf => z_base_sctb_buf + generic, public :: sct => sctb, sctb_x, sctb_buf end type psb_z_base_vect_type public :: psb_z_base_vect @@ -348,8 +363,8 @@ contains case default info = 321 -! !$ call psb_errpush(info,name) -! !$ goto 9999 + ! !$ call psb_errpush(info,name) + ! !$ goto 9999 end select end if call x%set_host() @@ -360,7 +375,6 @@ contains end subroutine z_base_ins_a - subroutine z_base_ins_v(n,irl,val,dupl,x,info) use psi_serial_mod implicit none @@ -452,6 +466,8 @@ contains info = 0 if (allocated(x%v)) deallocate(x%v, stat=info) + if (info == 0) call x%free_buffer(info) + if (info == 0) call x%free_comid(info) if (info /= 0) call & & psb_errpush(psb_err_alloc_dealloc_,'vect_free') @@ -778,12 +794,9 @@ contains complex(psb_dpk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info - select type(xx => x) - type is (psb_z_base_vect_type) - call psb_geaxpby(m,alpha,x%v,beta,y%v,info) - class default - call y%axpby(m,alpha,x%v,beta,info) - end select + if (x%is_dev()) call x%sync() + + call y%axpby(m,alpha,x%v,beta,info) end subroutine z_base_axpby_v @@ -809,7 +822,9 @@ contains complex(psb_dpk_), intent (in) :: alpha, beta integer(psb_ipk_), intent(out) :: info + if (y%is_dev()) call y%sync() call psb_geaxpby(m,alpha,x,beta,y%v,info) + call y%set_host() end subroutine z_base_axpby_a @@ -838,15 +853,8 @@ contains integer(psb_ipk_) :: i, n info = 0 - select type(xx => x) - type is (psb_z_base_vect_type) - n = min(size(y%v), size(xx%v)) - do i=1, n - y%v(i) = y%v(i)*xx%v(i) - end do - class default - call y%mlt(x%v,info) - end select + if (x%is_dev()) call x%sync() + call y%mlt(x%v,info) end subroutine z_base_mlt_v @@ -866,11 +874,13 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (y%is_dev()) call y%sync() n = min(size(y%v), size(x)) do i=1, n y%v(i) = y%v(i)*x(i) end do - + call y%set_host() + end subroutine z_base_mlt_a @@ -896,6 +906,8 @@ contains integer(psb_ipk_) :: i, n info = 0 + if (z%is_dev()) call z%sync() + n = min(size(z%v), size(x), size(y)) !!$ write(0,*) 'Mlt_a_2: ',n if (alpha == zzero) then @@ -951,6 +963,8 @@ contains end if end if end if + call z%set_host() + end subroutine z_base_mlt_a_2 ! @@ -978,6 +992,8 @@ contains logical :: conjgx_, conjgy_ info = 0 + if (y%is_dev()) call y%sync() + if (x%is_dev()) call x%sync() if (.not.psb_z_is_complex_) then call z%mlt(alpha,x%v,y%v,beta,info) else @@ -1004,7 +1020,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (y%is_dev()) call y%sync() call z%mlt(alpha,x,y%v,beta,info) end subroutine z_base_mlt_av @@ -1020,7 +1036,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - + if (x%is_dev()) call x%sync() call z%mlt(alpha,y,x,beta,info) end subroutine z_base_mlt_va @@ -1040,7 +1056,10 @@ contains class(psb_z_base_vect_type), intent(inout) :: x complex(psb_dpk_), intent (in) :: alpha - if (allocated(x%v)) x%v = alpha*x%v + if (allocated(x%v)) then + x%v = alpha*x%v + call x%set_host() + end if end subroutine z_base_scal @@ -1058,6 +1077,7 @@ contains real(psb_dpk_) :: res real(psb_dpk_), external :: dznrm2 + if (x%is_dev()) call x%sync() res = dznrm2(n,x%v,1) end function z_base_nrm2 @@ -1073,6 +1093,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res + if (x%is_dev()) call x%sync() res = maxval(abs(x%v(1:n))) end function z_base_amax @@ -1088,6 +1109,7 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res + if (x%is_dev()) call x%sync() res = sum(abs(x%v(1:n))) end function z_base_asum @@ -1111,7 +1133,7 @@ contains complex(psb_dpk_) :: alpha, beta, y(:) class(psb_z_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,alpha,x%v,beta,y) end subroutine z_base_gthab @@ -1131,10 +1153,108 @@ contains complex(psb_dpk_) :: y(:) class(psb_z_base_vect_type) :: x + if (idx%is_dev()) call idx%sync() call x%gth(n,idx%v(i:),y) end subroutine z_base_gthzv_x + ! + ! New comm internals impl. + ! + subroutine z_base_gthzbuf(i,n,idx,x) + use psi_serial_mod + integer(psb_ipk_) :: i,n + class(psb_i_base_vect_type) :: idx + class(psb_z_base_vect_type) :: x + + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'gthzbuf') + return + end if + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + call x%gth(n,idx%v(i:),x%combuf(i:)) + + end subroutine z_base_gthzbuf + + subroutine z_base_sctb_buf(i,n,idx,beta,y) + use psi_serial_mod + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + complex(psb_dpk_) :: beta + class(psb_z_base_vect_type) :: y + + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_,'sctb_buf') + return + end if + if (y%is_dev()) call y%sync() + if (idx%is_dev()) call idx%sync() + call y%sct(n,idx%v(i:),y%combuf(i:),beta) + call y%set_host() + + end subroutine z_base_sctb_buf + + ! + !> Function base_device_wait: + !! \memberof psb_z_base_vect_type + !! \brief device_wait: base version is a no-op. + !! + ! + subroutine z_base_device_wait() + implicit none + + end subroutine z_base_device_wait + + function z_base_use_buffer() result(res) + logical :: res + + res = .true. + end function z_base_use_buffer + + subroutine z_base_new_buffer(n,x,info) + use psb_realloc_mod + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,x%combuf,info) + end subroutine z_base_new_buffer + + subroutine z_base_new_comid(n,x,info) + use psb_realloc_mod + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n,2,x%comid,info) + end subroutine z_base_new_comid + + + subroutine z_base_free_buffer(x,info) + use psb_realloc_mod + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%combuf)) & + & deallocate(x%combuf,stat=info) + end subroutine z_base_free_buffer + + subroutine z_base_free_comid(x,info) + use psb_realloc_mod + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%comid)) & + & deallocate(x%comid,stat=info) + end subroutine z_base_free_comid + + ! ! shortcut alpha=1 beta=0 ! @@ -1150,7 +1270,7 @@ contains complex(psb_dpk_) :: y(:) class(psb_z_base_vect_type) :: x - call x%sync() + if (x%is_dev()) call x%sync() call psi_gth(n,idx,x%v,y) end subroutine z_base_gthzv @@ -1174,7 +1294,7 @@ contains complex(psb_dpk_) :: beta, x(:) class(psb_z_base_vect_type) :: y - call y%sync() + if (y%is_dev()) call y%sync() call psi_sct(n,idx,x,beta,y%v) call y%set_host() @@ -1187,7 +1307,9 @@ contains complex(psb_dpk_) :: beta, x(:) class(psb_z_base_vect_type) :: y + if (idx%is_dev()) call idx%sync() call y%sct(n,idx%v(i:),x,beta) + call y%set_host() end subroutine z_base_sctb_x