Added implementation of multivector operations for all type/kinds

randomized
Fabio Durastante 1 year ago
parent e423e149fa
commit 611301a606

@ -529,6 +529,15 @@ module psb_c_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_cmlt_vect2 end subroutine psb_cmlt_vect2
subroutine psb_cmlt_multivect(x, y, res, desc_a,info,global)
  ! Interface of the distributed multivector-by-multivector product.
  ! NOTE(review): presumably res = X**H * Y; confirm against the implementation.
  import :: psb_desc_type, psb_spk_, psb_ipk_, &
       & psb_c_multivect_type
  ! Input multivectors
  type(psb_c_multivect_type), intent(inout) :: x, y
  ! Dense result; allocatable so that it can be sized by the callee
  complex(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
  ! Communication descriptor
  type(psb_desc_type), intent(in) :: desc_a
  ! Return code
  integer(psb_ipk_), intent(out) :: info
  ! NOTE(review): presumably selects whether the result is reduced over all
  ! processes; confirm against the implementation.
  logical, intent(in), optional :: global
end subroutine psb_cmlt_multivect
end interface end interface
interface psb_gediv interface psb_gediv
@ -568,6 +577,18 @@ module psb_c_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag logical, intent(in) :: flag
end subroutine psb_cdiv_vect2_check end subroutine psb_cdiv_vect2_check
subroutine psb_cdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  ! Interface of the triangular solve applied to a distributed multivector.
  import :: psb_c_multivect_type, psb_desc_type, &
       & psb_spk_, psb_ipk_
  ! Multivector overwritten by the solve
  type(psb_c_multivect_type), intent(inout) :: x
  ! Dense triangular matrix
  complex(psb_spk_), dimension(:,:), intent(in) :: a
  ! Communication descriptor
  type(psb_desc_type), intent(in) :: desc_a
  ! 'U' for upper triangular, 'L' for lower triangular
  character(len=1), intent(in) :: uplo
  ! Return code
  integer(psb_ipk_), intent(out) :: info
  ! Optional scaling factor
  complex(psb_spk_), intent(in), optional :: alpha
  ! Optional transposition and unit-diagonal selectors
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag
end subroutine psb_cdiv_trslv
end interface end interface
interface psb_geinv interface psb_geinv

@ -540,6 +540,15 @@ module psb_s_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_smlt_vect2 end subroutine psb_smlt_vect2
subroutine psb_smlt_multivect(x, y, res, desc_a,info,global)
  ! Interface of the distributed multivector-by-multivector product.
  ! NOTE(review): presumably res = X**T * Y; confirm against the implementation.
  import :: psb_desc_type, psb_spk_, psb_ipk_, &
       & psb_s_multivect_type
  ! Input multivectors
  type(psb_s_multivect_type), intent(inout) :: x, y
  ! Dense result; allocatable so that it can be sized by the callee
  real(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
  ! Communication descriptor
  type(psb_desc_type), intent(in) :: desc_a
  ! Return code
  integer(psb_ipk_), intent(out) :: info
  ! NOTE(review): presumably selects whether the result is reduced over all
  ! processes; confirm against the implementation.
  logical, intent(in), optional :: global
end subroutine psb_smlt_multivect
end interface end interface
interface psb_gediv interface psb_gediv
@ -579,6 +588,18 @@ module psb_s_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag logical, intent(in) :: flag
end subroutine psb_sdiv_vect2_check end subroutine psb_sdiv_vect2_check
subroutine psb_sdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  ! Interface of the triangular solve applied to a distributed multivector.
  import :: psb_s_multivect_type, psb_desc_type, &
       & psb_spk_, psb_ipk_
  ! Multivector overwritten by the solve
  type(psb_s_multivect_type), intent(inout) :: x
  ! Dense triangular matrix
  real(psb_spk_), dimension(:,:), intent(in) :: a
  ! Communication descriptor
  type(psb_desc_type), intent(in) :: desc_a
  ! 'U' for upper triangular, 'L' for lower triangular
  character(len=1), intent(in) :: uplo
  ! Return code
  integer(psb_ipk_), intent(out) :: info
  ! Optional scaling factor
  real(psb_spk_), intent(in), optional :: alpha
  ! Optional transposition and unit-diagonal selectors
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag
end subroutine psb_sdiv_trslv
end interface end interface
interface psb_geinv interface psb_geinv

@ -529,6 +529,15 @@ module psb_z_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_zmlt_vect2 end subroutine psb_zmlt_vect2
subroutine psb_zmlt_multivect(x, y, res, desc_a,info,global)
  ! Interface of the distributed multivector-by-multivector product.
  ! NOTE(review): presumably res = X**H * Y; confirm against the implementation.
  import :: psb_desc_type, psb_dpk_, psb_ipk_, &
       & psb_z_multivect_type
  ! Input multivectors
  type(psb_z_multivect_type), intent(inout) :: x, y
  ! Dense result; allocatable so that it can be sized by the callee
  complex(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res
  ! Communication descriptor
  type(psb_desc_type), intent(in) :: desc_a
  ! Return code
  integer(psb_ipk_), intent(out) :: info
  ! NOTE(review): presumably selects whether the result is reduced over all
  ! processes; confirm against the implementation.
  logical, intent(in), optional :: global
end subroutine psb_zmlt_multivect
end interface end interface
interface psb_gediv interface psb_gediv
@ -568,6 +577,18 @@ module psb_z_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag logical, intent(in) :: flag
end subroutine psb_zdiv_vect2_check end subroutine psb_zdiv_vect2_check
subroutine psb_zdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  ! Interface of the triangular solve applied to a distributed multivector.
  import :: psb_z_multivect_type, psb_desc_type, &
       & psb_dpk_, psb_ipk_
  ! Multivector overwritten by the solve
  type(psb_z_multivect_type), intent(inout) :: x
  ! Dense triangular matrix
  complex(psb_dpk_), dimension(:,:), intent(in) :: a
  ! Communication descriptor
  type(psb_desc_type), intent(in) :: desc_a
  ! 'U' for upper triangular, 'L' for lower triangular
  character(len=1), intent(in) :: uplo
  ! Return code
  integer(psb_ipk_), intent(out) :: info
  ! Optional scaling factor
  complex(psb_dpk_), intent(in), optional :: alpha
  ! Optional transposition and unit-diagonal selectors
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag
end subroutine psb_zdiv_trslv
end interface end interface
interface psb_geinv interface psb_geinv

@ -2177,13 +2177,15 @@ module psb_c_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => c_base_mlv_mlt_ar2 procedure, pass(y) :: mlt_ar2 => c_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => c_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_a_2 => c_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => c_base_mlv_mlt_v_2 procedure, pass(z) :: mlt_v_2 => c_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => c_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => c_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_va => c_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => c_base_mlv_mlt_av !!$ procedure, pass(z) :: mlt_av => c_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
! !
! Scaling and norms ! Scaling and norms
! !
procedure, pass(x) :: trslv => c_base_mlv_trslv
procedure, pass(x) :: scal => c_base_mlv_scal procedure, pass(x) :: scal => c_base_mlv_scal
procedure, pass(x) :: nrm2 => c_base_mlv_nrm2 procedure, pass(x) :: nrm2 => c_base_mlv_nrm2
procedure, pass(x) :: amax => c_base_mlv_amax procedure, pass(x) :: amax => c_base_mlv_amax
@ -2672,6 +2674,70 @@ contains
res(1:m,1:n) = x%v(1:m,1:n) res(1:m,1:n) = x%v(1:m,1:n)
end function c_base_mlv_get_vect end function c_base_mlv_get_vect
!
!> subroutine c_base_mlv_trslv
!! \memberof psb_c_base_multivect_type
!! \brief Computes X = alpha * X * inv(op(A)) with A a triangular matrix
!! \param n Number of rows of X to be considered
!! \param x The multivector to be used for the division (overwritten)
!! \param a The square triangular matrix, of order x%get_ncols()
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param alpha (optional) The scaling factor, default cone
!! \param trans (optional) 'N' for no transpose, 'T' for transpose,
!!        'C' for conjugate transpose; default 'N'
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal;
!!        default 'N'
!! \param info return code; psb_err_invalid_input_ if a is not square or
!!        its order differs from the number of columns of x
!!
subroutine c_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  implicit none
  class(psb_c_base_multivect_type), intent(inout) :: x
  complex(psb_spk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  complex(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: lda, ldb
  character(len=1) :: trans_, diag_, side
  complex(psb_spk_) :: alpha_

  ! Default values for the optional arguments
  alpha_ = cone
  if (present(alpha)) alpha_ = alpha
  trans_ = 'N'
  if (present(trans)) trans_ = trans
  diag_ = 'N'
  if (present(diag)) diag_ = diag

  info = psb_success_

  ! Check that a is square
  if (size(a,1) /= size(a,2)) then
    info = psb_err_invalid_input_
    return
  end if

  ! Check that a has the same number of columns as x
  if (size(a,2) /= x%get_ncols()) then
    info = psb_err_invalid_input_
    return
  end if

  if (x%is_dev()) call x%sync()
  if (x%is_sync()) then
    ! Call the BLAS routine matching the data type: the operands are
    ! single precision complex, hence CTRSM (not DTRSM).
    lda = size(a,1)
    ldb = x%get_nrows()
    side = 'R' ! X*op( A ) = alpha*B.
    ! NOTE(review): x%v is updated on the host here; confirm whether
    ! derived (device) types need to be told about it.
    call ctrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
  end if

end subroutine c_base_mlv_trslv
! !
! Reset all values ! Reset all values
! !
@ -2886,6 +2952,48 @@ contains
end subroutine c_base_mlv_axpby_a end subroutine c_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_c_base_multivect_type
!! \brief computes A = conjugatetranspose(X)*Y
!! \param n Number of rows of X and Y to be considered
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix; if already allocated it must be
!!        x%get_ncols() by y%get_ncols(), otherwise it is allocated here
!! \param info return code
subroutine c_base_mlv_mlt_mv2(n,x,y,a,info)
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_c_base_multivect_type), intent(inout) :: x
  class(psb_c_base_multivect_type), intent(inout) :: y
  complex(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: nxc, nyc

  info = psb_success_
  if (x%is_dev()) call x%sync()
  if (y%is_dev()) call y%sync()

  nxc = x%get_ncols()
  nyc = y%get_ncols()

  if (allocated(a)) then
    ! A caller-provided result must already have the right shape
    if (size(a,1) /= nxc) then
      info = psb_err_invalid_input_
      return
    end if
    if (size(a,2) /= nyc) then
      info = psb_err_invalid_input_
      return
    end if
  else
    allocate(a(nxc,nyc),stat=info)
    if (info /= 0) then
      ! Do not fall through to the BLAS call with an unallocated result
      info = psb_err_alloc_dealloc_
      call psb_errpush(psb_err_alloc_dealloc_,'base_mlv_mlt_mv2')
      return
    end if
  end if

  ! We do the multiplication by using the BLAS function
  ! cgemm (single precision complex), which computes the
  ! matrix-matrix product
  ! C = alpha*op( A )*op( B ) + beta*C
  ! In our case, we want to compute
  ! C = X'*Y
  call cgemm('C', 'N', nxc, nyc, n, cone, &
       & x%v, x%get_nrows(), y%v, y%get_nrows(), czero, a, nxc)

end subroutine c_base_mlv_mlt_mv2
! !
! Multiple variants of two operations: ! Multiple variants of two operations:

@ -1344,6 +1344,9 @@ module psb_c_multivect_mod
procedure, pass(x) :: dot_a => c_mvect_dot_a procedure, pass(x) :: dot_a => c_mvect_dot_a
procedure, pass(x) :: dot_a_vect => c_mvect_dot_vect procedure, pass(x) :: dot_a_vect => c_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => c_mlv_trslv
procedure, pass(x) :: mlt_mv2 => c_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => c_mvect_axpby_v !!$ procedure, pass(y) :: axpby_v => c_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => c_mvect_axpby_a !!$ procedure, pass(y) :: axpby_a => c_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a
@ -1876,6 +1879,43 @@ contains
end function c_mvect_dot_a end function c_mvect_dot_a
subroutine c_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  ! Triangular solve on the encapsulated multivector: forwards every
  ! argument to the trslv method of x%v, or reports an invalid vector
  ! state when x%v is not allocated.
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_c_multivect_type), intent(inout) :: x
  complex(psb_spk_), intent(in) :: a(:,:)
  character(len=1), intent(in) :: uplo
  complex(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v)) then
    ! Absent optionals are passed through as absent
    call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
  else
    info = psb_err_invalid_vect_state_
  end if

end subroutine c_mlv_trslv
subroutine c_mvect_mlt_mv2(n,x,y,a,info)
  ! Computes a = conjugatetranspose(X)*Y on the first n rows by delegating
  ! to the mlt generic of the encapsulated multivectors; reports an
  ! invalid vector state when either x%v or y%v is not allocated.
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_c_multivect_type), intent(inout) :: x
  class(psb_c_multivect_type), intent(inout) :: y
  complex(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  if (.not.(allocated(x%v).and.allocated(y%v))) then
    info = psb_err_invalid_vect_state_
    return
  end if

  ! The inner mlt_mv2 binding is pass(x) and conjugate-transposes its
  ! passed object, so x%v must be the invoking object: calling through
  ! y%v would yield conjugatetranspose(Y)*X instead.
  call x%v%mlt(n,y%v,a,info)

end subroutine c_mvect_mlt_mv2
!!$ subroutine c_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ subroutine c_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod !!$ use psi_serial_mod
!!$ implicit none !!$ implicit none

@ -2917,8 +2917,6 @@ contains
call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb) call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
end if end if
end subroutine d_base_mlv_trslv end subroutine d_base_mlv_trslv
! !
! Reset all values ! Reset all values
! !
@ -3133,10 +3131,9 @@ contains
end subroutine d_base_mlv_axpby_a end subroutine d_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!> Function base_mlv_mlt_mv2
!! \memberof psb_d_base_multivect_type !! \memberof psb_d_base_multivect_type
!! \brief computes A = transpose(X)*Y !! \brief computes A = transpose(X)*Y / conjugatetranspose(X)*Y
!! \param x The class(base_mlv_vect) to be multiplied by !! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by !! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix !! \param a The resulting matrix
@ -3172,13 +3169,11 @@ contains
! C = alpha*op( A )*op( B ) + beta*C ! C = alpha*op( A )*op( B ) + beta*C
! In our case, we want to compute ! In our case, we want to compute
! C = X'*Y ! C = X'*Y
call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), n, done, & call dgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, done, &
& x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols()) & x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols())
end subroutine d_base_mlv_mlt_mv2 end subroutine d_base_mlv_mlt_mv2
! !
! Multiple variants of two operations: ! Multiple variants of two operations:
! Simple multiplication Y(:.:) = X(:,:)*Y(:,:) ! Simple multiplication Y(:.:) = X(:,:)*Y(:,:)

