From 4823c5662aaf3e0ecd929a56450ea7851dbedabc Mon Sep 17 00:00:00 2001 From: Fabio Durastante Date: Tue, 8 Apr 2025 14:16:32 +0200 Subject: [PATCH] Added inner product implementation --- base/modules/psblas/psb_d_psblas_mod.F90 | 9 + base/modules/serial/psb_d_base_vect_mod.F90 | 114 +++++++++++- base/modules/serial/psb_d_vect_mod.F90 | 38 ++++ base/psblas/psb_dvmlt.f90 | 182 ++++++++++++++++++++ test/kernel/vecoperation.f90 | 13 +- 5 files changed, 354 insertions(+), 2 deletions(-) diff --git a/base/modules/psblas/psb_d_psblas_mod.F90 b/base/modules/psblas/psb_d_psblas_mod.F90 index ca61729c..6ed3eed7 100644 --- a/base/modules/psblas/psb_d_psblas_mod.F90 +++ b/base/modules/psblas/psb_d_psblas_mod.F90 @@ -540,6 +540,15 @@ module psb_d_psblas_mod integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy end subroutine psb_dmlt_vect2 + subroutine psb_dmlt_multivect(x, y, res, desc_a,info,global) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_d_multivect_type, psb_dspmat_type + real(psb_dpk_), dimension(:,:), allocatable :: res + type(psb_d_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end subroutine psb_dmlt_multivect end interface interface psb_gediv diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index c626c70e..2a2a6902 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -2356,13 +2356,15 @@ module psb_d_base_multivect_mod procedure, pass(y) :: mlt_ar2 => d_base_mlv_mlt_ar2 procedure, pass(z) :: mlt_a_2 => d_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_v_2 => d_base_mlv_mlt_v_2 + procedure, pass(x) :: mlt_mv2 => d_base_mlv_mlt_mv2 !!$ procedure, pass(z) :: mlt_va => d_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_av => d_base_mlv_mlt_av generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & - & mlt_a_2, mlt_v_2 !, mlt_av, mlt_va + & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va ! ! Scaling and norms ! + procedure, pass(x) :: trslv => d_base_mlv_trslv procedure, pass(x) :: scal => d_base_mlv_scal procedure, pass(x) :: nrm2 => d_base_mlv_nrm2 procedure, pass(x) :: amax => d_base_mlv_amax @@ -2851,6 +2853,72 @@ contains res(1:m,1:n) = x%v(1:m,1:n) end function d_base_mlv_get_vect + ! + !> subroutine d_base_mlv_trslv + !! \memberof psb_d_base_multivect_type + !! \brief Computes X = X / A with A an upper triangular matrix + !! \param n Number of entries to be considered + !! \param x The multivector to be used for the division + !! \param uplo 'U' for upper triangular, 'L' for lower triangular + !! \param a The matrix to be used for the division + !! \param alpha (optional) The scaling factor + !! \param trans (optional) 'N' for no transpose, 'T' for transpose + !! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal + !! \param info return code + !! + subroutine d_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info) + implicit none + class(psb_d_base_multivect_type), intent(inout) :: x + real(psb_dpk_), intent(in) :: a(:,:) + integer(psb_ipk_), intent(in) :: n + character(len=1), intent(in) :: uplo + real(psb_dpk_), intent(in), optional :: alpha + character(len=1), intent(in), optional :: trans, diag + integer(psb_ipk_), intent(out) :: info + ! Local variables + integer(psb_ipk_) :: lda, ldb + character(len=1) :: trans_, diag_, side + real(psb_dpk_) :: alpha_ + + ! Default values + if (.not.present(alpha)) then + alpha_ = done + else + alpha_ = alpha + end if + if (.not.present(trans)) then + trans_ = 'N' + else + trans_ = trans + end if + if (.not.present(diag)) then + diag_ = 'N' + else + diag_ = diag + end if + + info = psb_success_ + ! Check that a is square + if (size(a,1) /= size(a,2)) then + info = psb_err_invalid_input_ + return + end if + ! Check that a has the same number of columns as x + if (size(a,2) /= x%get_ncols()) then + info = psb_err_invalid_input_ + return + end if + if (x%is_dev()) call x%sync() + if (x%is_sync()) then + ! Call BLAS function to solve the system + lda = n + ldb = n + side = 'R' ! X*op( A ) = alpha*B. + call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb) + end if + end subroutine d_base_mlv_trslv + + ! ! Reset all values ! @@ -3066,6 +3134,50 @@ contains end subroutine d_base_mlv_axpby_a + !> Function base_mlv_mlt_mv2 + !! \memberof psb_d_base_multivect_type + !! \brief computes A = transpose(X)*Y + !! \param x The class(base_mlv_vect) to be multiplied by + !! \param y The class(base_mlv_vect) to be multiplied by + !! \param a The resulting matrix + !! \param info return code + subroutine d_base_mlv_mlt_mv2(x,y,a,info) + use psi_serial_mod + implicit none + class(psb_d_base_multivect_type), intent(inout) :: x + class(psb_d_base_multivect_type), intent(inout) :: y + real(psb_dpk_), intent(inout), allocatable :: a(:,:) + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + + if (allocated(a)) then + if (size(a,1) /= x%get_ncols()) then + info = psb_err_invalid_input_ + return + end if + if (size(a,2) /= y%get_ncols()) then + info = psb_err_invalid_input_ + return + end if + else + allocate(a(x%get_ncols(),y%get_ncols()),stat=info) + if (info /= 0) call psb_errpush(psb_err_alloc_dealloc_,'base_mlv_mlt_mv2') + end if + ! We do the multiplication by using the BLAS function + ! dgemm, which computes the matrix-matrix product + ! C = alpha*op( A )*op( B ) + beta*C + ! In our case, we want to compute + ! C = X'*Y + call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), x%get_nrows(), done, & + & x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols()) + + end subroutine d_base_mlv_mlt_mv2 + + + ! ! Multiple variants of two operations: ! Simple multiplication Y(:.:) = X(:,:)*Y(:,:) diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index abdac604..29e59ce5 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -1423,6 +1423,9 @@ module psb_d_multivect_mod procedure, pass(x) :: dot_a => d_mvect_dot_a procedure, pass(x) :: dot_a_vect => d_mvect_dot_vect generic, public :: dot => dot_v, dot_a, dot_a_vect + procedure, pass(x) :: trslv => d_mlv_trslv + procedure, pass(x) :: mlt_mv2 => d_mvect_mlt_mv2 + generic, public :: mlt => mlt_mv2 !!$ procedure, pass(y) :: axpby_v => d_mvect_axpby_v !!$ procedure, pass(y) :: axpby_a => d_mvect_axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a @@ -1955,6 +1958,41 @@ contains end function d_mvect_dot_a + subroutine d_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info) + implicit none + class(psb_d_multivect_type), intent(inout) :: x + real(psb_dpk_), intent(in) :: a(:,:) + integer(psb_ipk_), intent(in) :: n + character(len=1), intent(in) :: uplo + real(psb_dpk_), intent(in), optional :: alpha + character(len=1), intent(in), optional :: trans, diag + integer(psb_ipk_), intent(out) :: info + + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + return + else + call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info) + end if + end subroutine d_mlv_trslv + + subroutine d_mvect_mlt_mv2(x,y,a,info) + use psi_serial_mod + implicit none + class(psb_d_multivect_type), intent(inout) :: x + class(psb_d_multivect_type), intent(inout) :: y + real(psb_dpk_), intent(inout), allocatable :: a(:,:) + integer(psb_ipk_), intent(out) :: info + + if (allocated(x%v).and.allocated(y%v)) then + call y%v%mlt(x%v,a,info) + else + info = psb_err_invalid_vect_state_ + return + end if + + end subroutine d_mvect_mlt_mv2 + !!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ use psi_serial_mod !!$ implicit none diff --git a/base/psblas/psb_dvmlt.f90 b/base/psblas/psb_dvmlt.f90 index ea76e57a..c6995db7 100644 --- a/base/psblas/psb_dvmlt.f90 +++ b/base/psblas/psb_dvmlt.f90 @@ -109,3 +109,185 @@ subroutine psb_dvmlt(x,y,desc_a,info) return end subroutine psb_dvmlt + +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_dvmlt.f90 +! +! Function: psb_ddot_multivect +! psb_dvmlta computes the inner product of two distributed multivectors, +! +! dot(:,:) := ( X(:,:) )**C * ( Y(:,:) ) +! +! +! Arguments: +! x - type(psb_d_multivect_type) The input vector containing the entries of sub( X ). +! y - type(psb_d_multivect_type) The input vector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! +subroutine psb_dmlt_multivect(x, y, desc_a,res,info,global) + use psb_desc_mod + use psb_d_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_d_multivect_mod + use psb_d_psblas_mod, psb_protect_name => psb_dmlt_multivect + implicit none + real(psb_dpk_), dimension(:,:), allocatable :: res + type(psb_d_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny + logical :: global_ + character(len=20) :: name, ch_err + + name='psb_dmlt_multivect' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(y%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + iy = ione + ijy = ione + + m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() + + ! check vector correctness + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione).or.(iiy /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + if (x%get_ncols() /= y%get_ncols()) then + info=psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + else + allocate(res(x%get_ncols()),stat=info) + if (info /= 0) then + info=psb_err_alloc_dealloc_ + call psb_errpush(info,name) + goto 9999 + end if + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%mlt(y,res,info) + ! FIXME + ! adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + ! Since I'm coputing res via a dgemm on the whole vector, I need to adjust + ! the result by removing the contribution of the overlapped elements + ! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:) + ! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y + ! where R is the result, X' is the transpose of the matrix x%v%v(idx,:) + ! and Y is the matrix y%v%v(idx,:) + call dgemm('T','N',size(x%v%v(idx,:),2),size(y%v%v(idx,:),2),& + & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),& + & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols()) + end do + end if + else + res = dzero + end if + + ! compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_dmlt_multivect \ No newline at end of file diff --git a/test/kernel/vecoperation.f90 b/test/kernel/vecoperation.f90 index 5c87768b..44abefdd 100644 --- a/test/kernel/vecoperation.f90 +++ b/test/kernel/vecoperation.f90 @@ -1048,6 +1048,8 @@ program vecoperation type(psb_s_multivect_type) :: smv1, smv2 type(psb_c_multivect_type) :: cmv1, cmv2 type(psb_z_multivect_type) :: zmv1, zmv2 + ! scalars + real(psb_dpk_), allocatable, dimension(:,:) :: res ! blacs parameters type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: iam, np @@ -1569,7 +1571,15 @@ program vecoperation if(all(ansmv(:) == np*idim)) write(psb_out_unit,'("TEST PASSED >>> Dot product (mv vs vector) (double complex)")') if(any(ansmv(:) /= np*idim)) write(psb_out_unit,'("TEST FAILED --- Dot product (mv vs vector) (double complex)")') end if - + ! Inner product: multivector vs multivector (double real) + call psb_d_gen_const_multi(mv1,done,idim,nmv,ctxt,desc_a,info) + call psb_d_gen_const_multi(mv2,done,idim,nmv,ctxt,desc_a,info) + allocate(res(nmv,nmv)) + call psb_gemlt(mv1,mv2,res,desc_a,info) + if (iam == psb_root_) then + if(all(res(:,:) == np*idim)) write(psb_out_unit,'("TEST PASSED >>> Inner product (mv vs mv) (double real)")') + if(any(res(:,:) /= np*idim)) write(psb_out_unit,'("TEST FAILED --- Inner product (mv vs mv) (double real)")') + end if call psb_gefree(x,desc_a,info) @@ -1593,6 +1603,7 @@ program vecoperation call psb_gefree(zmv1,desc_a,info) call psb_gefree(zmv2,desc_a,info) call psb_cdfree(desc_a,info) + if(allocated(res)) deallocate(res) if(info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='free routine'