Added inner product implementation

randomized
Fabio Durastante 1 year ago
parent da9ef2f8cc
commit 4823c5662a

@ -540,6 +540,15 @@ module psb_d_psblas_mod
integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_dmlt_vect2
subroutine psb_dmlt_multivect(x, y, res, desc_a,info,global)
import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_multivect_type, psb_dspmat_type
real(psb_dpk_), dimension(:,:), allocatable :: res
type(psb_d_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
end subroutine psb_dmlt_multivect
end interface
interface psb_gediv

@ -2356,13 +2356,15 @@ module psb_d_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => d_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => d_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => d_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => d_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => d_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => d_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va
& mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
!
! Scaling and norms
!
procedure, pass(x) :: trslv => d_base_mlv_trslv
procedure, pass(x) :: scal => d_base_mlv_scal
procedure, pass(x) :: nrm2 => d_base_mlv_nrm2
procedure, pass(x) :: amax => d_base_mlv_amax
@ -2851,6 +2853,72 @@ contains
res(1:m,1:n) = x%v(1:m,1:n)
end function d_base_mlv_get_vect
!
!> subroutine d_base_mlv_trslv
!! \memberof psb_d_base_multivect_type
!! \brief Computes X = X / A with A an upper triangular matrix
!! \param n Number of entries to be considered
!! \param x The multivector to be used for the division
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param a The matrix to be used for the division
!! \param alpha (optional) The scaling factor
!! \param trans (optional) 'N' for no transpose, 'T' for transpose
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal
!! \param info return code
!!
subroutine d_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
implicit none
class(psb_d_base_multivect_type), intent(inout) :: x
real(psb_dpk_), intent(in) :: a(:,:)
integer(psb_ipk_), intent(in) :: n
character(len=1), intent(in) :: uplo
real(psb_dpk_), intent(in), optional :: alpha
character(len=1), intent(in), optional :: trans, diag
integer(psb_ipk_), intent(out) :: info
! Local variables
integer(psb_ipk_) :: lda, ldb
character(len=1) :: trans_, diag_, side
real(psb_dpk_) :: alpha_
! Default values
if (.not.present(alpha)) then
alpha_ = done
else
alpha_ = alpha
end if
if (.not.present(trans)) then
trans_ = 'N'
else
trans_ = trans
end if
if (.not.present(diag)) then
diag_ = 'N'
else
diag_ = diag
end if
info = psb_success_
! Check that a is square
if (size(a,1) /= size(a,2)) then
info = psb_err_invalid_input_
return
end if
! Check that a has the same number of columns as x
if (size(a,2) /= x%get_ncols()) then
info = psb_err_invalid_input_
return
end if
if (x%is_dev()) call x%sync()
if (x%is_sync()) then
! Call BLAS function to solve the system
lda = n
ldb = n
side = 'R' ! X*op( A ) = alpha*B.
call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
end if
end subroutine d_base_mlv_trslv
!
! Reset all values
!
@ -3066,6 +3134,50 @@ contains
end subroutine d_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_d_base_multivect_type
!! \brief computes A = transpose(X)*Y
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix
!! \param info return code
subroutine d_base_mlv_mlt_mv2(x,y,a,info)
use psi_serial_mod
implicit none
class(psb_d_base_multivect_type), intent(inout) :: x
class(psb_d_base_multivect_type), intent(inout) :: y
real(psb_dpk_), intent(inout), allocatable :: a(:,:)
integer(psb_ipk_), intent(out) :: info
info = psb_success_
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
if (allocated(a)) then
if (size(a,1) /= x%get_ncols()) then
info = psb_err_invalid_input_
return
end if
if (size(a,2) /= y%get_ncols()) then
info = psb_err_invalid_input_
return
end if
else
allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
if (info /= 0) call psb_errpush(psb_err_alloc_dealloc_,'base_mlv_mlt_mv2')
end if
! We do the multiplication by using the BLAS function
! dgemm, which computes the matrix-matrix product
! C = alpha*op( A )*op( B ) + beta*C
! In our case, we want to compute
! C = X'*Y
call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), x%get_nrows(), done, &
& x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols())
end subroutine d_base_mlv_mlt_mv2
!
! Multiple variants of two operations:
! Simple multiplication Y(:.:) = X(:,:)*Y(:,:)