@ -1994,6 +1994,7 @@ contains
end subroutine d_mvect_mlt_mv2 end subroutine d_mvect_mlt_mv2
!!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod !!$ use psi_serial_mod
!!$ implicit none !!$ implicit none

@ -2356,13 +2356,15 @@ module psb_s_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => s_base_mlv_mlt_ar2 procedure, pass(y) :: mlt_ar2 => s_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => s_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_a_2 => s_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => s_base_mlv_mlt_v_2 procedure, pass(z) :: mlt_v_2 => s_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => s_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => s_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_va => s_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => s_base_mlv_mlt_av !!$ procedure, pass(z) :: mlt_av => s_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
! !
! Scaling and norms ! Scaling and norms
! !
procedure, pass(x) :: trslv => s_base_mlv_trslv
procedure, pass(x) :: scal => s_base_mlv_scal procedure, pass(x) :: scal => s_base_mlv_scal
procedure, pass(x) :: nrm2 => s_base_mlv_nrm2 procedure, pass(x) :: nrm2 => s_base_mlv_nrm2
procedure, pass(x) :: amax => s_base_mlv_amax procedure, pass(x) :: amax => s_base_mlv_amax
@ -2851,6 +2853,70 @@ contains
res(1:m,1:n) = x%v(1:m,1:n) res(1:m,1:n) = x%v(1:m,1:n)
end function s_base_mlv_get_vect end function s_base_mlv_get_vect
!
!> subroutine s_base_mlv_trslv
!! \memberof psb_s_base_multivect_type
!! \brief Computes X = alpha * X * inv(op(A)) with A a triangular matrix
!! \param n Number of rows of X to be considered
!! \param x The multivector to be used for the division (overwritten)
!! \param a The square triangular matrix, of order x%get_ncols()
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param alpha (optional) The scaling factor, default sone
!! \param trans (optional) 'N' for no transpose, 'T' for transpose;
!!        default 'N'
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal;
!!        default 'N'
!! \param info return code; psb_err_invalid_input_ if a is not square or
!!        its order differs from the number of columns of x
!!
subroutine s_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  implicit none
  class(psb_s_base_multivect_type), intent(inout) :: x
  real(psb_spk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  real(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: lda, ldb
  character(len=1) :: trans_, diag_, side
  real(psb_spk_) :: alpha_

  ! Default values for the optional arguments
  alpha_ = sone
  if (present(alpha)) alpha_ = alpha
  trans_ = 'N'
  if (present(trans)) trans_ = trans
  diag_ = 'N'
  if (present(diag)) diag_ = diag

  info = psb_success_

  ! Check that a is square
  if (size(a,1) /= size(a,2)) then
    info = psb_err_invalid_input_
    return
  end if

  ! Check that a has the same number of columns as x
  if (size(a,2) /= x%get_ncols()) then
    info = psb_err_invalid_input_
    return
  end if

  if (x%is_dev()) call x%sync()
  if (x%is_sync()) then
    ! Call the BLAS routine matching the data type: the operands are
    ! single precision real, hence STRSM (not DTRSM).
    lda = size(a,1)
    ldb = x%get_nrows()
    side = 'R' ! X*op( A ) = alpha*B.
    ! NOTE(review): x%v is updated on the host here; confirm whether
    ! derived (device) types need to be told about it.
    call strsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
  end if

end subroutine s_base_mlv_trslv
! !
! Reset all values ! Reset all values
! !
@ -3065,6 +3131,48 @@ contains
end subroutine s_base_mlv_axpby_a end subroutine s_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_s_base_multivect_type
!! \brief computes A = transpose(X)*Y
!! \param n Number of rows of X and Y to be considered
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix; if already allocated it must be
!!        x%get_ncols() by y%get_ncols(), otherwise it is allocated here
!! \param info return code
subroutine s_base_mlv_mlt_mv2(n,x,y,a,info)
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_s_base_multivect_type), intent(inout) :: x
  class(psb_s_base_multivect_type), intent(inout) :: y
  real(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: nxc, nyc

  info = psb_success_
  if (x%is_dev()) call x%sync()
  if (y%is_dev()) call y%sync()

  nxc = x%get_ncols()
  nyc = y%get_ncols()

  if (allocated(a)) then
    ! A caller-provided result must already have the right shape
    if (size(a,1) /= nxc) then
      info = psb_err_invalid_input_
      return
    end if
    if (size(a,2) /= nyc) then
      info = psb_err_invalid_input_
      return
    end if
  else
    allocate(a(nxc,nyc),stat=info)
    if (info /= 0) then
      ! Do not fall through to the BLAS call with an unallocated result
      info = psb_err_alloc_dealloc_
      call psb_errpush(psb_err_alloc_dealloc_,'base_mlv_mlt_mv2')
      return
    end if
  end if

  ! We do the multiplication by using the BLAS function
  ! sgemm (single precision real), which computes the
  ! matrix-matrix product
  ! C = alpha*op( A )*op( B ) + beta*C
  ! In our case, we want to compute
  ! C = X'*Y
  ! ('C' is equivalent to 'T' for real data)
  call sgemm('C', 'N', nxc, nyc, n, sone, &
       & x%v, x%get_nrows(), y%v, y%get_nrows(), szero, a, nxc)

end subroutine s_base_mlv_mlt_mv2
! !
! Multiple variants of two operations: ! Multiple variants of two operations:

@ -1423,6 +1423,9 @@ module psb_s_multivect_mod
procedure, pass(x) :: dot_a => s_mvect_dot_a procedure, pass(x) :: dot_a => s_mvect_dot_a
procedure, pass(x) :: dot_a_vect => s_mvect_dot_vect procedure, pass(x) :: dot_a_vect => s_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => s_mlv_trslv
procedure, pass(x) :: mlt_mv2 => s_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => s_mvect_axpby_v !!$ procedure, pass(y) :: axpby_v => s_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => s_mvect_axpby_a !!$ procedure, pass(y) :: axpby_a => s_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a
@ -1955,6 +1958,43 @@ contains
end function s_mvect_dot_a end function s_mvect_dot_a
subroutine s_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  ! Triangular solve on the encapsulated multivector: forwards every
  ! argument to the trslv method of x%v, or reports an invalid vector
  ! state when x%v is not allocated.
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_s_multivect_type), intent(inout) :: x
  real(psb_spk_), intent(in) :: a(:,:)
  character(len=1), intent(in) :: uplo
  real(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v)) then
    ! Absent optionals are passed through as absent
    call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
  else
    info = psb_err_invalid_vect_state_
  end if

