From 01c6c2e9f7b8fbe1f495b5b36c9181ab03f3ab23 Mon Sep 17 00:00:00 2001 From: Fabio Durastante Date: Wed, 2 Apr 2025 09:56:24 +0200 Subject: [PATCH] Implemented dot routines for multivec (all kinds) --- base/modules/psblas/psb_c_psblas_mod.F90 | 21 ++ base/modules/psblas/psb_d_psblas_mod.F90 | 2 +- base/modules/psblas/psb_s_psblas_mod.F90 | 21 ++ base/modules/psblas/psb_z_psblas_mod.F90 | 21 ++ base/modules/serial/psb_c_base_vect_mod.F90 | 64 +++- base/modules/serial/psb_c_vect_mod.F90 | 150 +++++++-- base/modules/serial/psb_d_base_vect_mod.F90 | 69 ++-- base/modules/serial/psb_d_vect_mod.F90 | 33 +- base/modules/serial/psb_i_vect_mod.F90 | 71 +++- base/modules/serial/psb_l_vect_mod.F90 | 71 +++- base/modules/serial/psb_s_base_vect_mod.F90 | 64 +++- base/modules/serial/psb_s_vect_mod.F90 | 150 +++++++-- base/modules/serial/psb_z_base_vect_mod.F90 | 64 +++- base/modules/serial/psb_z_vect_mod.F90 | 150 +++++++-- base/psblas/psb_cdot.f90 | 344 ++++++++++++++++++++ base/psblas/psb_ddot.f90 | 23 +- base/psblas/psb_sdot.f90 | 344 ++++++++++++++++++++ base/psblas/psb_zdot.f90 | 344 ++++++++++++++++++++ 18 files changed, 1845 insertions(+), 161 deletions(-) diff --git a/base/modules/psblas/psb_c_psblas_mod.F90 b/base/modules/psblas/psb_c_psblas_mod.F90 index 130159bc0..4a606f7e2 100644 --- a/base/modules/psblas/psb_c_psblas_mod.F90 +++ b/base/modules/psblas/psb_c_psblas_mod.F90 @@ -31,6 +31,7 @@ ! 
module psb_c_psblas_mod use psb_desc_mod, only : psb_desc_type, psb_spk_, psb_ipk_, psb_lpk_ + use psb_c_multivect_mod, only : psb_c_multivect_type use psb_c_vect_mod, only : psb_c_vect_type use psb_c_mat_mod, only : psb_cspmat_type @@ -63,6 +64,26 @@ module psb_c_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in), optional :: global end function psb_cdot + function psb_cdot_multivect(x, y, desc_a,info,global) result(res) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_c_multivect_type, psb_cspmat_type + complex(psb_spk_), dimension(:), allocatable :: res + type(psb_c_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end function psb_cdot_multivect + function psb_cdot_mvect_vect(x, y, desc_a,info,global) result(res) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_c_multivect_type, psb_cspmat_type, & + & psb_c_vect_type + complex(psb_spk_), dimension(:), allocatable :: res + type(psb_c_multivect_type), intent(inout) :: x + type(psb_c_vect_type), intent(inout) :: y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end function psb_cdot_mvect_vect end interface diff --git a/base/modules/psblas/psb_d_psblas_mod.F90 b/base/modules/psblas/psb_d_psblas_mod.F90 index 45a2a25c4..ca61729c3 100644 --- a/base/modules/psblas/psb_d_psblas_mod.F90 +++ b/base/modules/psblas/psb_d_psblas_mod.F90 @@ -31,8 +31,8 @@ ! 
module psb_d_psblas_mod use psb_desc_mod, only : psb_desc_type, psb_dpk_, psb_ipk_, psb_lpk_ - use psb_d_vect_mod, only : psb_d_vect_type use psb_d_multivect_mod, only : psb_d_multivect_type + use psb_d_vect_mod, only : psb_d_vect_type use psb_d_mat_mod, only : psb_dspmat_type interface psb_gedot diff --git a/base/modules/psblas/psb_s_psblas_mod.F90 b/base/modules/psblas/psb_s_psblas_mod.F90 index 6048d023e..f8da1c7c9 100644 --- a/base/modules/psblas/psb_s_psblas_mod.F90 +++ b/base/modules/psblas/psb_s_psblas_mod.F90 @@ -31,6 +31,7 @@ ! module psb_s_psblas_mod use psb_desc_mod, only : psb_desc_type, psb_spk_, psb_ipk_, psb_lpk_ + use psb_s_multivect_mod, only : psb_s_multivect_type use psb_s_vect_mod, only : psb_s_vect_type use psb_s_mat_mod, only : psb_sspmat_type @@ -63,6 +64,26 @@ module psb_s_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in), optional :: global end function psb_sdot + function psb_sdot_multivect(x, y, desc_a,info,global) result(res) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_s_multivect_type, psb_sspmat_type + real(psb_spk_), dimension(:), allocatable :: res + type(psb_s_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end function psb_sdot_multivect + function psb_sdot_mvect_vect(x, y, desc_a,info,global) result(res) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_s_multivect_type, psb_sspmat_type, & + & psb_s_vect_type + real(psb_spk_), dimension(:), allocatable :: res + type(psb_s_multivect_type), intent(inout) :: x + type(psb_s_vect_type), intent(inout) :: y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end function psb_sdot_mvect_vect end interface diff --git a/base/modules/psblas/psb_z_psblas_mod.F90 b/base/modules/psblas/psb_z_psblas_mod.F90 index fd0cc300d..c45f02af9 100644 --- 
a/base/modules/psblas/psb_z_psblas_mod.F90 +++ b/base/modules/psblas/psb_z_psblas_mod.F90 @@ -31,6 +31,7 @@ ! module psb_z_psblas_mod use psb_desc_mod, only : psb_desc_type, psb_dpk_, psb_ipk_, psb_lpk_ + use psb_z_multivect_mod, only : psb_z_multivect_type use psb_z_vect_mod, only : psb_z_vect_type use psb_z_mat_mod, only : psb_zspmat_type @@ -63,6 +64,26 @@ module psb_z_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in), optional :: global end function psb_zdot + function psb_zdot_multivect(x, y, desc_a,info,global) result(res) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_z_multivect_type, psb_zspmat_type + complex(psb_dpk_), dimension(:), allocatable :: res + type(psb_z_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end function psb_zdot_multivect + function psb_zdot_mvect_vect(x, y, desc_a,info,global) result(res) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_z_multivect_type, psb_zspmat_type, & + & psb_z_vect_type + complex(psb_dpk_), dimension(:), allocatable :: res + type(psb_z_multivect_type), intent(inout) :: x + type(psb_z_vect_type), intent(inout) :: y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end function psb_zdot_mvect_vect end interface diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index eac61c75d..1b285556a 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -149,7 +149,8 @@ module psb_c_base_vect_mod ! 
procedure, pass(x) :: dot_v => c_base_dot_v procedure, pass(x) :: dot_a => c_base_dot_a - generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_a2 => c_base_dot_a2 + generic, public :: dot => dot_v, dot_a, dot_a2 procedure, pass(y) :: axpby_v => c_base_axpby_v procedure, pass(y) :: axpby_a => c_base_axpby_a procedure, pass(z) :: axpby_v2 => c_base_axpby_v2 @@ -1010,6 +1011,34 @@ contains end function c_base_dot_a + ! + ! Base workhorse is good old BLAS2 + ! + ! + !> Function c_base_dot_a2 + !! \memberof psb_c_base_vect_type + !! \brief Conjugated dot products with the columns of a rank-2 array + !! \param n Number of entries to be considered + !! \param y(:,:) The matrix whose columns are multiplied by; its first n rows are used + !! + function c_base_dot_a2(n,x,y) result(res) + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + complex(psb_spk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_), allocatable, dimension(:) :: res + + ! local + integer(psb_ipk_) :: ncol + + ncol = size(y,2) + allocate(res(ncol)) + ! 'C' requests the conjugate transpose (a plain transpose on real data). + ! lda is the declared row count of y, which can be larger than n. + call cgemv('C',n,ncol,cone,y,max(1_psb_ipk_,size(y,1,kind=psb_ipk_)),x%v,1,czero,res,1) + + end function c_base_dot_a2 + ! ! AXPBY is invoked via Y, hence the structure below. ! @@ -2133,7 +2162,8 @@ module psb_c_base_multivect_mod ! procedure, pass(x) :: dot_v => c_base_mlv_dot_v procedure, pass(x) :: dot_a => c_base_mlv_dot_a - generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_vect => c_base_mlv_dot_vect + generic, public :: dot => dot_v, dot_a, dot_vect procedure, pass(y) :: axpby_v => c_base_mlv_axpby_v procedure, pass(y) :: axpby_a => c_base_mlv_axpby_a generic, public :: axpby => axpby_v, axpby_a @@ -2754,6 +2784,36 @@ contains end function c_base_mlv_dot_a + !> Function c_base_mlv_dot_vect + !! \memberof psb_c_base_multivect_type + !! \brief Dot product of every column by a single base vector + !! \param n Number of entries to be considered + !! \param y The other (base_vect) to be multiplied by + !! 
+ function c_base_mlv_dot_vect(n,x,y) result(res) + implicit none + class(psb_c_base_multivect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_), allocatable :: res(:) + complex(psb_spk_), external :: cdotc ! BLAS conjugated dot product; BLAS has no 'cdot' + integer(psb_ipk_) :: j,nc + + if (x%is_dev()) call x%sync() + + select type(yy => y) + type is (psb_c_base_vect_type) + nc = psb_size(x%v,2_psb_ipk_) + allocate(res(nc)) + do j=1,nc + res(j) = cdotc(n,x%v(:,j),1,y%v,1) + end do + class default + res = y%dot(n,x%v) ! not a plain base vector: delegate to y's own dot with x%v as a rank-2 array + end select + + end function c_base_mlv_dot_vect + ! ! AXPBY is invoked via Y, hence the structure below. ! diff --git a/base/modules/serial/psb_c_vect_mod.F90 b/base/modules/serial/psb_c_vect_mod.F90 index 737ba26da..0371b4a73 100644 --- a/base/modules/serial/psb_c_vect_mod.F90 +++ b/base/modules/serial/psb_c_vect_mod.F90 @@ -1287,6 +1287,7 @@ end module psb_c_vect_mod module psb_c_multivect_mod use psb_c_base_multivect_mod + use psb_c_vect_mod use psb_const_mod use psb_i_vect_mod @@ -1309,11 +1310,18 @@ module psb_c_multivect_mod procedure, pass(x) :: get_dupl => c_mvect_get_dupl procedure, pass(x) :: set_dupl => c_mvect_set_dupl + procedure, pass(x) :: sync => c_mvect_sync + procedure, pass(x) :: is_host => c_mvect_is_host + procedure, pass(x) :: is_dev => c_mvect_is_dev + procedure, pass(x) :: is_sync => c_mvect_is_sync + procedure, pass(x) :: set_host => c_mvect_set_host + procedure, pass(x) :: set_dev => c_mvect_set_dev + procedure, pass(x) :: set_sync => c_mvect_set_sync + procedure, pass(x) :: all => c_mvect_all procedure, pass(x) :: reall => c_mvect_reall procedure, pass(x) :: zero => c_mvect_zero procedure, pass(x) :: asb => c_mvect_asb - procedure, pass(x) :: sync => c_mvect_sync procedure, pass(x) :: free => c_mvect_free procedure, pass(x) :: ins => c_mvect_ins procedure, pass(x) :: bld_x => c_mvect_bld_x @@ -1332,9 +1340,10 @@ module psb_c_multivect_mod procedure, pass(y) :: sctb => c_mvect_sctb procedure, 
pass(y) :: sctb_x => c_mvect_sctb_x generic, public :: sct => sctb, sctb_x -!!$ procedure, pass(x) :: dot_v => c_mvect_dot_v -!!$ procedure, pass(x) :: dot_a => c_mvect_dot_a -!!$ generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_v => c_mvect_dot_v + procedure, pass(x) :: dot_a => c_mvect_dot_a + procedure, pass(x) :: dot_a_vect => c_mvect_dot_vect + generic, public :: dot => dot_v, dot_a, dot_a_vect !!$ procedure, pass(y) :: axpby_v => c_mvect_axpby_v !!$ procedure, pass(y) :: axpby_a => c_mvect_axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a @@ -1394,7 +1403,66 @@ contains x%dupl = psb_dupl_def_ end if end subroutine c_mvect_set_dupl - + + subroutine c_mvect_set_sync(x) + implicit none + class(psb_c_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_sync() + + end subroutine c_mvect_set_sync + + subroutine c_mvect_set_host(x) + implicit none + class(psb_c_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_host() + + end subroutine c_mvect_set_host + + subroutine c_mvect_set_dev(x) + implicit none + class(psb_c_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_dev() + + end subroutine c_mvect_set_dev + + function c_mvect_is_sync(x) result(res) + implicit none + logical :: res + class(psb_c_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_sync() + + end function c_mvect_is_sync + + function c_mvect_is_host(x) result(res) + implicit none + logical :: res + class(psb_c_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_host() + + end function c_mvect_is_host + + function c_mvect_is_dev(x) result(res) + implicit none + logical :: res + class(psb_c_multivect_type), intent(inout) :: x + + res = .false. 
+ if (allocated(x%v)) & + & res = x%v%is_dev() + + end function c_mvect_is_dev function c_mvect_is_remote_build(x) result(res) implicit none @@ -1761,31 +1829,53 @@ contains end subroutine c_mvect_cnv -!!$ function c_mvect_dot_v(n,x,y) result(res) -!!$ implicit none -!!$ class(psb_c_multivect_type), intent(inout) :: x, y -!!$ integer(psb_ipk_), intent(in) :: n -!!$ complex(psb_spk_) :: res -!!$ -!!$ res = czero -!!$ if (allocated(x%v).and.allocated(y%v)) & -!!$ & res = x%v%dot(n,y%v) -!!$ -!!$ end function c_mvect_dot_v -!!$ -!!$ function c_mvect_dot_a(n,x,y) result(res) -!!$ implicit none -!!$ class(psb_c_multivect_type), intent(inout) :: x -!!$ complex(psb_spk_), intent(in) :: y(:) -!!$ integer(psb_ipk_), intent(in) :: n -!!$ complex(psb_spk_) :: res -!!$ -!!$ res = czero -!!$ if (allocated(x%v)) & -!!$ & res = x%v%dot(n,y) -!!$ -!!$ end function c_mvect_dot_a -!!$ + function c_mvect_dot_v(n,x,y) result(res) + implicit none + class(psb_c_multivect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_), dimension(:), allocatable :: res + + if (allocated(x%v).and.allocated(y%v)) then + res = x%v%dot(n,y%v) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function c_mvect_dot_v + + function c_mvect_dot_vect(n,x,y) result(res) + implicit none + class(psb_c_multivect_type), intent(inout) :: x + class(psb_c_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_), dimension(:), allocatable :: res + + if (allocated(x%v).and.allocated(y%v)) then + res = x%v%dot(n,y%v) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function c_mvect_dot_vect + + function c_mvect_dot_a(n,x,y) result(res) + implicit none + class(psb_c_multivect_type), intent(inout) :: x + complex(psb_spk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_), dimension(:), allocatable :: res + + if (allocated(x%v)) then + res = x%v%dot(n,y) + else + 
allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function c_mvect_dot_a + !!$ subroutine c_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ use psi_serial_mod !!$ implicit none diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 99c40e30c..c626c70e2 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -1018,8 +1018,8 @@ contains end function d_base_dot_a - ! - ! Base workhorse is good old BLAS1 + ! + ! Base workhorse is good old BLAS2 ! ! !> Function base_dot_a @@ -1034,15 +1034,15 @@ contains real(psb_dpk_), intent(in) :: y(:,:) integer(psb_ipk_), intent(in) :: n real(psb_dpk_), allocatable, dimension(:) :: res - real(psb_dpk_), external :: ddot ! local integer(psb_ipk_) :: ncol ncol = size(y,2) allocate(res(ncol)) - - call dgemv('T',n,ncol,done,y,n,x%v,1,dzero,res,1) + ! 'C' acts as a plain transpose on real data (conjugate transpose on complex). + ! lda is the declared row count of y, which can be larger than n. + call dgemv('C',n,ncol,done,y,max(1_psb_ipk_,size(y,1,kind=psb_ipk_)),x%v,1,dzero,res,1) end function d_base_dot_a2 @@ -2913,6 +2913,7 @@ contains integer(psb_ipk_) :: j,nc if (x%is_dev()) call x%sync() + res = dzero ! ! Note: this is the base implementation. ! When we get here, we are sure that X is of @@ -2934,9 +2935,37 @@ contains end function d_base_mlv_dot_v + ! + ! Base workhorse is good old BLAS1 + ! + ! + !> Function base_mlv_dot_a + !! \memberof psb_d_base_multivect_type + !! \brief Column-by-column dot product by a normal rank-2 array + !! \param n Number of entries to be considered + !! \param y(:,:) The array to be multiplied by + !! 
+ function d_base_mlv_dot_a(n,x,y) result(res) + implicit none + class(psb_d_base_multivect_type), intent(inout) :: x + real(psb_dpk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + real(psb_dpk_), allocatable :: res(:) + real(psb_dpk_), external :: ddot + integer(psb_ipk_) :: j,nc + + if (x%is_dev()) call x%sync() + nc = min(psb_size(x%v,2_psb_ipk_),size(y,2_psb_ipk_)) + allocate(res(nc)) + do j=1,nc + res(j) = ddot(n,x%v(:,j),1,y(:,j),1) + end do + + end function d_base_mlv_dot_a + !> Function d_base_mlv_dot_vect !! \memberof psb_d_base_multivect_type - !! \brief Dot product by another base_mlv_vector + !! \brief Dot product by a base_mlv_vector !! \param n Number of entries to be considered !! \param y The other (base_vect) to be multiplied by !! @@ -2964,34 +2993,6 @@ contains end function d_base_mlv_dot_vect - ! - ! Base workhorse is good old BLAS1 - ! - ! - !> Function base_mlv_dot_a - !! \memberof psb_d_base_multivect_type - !! \brief Dot product by a normal array - !! \param n Number of entries to be considered - !! \param y(:) The array to be multiplied by - !! - function d_base_mlv_dot_a(n,x,y) result(res) - implicit none - class(psb_d_base_multivect_type), intent(inout) :: x - real(psb_dpk_), intent(in) :: y(:,:) - integer(psb_ipk_), intent(in) :: n - real(psb_dpk_), allocatable :: res(:) - real(psb_dpk_), external :: ddot - integer(psb_ipk_) :: j,nc - - if (x%is_dev()) call x%sync() - nc = min(psb_size(x%v,2_psb_ipk_),size(y,2_psb_ipk_)) - allocate(res(nc)) - do j=1,nc - res(j) = ddot(n,x%v(:,j),1,y(:,j),1) - end do - - end function d_base_mlv_dot_a - ! ! AXPBY is invoked via Y, hence the structure below. ! 
diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index 9406d2363..abdac604f 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -1421,8 +1421,8 @@ module psb_d_multivect_mod generic, public :: sct => sctb, sctb_x procedure, pass(x) :: dot_v => d_mvect_dot_v procedure, pass(x) :: dot_a => d_mvect_dot_a - procedure, pass(x) :: dot_vect => d_mvect_dot_vect - generic, public :: dot => dot_v, dot_a, dot_vect + procedure, pass(x) :: dot_a_vect => d_mvect_dot_vect + generic, public :: dot => dot_v, dot_a, dot_a_vect !!$ procedure, pass(y) :: axpby_v => d_mvect_axpby_v !!$ procedure, pass(y) :: axpby_a => d_mvect_axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a @@ -1483,15 +1483,6 @@ contains end if end subroutine d_mvect_set_dupl - subroutine d_mvect_sync(x) - implicit none - class(psb_d_multivect_type), intent(inout) :: x - - if (allocated(x%v)) & - & call x%v%sync() - - end subroutine d_mvect_sync - subroutine d_mvect_set_sync(x) implicit none class(psb_d_multivect_type), intent(inout) :: x @@ -1551,7 +1542,6 @@ contains & res = x%v%is_dev() end function d_mvect_is_dev - function d_mvect_is_remote_build(x) result(res) implicit none @@ -1795,6 +1785,15 @@ contains end subroutine d_mvect_asb + subroutine d_mvect_sync(x) + implicit none + class(psb_d_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%sync() + + end subroutine d_mvect_sync + subroutine d_mvect_gthab(n,idx,alpha,x,beta,y) use psi_serial_mod integer(psb_ipk_) :: n, idx(:) @@ -1912,7 +1911,7 @@ contains function d_mvect_dot_v(n,x,y) result(res) implicit none class(psb_d_multivect_type), intent(inout) :: x, y - integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(in) :: n real(psb_dpk_), dimension(:), allocatable :: res if (allocated(x%v).and.allocated(y%v)) then @@ -1947,13 +1946,15 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_), dimension(:), allocatable :: res - 
res = dzero if (allocated(x%v)) then - res = x%v%dot(n,y) + res = x%v%dot(n,y) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ end if end function d_mvect_dot_a -!!$ + !!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ use psi_serial_mod !!$ implicit none diff --git a/base/modules/serial/psb_i_vect_mod.F90 b/base/modules/serial/psb_i_vect_mod.F90 index 0ff16c54b..13f90ada5 100644 --- a/base/modules/serial/psb_i_vect_mod.F90 +++ b/base/modules/serial/psb_i_vect_mod.F90 @@ -628,6 +628,7 @@ end module psb_i_vect_mod module psb_i_multivect_mod use psb_i_base_multivect_mod + use psb_i_vect_mod use psb_const_mod use psb_i_vect_mod @@ -650,11 +651,18 @@ module psb_i_multivect_mod procedure, pass(x) :: get_dupl => i_mvect_get_dupl procedure, pass(x) :: set_dupl => i_mvect_set_dupl + procedure, pass(x) :: sync => i_mvect_sync + procedure, pass(x) :: is_host => i_mvect_is_host + procedure, pass(x) :: is_dev => i_mvect_is_dev + procedure, pass(x) :: is_sync => i_mvect_is_sync + procedure, pass(x) :: set_host => i_mvect_set_host + procedure, pass(x) :: set_dev => i_mvect_set_dev + procedure, pass(x) :: set_sync => i_mvect_set_sync + procedure, pass(x) :: all => i_mvect_all procedure, pass(x) :: reall => i_mvect_reall procedure, pass(x) :: zero => i_mvect_zero procedure, pass(x) :: asb => i_mvect_asb - procedure, pass(x) :: sync => i_mvect_sync procedure, pass(x) :: free => i_mvect_free procedure, pass(x) :: ins => i_mvect_ins procedure, pass(x) :: bld_x => i_mvect_bld_x @@ -717,7 +725,66 @@ contains x%dupl = psb_dupl_def_ end if end subroutine i_mvect_set_dupl - + + subroutine i_mvect_set_sync(x) + implicit none + class(psb_i_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_sync() + + end subroutine i_mvect_set_sync + + subroutine i_mvect_set_host(x) + implicit none + class(psb_i_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_host() + + end subroutine i_mvect_set_host + + subroutine 
i_mvect_set_dev(x) + implicit none + class(psb_i_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_dev() + + end subroutine i_mvect_set_dev + + function i_mvect_is_sync(x) result(res) + implicit none + logical :: res + class(psb_i_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_sync() + + end function i_mvect_is_sync + + function i_mvect_is_host(x) result(res) + implicit none + logical :: res + class(psb_i_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_host() + + end function i_mvect_is_host + + function i_mvect_is_dev(x) result(res) + implicit none + logical :: res + class(psb_i_multivect_type), intent(inout) :: x + + res = .false. + if (allocated(x%v)) & + & res = x%v%is_dev() + + end function i_mvect_is_dev function i_mvect_is_remote_build(x) result(res) implicit none diff --git a/base/modules/serial/psb_l_vect_mod.F90 b/base/modules/serial/psb_l_vect_mod.F90 index 2490d7a2a..1fcb9d5be 100644 --- a/base/modules/serial/psb_l_vect_mod.F90 +++ b/base/modules/serial/psb_l_vect_mod.F90 @@ -629,6 +629,7 @@ end module psb_l_vect_mod module psb_l_multivect_mod use psb_l_base_multivect_mod + use psb_l_vect_mod use psb_const_mod use psb_i_vect_mod @@ -651,11 +652,18 @@ module psb_l_multivect_mod procedure, pass(x) :: get_dupl => l_mvect_get_dupl procedure, pass(x) :: set_dupl => l_mvect_set_dupl + procedure, pass(x) :: sync => l_mvect_sync + procedure, pass(x) :: is_host => l_mvect_is_host + procedure, pass(x) :: is_dev => l_mvect_is_dev + procedure, pass(x) :: is_sync => l_mvect_is_sync + procedure, pass(x) :: set_host => l_mvect_set_host + procedure, pass(x) :: set_dev => l_mvect_set_dev + procedure, pass(x) :: set_sync => l_mvect_set_sync + procedure, pass(x) :: all => l_mvect_all procedure, pass(x) :: reall => l_mvect_reall procedure, pass(x) :: zero => l_mvect_zero procedure, pass(x) :: asb => l_mvect_asb - procedure, pass(x) :: sync => 
l_mvect_sync procedure, pass(x) :: free => l_mvect_free procedure, pass(x) :: ins => l_mvect_ins procedure, pass(x) :: bld_x => l_mvect_bld_x @@ -718,7 +726,66 @@ contains x%dupl = psb_dupl_def_ end if end subroutine l_mvect_set_dupl - + + subroutine l_mvect_set_sync(x) + implicit none + class(psb_l_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_sync() + + end subroutine l_mvect_set_sync + + subroutine l_mvect_set_host(x) + implicit none + class(psb_l_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_host() + + end subroutine l_mvect_set_host + + subroutine l_mvect_set_dev(x) + implicit none + class(psb_l_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_dev() + + end subroutine l_mvect_set_dev + + function l_mvect_is_sync(x) result(res) + implicit none + logical :: res + class(psb_l_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_sync() + + end function l_mvect_is_sync + + function l_mvect_is_host(x) result(res) + implicit none + logical :: res + class(psb_l_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_host() + + end function l_mvect_is_host + + function l_mvect_is_dev(x) result(res) + implicit none + logical :: res + class(psb_l_multivect_type), intent(inout) :: x + + res = .false. + if (allocated(x%v)) & + & res = x%v%is_dev() + + end function l_mvect_is_dev function l_mvect_is_remote_build(x) result(res) implicit none diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index 2d1b03b73..d4a8cdf32 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -149,7 +149,8 @@ module psb_s_base_vect_mod ! 
procedure, pass(x) :: dot_v => s_base_dot_v procedure, pass(x) :: dot_a => s_base_dot_a - generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_a2 => s_base_dot_a2 + generic, public :: dot => dot_v, dot_a, dot_a2 procedure, pass(y) :: axpby_v => s_base_axpby_v procedure, pass(y) :: axpby_a => s_base_axpby_a procedure, pass(z) :: axpby_v2 => s_base_axpby_v2 @@ -1017,6 +1018,34 @@ contains end function s_base_dot_a + ! + ! Base workhorse is good old BLAS2 + ! + ! + !> Function s_base_dot_a2 + !! \memberof psb_s_base_vect_type + !! \brief Dot products with the columns of a rank-2 array + !! \param n Number of entries to be considered + !! \param y(:,:) The matrix whose columns are multiplied by; its first n rows are used + !! + function s_base_dot_a2(n,x,y) result(res) + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + real(psb_spk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + real(psb_spk_), allocatable, dimension(:) :: res + + ! local + integer(psb_ipk_) :: ncol + + ncol = size(y,2) + allocate(res(ncol)) + ! 'C' acts as a plain transpose on real data (conjugate transpose on complex). + ! lda is the declared row count of y, which can be larger than n. + call sgemv('C',n,ncol,sone,y,max(1_psb_ipk_,size(y,1,kind=psb_ipk_)),x%v,1,szero,res,1) + + end function s_base_dot_a2 + ! ! AXPBY is invoked via Y, hence the structure below. ! @@ -2312,7 +2341,8 @@ module psb_s_base_multivect_mod ! procedure, pass(x) :: dot_v => s_base_mlv_dot_v procedure, pass(x) :: dot_a => s_base_mlv_dot_a - generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_vect => s_base_mlv_dot_vect + generic, public :: dot => dot_v, dot_a, dot_vect procedure, pass(y) :: axpby_v => s_base_mlv_axpby_v procedure, pass(y) :: axpby_a => s_base_mlv_axpby_a generic, public :: axpby => axpby_v, axpby_a @@ -2933,6 +2963,36 @@ contains end function s_base_mlv_dot_a + !> Function s_base_mlv_dot_vect + !! \memberof psb_s_base_multivect_type + !! \brief Dot product of every column by a single base vector + !! \param n Number of entries to be considered + !! \param y The other (base_vect) to be multiplied by + !! 
+ function s_base_mlv_dot_vect(n,x,y) result(res) + implicit none + class(psb_s_base_multivect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + real(psb_spk_), allocatable :: res(:) + real(psb_spk_), external :: sdot + integer(psb_ipk_) :: j,nc + + if (x%is_dev()) call x%sync() + + select type(yy => y) + type is (psb_s_base_vect_type) + nc = psb_size(x%v,2_psb_ipk_) + allocate(res(nc)) + do j=1,nc + res(j) = sdot(n,x%v(:,j),1,y%v,1) + end do + class default + res = y%dot(n,x%v) + end select + + end function s_base_mlv_dot_vect + ! ! AXPBY is invoked via Y, hence the structure below. ! diff --git a/base/modules/serial/psb_s_vect_mod.F90 b/base/modules/serial/psb_s_vect_mod.F90 index 259081a64..95a7ab02c 100644 --- a/base/modules/serial/psb_s_vect_mod.F90 +++ b/base/modules/serial/psb_s_vect_mod.F90 @@ -1366,6 +1366,7 @@ end module psb_s_vect_mod module psb_s_multivect_mod use psb_s_base_multivect_mod + use psb_s_vect_mod use psb_const_mod use psb_i_vect_mod @@ -1388,11 +1389,18 @@ module psb_s_multivect_mod procedure, pass(x) :: get_dupl => s_mvect_get_dupl procedure, pass(x) :: set_dupl => s_mvect_set_dupl + procedure, pass(x) :: sync => s_mvect_sync + procedure, pass(x) :: is_host => s_mvect_is_host + procedure, pass(x) :: is_dev => s_mvect_is_dev + procedure, pass(x) :: is_sync => s_mvect_is_sync + procedure, pass(x) :: set_host => s_mvect_set_host + procedure, pass(x) :: set_dev => s_mvect_set_dev + procedure, pass(x) :: set_sync => s_mvect_set_sync + procedure, pass(x) :: all => s_mvect_all procedure, pass(x) :: reall => s_mvect_reall procedure, pass(x) :: zero => s_mvect_zero procedure, pass(x) :: asb => s_mvect_asb - procedure, pass(x) :: sync => s_mvect_sync procedure, pass(x) :: free => s_mvect_free procedure, pass(x) :: ins => s_mvect_ins procedure, pass(x) :: bld_x => s_mvect_bld_x @@ -1411,9 +1419,10 @@ module psb_s_multivect_mod procedure, pass(y) :: sctb => s_mvect_sctb procedure, 
pass(y) :: sctb_x => s_mvect_sctb_x generic, public :: sct => sctb, sctb_x -!!$ procedure, pass(x) :: dot_v => s_mvect_dot_v -!!$ procedure, pass(x) :: dot_a => s_mvect_dot_a -!!$ generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_v => s_mvect_dot_v + procedure, pass(x) :: dot_a => s_mvect_dot_a + procedure, pass(x) :: dot_a_vect => s_mvect_dot_vect + generic, public :: dot => dot_v, dot_a, dot_a_vect !!$ procedure, pass(y) :: axpby_v => s_mvect_axpby_v !!$ procedure, pass(y) :: axpby_a => s_mvect_axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a @@ -1473,7 +1482,66 @@ contains x%dupl = psb_dupl_def_ end if end subroutine s_mvect_set_dupl - + + subroutine s_mvect_set_sync(x) + implicit none + class(psb_s_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_sync() + + end subroutine s_mvect_set_sync + + subroutine s_mvect_set_host(x) + implicit none + class(psb_s_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_host() + + end subroutine s_mvect_set_host + + subroutine s_mvect_set_dev(x) + implicit none + class(psb_s_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_dev() + + end subroutine s_mvect_set_dev + + function s_mvect_is_sync(x) result(res) + implicit none + logical :: res + class(psb_s_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_sync() + + end function s_mvect_is_sync + + function s_mvect_is_host(x) result(res) + implicit none + logical :: res + class(psb_s_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_host() + + end function s_mvect_is_host + + function s_mvect_is_dev(x) result(res) + implicit none + logical :: res + class(psb_s_multivect_type), intent(inout) :: x + + res = .false. 
+ if (allocated(x%v)) & + & res = x%v%is_dev() + + end function s_mvect_is_dev function s_mvect_is_remote_build(x) result(res) implicit none @@ -1840,31 +1908,53 @@ contains end subroutine s_mvect_cnv -!!$ function s_mvect_dot_v(n,x,y) result(res) -!!$ implicit none -!!$ class(psb_s_multivect_type), intent(inout) :: x, y -!!$ integer(psb_ipk_), intent(in) :: n -!!$ real(psb_spk_) :: res -!!$ -!!$ res = szero -!!$ if (allocated(x%v).and.allocated(y%v)) & -!!$ & res = x%v%dot(n,y%v) -!!$ -!!$ end function s_mvect_dot_v -!!$ -!!$ function s_mvect_dot_a(n,x,y) result(res) -!!$ implicit none -!!$ class(psb_s_multivect_type), intent(inout) :: x -!!$ real(psb_spk_), intent(in) :: y(:) -!!$ integer(psb_ipk_), intent(in) :: n -!!$ real(psb_spk_) :: res -!!$ -!!$ res = szero -!!$ if (allocated(x%v)) & -!!$ & res = x%v%dot(n,y) -!!$ -!!$ end function s_mvect_dot_a -!!$ + function s_mvect_dot_v(n,x,y) result(res) + implicit none + class(psb_s_multivect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(in) :: n + real(psb_spk_), dimension(:), allocatable :: res + + if (allocated(x%v).and.allocated(y%v)) then + res = x%v%dot(n,y%v) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function s_mvect_dot_v + + function s_mvect_dot_vect(n,x,y) result(res) + implicit none + class(psb_s_multivect_type), intent(inout) :: x + class(psb_s_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + real(psb_spk_), dimension(:), allocatable :: res + + if (allocated(x%v).and.allocated(y%v)) then + res = x%v%dot(n,y%v) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function s_mvect_dot_vect + + function s_mvect_dot_a(n,x,y) result(res) + implicit none + class(psb_s_multivect_type), intent(inout) :: x + real(psb_spk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + real(psb_spk_), dimension(:), allocatable :: res + + if (allocated(x%v)) then + res = x%v%dot(n,y) + else + allocate(res(1)) + res(1) 
= psb_err_invalid_vect_state_ + end if + + end function s_mvect_dot_a + !!$ subroutine s_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ use psi_serial_mod !!$ implicit none diff --git a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index 5a55cdc66..b7981a373 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -149,7 +149,8 @@ module psb_z_base_vect_mod ! procedure, pass(x) :: dot_v => z_base_dot_v procedure, pass(x) :: dot_a => z_base_dot_a - generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_a2 => z_base_dot_a2 + generic, public :: dot => dot_v, dot_a, dot_a2 procedure, pass(y) :: axpby_v => z_base_axpby_v procedure, pass(y) :: axpby_a => z_base_axpby_a procedure, pass(z) :: axpby_v2 => z_base_axpby_v2 @@ -1010,6 +1011,34 @@ contains end function z_base_dot_a + ! + ! Base workhorse is good old BLAS2 (zgemv) + ! + ! + !> Function z_base_dot_a2 + !! \memberof psb_z_base_vect_type + !! \brief Dot products of the columns of a rank-2 array with x, res = y**C * x + !! \param n Number of entries to be considered + !! \param y(:,:) The matrix to be multiplied by + !! + function z_base_dot_a2(n,x,y) result(res) + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + complex(psb_dpk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_), allocatable, dimension(:) :: res + + ! local + integer(psb_ipk_) :: ncol + + ncol = size(y,2) + allocate(res(ncol)) + ! 'C' is a plain transpose for real data and a conjugate transpose for complex; + ! the leading dimension is size(y,1), which may be larger than n + call zgemv('C',n,ncol,zone,y,size(y,1,kind=psb_ipk_),x%v,1,zzero,res,1) + + end function z_base_dot_a2 + ! ! AXPBY is invoked via Y, hence the structure below. ! @@ -2133,7 +2162,8 @@ module psb_z_base_multivect_mod ! 
procedure, pass(x) :: dot_v => z_base_mlv_dot_v procedure, pass(x) :: dot_a => z_base_mlv_dot_a - generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_vect => z_base_mlv_dot_vect + generic, public :: dot => dot_v, dot_a, dot_vect procedure, pass(y) :: axpby_v => z_base_mlv_axpby_v procedure, pass(y) :: axpby_a => z_base_mlv_axpby_a generic, public :: axpby => axpby_v, axpby_a @@ -2754,6 +2784,36 @@ contains end function z_base_mlv_dot_a + !> Function z_base_mlv_dot_vect + !! \memberof psb_z_base_multivect_type + !! \brief Conjugated dot product (zdotc) of each column of x by a base vector + !! \param n Number of entries to be considered + !! \param y The other (base_vect) to be multiplied by + !! + function z_base_mlv_dot_vect(n,x,y) result(res) + implicit none + class(psb_z_base_multivect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_), allocatable :: res(:) + complex(psb_dpk_), external :: zdotc + integer(psb_ipk_) :: j,nc + + if (x%is_dev()) call x%sync() + + select type(yy => y) + type is (psb_z_base_vect_type) + nc = psb_size(x%v,2_psb_ipk_) + allocate(res(nc)) + do j=1,nc + res(j) = zdotc(n,x%v(:,j),1,y%v,1) + end do + class default + res = y%dot(n,x%v) + end select + + end function z_base_mlv_dot_vect + ! ! AXPBY is invoked via Y, hence the structure below. ! 
diff --git a/base/modules/serial/psb_z_vect_mod.F90 b/base/modules/serial/psb_z_vect_mod.F90 index 5342cc747..c7c0148d5 100644 --- a/base/modules/serial/psb_z_vect_mod.F90 +++ b/base/modules/serial/psb_z_vect_mod.F90 @@ -1287,6 +1287,7 @@ end module psb_z_vect_mod module psb_z_multivect_mod use psb_z_base_multivect_mod + use psb_z_vect_mod use psb_const_mod use psb_i_vect_mod @@ -1309,11 +1310,18 @@ module psb_z_multivect_mod procedure, pass(x) :: get_dupl => z_mvect_get_dupl procedure, pass(x) :: set_dupl => z_mvect_set_dupl + procedure, pass(x) :: sync => z_mvect_sync + procedure, pass(x) :: is_host => z_mvect_is_host + procedure, pass(x) :: is_dev => z_mvect_is_dev + procedure, pass(x) :: is_sync => z_mvect_is_sync + procedure, pass(x) :: set_host => z_mvect_set_host + procedure, pass(x) :: set_dev => z_mvect_set_dev + procedure, pass(x) :: set_sync => z_mvect_set_sync + procedure, pass(x) :: all => z_mvect_all procedure, pass(x) :: reall => z_mvect_reall procedure, pass(x) :: zero => z_mvect_zero procedure, pass(x) :: asb => z_mvect_asb - procedure, pass(x) :: sync => z_mvect_sync procedure, pass(x) :: free => z_mvect_free procedure, pass(x) :: ins => z_mvect_ins procedure, pass(x) :: bld_x => z_mvect_bld_x @@ -1332,9 +1340,10 @@ module psb_z_multivect_mod procedure, pass(y) :: sctb => z_mvect_sctb procedure, pass(y) :: sctb_x => z_mvect_sctb_x generic, public :: sct => sctb, sctb_x -!!$ procedure, pass(x) :: dot_v => z_mvect_dot_v -!!$ procedure, pass(x) :: dot_a => z_mvect_dot_a -!!$ generic, public :: dot => dot_v, dot_a + procedure, pass(x) :: dot_v => z_mvect_dot_v + procedure, pass(x) :: dot_a => z_mvect_dot_a + procedure, pass(x) :: dot_a_vect => z_mvect_dot_vect + generic, public :: dot => dot_v, dot_a, dot_a_vect !!$ procedure, pass(y) :: axpby_v => z_mvect_axpby_v !!$ procedure, pass(y) :: axpby_a => z_mvect_axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a @@ -1394,7 +1403,66 @@ contains x%dupl = psb_dupl_def_ end if end subroutine 
z_mvect_set_dupl - + + subroutine z_mvect_set_sync(x) + implicit none + class(psb_z_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_sync() + + end subroutine z_mvect_set_sync + + subroutine z_mvect_set_host(x) + implicit none + class(psb_z_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_host() + + end subroutine z_mvect_set_host + + subroutine z_mvect_set_dev(x) + implicit none + class(psb_z_multivect_type), intent(inout) :: x + + if (allocated(x%v)) & + & call x%v%set_dev() + + end subroutine z_mvect_set_dev + + function z_mvect_is_sync(x) result(res) + implicit none + logical :: res + class(psb_z_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_sync() + + end function z_mvect_is_sync + + function z_mvect_is_host(x) result(res) + implicit none + logical :: res + class(psb_z_multivect_type), intent(inout) :: x + + res = .true. + if (allocated(x%v)) & + & res = x%v%is_host() + + end function z_mvect_is_host + + function z_mvect_is_dev(x) result(res) + implicit none + logical :: res + class(psb_z_multivect_type), intent(inout) :: x + + res = .false. 
+ if (allocated(x%v)) & + & res = x%v%is_dev() + + end function z_mvect_is_dev function z_mvect_is_remote_build(x) result(res) implicit none @@ -1761,31 +1829,53 @@ contains end subroutine z_mvect_cnv -!!$ function z_mvect_dot_v(n,x,y) result(res) -!!$ implicit none -!!$ class(psb_z_multivect_type), intent(inout) :: x, y -!!$ integer(psb_ipk_), intent(in) :: n -!!$ complex(psb_dpk_) :: res -!!$ -!!$ res = zzero -!!$ if (allocated(x%v).and.allocated(y%v)) & -!!$ & res = x%v%dot(n,y%v) -!!$ -!!$ end function z_mvect_dot_v -!!$ -!!$ function z_mvect_dot_a(n,x,y) result(res) -!!$ implicit none -!!$ class(psb_z_multivect_type), intent(inout) :: x -!!$ complex(psb_dpk_), intent(in) :: y(:) -!!$ integer(psb_ipk_), intent(in) :: n -!!$ complex(psb_dpk_) :: res -!!$ -!!$ res = zzero -!!$ if (allocated(x%v)) & -!!$ & res = x%v%dot(n,y) -!!$ -!!$ end function z_mvect_dot_a -!!$ + function z_mvect_dot_v(n,x,y) result(res) + implicit none + class(psb_z_multivect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_), dimension(:), allocatable :: res + + if (allocated(x%v).and.allocated(y%v)) then + res = x%v%dot(n,y%v) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function z_mvect_dot_v + + function z_mvect_dot_vect(n,x,y) result(res) + implicit none + class(psb_z_multivect_type), intent(inout) :: x + class(psb_z_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_), dimension(:), allocatable :: res + + if (allocated(x%v).and.allocated(y%v)) then + res = x%v%dot(n,y%v) + else + allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function z_mvect_dot_vect + + function z_mvect_dot_a(n,x,y) result(res) + implicit none + class(psb_z_multivect_type), intent(inout) :: x + complex(psb_dpk_), intent(in) :: y(:,:) + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_), dimension(:), allocatable :: res + + if (allocated(x%v)) then + res = x%v%dot(n,y) + else + 
allocate(res(1)) + res(1) = psb_err_invalid_vect_state_ + end if + + end function z_mvect_dot_a + !!$ subroutine z_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ use psi_serial_mod !!$ implicit none diff --git a/base/psblas/psb_cdot.f90 b/base/psblas/psb_cdot.f90 index ed300b7ce..d53812070 100644 --- a/base/psblas/psb_cdot.f90 +++ b/base/psblas/psb_cdot.f90 @@ -157,6 +157,350 @@ function psb_cdot_vect(x, y, desc_a,info,global) result(res) return end function psb_cdot_vect +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! 
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_cdot.f90 +! +! Function: psb_cdot_multivect +! psb_cdot_multivect computes the column-wise dot products of two distributed multivectors, +! +! dot(:) := ( X(:) )**C * ( Y(:) ) +! +! +! Arguments: +! x - type(psb_c_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_c_multivect_type) The input multivector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! +function psb_cdot_multivect(x, y, desc_a,info,global) result(res) + use psb_desc_mod + use psb_c_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_c_multivect_mod + use psb_c_psblas_mod, psb_protect_name => psb_cdot_multivect + implicit none + complex(psb_spk_), dimension(:), allocatable :: res + type(psb_c_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + + ! 
locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny + logical :: global_ + character(len=20) :: name, ch_err + + name='psb_cdot_multivect' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + iy = ione + ijy = ione + + m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() + + ! check vector correctness + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione).or.(iiy /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + if (x%get_ncols() /= y%get_ncols()) then + info=psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + else + allocate(res(x%get_ncols()),stat=info) + if (info /= 0) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + res = x%dot(nr,y) + ! FIXME + ! 
adjust dot_local because overlapped elements are computed more than once; x is conjugated to match dot(:) = X**C * Y + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + res(:) = res(:) - (real(ndm-1)/real(ndm))*(conjg(x%v%v(idx,:))*y%v%v(idx,:)) + end do + end if + else + res = czero + end if + + ! compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end function psb_cdot_multivect +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! 
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_cdot.f90 +! +! Function: psb_cdot_mvect_vect +! psb_cdot_mvect_vect computes the dot products of the columns of a distributed multivector with a distributed vector, +! +! dot(:) := ( X )**C * ( Y ) +! +! +! Arguments: +! x - type(psb_c_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_c_vect_type) The input vector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! +function psb_cdot_mvect_vect(x, y, desc_a,info,global) result(res) + use psb_desc_mod + use psb_c_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_c_multivect_mod + use psb_c_vect_mod + use psb_c_psblas_mod, psb_protect_name => psb_cdot_mvect_vect + implicit none + complex(psb_spk_), dimension(:), allocatable :: res + type(psb_c_multivect_type), intent(inout) :: x + type(psb_c_vect_type), intent(inout) :: y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + + ! 
locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n + logical :: global_ + character(len=20) :: name, ch_err + + name='psb_cdot_mvect_vect' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + iy = ione + ijy = ione + + m = desc_a%get_global_rows() + n = x%get_ncols() + + ! check vector correctness + call psb_chkvect(m,n,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,lone,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione).or.(iiy /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + allocate(res(x%get_ncols()),stat=info) + if (info /= 0) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + res = x%dot(nr,y) + ! FIXME + ! adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + ! 
Remove the surplus contributions of the overlapped row idx, column by column: + ! res(:) = res(:) - ((ndm-1)/ndm) * conjg(x(idx,:)) * y(idx) + res(:) = res(:) - (real(ndm-1)/real(ndm)) * & + & (conjg(x%v%v(idx,:))*y%v%v(idx)) + end do + end if + else + res = czero + end if + + ! compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end function psb_cdot_mvect_vect ! ! Function: psb_cdot ! psb_cdot computes the dot product of two distributed vectors, diff --git a/base/psblas/psb_ddot.f90 b/base/psblas/psb_ddot.f90 index ff4e91540..98adb42a2 100644 --- a/base/psblas/psb_ddot.f90 +++ b/base/psblas/psb_ddot.f90 @@ -190,10 +190,10 @@ end function psb_ddot_vect ! ! File: psb_ddot.f90 ! -! Function: psb_ddot_vect +! Function: psb_ddot_multivect ! psb_ddot computes the dot product of two distributed vectors, ! -! dot := ( X )**C * ( Y ) +! dot(:) := ( X(:) )**C * ( Y(:) ) ! ! ! Arguments: @@ -226,7 +226,7 @@ function psb_ddot_multivect(x, y, desc_a,info,global) result(res) type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me, idx, ndm,& & err_act, iix, jjx, iiy, jjy, i, nr - integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny logical :: global_ character(len=20) :: name, ch_err @@ -268,11 +268,13 @@ function psb_ddot_multivect(x, y, desc_a,info,global) result(res) ijy = ione m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() ! 
check vector correctness - call psb_chkvect(m,x%get_ncols(),x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) if (info == psb_success_) & - & call psb_chkvect(m,y%get_ncols(),y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) if(info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_chkvect' @@ -361,14 +363,14 @@ end function psb_ddot_multivect ! ! File: psb_ddot.f90 ! -! Function: psb_ddot_vect +! Function: psb_ddot_mvect_vect ! psb_ddot computes the dot product of two distributed vectors, ! -! dot := ( X )**C * ( Y ) +! dot(:) := ( X )**C * ( Y ) ! ! ! Arguments: -! x - type(psb_d_vect_type) The input vector containing the entries of sub( X ). +! x - type(psb_d_multivect_type) The input vector containing the entries of sub( X ). ! y - type(psb_d_vect_type) The input vector containing the entries of sub( Y ). ! desc_a - type(psb_desc_type). The communication descriptor. ! info - integer. Return code @@ -441,9 +443,10 @@ function psb_ddot_mvect_vect(x, y, desc_a,info,global) result(res) ijy = ione m = desc_a%get_global_rows() + n = x%get_ncols() ! check vector correctness - call psb_chkvect(m,x%get_ncols(),x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + call psb_chkvect(m,n,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) if (info == psb_success_) & & call psb_chkvect(m,lone,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) if(info /= psb_success_) then @@ -479,7 +482,7 @@ function psb_ddot_mvect_vect(x, y, desc_a,info,global) result(res) ndm = desc_a%ovrlap_elem(i,2) ! Remove the overlapped elements via dgemv calls ! 
res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res - call dgemv('T',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & + call dgemv('C',ione,size(x%v%v,2),-(real(ndm-1,psb_dpk_)/real(ndm,psb_dpk_)),x%v%v(idx,1), & & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) end do end if diff --git a/base/psblas/psb_sdot.f90 b/base/psblas/psb_sdot.f90 index cf0678a71..1ee5ae7ea 100644 --- a/base/psblas/psb_sdot.f90 +++ b/base/psblas/psb_sdot.f90 @@ -157,6 +157,350 @@ function psb_sdot_vect(x, y, desc_a,info,global) result(res) return end function psb_sdot_vect +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! 
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_sdot.f90 +! +! Function: psb_sdot_multivect +! psb_sdot_multivect computes the column-wise dot products of two distributed multivectors, +! +! dot(:) := ( X(:) )**C * ( Y(:) ) +! +! +! Arguments: +! x - type(psb_s_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_s_multivect_type) The input multivector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! +function psb_sdot_multivect(x, y, desc_a,info,global) result(res) + use psb_desc_mod + use psb_s_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_s_multivect_mod + use psb_s_psblas_mod, psb_protect_name => psb_sdot_multivect + implicit none + real(psb_spk_), dimension(:), allocatable :: res + type(psb_s_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + + ! 
locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny + logical :: global_ + character(len=20) :: name, ch_err + + name='psb_sdot_multivect' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + iy = ione + ijy = ione + + m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() + + ! check vector correctness + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione).or.(iiy /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + if (x%get_ncols() /= y%get_ncols()) then + info=psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + else + allocate(res(x%get_ncols()),stat=info) + if (info /= 0) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + res = x%dot(nr,y) + ! FIXME + ! 
adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + res(:) = res(:) - (real(ndm-1)/real(ndm))*(x%v%v(idx,:)*y%v%v(idx,:)) + end do + end if + else + res = szero + end if + + ! compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end function psb_sdot_multivect +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! 
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_sdot.f90 +! +! Function: psb_sdot_mvect_vect +! psb_sdot_mvect_vect computes the dot products of the columns of a distributed multivector with a distributed vector, +! +! dot(:) := ( X )**C * ( Y ) +! +! +! Arguments: +! x - type(psb_s_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_s_vect_type) The input vector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! +function psb_sdot_mvect_vect(x, y, desc_a,info,global) result(res) + use psb_desc_mod + use psb_s_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_s_multivect_mod + use psb_s_vect_mod + use psb_s_psblas_mod, psb_protect_name => psb_sdot_mvect_vect + implicit none + real(psb_spk_), dimension(:), allocatable :: res + type(psb_s_multivect_type), intent(inout) :: x + type(psb_s_vect_type), intent(inout) :: y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + + ! 
locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n + logical :: global_ + character(len=20) :: name, ch_err + + name='psb_sdot_mvect_vect' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + iy = ione + ijy = ione + + m = desc_a%get_global_rows() + n = x%get_ncols() + + ! check vector correctness + call psb_chkvect(m,n,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,lone,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione).or.(iiy /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + allocate(res(x%get_ncols()),stat=info) + if (info /= 0) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + res = x%dot(nr,y) + ! FIXME + ! adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + ! 
Remove the overlapped elements via sgemv calls + ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res + call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & + & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + end do + end if + else + res = szero + end if + + ! compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end function psb_sdot_mvect_vect ! ! Function: psb_sdot ! psb_sdot computes the dot product of two distributed vectors, diff --git a/base/psblas/psb_zdot.f90 b/base/psblas/psb_zdot.f90 index 97ecbedf3..ff2d74013 100644 --- a/base/psblas/psb_zdot.f90 +++ b/base/psblas/psb_zdot.f90 @@ -157,6 +157,350 @@ function psb_zdot_vect(x, y, desc_a,info,global) result(res) return end function psb_zdot_vect +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! 
BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+!    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+!    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+!    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+!    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+!    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+!    POSSIBILITY OF SUCH DAMAGE.
+!
+!
+! File: psb_zdot.f90
+!
+! Function: psb_zdot_multivect
+!    psb_zdot_multivect computes the column-by-column dot products of two multivectors,
+!
+!    dot(:) := ( X(:) )**C * ( Y(:) )
+!
+!
+! Arguments:
+!    x      -  type(psb_z_multivect_type) The input multivector containing the entries of sub( X ).
+!    y      -  type(psb_z_multivect_type) The input multivector containing the entries of sub( Y ).
+!    desc_a -  type(psb_desc_type).  The communication descriptor.
+!    info   -  integer.              Return code
+!    global -  logical(optional)    Whether to perform the global sum, default: .true.
+!
+!  Note: from a functional point of view, X and Y are input, but here
+!        they are declared INOUT because of the sync() methods.
+!
+!  Result: res, allocated here with x%get_ncols() entries; X and Y must have equal ncols.
+function psb_zdot_multivect(x, y, desc_a,info,global) result(res)
+  use psb_desc_mod
+  use psb_z_base_mat_mod
+  use psb_check_mod
+  use psb_error_mod
+  use psb_penv_mod
+  use psb_z_multivect_mod
+  use psb_z_psblas_mod, psb_protect_name => psb_zdot_multivect
+  implicit none
+  complex(psb_dpk_), dimension(:), allocatable :: res
+  type(psb_z_multivect_type), intent(inout) :: x, y
+  type(psb_desc_type), intent(in) :: desc_a
+  integer(psb_ipk_), intent(out) :: info
+  logical, intent(in), optional :: global
+
+  ! locals
+  type(psb_ctxt_type) :: ctxt
+  integer(psb_ipk_) :: np, me, idx, ndm,&
+       & err_act, iix, jjx, iiy, jjy, i, nr
+  integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx, ny
+  logical :: global_
+  character(len=20) :: name, ch_err
+
+  name='psb_zdot_multivect'
+  info=psb_success_
+  call psb_erractionsave(err_act)
+  if (psb_errstatus_fatal()) then
+    info = psb_err_internal_error_ ;    goto 9999
+  end if
+
+  ctxt=desc_a%get_context()
+  call psb_info(ctxt, me, np)
+  if (np == -ione) then
+    info = psb_err_context_error_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+  if (.not.allocated(x%v)) then
+    info = psb_err_invalid_vect_state_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+  if (.not.allocated(y%v)) then
+    info = psb_err_invalid_vect_state_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+
+  if (present(global)) then
+    global_ = global
+  else
+    global_ = .true.
+  end if
+
+  ix = ione
+  ijx = ione
+
+  iy = ione
+  ijy = ione
+
+  m = desc_a%get_global_rows()
+  nx = x%get_ncols()
+  ny = y%get_ncols()
+
+  ! check vector correctness
+  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
+  if (info == psb_success_) &
+       & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
+  if(info /= psb_success_) then
+    info=psb_err_from_subroutine_
+    ch_err='psb_chkvect'
+    call psb_errpush(info,name,a_err=ch_err)
+    goto 9999
+  end if
+
+  if ((iix /= ione).or.(iiy /= ione)) then
+    info=psb_err_ix_n1_iy_n1_unsupported_
+    call psb_errpush(info,name)
+    goto 9999
+  end if
+
+  if (nx /= ny) then   ! the column-wise product needs matching column counts
+    info=psb_err_invalid_vect_state_
+    call psb_errpush(info,name)
+    goto 9999
+  else
+    allocate(res(x%get_ncols()),stat=info)
+    if (info /= 0) then
+      info=psb_err_alloc_dealloc_
+      call psb_errpush(info,name)
+      goto 9999
+    end if
+  end if
+
+  nr = desc_a%get_local_rows()
+  if(nr > 0) then
+    res = x%dot(nr,y)
+    ! FIXME
+    ! adjust dot_local because overlapped elements are computed more than once
+    if (size(desc_a%ovrlap_elem,1)>0) then
+      if (x%is_dev()) call x%sync()   ! x%v%v and y%v%v are read directly below
+      if (y%is_dev()) call y%sync()
+      do i=1,size(desc_a%ovrlap_elem,1)
+        idx = desc_a%ovrlap_elem(i,1)
+        ndm = desc_a%ovrlap_elem(i,2)
+        res(:) = res(:) - (real(ndm-1,psb_dpk_)/real(ndm,psb_dpk_))*(conjg(x%v%v(idx,:))*y%v%v(idx,:))
+      end do
+    end if
+  else
+    res = zzero
+  end if
+
+  ! compute global sum
+  if (global_) call psb_sum(ctxt, res)
+
+  call psb_erractionrestore(err_act)
+  return
+
+9999 call psb_error_handler(ctxt,err_act)
+
+  return
+
+end function psb_zdot_multivect
+!
+!                Parallel Sparse BLAS  version 3.5
+!      (C) Copyright 2006-2018
+!        Salvatore  Filippone
+!        Alfredo Buttari
+!
+!    Redistribution and use in source and binary forms, with or without
+!    modification, are permitted provided that the following conditions
+!    are met:
+!      1. Redistributions of source code must retain the above copyright
+!         notice, this list of conditions and the following disclaimer.
+!      2. Redistributions in binary form must reproduce the above copyright
+!         notice, this list of conditions, and the following disclaimer in the
+!         documentation and/or other materials provided with the distribution.
+!      3. The name of the PSBLAS group or the names of its contributors may
+!         not be used to endorse or promote products derived from this
+!         software without specific written permission.
+!
+!    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+!    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
+!    TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+!    PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
+!    BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+!    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+!    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+! 
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+!    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+!    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+!    POSSIBILITY OF SUCH DAMAGE.
+!
+!
+! File: psb_zdot.f90
+!
+! Function: psb_zdot_mvect_vect
+!    psb_zdot_mvect_vect computes the dot product of each column of X with Y,
+!
+!    dot(:) := ( X )**C * ( Y )
+!
+!
+! Arguments:
+!    x      -  type(psb_z_multivect_type) The input multivector containing the entries of sub( X ).
+!    y      -  type(psb_z_vect_type) The input vector containing the entries of sub( Y ).
+!    desc_a -  type(psb_desc_type).  The communication descriptor.
+!    info   -  integer.              Return code
+!    global -  logical(optional)    Whether to perform the global sum, default: .true.
+!
+!  Note: from a functional point of view, X and Y are input, but here
+!        they are declared INOUT because of the sync() methods.
+!
+!  Result: res, allocated here with x%get_ncols() entries (one per column of X).
+function psb_zdot_mvect_vect(x, y, desc_a,info,global) result(res)
+  use psb_desc_mod
+  use psb_z_base_mat_mod
+  use psb_check_mod
+  use psb_error_mod
+  use psb_penv_mod
+  use psb_z_multivect_mod
+  use psb_z_vect_mod
+  use psb_z_psblas_mod, psb_protect_name => psb_zdot_mvect_vect
+  implicit none
+  complex(psb_dpk_), dimension(:), allocatable :: res
+  type(psb_z_multivect_type), intent(inout) :: x
+  type(psb_z_vect_type), intent(inout) :: y
+  type(psb_desc_type), intent(in) :: desc_a
+  integer(psb_ipk_), intent(out) :: info
+  logical, intent(in), optional :: global
+
+  ! locals
+  type(psb_ctxt_type) :: ctxt
+  integer(psb_ipk_) :: np, me, idx, ndm,&
+       & err_act, iix, jjx, iiy, jjy, i, nr
+  integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n
+  logical :: global_
+  character(len=20) :: name, ch_err
+
+  name='psb_zdot_mvect_vect'
+  info=psb_success_
+  call psb_erractionsave(err_act)
+  if (psb_errstatus_fatal()) then
+    info = psb_err_internal_error_ ;    goto 9999
+  end if
+
+  ctxt=desc_a%get_context()
+  call psb_info(ctxt, me, np)
+  if (np == -ione) then
+    info = psb_err_context_error_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+  if (.not.allocated(x%v)) then
+    info = psb_err_invalid_vect_state_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+  if (.not.allocated(y%v)) then
+    info = psb_err_invalid_vect_state_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+
+  if (present(global)) then
+    global_ = global
+  else
+    global_ = .true.
+  end if
+
+  ix = ione
+  ijx = ione
+
+  iy = ione
+  ijy = ione
+
+  m = desc_a%get_global_rows()
+  n = x%get_ncols()
+
+  ! check vector correctness
+  call psb_chkvect(m,n,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
+  if (info == psb_success_) &
+       & call psb_chkvect(m,lone,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
+  if(info /= psb_success_) then
+    info=psb_err_from_subroutine_
+    ch_err='psb_chkvect'
+    call psb_errpush(info,name,a_err=ch_err)
+    goto 9999
+  end if
+
+  if ((iix /= ione).or.(iiy /= ione)) then
+    info=psb_err_ix_n1_iy_n1_unsupported_
+    call psb_errpush(info,name)
+    goto 9999
+  end if
+
+  allocate(res(x%get_ncols()),stat=info)   ! one result entry per column of X
+  if (info /= 0) then
+    info=psb_err_alloc_dealloc_
+    call psb_errpush(info,name)
+    goto 9999
+  end if
+
+  nr = desc_a%get_local_rows()
+  if(nr > 0) then
+    res = x%dot(nr,y)
+    ! FIXME
+    ! adjust dot_local because overlapped elements are computed more than once
+    if (size(desc_a%ovrlap_elem,1)>0) then
+      if (x%is_dev()) call x%sync()   ! x%v%v and y%v%v are read directly below
+      if (y%is_dev()) call y%sync()
+      do i=1,size(desc_a%ovrlap_elem,1)
+        idx = desc_a%ovrlap_elem(i,1)
+        ndm = desc_a%ovrlap_elem(i,2)
+        ! Row idx is an overlap row with multiplicity ndm (presumably the number
+        ! of processes holding it): remove (ndm-1)/ndm of its local contribution.
+        res(:) = res(:) - (real(ndm-1,psb_dpk_)/real(ndm,psb_dpk_)) &
+             & *(conjg(x%v%v(idx,:))*y%v%v(idx))
+      end do
+    end if
+  else
+    res = zzero
+  end if
+
+  ! compute global sum
+  if (global_) call psb_sum(ctxt, res)
+
+  call psb_erractionrestore(err_act)
+  return
+
+9999 call psb_error_handler(ctxt,err_act)
+
+  return
+
+end function psb_zdot_mvect_vect
 !
 ! Function: psb_zdot
 !    psb_zdot computes the dot product of two distributed vectors,