@ -1423,6 +1423,9 @@ module psb_d_multivect_mod
procedure, pass(x) :: dot_a => d_mvect_dot_a
procedure, pass(x) :: dot_a_vect => d_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => d_mlv_trslv
procedure, pass(x) :: mlt_mv2 => d_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => d_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => d_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a
@ -1955,6 +1958,41 @@ contains
end function d_mvect_dot_a
subroutine d_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
implicit none
class(psb_d_multivect_type), intent(inout) :: x
real(psb_dpk_), intent(in) :: a(:,:)
integer(psb_ipk_), intent(in) :: n
character(len=1), intent(in) :: uplo
real(psb_dpk_), intent(in), optional :: alpha
character(len=1), intent(in), optional :: trans, diag
integer(psb_ipk_), intent(out) :: info
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
return
else
call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
end if
end subroutine d_mlv_trslv
subroutine d_mvect_mlt_mv2(x,y,a,info)
use psi_serial_mod
implicit none
class(psb_d_multivect_type), intent(inout) :: x
class(psb_d_multivect_type), intent(inout) :: y
real(psb_dpk_), intent(inout), allocatable :: a(:,:)
integer(psb_ipk_), intent(out) :: info
if (allocated(x%v).and.allocated(y%v)) then
call y%v%mlt(x%v,a,info)
else
info = psb_err_invalid_vect_state_
return
end if
end subroutine d_mvect_mlt_mv2
!!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod
!!$ implicit none

@ -109,3 +109,185 @@ subroutine psb_dvmlt(x,y,desc_a,info)
return
end subroutine psb_dvmlt
!
! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_dvmlt.f90
!
! Function: psb_ddot_multivect
! psb_dvmlta computes the inner product of two distributed multivectors,
!
! dot(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_d_multivect_type) The input vector containing the entries of sub( X ).
! y - type(psb_d_multivect_type) The input vector containing the entries of sub( Y ).
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
subroutine psb_dmlt_multivect(x, y, desc_a,res,info,global)
use psb_desc_mod
use psb_d_base_mat_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_d_multivect_mod
use psb_d_psblas_mod, psb_protect_name => psb_dmlt_multivect
implicit none
real(psb_dpk_), dimension(:,:), allocatable :: res
type(psb_d_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, idx, ndm,&
& err_act, iix, jjx, iiy, jjy, i, nr
integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny
logical :: global_
character(len=20) :: name, ch_err
name='psb_dmlt_multivect'
info=psb_success_
call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
ctxt=desc_a%get_context()
call psb_info(ctxt, me, np)
if (np == -ione) then
info = psb_err_context_error_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(y%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = ione
ijx = ione
iy = ione
ijy = ione
m = desc_a%get_global_rows()
nx = x%get_ncols()
ny = y%get_ncols()
! check vector correctness
call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if (info == psb_success_) &
& call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name)
goto 9999
end if
if (x%get_ncols() /= y%get_ncols()) then
info=psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
else
allocate(res(x%get_ncols()),stat=info)
if (info /= 0) then
info=psb_err_alloc_dealloc_
call psb_errpush(info,name)
goto 9999
end if
end if
nr = desc_a%get_local_rows()
if(nr > 0) then
call x%mlt(y,res,info)
! FIXME
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! Since I'm coputing res via a dgemm on the whole vector, I need to adjust
! the result by removing the contribution of the overlapped elements
! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:)
! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y
! where R is the result, X' is the transpose of the matrix x%v%v(idx,:)
! and Y is the matrix y%v%v(idx,:)
call dgemm('T','N',size(x%v%v(idx,:),2),size(y%v%v(idx,:),2),&
& size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
& y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
end do
end if
else
res = dzero
end if
! compute global sum
if (global_) call psb_sum(ctxt, res)
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(ctxt,err_act)
return
end subroutine psb_dmlt_multivect

@ -1048,6 +1048,8 @@ program vecoperation
type(psb_s_multivect_type) :: smv1, smv2
type(psb_c_multivect_type) :: cmv1, cmv2
type(psb_z_multivect_type) :: zmv1, zmv2
! scalars
real(psb_dpk_), allocatable, dimension(:,:) :: res
! blacs parameters
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: iam, np
@ -1569,7 +1571,15 @@ program vecoperation
if(all(ansmv(:) == np*idim)) write(psb_out_unit,'("TEST PASSED >>> Dot product (mv vs vector) (double complex)")')
if(any(ansmv(:) /= np*idim)) write(psb_out_unit,'("TEST FAILED --- Dot product (mv vs vector) (double complex)")')
end if
! Inner product: multivector vs multivector (double real)
call psb_d_gen_const_multi(mv1,done,idim,nmv,ctxt,desc_a,info)
call psb_d_gen_const_multi(mv2,done,idim,nmv,ctxt,desc_a,info)
allocate(res(nmv,nmv))
call psb_gemlt(mv1,mv2,res,desc_a,info)
if (iam == psb_root_) then
if(all(res(:,:) == np*idim)) write(psb_out_unit,'("TEST PASSED >>> Inner product (mv vs mv) (double real)")')
if(any(res(:,:) /= np*idim)) write(psb_out_unit,'("TEST FAILED --- Inner product (mv vs mv) (double real)")')
end if
call psb_gefree(x,desc_a,info)
@ -1593,6 +1603,7 @@ program vecoperation
call psb_gefree(zmv1,desc_a,info)
call psb_gefree(zmv2,desc_a,info)
call psb_cdfree(desc_a,info)
if(allocated(res)) deallocate(res)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='free routine'

Loading…
Cancel
Save