end subroutine s_mlv_trslv
subroutine s_mvect_mlt_mv2(n,x,y,a,info)
  ! Computes a = transpose(X)*Y on the first n rows by delegating to the
  ! mlt generic of the encapsulated multivectors; reports an invalid
  ! vector state when either x%v or y%v is not allocated.
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_s_multivect_type), intent(inout) :: x
  class(psb_s_multivect_type), intent(inout) :: y
  real(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  if (.not.(allocated(x%v).and.allocated(y%v))) then
    info = psb_err_invalid_vect_state_
    return
  end if

  ! The inner mlt_mv2 binding is pass(x) and transposes its passed
  ! object, so x%v must be the invoking object: calling through y%v
  ! would yield transpose(Y)*X instead.
  call x%v%mlt(n,y%v,a,info)

end subroutine s_mvect_mlt_mv2
!!$ subroutine s_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ subroutine s_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod !!$ use psi_serial_mod
!!$ implicit none !!$ implicit none

@ -2177,13 +2177,15 @@ module psb_z_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => z_base_mlv_mlt_ar2 procedure, pass(y) :: mlt_ar2 => z_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => z_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_a_2 => z_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => z_base_mlv_mlt_v_2 procedure, pass(z) :: mlt_v_2 => z_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => z_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => z_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_va => z_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => z_base_mlv_mlt_av !!$ procedure, pass(z) :: mlt_av => z_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
! !
! Scaling and norms ! Scaling and norms
! !
procedure, pass(x) :: trslv => z_base_mlv_trslv
procedure, pass(x) :: scal => z_base_mlv_scal procedure, pass(x) :: scal => z_base_mlv_scal
procedure, pass(x) :: nrm2 => z_base_mlv_nrm2 procedure, pass(x) :: nrm2 => z_base_mlv_nrm2
procedure, pass(x) :: amax => z_base_mlv_amax procedure, pass(x) :: amax => z_base_mlv_amax
@ -2672,6 +2674,70 @@ contains
res(1:m,1:n) = x%v(1:m,1:n) res(1:m,1:n) = x%v(1:m,1:n)
end function z_base_mlv_get_vect end function z_base_mlv_get_vect
!
!> subroutine z_base_mlv_trslv
!! \memberof psb_z_base_multivect_type
!! \brief Computes X = alpha * X * inv(op(A)) with A a triangular matrix
!! \param n Number of rows of X to be considered
!! \param x The multivector to be used for the division (overwritten)
!! \param a The square triangular matrix, of order x%get_ncols()
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param alpha (optional) The scaling factor, default zone
!! \param trans (optional) 'N' for no transpose, 'T' for transpose,
!!        'C' for conjugate transpose; default 'N'
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal;
!!        default 'N'
!! \param info return code; psb_err_invalid_input_ if a is not square or
!!        its order differs from the number of columns of x
!!
subroutine z_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  implicit none
  class(psb_z_base_multivect_type), intent(inout) :: x
  complex(psb_dpk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  complex(psb_dpk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: lda, ldb
  character(len=1) :: trans_, diag_, side
  complex(psb_dpk_) :: alpha_

  ! Default values for the optional arguments
  alpha_ = zone
  if (present(alpha)) alpha_ = alpha
  trans_ = 'N'
  if (present(trans)) trans_ = trans
  diag_ = 'N'
  if (present(diag)) diag_ = diag

  info = psb_success_

  ! Check that a is square
  if (size(a,1) /= size(a,2)) then
    info = psb_err_invalid_input_
    return
  end if

  ! Check that a has the same number of columns as x
  if (size(a,2) /= x%get_ncols()) then
    info = psb_err_invalid_input_
    return
  end if

  if (x%is_dev()) call x%sync()
  if (x%is_sync()) then
    ! Call the BLAS routine matching the data type: the operands are
    ! double precision complex, hence ZTRSM (not DTRSM).
    lda = size(a,1)
    ldb = x%get_nrows()
    side = 'R' ! X*op( A ) = alpha*B.
    ! NOTE(review): x%v is updated on the host here; confirm whether
    ! derived (device) types need to be told about it.
    call ztrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
  end if

end subroutine z_base_mlv_trslv
! !
! Reset all values ! Reset all values
! !
@ -2886,6 +2952,48 @@ contains
end subroutine z_base_mlv_axpby_a end subroutine z_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_z_base_multivect_type
!! \brief computes A = conjugatetranspose(X)*Y
!! \param n Number of rows of X and Y to be considered
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix; if already allocated it must be
!!        x%get_ncols() by y%get_ncols(), otherwise it is allocated here
!! \param info return code
subroutine z_base_mlv_mlt_mv2(n,x,y,a,info)
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_z_base_multivect_type), intent(inout) :: x
  class(psb_z_base_multivect_type), intent(inout) :: y
  complex(psb_dpk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: nxc, nyc

  info = psb_success_
  if (x%is_dev()) call x%sync()
  if (y%is_dev()) call y%sync()

  nxc = x%get_ncols()
  nyc = y%get_ncols()

  if (allocated(a)) then
    ! A caller-provided result must already have the right shape
    if (size(a,1) /= nxc) then
      info = psb_err_invalid_input_
      return
    end if
    if (size(a,2) /= nyc) then
      info = psb_err_invalid_input_
      return
    end if
  else
    allocate(a(nxc,nyc),stat=info)
    if (info /= 0) then
      ! Do not fall through to the BLAS call with an unallocated result
      info = psb_err_alloc_dealloc_
      call psb_errpush(psb_err_alloc_dealloc_,'base_mlv_mlt_mv2')
      return
    end if
  end if

  ! We do the multiplication by using the BLAS function
  ! zgemm (double precision complex), which computes the
  ! matrix-matrix product
  ! C = alpha*op( A )*op( B ) + beta*C
  ! In our case, we want to compute
  ! C = X'*Y
  call zgemm('C', 'N', nxc, nyc, n, zone, &
       & x%v, x%get_nrows(), y%v, y%get_nrows(), zzero, a, nxc)

end subroutine z_base_mlv_mlt_mv2
! !
! Multiple variants of two operations: ! Multiple variants of two operations:

@ -1344,6 +1344,9 @@ module psb_z_multivect_mod
procedure, pass(x) :: dot_a => z_mvect_dot_a procedure, pass(x) :: dot_a => z_mvect_dot_a
procedure, pass(x) :: dot_a_vect => z_mvect_dot_vect procedure, pass(x) :: dot_a_vect => z_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => z_mlv_trslv
procedure, pass(x) :: mlt_mv2 => z_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => z_mvect_axpby_v !!$ procedure, pass(y) :: axpby_v => z_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => z_mvect_axpby_a !!$ procedure, pass(y) :: axpby_a => z_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a !!$ generic, public :: axpby => axpby_v, axpby_a
@ -1876,6 +1879,43 @@ contains
end function z_mvect_dot_a end function z_mvect_dot_a
subroutine z_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  ! Triangular solve on the encapsulated multivector: forwards every
  ! argument to the trslv method of x%v, or reports an invalid vector
  ! state when x%v is not allocated.
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_z_multivect_type), intent(inout) :: x
  complex(psb_dpk_), intent(in) :: a(:,:)
  character(len=1), intent(in) :: uplo
  complex(psb_dpk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v)) then
    ! Absent optionals are passed through as absent
    call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
  else
    info = psb_err_invalid_vect_state_
  end if

end subroutine z_mlv_trslv
subroutine z_mvect_mlt_mv2(n,x,y,a,info)
  ! Computes a = conjugatetranspose(X)*Y on the first n rows by delegating
  ! to the mlt generic of the encapsulated multivectors; reports an
  ! invalid vector state when either x%v or y%v is not allocated.
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_z_multivect_type), intent(inout) :: x
  class(psb_z_multivect_type), intent(inout) :: y
  complex(psb_dpk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  if (.not.(allocated(x%v).and.allocated(y%v))) then
    info = psb_err_invalid_vect_state_
    return
  end if

  ! The inner mlt_mv2 binding is pass(x) and conjugate-transposes its
  ! passed object, so x%v must be the invoking object: calling through
  ! y%v would yield conjugatetranspose(Y)*X instead.
  call x%v%mlt(n,y%v,a,info)

end subroutine z_mvect_mlt_mv2
!!$ subroutine z_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ subroutine z_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod !!$ use psi_serial_mod
!!$ implicit none !!$ implicit none

@ -356,3 +356,70 @@ subroutine psb_cdiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_cdiv_vect2_check end subroutine psb_cdiv_vect2_check
! Subroutine: psb_cdiv_trslv
!    Applies a dense triangular solve to the local rows of a distributed
!    multivector by calling x%trslv on desc_a%get_local_rows() rows.
!
! Arguments:
!    x      - type(psb_c_multivect_type)  The multivector, overwritten.
!    a      - complex(:,:)                The dense triangular matrix.
!    desc_a - type(psb_desc_type)         The communication descriptor.
!    uplo   - character                   'U' upper / 'L' lower triangular.
!    info   - integer                     Return code.
!    alpha  - complex, optional           Scaling factor.
!    trans  - character, optional         Transposition selector.
!    diag   - character, optional         Unit-diagonal selector.
!
subroutine psb_cdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_cdiv_trslv
  implicit none
  type(psb_c_multivect_type), intent (inout) :: x
  complex(psb_spk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  complex(psb_spk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_cdiv_trslv'
  if (psb_errstatus_fatal()) return
  info=psb_success_
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  ! The encapsulated multivector must exist before it can be used
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix = ione
  ijx = ione

  m = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  ! Only the locally owned rows take part in the solve; the return code
  ! of trslv is handed back to the caller through info.
  nr = desc_a%get_local_rows()
  if(nr > 0) then
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_cdiv_trslv

@ -478,12 +478,17 @@ function psb_cdot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync() if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1) do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1) !idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2) !ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via cgemv calls ! Remove the overlapped elements via cgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call cgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & !call cgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione) ! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do end do
end if end if
else else

@ -29,29 +29,168 @@
! POSSIBILITY OF SUCH DAMAGE. ! POSSIBILITY OF SUCH DAMAGE.
! !
! !
! File: psb_cvmlt.f90 ! File: psb_dvmlt.f90
!!$
!!$subroutine psb_dvmlt(x,y,desc_a,info)
!!$ use psb_base_mod, psb_protect_name => psb_dvmlt
!!$ implicit none
!!$ type(psb_d_vect_type), intent (inout) :: x
!!$ type(psb_d_vect_type), intent (inout) :: y
!!$ type(psb_desc_type), intent (in) :: desc_a
!!$ integer(psb_ipk_), intent(out) :: info
!!$
!!$ ! locals
!!$ integer(psb_ipk_) :: ctxt, np, me,&
!!$ & err_act, iix, jjx, iiy, jjy
!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m
!!$ character(len=20) :: name, ch_err
!!$
!!$ name='psb_dgevmlt'
!!$ if (psb_errstatus_fatal()) return
!!$ info=psb_success_
!!$ call psb_erractionsave(err_act)
!!$
!!$ ctxt=desc_a%get_context()
!!$
!!$ call psb_info(ctxt, me, np)
!!$ if (np == -ione) then
!!$ info = psb_err_context_error_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(x%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(y%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$
!!$
!!$ ix = ione
!!$ iy = ione
!!$
!!$ m = desc_a%get_global_rows()
!!$
!!$ ! check vector correctness
!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 1'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 2'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$
!!$ if ((iix /= ione).or.(iiy /= ione)) then
!!$ info=psb_err_ix_n1_iy_n1_unsupported_
!!$ call psb_errpush(info,name)
!!$ end if
!!$
!!$ if(desc_a%get_local_rows() > 0) then
!!$ call y%base_mlt_v(desc_a%get_local_rows(),&
!!$ & alpha,x,beta,info)
!!$ end if
!!$
!!$ call psb_erractionrestore(err_act)
!!$ return
!!$
!!$9999 call psb_error_handler(ctxt,err_act)
!!$
!!$ return
!!$
!!$end subroutine psb_dvmlt
subroutine psb_cvmlt(x,y,desc_a,info) !
use psb_base_mod, psb_protect_name => psb_cvmlt ! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_cvmlt.f90
!
! Subroutine: psb_cmlt_multivect
! psb_cmlt_multivect computes the inner product of two distributed multivectors,
!
! res(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_c_multivect_type) The input multivector containing the entries of sub( X ).
! y - type(psb_c_multivect_type) The input multivector containing the entries of sub( Y ).
! res - complex(psb_spk_)(:,:), allocatable. The matrix of inner products.
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
subroutine psb_cmlt_multivect(x, y, res,desc_a,info,global)
use psb_desc_mod
use psb_c_base_mat_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_c_multivect_mod
use psb_c_psblas_mod, psb_protect_name => psb_dmlt_multivect
implicit none implicit none
type(psb_c_vect_type), intent (inout) :: x complex(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
type(psb_c_vect_type), intent (inout) :: y type(psb_c_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals ! locals
integer(psb_ipk_) :: ctxt, np, me,& type(psb_ctxt_type) :: ctxt
& err_act, iix, jjx, iiy, jjy integer(psb_ipk_) :: np, me, idx, ndm,&
integer(psb_lpk_) :: ix, ijx, iy, ijy, m & err_act, iix, jjx, iiy, jjy, i, nr
character(len=20) :: name, ch_err integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny
logical :: global_
character(len=20) :: name, ch_err
name='psb_cgevmlt' name='psb_dmlt_multivect'
if (psb_errstatus_fatal()) return
info=psb_success_ info=psb_success_
call psb_erractionsave(err_act) call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
ctxt=desc_a%get_context() ctxt=desc_a%get_context()
call psb_info(ctxt, me, np) call psb_info(ctxt, me, np)
if (np == -ione) then if (np == -ione) then
info = psb_err_context_error_ info = psb_err_context_error_
@ -69,24 +208,29 @@ subroutine psb_cvmlt(x,y,desc_a,info)
goto 9999 goto 9999
endif endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = ione ix = ione
ijx = ione
iy = ione iy = ione
ijy = ione
m = desc_a%get_global_rows() m = desc_a%get_global_rows()
nx = x%get_ncols()
ny = y%get_ncols()
! check vector correctness ! check vector correctness
call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if(info /= psb_success_) then if (info == psb_success_) &
info=psb_err_from_subroutine_ & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
ch_err='psb_chkvect 1'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
if(info /= psb_success_) then if(info /= psb_success_) then
info=psb_err_from_subroutine_ info=psb_err_from_subroutine_
ch_err='psb_chkvect 2' ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err) call psb_errpush(info,name,a_err=ch_err)
goto 9999 goto 9999
end if end if
@ -94,13 +238,55 @@ subroutine psb_cvmlt(x,y,desc_a,info)
if ((iix /= ione).or.(iiy /= ione)) then if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_ info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name) call psb_errpush(info,name)
goto 9999
end if
if (x%get_ncols() /= y%get_ncols()) then
info=psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
else
if (allocated(res)) then
if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then
deallocate(res,stat=info)
end if
end if
end if end if
if(desc_a%get_local_rows() > 0) then nr = desc_a%get_local_rows()
call y%base_mlt_v(desc_a%get_local_rows(),& if(nr > 0) then
& alpha,x,beta,info) call x%mlt(nr,y,res,info)
! FIXME
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! FIXME: case of AS-type descriptors
! Since I'm computing res via a dgemm on the whole vector, I need to adjust
! the result by removing the contribution of the overlapped elements
! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:)
! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y
! where R is the result, X' is the transpose of the matrix x%v%v(idx,:)
! and Y is the matrix y%v%v(idx,:)
! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),&
! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else
res = dzero
end if end if
! compute global sum
if (global_) call psb_sum(ctxt, res)
call psb_erractionrestore(err_act) call psb_erractionrestore(err_act)
return return
@ -108,4 +294,4 @@ subroutine psb_cvmlt(x,y,desc_a,info)
return return
end subroutine psb_cvmlt end subroutine psb_cmlt_multivect

@ -356,6 +356,73 @@ subroutine psb_ddiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_ddiv_vect2_check end subroutine psb_ddiv_vect2_check
!
! Subroutine: psb_ddiv_trslv
!   Validates the multivector X against the descriptor and then invokes the
!   type-bound X%trslv on the locally owned rows.
!   NOTE(review): judging by the names this is a triangular solve of X
!   against the dense matrix A; confirm against psb_d_multivect_type%trslv.
!
! Arguments:
!   x      - type(psb_d_multivect_type), inout. Multivector on which trslv is invoked.
!   a      - real(psb_dpk_)(:,:), in.   Dense matrix passed through to x%trslv.
!   desc_a - type(psb_desc_type), in.   The communication descriptor.
!   uplo   - character, in.             Passed through to x%trslv.
!   info   - integer, out.              Return code.
!   alpha, trans, diag - optional, in.  Forwarded unchanged to x%trslv
!                                       (an absent optional stays absent).
!
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_ddiv_trslv
  implicit none
  type(psb_d_multivect_type), intent (inout) :: x
  real(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_dpk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_ddiv_trslv'
  ! info is intent(out): give it a value on every exit path, including
  ! the early return taken when the error stack is already fatal.
  info=psb_success_
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_
    return
  end if
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione

  m  = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if(nr > 0) then
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Do not let a failure in the local kernel go unnoticed.
    if(info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_ddiv_trslv
function psb_dminquotient_vect(x,y,desc_a,info,global) result(res) function psb_dminquotient_vect(x,y,desc_a,info,global) result(res)
use psb_penv_mod use psb_penv_mod
use psb_serial_mod use psb_serial_mod
@ -444,70 +511,3 @@ function psb_dminquotient_vect(x,y,desc_a,info,global) result(res)
return return
end function psb_dminquotient_vect end function psb_dminquotient_vect
!
! Subroutine: psb_ddiv_trslv
!   Validates the multivector X against the descriptor and then invokes the
!   type-bound X%trslv on the locally owned rows.
!   NOTE(review): judging by the names this is a triangular solve of X
!   against the dense matrix A; confirm against psb_d_multivect_type%trslv.
!
! Arguments:
!   x      - type(psb_d_multivect_type), inout. Multivector on which trslv is invoked.
!   a      - real(psb_dpk_)(:,:), in.   Dense matrix passed through to x%trslv.
!   desc_a - type(psb_desc_type), in.   The communication descriptor.
!   uplo   - character, in.             Passed through to x%trslv.
!   info   - integer, out.              Return code.
!   alpha, trans, diag - optional, in.  Forwarded unchanged to x%trslv
!                                       (an absent optional stays absent).
!
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_ddiv_trslv
  implicit none
  type(psb_d_multivect_type), intent (inout) :: x
  real(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_dpk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_ddiv_trslv'
  ! info is intent(out): give it a value on every exit path, including
  ! the early return taken when the error stack is already fatal.
  info=psb_success_
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_
    return
  end if
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione

  m  = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if(nr > 0) then
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Do not let a failure in the local kernel go unnoticed.
    if(info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_ddiv_trslv

@ -478,13 +478,17 @@ function psb_ddot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync() if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1) do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1) !idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2) !ndm = desc_a%ovrlap_elem(i,2)
! FIXME: MAKES NO SENSE! ! Remove the overlapped elements via dgemv calls
! Remove the overlapped elements via dgemv calls which are axpy
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & !call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione) ! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do end do
end if end if
else else

@ -356,6 +356,73 @@ subroutine psb_sdiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_sdiv_vect2_check end subroutine psb_sdiv_vect2_check
!
! Subroutine: psb_sdiv_trslv
!   Validates the multivector X against the descriptor and then invokes the
!   type-bound X%trslv on the locally owned rows.
!   NOTE(review): judging by the names this is a triangular solve of X
!   against the dense matrix A; confirm against psb_s_multivect_type%trslv.
!
! Arguments:
!   x      - type(psb_s_multivect_type), inout. Multivector on which trslv is invoked.
!   a      - real(psb_spk_)(:,:), in.   Dense matrix passed through to x%trslv.
!   desc_a - type(psb_desc_type), in.   The communication descriptor.
!   uplo   - character, in.             Passed through to x%trslv.
!   info   - integer, out.              Return code.
!   alpha, trans, diag - optional, in.  Forwarded unchanged to x%trslv
!                                       (an absent optional stays absent).
!
subroutine psb_sdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_sdiv_trslv
  implicit none
  type(psb_s_multivect_type), intent (inout) :: x
  real(psb_spk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_spk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_sdiv_trslv'
  ! info is intent(out): give it a value on every exit path, including
  ! the early return taken when the error stack is already fatal.
  info=psb_success_
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_
    return
  end if
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione

  m  = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if(nr > 0) then
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Do not let a failure in the local kernel go unnoticed.
    if(info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_sdiv_trslv
function psb_sminquotient_vect(x,y,desc_a,info,global) result(res) function psb_sminquotient_vect(x,y,desc_a,info,global) result(res)
use psb_penv_mod use psb_penv_mod
use psb_serial_mod use psb_serial_mod

@ -478,12 +478,17 @@ function psb_sdot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync() if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1) do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1) !idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2) !ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via sgemv calls ! Remove the overlapped elements via sgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & !call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione) ! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do end do
end if end if
else else

@ -29,29 +29,168 @@
! POSSIBILITY OF SUCH DAMAGE. ! POSSIBILITY OF SUCH DAMAGE.
! !
! !
! File: psb_svmlt.f90 ! File: psb_dvmlt.f90
!!$
!!$subroutine psb_dvmlt(x,y,desc_a,info)
!!$ use psb_base_mod, psb_protect_name => psb_dvmlt
!!$ implicit none
!!$ type(psb_d_vect_type), intent (inout) :: x
!!$ type(psb_d_vect_type), intent (inout) :: y
!!$ type(psb_desc_type), intent (in) :: desc_a
!!$ integer(psb_ipk_), intent(out) :: info
!!$
!!$ ! locals
!!$ integer(psb_ipk_) :: ctxt, np, me,&
!!$ & err_act, iix, jjx, iiy, jjy
!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m
!!$ character(len=20) :: name, ch_err
!!$
!!$ name='psb_dgevmlt'
!!$ if (psb_errstatus_fatal()) return
!!$ info=psb_success_
!!$ call psb_erractionsave(err_act)
!!$
!!$ ctxt=desc_a%get_context()
!!$
!!$ call psb_info(ctxt, me, np)
!!$ if (np == -ione) then
!!$ info = psb_err_context_error_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(x%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(y%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$
!!$
!!$ ix = ione
!!$ iy = ione
!!$
!!$ m = desc_a%get_global_rows()
!!$
!!$ ! check vector correctness
!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 1'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 2'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$
!!$ if ((iix /= ione).or.(iiy /= ione)) then
!!$ info=psb_err_ix_n1_iy_n1_unsupported_
!!$ call psb_errpush(info,name)
!!$ end if
!!$
!!$ if(desc_a%get_local_rows() > 0) then
!!$ call y%base_mlt_v(desc_a%get_local_rows(),&
!!$ & alpha,x,beta,info)
!!$ end if
!!$
!!$ call psb_erractionrestore(err_act)
!!$ return
!!$
!!$9999 call psb_error_handler(ctxt,err_act)
!!$
!!$ return
!!$
!!$end subroutine psb_dvmlt
subroutine psb_svmlt(x,y,desc_a,info) !
use psb_base_mod, psb_protect_name => psb_svmlt ! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_svmlt.f90
!
! Subroutine: psb_smlt_multivect
! psb_smlt_multivect computes the inner product of two distributed multivectors,
!
! res(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_s_multivect_type) The input multivector containing the entries of sub( X ).
! y - type(psb_s_multivect_type) The input multivector containing the entries of sub( Y ).
! res - real(psb_spk_)(:,:), allocatable. The matrix of inner products.
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
subroutine psb_smlt_multivect(x, y, res,desc_a,info,global)
use psb_desc_mod
use psb_s_base_mat_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_s_multivect_mod
use psb_s_psblas_mod, psb_protect_name => psb_dmlt_multivect
implicit none implicit none
type(psb_s_vect_type), intent (inout) :: x real(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
type(psb_s_vect_type), intent (inout) :: y type(psb_s_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals ! locals
integer(psb_ipk_) :: ctxt, np, me,& type(psb_ctxt_type) :: ctxt
& err_act, iix, jjx, iiy, jjy integer(psb_ipk_) :: np, me, idx, ndm,&
integer(psb_lpk_) :: ix, ijx, iy, ijy, m & err_act, iix, jjx, iiy, jjy, i, nr
character(len=20) :: name, ch_err integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny
logical :: global_
character(len=20) :: name, ch_err
name='psb_sgevmlt' name='psb_dmlt_multivect'
if (psb_errstatus_fatal()) return
info=psb_success_ info=psb_success_
call psb_erractionsave(err_act) call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
ctxt=desc_a%get_context() ctxt=desc_a%get_context()
call psb_info(ctxt, me, np) call psb_info(ctxt, me, np)
if (np == -ione) then if (np == -ione) then
info = psb_err_context_error_ info = psb_err_context_error_
@ -69,24 +208,29 @@ subroutine psb_svmlt(x,y,desc_a,info)
goto 9999 goto 9999
endif endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = ione ix = ione
ijx = ione
iy = ione iy = ione
ijy = ione
m = desc_a%get_global_rows() m = desc_a%get_global_rows()
nx = x%get_ncols()
ny = y%get_ncols()
! check vector correctness ! check vector correctness
call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if(info /= psb_success_) then if (info == psb_success_) &
info=psb_err_from_subroutine_ & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
ch_err='psb_chkvect 1'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
if(info /= psb_success_) then if(info /= psb_success_) then
info=psb_err_from_subroutine_ info=psb_err_from_subroutine_
ch_err='psb_chkvect 2' ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err) call psb_errpush(info,name,a_err=ch_err)
goto 9999 goto 9999
end if end if
@ -94,13 +238,55 @@ subroutine psb_svmlt(x,y,desc_a,info)
if ((iix /= ione).or.(iiy /= ione)) then if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_ info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name) call psb_errpush(info,name)
goto 9999
end if
if (x%get_ncols() /= y%get_ncols()) then
info=psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
else
if (allocated(res)) then
if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then
deallocate(res,stat=info)
end if
end if
end if end if
if(desc_a%get_local_rows() > 0) then nr = desc_a%get_local_rows()
call y%base_mlt_v(desc_a%get_local_rows(),& if(nr > 0) then
& alpha,x,beta,info) call x%mlt(nr,y,res,info)
! FIXME
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! FIXME: case of AS-type descriptors
! Since I'm computing res via a dgemm on the whole vector, I need to adjust
! the result by removing the contribution of the overlapped elements
! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:)
! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y
! where R is the result, X' is the transpose of the matrix x%v%v(idx,:)
! and Y is the matrix y%v%v(idx,:)
! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),&
! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else
res = dzero
end if end if
! compute global sum
if (global_) call psb_sum(ctxt, res)
call psb_erractionrestore(err_act) call psb_erractionrestore(err_act)
return return
@ -108,4 +294,4 @@ subroutine psb_svmlt(x,y,desc_a,info)
return return
end subroutine psb_svmlt end subroutine psb_smlt_multivect

@ -356,3 +356,70 @@ subroutine psb_zdiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_zdiv_vect2_check end subroutine psb_zdiv_vect2_check
!
! Subroutine: psb_zdiv_trslv
!   Validates the multivector X against the descriptor and then invokes the
!   type-bound X%trslv on the locally owned rows.
!   NOTE(review): judging by the names this is a triangular solve of X
!   against the dense matrix A; confirm against psb_z_multivect_type%trslv.
!
! Arguments:
!   x      - type(psb_z_multivect_type), inout. Multivector on which trslv is invoked.
!   a      - complex(psb_dpk_)(:,:), in. Dense matrix passed through to x%trslv.
!   desc_a - type(psb_desc_type), in.    The communication descriptor.
!   uplo   - character, in.              Passed through to x%trslv.
!   info   - integer, out.               Return code.
!   alpha, trans, diag - optional, in.   Forwarded unchanged to x%trslv
!                                        (an absent optional stays absent).
!
subroutine psb_zdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_zdiv_trslv
  implicit none
  type(psb_z_multivect_type), intent (inout) :: x
  complex(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  complex(psb_dpk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_zdiv_trslv'
  ! info is intent(out): give it a value on every exit path, including
  ! the early return taken when the error stack is already fatal.
  info=psb_success_
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_
    return
  end if
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione

  m  = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if(nr > 0) then
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Do not let a failure in the local kernel go unnoticed.
    if(info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_zdiv_trslv

@ -478,12 +478,17 @@ function psb_zdot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync() if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1) do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1) !idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2) !ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via zgemv calls ! Remove the overlapped elements via zgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call zgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & !call zgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione) ! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do end do
end if end if
else else

@ -29,29 +29,168 @@
! POSSIBILITY OF SUCH DAMAGE. ! POSSIBILITY OF SUCH DAMAGE.
! !
! !
! File: psb_zvmlt.f90 ! File: psb_dvmlt.f90
!!$
!!$subroutine psb_dvmlt(x,y,desc_a,info)
!!$ use psb_base_mod, psb_protect_name => psb_dvmlt
!!$ implicit none
!!$ type(psb_d_vect_type), intent (inout) :: x
!!$ type(psb_d_vect_type), intent (inout) :: y
!!$ type(psb_desc_type), intent (in) :: desc_a
!!$ integer(psb_ipk_), intent(out) :: info
!!$
!!$ ! locals
!!$ integer(psb_ipk_) :: ctxt, np, me,&
!!$ & err_act, iix, jjx, iiy, jjy
!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m
!!$ character(len=20) :: name, ch_err
!!$
!!$ name='psb_dgevmlt'
!!$ if (psb_errstatus_fatal()) return
!!$ info=psb_success_
!!$ call psb_erractionsave(err_act)
!!$
!!$ ctxt=desc_a%get_context()
!!$
!!$ call psb_info(ctxt, me, np)
!!$ if (np == -ione) then
!!$ info = psb_err_context_error_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(x%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(y%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$
!!$
!!$ ix = ione
!!$ iy = ione
!!$
!!$ m = desc_a%get_global_rows()
!!$
!!$ ! check vector correctness
!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 1'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 2'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$
!!$ if ((iix /= ione).or.(iiy /= ione)) then
!!$ info=psb_err_ix_n1_iy_n1_unsupported_
!!$ call psb_errpush(info,name)
!!$ end if
!!$
!!$ if(desc_a%get_local_rows() > 0) then
!!$ call y%base_mlt_v(desc_a%get_local_rows(),&
!!$ & alpha,x,beta,info)
!!$ end if
!!$
!!$ call psb_erractionrestore(err_act)
!!$ return
!!$
!!$9999 call psb_error_handler(ctxt,err_act)
!!$
!!$ return
!!$
!!$end subroutine psb_dvmlt
subroutine psb_zvmlt(x,y,desc_a,info) !
use psb_base_mod, psb_protect_name => psb_zvmlt ! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_zvmlt.f90
!
! Subroutine: psb_zmlt_multivect
! psb_zmlt_multivect computes the inner product of two distributed multivectors,
!
! res(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_z_multivect_type) The input multivector containing the entries of sub( X ).
! y - type(psb_z_multivect_type) The input multivector containing the entries of sub( Y ).
! res - complex(psb_dpk_)(:,:), allocatable. The matrix of inner products.
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
subroutine psb_zmlt_multivect(x, y, res,desc_a,info,global)
use psb_desc_mod
use psb_z_base_mat_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_z_multivect_mod
use psb_z_psblas_mod, psb_protect_name => psb_dmlt_multivect
implicit none implicit none
type(psb_z_vect_type), intent (inout) :: x complex(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res
type(psb_z_vect_type), intent (inout) :: y type(psb_z_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals ! locals
integer(psb_ipk_) :: ctxt, np, me,& type(psb_ctxt_type) :: ctxt
& err_act, iix, jjx, iiy, jjy integer(psb_ipk_) :: np, me, idx, ndm,&
integer(psb_lpk_) :: ix, ijx, iy, ijy, m & err_act, iix, jjx, iiy, jjy, i, nr
character(len=20) :: name, ch_err integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny
logical :: global_
character(len=20) :: name, ch_err
name='psb_zgevmlt' name='psb_dmlt_multivect'
if (psb_errstatus_fatal()) return
info=psb_success_ info=psb_success_
call psb_erractionsave(err_act) call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
ctxt=desc_a%get_context() ctxt=desc_a%get_context()
call psb_info(ctxt, me, np) call psb_info(ctxt, me, np)
if (np == -ione) then if (np == -ione) then
info = psb_err_context_error_ info = psb_err_context_error_
@ -69,24 +208,29 @@ subroutine psb_zvmlt(x,y,desc_a,info)
goto 9999 goto 9999
endif endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = ione ix = ione
ijx = ione
iy = ione iy = ione
ijy = ione
m = desc_a%get_global_rows() m = desc_a%get_global_rows()
nx = x%get_ncols()
ny = y%get_ncols()
! check vector correctness ! check vector correctness
call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if(info /= psb_success_) then if (info == psb_success_) &
info=psb_err_from_subroutine_ & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
ch_err='psb_chkvect 1'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
if(info /= psb_success_) then if(info /= psb_success_) then
info=psb_err_from_subroutine_ info=psb_err_from_subroutine_
ch_err='psb_chkvect 2' ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err) call psb_errpush(info,name,a_err=ch_err)
goto 9999 goto 9999
end if end if
@ -94,13 +238,55 @@ subroutine psb_zvmlt(x,y,desc_a,info)
if ((iix /= ione).or.(iiy /= ione)) then if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_ info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name) call psb_errpush(info,name)
goto 9999
end if
if (x%get_ncols() /= y%get_ncols()) then
info=psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
else
if (allocated(res)) then
if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then
deallocate(res,stat=info)
end if
end if
end if end if
if(desc_a%get_local_rows() > 0) then nr = desc_a%get_local_rows()
call y%base_mlt_v(desc_a%get_local_rows(),& if(nr > 0) then
& alpha,x,beta,info) call x%mlt(nr,y,res,info)
! FIXME
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! FIXME: case of AS-type descriptors
! Since I'm computing res via a dgemm on the whole vector, I need to adjust
! the result by removing the contribution of the overlapped elements
! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:)
! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y
! where R is the result, X' is the transpose of the matrix x%v%v(idx,:)
! and Y is the matrix y%v%v(idx,:)
! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),&
! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else
res = dzero
end if end if
! compute global sum
if (global_) call psb_sum(ctxt, res)
call psb_erractionrestore(err_act) call psb_erractionrestore(err_act)
return return
@ -108,4 +294,4 @@ subroutine psb_zvmlt(x,y,desc_a,info)
return return
end subroutine psb_zvmlt end subroutine psb_zmlt_multivect

Loading…
Cancel
Save