Added implementation of multivector operations for all type/kinds

randomized
Fabio Durastante 1 year ago
parent e423e149fa
commit 611301a606

@ -529,6 +529,15 @@ module psb_c_psblas_mod
integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_cmlt_vect2
subroutine psb_cmlt_multivect(x, y, res, desc_a,info,global)
  ! Interface of the multivector-by-multivector product for
  ! single precision complex data.
  !   x, y    - the two distributed multivectors
  !   res     - allocatable matrix receiving the result
  !   desc_a  - the communication descriptor
  !   info    - return code
  !   global  - (optional) flag forwarded to the implementation
  ! Only the entities actually referenced by the declarations below are
  ! imported; the sparse matrix type is not needed by this interface.
  import :: psb_desc_type, psb_spk_, psb_ipk_, &
       & psb_c_multivect_type
  complex(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
  type(psb_c_multivect_type), intent(inout) :: x, y
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: global
end subroutine psb_cmlt_multivect
end interface
interface psb_gediv
@ -568,6 +577,18 @@ module psb_c_psblas_mod
integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag
end subroutine psb_cdiv_vect2_check
subroutine psb_cdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  ! Interface of the triangular solve applied to a distributed
  ! single precision complex multivector.
  ! Declarations are listed in dummy-argument order.
  import :: psb_c_multivect_type, psb_desc_type, &
       & psb_spk_, psb_ipk_
  type(psb_c_multivect_type), intent(inout) :: x
  complex(psb_spk_), dimension(:,:), intent(in) :: a
  type(psb_desc_type), intent(in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  complex(psb_spk_), optional, intent(in) :: alpha
  character(len=1), optional, intent(in) :: trans
  character(len=1), optional, intent(in) :: diag
end subroutine psb_cdiv_trslv
end interface
interface psb_geinv

@ -540,6 +540,15 @@ module psb_s_psblas_mod
integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_smlt_vect2
subroutine psb_smlt_multivect(x, y, res, desc_a,info,global)
  ! Interface of the multivector-by-multivector product for
  ! single precision real data.
  !   x, y    - the two distributed multivectors
  !   res     - allocatable matrix receiving the result
  !   desc_a  - the communication descriptor
  !   info    - return code
  !   global  - (optional) flag forwarded to the implementation
  ! Only the entities actually referenced by the declarations below are
  ! imported; the sparse matrix type is not needed by this interface.
  import :: psb_desc_type, psb_spk_, psb_ipk_, &
       & psb_s_multivect_type
  real(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
  type(psb_s_multivect_type), intent(inout) :: x, y
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: global
end subroutine psb_smlt_multivect
end interface
interface psb_gediv
@ -579,6 +588,18 @@ module psb_s_psblas_mod
integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag
end subroutine psb_sdiv_vect2_check
subroutine psb_sdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  ! Interface of the triangular solve applied to a distributed
  ! single precision real multivector.
  ! Declarations are listed in dummy-argument order.
  import :: psb_s_multivect_type, psb_desc_type, &
       & psb_spk_, psb_ipk_
  type(psb_s_multivect_type), intent(inout) :: x
  real(psb_spk_), dimension(:,:), intent(in) :: a
  type(psb_desc_type), intent(in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_spk_), optional, intent(in) :: alpha
  character(len=1), optional, intent(in) :: trans
  character(len=1), optional, intent(in) :: diag
end subroutine psb_sdiv_trslv
end interface
interface psb_geinv

@ -529,6 +529,15 @@ module psb_z_psblas_mod
integer(psb_ipk_), intent(out) :: info
character(len=1), intent(in), optional :: conjgx, conjgy
end subroutine psb_zmlt_vect2
subroutine psb_zmlt_multivect(x, y, res, desc_a,info,global)
  ! Interface of the multivector-by-multivector product for
  ! double precision complex data.
  !   x, y    - the two distributed multivectors
  !   res     - allocatable matrix receiving the result
  !   desc_a  - the communication descriptor
  !   info    - return code
  !   global  - (optional) flag forwarded to the implementation
  ! Only the entities actually referenced by the declarations below are
  ! imported; the sparse matrix type is not needed by this interface.
  import :: psb_desc_type, psb_dpk_, psb_ipk_, &
       & psb_z_multivect_type
  complex(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res
  type(psb_z_multivect_type), intent(inout) :: x, y
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: global
end subroutine psb_zmlt_multivect
end interface
interface psb_gediv
@ -568,6 +577,18 @@ module psb_z_psblas_mod
integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag
end subroutine psb_zdiv_vect2_check
subroutine psb_zdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  ! Interface of the triangular solve applied to a distributed
  ! double precision complex multivector.
  ! Declarations are listed in dummy-argument order.
  import :: psb_z_multivect_type, psb_desc_type, &
       & psb_dpk_, psb_ipk_
  type(psb_z_multivect_type), intent(inout) :: x
  complex(psb_dpk_), dimension(:,:), intent(in) :: a
  type(psb_desc_type), intent(in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  complex(psb_dpk_), optional, intent(in) :: alpha
  character(len=1), optional, intent(in) :: trans
  character(len=1), optional, intent(in) :: diag
end subroutine psb_zdiv_trslv
end interface
interface psb_geinv

@ -2177,13 +2177,15 @@ module psb_c_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => c_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => c_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => c_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => c_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => c_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => c_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va
& mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
!
! Scaling and norms
!
procedure, pass(x) :: trslv => c_base_mlv_trslv
procedure, pass(x) :: scal => c_base_mlv_scal
procedure, pass(x) :: nrm2 => c_base_mlv_nrm2
procedure, pass(x) :: amax => c_base_mlv_amax
@ -2672,6 +2674,70 @@ contains
res(1:m,1:n) = x%v(1:m,1:n)
end function c_base_mlv_get_vect
!
!> subroutine c_base_mlv_trslv
!! \memberof psb_c_base_multivect_type
!! \brief In-place triangular solve from the right:
!!        X is overwritten with the solution of X*op(A) = alpha*X,
!!        with A a triangular matrix
!! \param n Number of rows of X to be considered
!! \param x The multivector to be used for the division (overwritten)
!! \param a The square triangular matrix to be used for the division
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param alpha (optional) The scaling factor, default cone
!! \param trans (optional) 'N' for no transpose, 'T' for transpose,
!!        'C' for conjugate transpose, default 'N'
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal,
!!        default 'N'
!! \param info return code
!!
subroutine c_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  implicit none
  class(psb_c_base_multivect_type), intent(inout) :: x
  complex(psb_spk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  complex(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: lda, ldb
  character(len=1) :: trans_, diag_, side
  complex(psb_spk_) :: alpha_

  ! Default values
  if (.not.present(alpha)) then
    alpha_ = cone
  else
    alpha_ = alpha
  end if
  if (.not.present(trans)) then
    trans_ = 'N'
  else
    trans_ = trans
  end if
  if (.not.present(diag)) then
    diag_ = 'N'
  else
    diag_ = diag
  end if

  info = psb_success_
  ! Check that a is square
  if (size(a,1) /= size(a,2)) then
    info = psb_err_invalid_input_
    return
  end if
  ! Check that a has the same number of columns as x
  if (size(a,2) /= x%get_ncols()) then
    info = psb_err_invalid_input_
    return
  end if
  ! The BLAS call below uses get_nrows() as leading dimension of x%v,
  ! hence n cannot exceed it (BLAS would abort on an illegal argument).
  if ((n < 0).or.(n > x%get_nrows())) then
    info = psb_err_invalid_input_
    return
  end if

  if (x%is_dev()) call x%sync()

  if (x%is_sync()) then
    ! Call the BLAS function to solve the system.
    ! The data is single precision complex, so the matching
    ! routine is ctrsm (dtrsm only handles double precision real).
    lda = size(a,1)
    ldb = x%get_nrows()
    side = 'R' ! X*op( A ) = alpha*B.
    call ctrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
  end if

end subroutine c_base_mlv_trslv
!
! Reset all values
!
@ -2886,6 +2952,48 @@ contains
end subroutine c_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_c_base_multivect_type
!! \brief computes A = conjugatetranspose(X)*Y on the first n rows
!! \param n Number of rows of X and Y to be considered
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix, of size ncols(X) by ncols(Y);
!!          allocated here when not allocated on entry
!! \param info return code
subroutine c_base_mlv_mlt_mv2(n,x,y,a,info)
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_c_base_multivect_type), intent(inout) :: x
  class(psb_c_base_multivect_type), intent(inout) :: y
  complex(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  info = psb_success_
  if (x%is_dev()) call x%sync()
  if (y%is_dev()) call y%sync()

  ! The BLAS call below uses get_nrows() as leading dimension of both
  ! operands, hence n cannot exceed either of them.
  if ((n < 0).or.(n > x%get_nrows()).or.(n > y%get_nrows())) then
    info = psb_err_invalid_input_
    return
  end if

  if (allocated(a)) then
    if (size(a,1) /= x%get_ncols()) then
      info = psb_err_invalid_input_
      return
    end if
    if (size(a,2) /= y%get_ncols()) then
      info = psb_err_invalid_input_
      return
    end if
  else
    allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
    if (info /= 0) then
      ! Do not fall through to BLAS with an unallocated result
      info = psb_err_alloc_dealloc_
      call psb_errpush(info,'base_mlv_mlt_mv2')
      return
    end if
  end if

  ! We do the multiplication by using the BLAS function
  ! cgemm (single precision complex), which computes the
  ! matrix-matrix product
  ! C = alpha*op( A )*op( B ) + beta*C
  ! In our case, we want to compute
  ! C = X'*Y
  call cgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, cone, &
       & x%v, x%get_nrows(), y%v, y%get_nrows(), czero, a, x%get_ncols())

end subroutine c_base_mlv_mlt_mv2
!
! Multiple variants of two operations:

@ -1344,6 +1344,9 @@ module psb_c_multivect_mod
procedure, pass(x) :: dot_a => c_mvect_dot_a
procedure, pass(x) :: dot_a_vect => c_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => c_mlv_trslv
procedure, pass(x) :: mlt_mv2 => c_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => c_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => c_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a
@ -1876,6 +1879,43 @@ contains
end function c_mvect_dot_a
subroutine c_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  ! Triangular solve on an encapsulated multivector: forwards the
  ! request to the inner component x%v, provided it is allocated.
  implicit none
  class(psb_c_multivect_type), intent(inout) :: x
  complex(psb_spk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  complex(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v)) then
    ! Absent optionals are passed through unchanged
    call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
  else
    ! No inner storage: nothing to solve on
    info = psb_err_invalid_vect_state_
  end if

end subroutine c_mlv_trslv
subroutine c_mvect_mlt_mv2(n,x,y,a,info)
  ! Computes a = conjugatetranspose(X)*Y on the first n rows by
  ! delegating to the inner components of the two multivectors.
  !   n    - number of rows to be considered
  !   x, y - the two multivectors (INOUT because of possible sync())
  !   a    - the resulting matrix
  !   info - return code; psb_err_invalid_vect_state_ when either
  !          inner component is not allocated
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_c_multivect_type), intent(inout) :: x
  class(psb_c_multivect_type), intent(inout) :: y
  complex(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v).and.allocated(y%v)) then
    ! The inner routine computes (passed object)'*(argument): X must be
    ! the passed object, otherwise the result would be Y'*X, i.e. the
    ! conjugate transpose of the requested matrix.
    call x%v%mlt(n,y%v,a,info)
  else
    info = psb_err_invalid_vect_state_
    return
  end if

end subroutine c_mvect_mlt_mv2
!!$ subroutine c_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod
!!$ implicit none

@ -2917,8 +2917,6 @@ contains
call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
end if
end subroutine d_base_mlv_trslv
!
! Reset all values
!
@ -3133,10 +3131,9 @@ contains
end subroutine d_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!> Function base_mlv_mlt_mv2
!! \memberof psb_d_base_multivect_type
!! \brief computes A = transpose(X)*Y
!! \brief computes A = transpose(X)*Y / conjugatetranspose(X)*Y
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix
@ -3172,13 +3169,11 @@ contains
! C = alpha*op( A )*op( B ) + beta*C
! In our case, we want to compute
! C = X'*Y
call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), n, done, &
call dgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, done, &
& x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols())
end subroutine d_base_mlv_mlt_mv2
!
! Multiple variants of two operations:
! Simple multiplication Y(:.:) = X(:,:)*Y(:,:)

@ -1994,6 +1994,7 @@ contains
end subroutine d_mvect_mlt_mv2
!!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod
!!$ implicit none

@ -2356,13 +2356,15 @@ module psb_s_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => s_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => s_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => s_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => s_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => s_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => s_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va
& mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
!
! Scaling and norms
!
procedure, pass(x) :: trslv => s_base_mlv_trslv
procedure, pass(x) :: scal => s_base_mlv_scal
procedure, pass(x) :: nrm2 => s_base_mlv_nrm2
procedure, pass(x) :: amax => s_base_mlv_amax
@ -2851,6 +2853,70 @@ contains
res(1:m,1:n) = x%v(1:m,1:n)
end function s_base_mlv_get_vect
!
!> subroutine s_base_mlv_trslv
!! \memberof psb_s_base_multivect_type
!! \brief In-place triangular solve from the right:
!!        X is overwritten with the solution of X*op(A) = alpha*X,
!!        with A a triangular matrix
!! \param n Number of rows of X to be considered
!! \param x The multivector to be used for the division (overwritten)
!! \param a The square triangular matrix to be used for the division
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param alpha (optional) The scaling factor, default sone
!! \param trans (optional) 'N' for no transpose, 'T' for transpose,
!!        default 'N'
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal,
!!        default 'N'
!! \param info return code
!!
subroutine s_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  implicit none
  class(psb_s_base_multivect_type), intent(inout) :: x
  real(psb_spk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  real(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: lda, ldb
  character(len=1) :: trans_, diag_, side
  real(psb_spk_) :: alpha_

  ! Default values
  if (.not.present(alpha)) then
    alpha_ = sone
  else
    alpha_ = alpha
  end if
  if (.not.present(trans)) then
    trans_ = 'N'
  else
    trans_ = trans
  end if
  if (.not.present(diag)) then
    diag_ = 'N'
  else
    diag_ = diag
  end if

  info = psb_success_
  ! Check that a is square
  if (size(a,1) /= size(a,2)) then
    info = psb_err_invalid_input_
    return
  end if
  ! Check that a has the same number of columns as x
  if (size(a,2) /= x%get_ncols()) then
    info = psb_err_invalid_input_
    return
  end if
  ! The BLAS call below uses get_nrows() as leading dimension of x%v,
  ! hence n cannot exceed it (BLAS would abort on an illegal argument).
  if ((n < 0).or.(n > x%get_nrows())) then
    info = psb_err_invalid_input_
    return
  end if

  if (x%is_dev()) call x%sync()

  if (x%is_sync()) then
    ! Call the BLAS function to solve the system.
    ! The data is single precision real, so the matching
    ! routine is strsm (dtrsm only handles double precision real).
    lda = size(a,1)
    ldb = x%get_nrows()
    side = 'R' ! X*op( A ) = alpha*B.
    call strsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
  end if

end subroutine s_base_mlv_trslv
!
! Reset all values
!
@ -3065,6 +3131,48 @@ contains
end subroutine s_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_s_base_multivect_type
!! \brief computes A = transpose(X)*Y on the first n rows
!! \param n Number of rows of X and Y to be considered
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix, of size ncols(X) by ncols(Y);
!!          allocated here when not allocated on entry
!! \param info return code
subroutine s_base_mlv_mlt_mv2(n,x,y,a,info)
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_s_base_multivect_type), intent(inout) :: x
  class(psb_s_base_multivect_type), intent(inout) :: y
  real(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  info = psb_success_
  if (x%is_dev()) call x%sync()
  if (y%is_dev()) call y%sync()

  ! The BLAS call below uses get_nrows() as leading dimension of both
  ! operands, hence n cannot exceed either of them.
  if ((n < 0).or.(n > x%get_nrows()).or.(n > y%get_nrows())) then
    info = psb_err_invalid_input_
    return
  end if

  if (allocated(a)) then
    if (size(a,1) /= x%get_ncols()) then
      info = psb_err_invalid_input_
      return
    end if
    if (size(a,2) /= y%get_ncols()) then
      info = psb_err_invalid_input_
      return
    end if
  else
    allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
    if (info /= 0) then
      ! Do not fall through to BLAS with an unallocated result
      info = psb_err_alloc_dealloc_
      call psb_errpush(info,'base_mlv_mlt_mv2')
      return
    end if
  end if

  ! We do the multiplication by using the BLAS function
  ! sgemm (single precision real), which computes the
  ! matrix-matrix product
  ! C = alpha*op( A )*op( B ) + beta*C
  ! In our case, we want to compute
  ! C = X'*Y
  ! ('C' and 'T' are equivalent for real data)
  call sgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, sone, &
       & x%v, x%get_nrows(), y%v, y%get_nrows(), szero, a, x%get_ncols())

end subroutine s_base_mlv_mlt_mv2
!
! Multiple variants of two operations:

@ -1423,6 +1423,9 @@ module psb_s_multivect_mod
procedure, pass(x) :: dot_a => s_mvect_dot_a
procedure, pass(x) :: dot_a_vect => s_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => s_mlv_trslv
procedure, pass(x) :: mlt_mv2 => s_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => s_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => s_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a
@ -1955,6 +1958,43 @@ contains
end function s_mvect_dot_a
subroutine s_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  ! Triangular solve on an encapsulated multivector: forwards the
  ! request to the inner component x%v, provided it is allocated.
  implicit none
  class(psb_s_multivect_type), intent(inout) :: x
  real(psb_spk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  real(psb_spk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v)) then
    ! Absent optionals are passed through unchanged
    call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
  else
    ! No inner storage: nothing to solve on
    info = psb_err_invalid_vect_state_
  end if

end subroutine s_mlv_trslv
subroutine s_mvect_mlt_mv2(n,x,y,a,info)
  ! Computes a = transpose(X)*Y on the first n rows by delegating to
  ! the inner components of the two multivectors.
  !   n    - number of rows to be considered
  !   x, y - the two multivectors (INOUT because of possible sync())
  !   a    - the resulting matrix
  !   info - return code; psb_err_invalid_vect_state_ when either
  !          inner component is not allocated
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_s_multivect_type), intent(inout) :: x
  class(psb_s_multivect_type), intent(inout) :: y
  real(psb_spk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v).and.allocated(y%v)) then
    ! The inner routine computes (passed object)'*(argument): X must be
    ! the passed object, otherwise the result would be Y'*X, i.e. the
    ! transpose of the requested matrix.
    call x%v%mlt(n,y%v,a,info)
  else
    info = psb_err_invalid_vect_state_
    return
  end if

end subroutine s_mvect_mlt_mv2
!!$ subroutine s_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod
!!$ implicit none

@ -2177,13 +2177,15 @@ module psb_z_base_multivect_mod
procedure, pass(y) :: mlt_ar2 => z_base_mlv_mlt_ar2
procedure, pass(z) :: mlt_a_2 => z_base_mlv_mlt_a_2
procedure, pass(z) :: mlt_v_2 => z_base_mlv_mlt_v_2
procedure, pass(x) :: mlt_mv2 => z_base_mlv_mlt_mv2
!!$ procedure, pass(z) :: mlt_va => z_base_mlv_mlt_va
!!$ procedure, pass(z) :: mlt_av => z_base_mlv_mlt_av
generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, &
& mlt_a_2, mlt_v_2 !, mlt_av, mlt_va
& mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va
!
! Scaling and norms
!
procedure, pass(x) :: trslv => z_base_mlv_trslv
procedure, pass(x) :: scal => z_base_mlv_scal
procedure, pass(x) :: nrm2 => z_base_mlv_nrm2
procedure, pass(x) :: amax => z_base_mlv_amax
@ -2672,6 +2674,70 @@ contains
res(1:m,1:n) = x%v(1:m,1:n)
end function z_base_mlv_get_vect
!
!> subroutine z_base_mlv_trslv
!! \memberof psb_z_base_multivect_type
!! \brief In-place triangular solve from the right:
!!        X is overwritten with the solution of X*op(A) = alpha*X,
!!        with A a triangular matrix
!! \param n Number of rows of X to be considered
!! \param x The multivector to be used for the division (overwritten)
!! \param a The square triangular matrix to be used for the division
!! \param uplo 'U' for upper triangular, 'L' for lower triangular
!! \param alpha (optional) The scaling factor, default zone
!! \param trans (optional) 'N' for no transpose, 'T' for transpose,
!!        'C' for conjugate transpose, default 'N'
!! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal,
!!        default 'N'
!! \param info return code
!!
subroutine z_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  implicit none
  class(psb_z_base_multivect_type), intent(inout) :: x
  complex(psb_dpk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  complex(psb_dpk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info
  ! Local variables
  integer(psb_ipk_) :: lda, ldb
  character(len=1) :: trans_, diag_, side
  complex(psb_dpk_) :: alpha_

  ! Default values
  if (.not.present(alpha)) then
    alpha_ = zone
  else
    alpha_ = alpha
  end if
  if (.not.present(trans)) then
    trans_ = 'N'
  else
    trans_ = trans
  end if
  if (.not.present(diag)) then
    diag_ = 'N'
  else
    diag_ = diag
  end if

  info = psb_success_
  ! Check that a is square
  if (size(a,1) /= size(a,2)) then
    info = psb_err_invalid_input_
    return
  end if
  ! Check that a has the same number of columns as x
  if (size(a,2) /= x%get_ncols()) then
    info = psb_err_invalid_input_
    return
  end if
  ! The BLAS call below uses get_nrows() as leading dimension of x%v,
  ! hence n cannot exceed it (BLAS would abort on an illegal argument).
  if ((n < 0).or.(n > x%get_nrows())) then
    info = psb_err_invalid_input_
    return
  end if

  if (x%is_dev()) call x%sync()

  if (x%is_sync()) then
    ! Call the BLAS function to solve the system.
    ! The data is double precision complex, so the matching
    ! routine is ztrsm (dtrsm only handles double precision real).
    lda = size(a,1)
    ldb = x%get_nrows()
    side = 'R' ! X*op( A ) = alpha*B.
    call ztrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
  end if

end subroutine z_base_mlv_trslv
!
! Reset all values
!
@ -2886,6 +2952,48 @@ contains
end subroutine z_base_mlv_axpby_a
!> Function base_mlv_mlt_mv2
!! \memberof psb_z_base_multivect_type
!! \brief computes A = conjugatetranspose(X)*Y on the first n rows
!! \param n Number of rows of X and Y to be considered
!! \param x The class(base_mlv_vect) to be multiplied by
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix, of size ncols(X) by ncols(Y);
!!          allocated here when not allocated on entry
!! \param info return code
subroutine z_base_mlv_mlt_mv2(n,x,y,a,info)
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_z_base_multivect_type), intent(inout) :: x
  class(psb_z_base_multivect_type), intent(inout) :: y
  complex(psb_dpk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  info = psb_success_
  if (x%is_dev()) call x%sync()
  if (y%is_dev()) call y%sync()

  ! The BLAS call below uses get_nrows() as leading dimension of both
  ! operands, hence n cannot exceed either of them.
  if ((n < 0).or.(n > x%get_nrows()).or.(n > y%get_nrows())) then
    info = psb_err_invalid_input_
    return
  end if

  if (allocated(a)) then
    if (size(a,1) /= x%get_ncols()) then
      info = psb_err_invalid_input_
      return
    end if
    if (size(a,2) /= y%get_ncols()) then
      info = psb_err_invalid_input_
      return
    end if
  else
    allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
    if (info /= 0) then
      ! Do not fall through to BLAS with an unallocated result
      info = psb_err_alloc_dealloc_
      call psb_errpush(info,'base_mlv_mlt_mv2')
      return
    end if
  end if

  ! We do the multiplication by using the BLAS function
  ! zgemm (double precision complex), which computes the
  ! matrix-matrix product
  ! C = alpha*op( A )*op( B ) + beta*C
  ! In our case, we want to compute
  ! C = X'*Y
  call zgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, zone, &
       & x%v, x%get_nrows(), y%v, y%get_nrows(), zzero, a, x%get_ncols())

end subroutine z_base_mlv_mlt_mv2
!
! Multiple variants of two operations:

@ -1344,6 +1344,9 @@ module psb_z_multivect_mod
procedure, pass(x) :: dot_a => z_mvect_dot_a
procedure, pass(x) :: dot_a_vect => z_mvect_dot_vect
generic, public :: dot => dot_v, dot_a, dot_a_vect
procedure, pass(x) :: trslv => z_mlv_trslv
procedure, pass(x) :: mlt_mv2 => z_mvect_mlt_mv2
generic, public :: mlt => mlt_mv2
!!$ procedure, pass(y) :: axpby_v => z_mvect_axpby_v
!!$ procedure, pass(y) :: axpby_a => z_mvect_axpby_a
!!$ generic, public :: axpby => axpby_v, axpby_a
@ -1876,6 +1879,43 @@ contains
end function z_mvect_dot_a
subroutine z_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
  ! Triangular solve on an encapsulated multivector: forwards the
  ! request to the inner component x%v, provided it is allocated.
  implicit none
  class(psb_z_multivect_type), intent(inout) :: x
  complex(psb_dpk_), intent(in) :: a(:,:)
  integer(psb_ipk_), intent(in) :: n
  character(len=1), intent(in) :: uplo
  complex(psb_dpk_), intent(in), optional :: alpha
  character(len=1), intent(in), optional :: trans, diag
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v)) then
    ! Absent optionals are passed through unchanged
    call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
  else
    ! No inner storage: nothing to solve on
    info = psb_err_invalid_vect_state_
  end if

end subroutine z_mlv_trslv
subroutine z_mvect_mlt_mv2(n,x,y,a,info)
  ! Computes a = conjugatetranspose(X)*Y on the first n rows by
  ! delegating to the inner components of the two multivectors.
  !   n    - number of rows to be considered
  !   x, y - the two multivectors (INOUT because of possible sync())
  !   a    - the resulting matrix
  !   info - return code; psb_err_invalid_vect_state_ when either
  !          inner component is not allocated
  use psi_serial_mod
  implicit none
  integer(psb_ipk_), intent(in) :: n
  class(psb_z_multivect_type), intent(inout) :: x
  class(psb_z_multivect_type), intent(inout) :: y
  complex(psb_dpk_), intent(inout), allocatable :: a(:,:)
  integer(psb_ipk_), intent(out) :: info

  if (allocated(x%v).and.allocated(y%v)) then
    ! The inner routine computes (passed object)'*(argument): X must be
    ! the passed object, otherwise the result would be Y'*X, i.e. the
    ! conjugate transpose of the requested matrix.
    call x%v%mlt(n,y%v,a,info)
  else
    info = psb_err_invalid_vect_state_
    return
  end if

end subroutine z_mvect_mlt_mv2
!!$ subroutine z_mvect_axpby_v(m,alpha, x, beta, y, info)
!!$ use psi_serial_mod
!!$ implicit none

@ -356,3 +356,70 @@ subroutine psb_cdiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_cdiv_vect2_check
subroutine psb_cdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  !
  ! Applies a triangular solve with the matrix a to the local rows
  ! of the distributed multivector x, overwriting x.
  !
  ! Arguments:
  !    x      - type(psb_c_multivect_type) The multivector, overwritten.
  !    a      - complex(:,:)               The triangular matrix.
  !    desc_a - type(psb_desc_type).       The communication descriptor.
  !    uplo   - character.                 Forwarded to x%trslv.
  !    info   - integer.                   Return code
  !    alpha  - complex(optional)          Forwarded to x%trslv.
  !    trans  - character(optional)        Forwarded to x%trslv.
  !    diag   - character(optional)        Forwarded to x%trslv.
  !
  use psb_base_mod, psb_protect_name => psb_cdiv_trslv
  implicit none
  type(psb_c_multivect_type), intent (inout) :: x
  complex(psb_spk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  complex(psb_spk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_cdiv_trslv'
  ! info is intent(out): give it a defined value on every exit path,
  ! including the early return taken when an error is already pending
  info=psb_success_
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_
    return
  end if
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix = ione
  ijx = ione

  m = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if(nr > 0) then
    call x%trslv(nr,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
    ! Do not silently drop a failure of the local solve
    if(info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_cdiv_trslv

@ -478,12 +478,17 @@ function psb_cdot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
!idx = desc_a%ovrlap_elem(i,1)
!ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via cgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call cgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
!call cgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else

@ -29,64 +29,208 @@
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_cvmlt.f90
! File: psb_dvmlt.f90
!!$
!!$subroutine psb_dvmlt(x,y,desc_a,info)
!!$ use psb_base_mod, psb_protect_name => psb_dvmlt
!!$ implicit none
!!$ type(psb_d_vect_type), intent (inout) :: x
!!$ type(psb_d_vect_type), intent (inout) :: y
!!$ type(psb_desc_type), intent (in) :: desc_a
!!$ integer(psb_ipk_), intent(out) :: info
!!$
!!$ ! locals
!!$ integer(psb_ipk_) :: ctxt, np, me,&
!!$ & err_act, iix, jjx, iiy, jjy
!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m
!!$ character(len=20) :: name, ch_err
!!$
!!$ name='psb_dgevmlt'
!!$ if (psb_errstatus_fatal()) return
!!$ info=psb_success_
!!$ call psb_erractionsave(err_act)
!!$
!!$ ctxt=desc_a%get_context()
!!$
!!$ call psb_info(ctxt, me, np)
!!$ if (np == -ione) then
!!$ info = psb_err_context_error_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(x%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(y%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$
!!$
!!$ ix = ione
!!$ iy = ione
!!$
!!$ m = desc_a%get_global_rows()
!!$
!!$ ! check vector correctness
!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 1'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 2'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$
!!$ if ((iix /= ione).or.(iiy /= ione)) then
!!$ info=psb_err_ix_n1_iy_n1_unsupported_
!!$ call psb_errpush(info,name)
!!$ end if
!!$
!!$ if(desc_a%get_local_rows() > 0) then
!!$ call y%base_mlt_v(desc_a%get_local_rows(),&
!!$ & alpha,x,beta,info)
!!$ end if
!!$
!!$ call psb_erractionrestore(err_act)
!!$ return
!!$
!!$9999 call psb_error_handler(ctxt,err_act)
!!$
!!$ return
!!$
!!$end subroutine psb_dvmlt
subroutine psb_cvmlt(x,y,desc_a,info)
use psb_base_mod, psb_protect_name => psb_cvmlt
implicit none
type(psb_c_vect_type), intent (inout) :: x
type(psb_c_vect_type), intent (inout) :: y
type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info
!
! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_cvmlt.f90
!
! Subroutine: psb_cmlt_multivect
! psb_cmlt_multivect computes the inner product of two distributed multivectors,
!
! dot(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_c_multivect_type) The input vector containing the entries of sub( X ).
! y - type(psb_c_multivect_type) The input vector containing the entries of sub( Y ).
! res - complex(psb_spk_), allocatable(:,:). The matrix receiving the product.
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
subroutine psb_cmlt_multivect(x, y, res,desc_a,info,global)
use psb_desc_mod
use psb_c_base_mat_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_c_multivect_mod
use psb_c_psblas_mod, psb_protect_name => psb_dmlt_multivect
implicit none
complex(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
type(psb_c_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals
integer(psb_ipk_) :: ctxt, np, me,&
& err_act, iix, jjx, iiy, jjy
integer(psb_lpk_) :: ix, ijx, iy, ijy, m
character(len=20) :: name, ch_err
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, idx, ndm,&
& err_act, iix, jjx, iiy, jjy, i, nr
integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny
logical :: global_
character(len=20) :: name, ch_err
name='psb_cgevmlt'
if (psb_errstatus_fatal()) return
name='psb_dmlt_multivect'
info=psb_success_
call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
ctxt=desc_a%get_context()
call psb_info(ctxt, me, np)
if (np == -ione) then
info = psb_err_context_error_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(x%v)) then
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(y%v)) then
if (.not.allocated(y%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = ione
ijx = ione
iy = ione
ijy = ione
m = desc_a%get_global_rows()
nx = x%get_ncols()
ny = y%get_ncols()
! check vector correctness
call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if (info == psb_success_) &
& call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect 1'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect 2'
ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
@ -94,18 +238,60 @@ subroutine psb_cvmlt(x,y,desc_a,info)
if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name)
goto 9999
end if
if(desc_a%get_local_rows() > 0) then
call y%base_mlt_v(desc_a%get_local_rows(),&
& alpha,x,beta,info)
if (x%get_ncols() /= y%get_ncols()) then
info=psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
else
if (allocated(res)) then
if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then
deallocate(res,stat=info)
end if
end if
end if
nr = desc_a%get_local_rows()
if(nr > 0) then
call x%mlt(nr,y,res,info)
! FIXME
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! FIXME: case of AS-type descriptors
! Since I'm coputing res via a dgemm on the whole vector, I need to adjust
! the result by removing the contribution of the overlapped elements
! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:)
! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y
! where R is the result, X' is the transpose of the matrix x%v%v(idx,:)
! and Y is the matrix y%v%v(idx,:)
! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),&
! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else
res = dzero
end if
! compute global sum
if (global_) call psb_sum(ctxt, res)
call psb_erractionrestore(err_act)
return
return
9999 call psb_error_handler(ctxt,err_act)
return
end subroutine psb_cvmlt
end subroutine psb_cmlt_multivect

@ -356,6 +356,73 @@ subroutine psb_ddiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_ddiv_vect2_check
!
! Subroutine: psb_ddiv_trslv
!    Validates the multivector X against the descriptor and, when this
!    process owns at least one local row, forwards the request to the
!    trslv method of X together with the dense matrix A.
!
! Arguments:
!    x      - type(psb_d_multivect_type)      The multivector handed to trslv (inout).
!    a      - real(psb_dpk_), dimension(:,:)  The dense matrix forwarded to x%trslv.
!    desc_a - type(psb_desc_type)             The communication descriptor.
!    uplo   - character                       Forwarded unchanged to x%trslv.
!    info   - integer                         Return code.
!    alpha  - real(psb_dpk_), optional        Forwarded unchanged to x%trslv.
!    trans  - character, optional             Forwarded unchanged to x%trslv.
!    diag   - character, optional             Forwarded unchanged to x%trslv.
!
! NOTE(review): the argument names suggest a triangular solve in the
! BLAS trsm style; the exact semantics are defined by x%trslv - confirm there.
!
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_ddiv_trslv
  implicit none
  type(psb_d_multivect_type), intent (inout) :: x
  real(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_dpk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_ddiv_trslv'
  ! Define info before any exit path: it is intent(out), so returning
  ! on a pending fatal error without setting it would hand the caller
  ! an undefined value.
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione
  m   = desc_a%get_global_rows()
  nx  = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    ! Absent optionals are passed straight through to x%trslv.
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Report a failure of the local kernel through the error stack,
    ! in the same way as the psb_chkvect failure above.
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_ddiv_trslv
function psb_dminquotient_vect(x,y,desc_a,info,global) result(res)
use psb_penv_mod
use psb_serial_mod
@ -444,70 +511,3 @@ function psb_dminquotient_vect(x,y,desc_a,info,global) result(res)
return
end function psb_dminquotient_vect
!
! Subroutine: psb_ddiv_trslv
!    Validates the multivector X against the descriptor and, when this
!    process owns at least one local row, forwards the request to the
!    trslv method of X together with the dense matrix A.
!
! Arguments:
!    x      - type(psb_d_multivect_type)      The multivector handed to trslv (inout).
!    a      - real(psb_dpk_), dimension(:,:)  The dense matrix forwarded to x%trslv.
!    desc_a - type(psb_desc_type)             The communication descriptor.
!    uplo   - character                       Forwarded unchanged to x%trslv.
!    info   - integer                         Return code.
!    alpha  - real(psb_dpk_), optional        Forwarded unchanged to x%trslv.
!    trans  - character, optional             Forwarded unchanged to x%trslv.
!    diag   - character, optional             Forwarded unchanged to x%trslv.
!
! NOTE(review): the argument names suggest a triangular solve in the
! BLAS trsm style; the exact semantics are defined by x%trslv - confirm there.
!
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_ddiv_trslv
  implicit none
  type(psb_d_multivect_type), intent (inout) :: x
  real(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_dpk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_ddiv_trslv'
  ! Define info before any exit path: it is intent(out), so returning
  ! on a pending fatal error without setting it would hand the caller
  ! an undefined value.
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione
  m   = desc_a%get_global_rows()
  nx  = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    ! Absent optionals are passed straight through to x%trslv.
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Report a failure of the local kernel through the error stack,
    ! in the same way as the psb_chkvect failure above.
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_ddiv_trslv

@ -478,13 +478,17 @@ function psb_ddot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! FIXME: MAKES NO SENSE!
! Remove the overlapped elements via dgemv calls which are axpy
!idx = desc_a%ovrlap_elem(i,1)
!ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via dgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
!call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else

@ -356,6 +356,73 @@ subroutine psb_sdiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_sdiv_vect2_check
!
! Subroutine: psb_sdiv_trslv
!    Validates the multivector X against the descriptor and, when this
!    process owns at least one local row, forwards the request to the
!    trslv method of X together with the dense matrix A.
!
! Arguments:
!    x      - type(psb_s_multivect_type)      The multivector handed to trslv (inout).
!    a      - real(psb_spk_), dimension(:,:)  The dense matrix forwarded to x%trslv.
!    desc_a - type(psb_desc_type)             The communication descriptor.
!    uplo   - character                       Forwarded unchanged to x%trslv.
!    info   - integer                         Return code.
!    alpha  - real(psb_spk_), optional        Forwarded unchanged to x%trslv.
!    trans  - character, optional             Forwarded unchanged to x%trslv.
!    diag   - character, optional             Forwarded unchanged to x%trslv.
!
! NOTE(review): the argument names suggest a triangular solve in the
! BLAS trsm style; the exact semantics are defined by x%trslv - confirm there.
!
subroutine psb_sdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_sdiv_trslv
  implicit none
  type(psb_s_multivect_type), intent (inout) :: x
  real(psb_spk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  real(psb_spk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_sdiv_trslv'
  ! Define info before any exit path: it is intent(out), so returning
  ! on a pending fatal error without setting it would hand the caller
  ! an undefined value.
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione
  m   = desc_a%get_global_rows()
  nx  = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    ! Absent optionals are passed straight through to x%trslv.
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Report a failure of the local kernel through the error stack,
    ! in the same way as the psb_chkvect failure above.
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_sdiv_trslv
function psb_sminquotient_vect(x,y,desc_a,info,global) result(res)
use psb_penv_mod
use psb_serial_mod

@ -478,12 +478,17 @@ function psb_sdot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
!idx = desc_a%ovrlap_elem(i,1)
!ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via sgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
!call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else

@ -29,64 +29,208 @@
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_svmlt.f90
! File: psb_dvmlt.f90
!!$
!!$subroutine psb_dvmlt(x,y,desc_a,info)
!!$ use psb_base_mod, psb_protect_name => psb_dvmlt
!!$ implicit none
!!$ type(psb_d_vect_type), intent (inout) :: x
!!$ type(psb_d_vect_type), intent (inout) :: y
!!$ type(psb_desc_type), intent (in) :: desc_a
!!$ integer(psb_ipk_), intent(out) :: info
!!$
!!$ ! locals
!!$ integer(psb_ipk_) :: ctxt, np, me,&
!!$ & err_act, iix, jjx, iiy, jjy
!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m
!!$ character(len=20) :: name, ch_err
!!$
!!$ name='psb_dgevmlt'
!!$ if (psb_errstatus_fatal()) return
!!$ info=psb_success_
!!$ call psb_erractionsave(err_act)
!!$
!!$ ctxt=desc_a%get_context()
!!$
!!$ call psb_info(ctxt, me, np)
!!$ if (np == -ione) then
!!$ info = psb_err_context_error_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(x%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(y%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$
!!$
!!$ ix = ione
!!$ iy = ione
!!$
!!$ m = desc_a%get_global_rows()
!!$
!!$ ! check vector correctness
!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 1'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 2'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$
!!$ if ((iix /= ione).or.(iiy /= ione)) then
!!$ info=psb_err_ix_n1_iy_n1_unsupported_
!!$ call psb_errpush(info,name)
!!$ end if
!!$
!!$ if(desc_a%get_local_rows() > 0) then
!!$ call y%base_mlt_v(desc_a%get_local_rows(),&
!!$ & alpha,x,beta,info)
!!$ end if
!!$
!!$ call psb_erractionrestore(err_act)
!!$ return
!!$
!!$9999 call psb_error_handler(ctxt,err_act)
!!$
!!$ return
!!$
!!$end subroutine psb_dvmlt
subroutine psb_svmlt(x,y,desc_a,info)
use psb_base_mod, psb_protect_name => psb_svmlt
implicit none
type(psb_s_vect_type), intent (inout) :: x
type(psb_s_vect_type), intent (inout) :: y
type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info
!
! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_svmlt.f90
!
! Subroutine: psb_smlt_multivect
! psb_smlt_multivect computes the inner product of two distributed multivectors,
!
! res(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_s_multivect_type) The input multivector containing the entries of sub( X ).
! y - type(psb_s_multivect_type) The input multivector containing the entries of sub( Y ).
! res - real(psb_spk_), allocatable, dimension(:,:) The result of the product.
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
!
! psb_smlt_multivect: product of two distributed multivectors,
! res(:,:) := X**C * Y, optionally summed over all processes.
!
!    x, y   - multivectors; INOUT only because sync() may be called on them.
!    res    - allocatable result; a previously allocated array whose shape
!             is not (ncols(x), ncols(y)) is released before the product.
!    desc_a - the communication descriptor.
!    info   - return code.
!    global - optional; perform the global sum (default .true.).
!
! Descriptors with overlap elements are not supported yet and raise an
! internal error.
!
subroutine psb_smlt_multivect(x, y, res,desc_a,info,global)
  use psb_desc_mod
  use psb_s_base_mat_mod
  use psb_check_mod
  use psb_error_mod
  use psb_penv_mod
  use psb_s_multivect_mod
  ! The protected name must be the routine defined here, otherwise the
  ! module interface of the same name clashes with this definition.
  use psb_s_psblas_mod, psb_protect_name => psb_smlt_multivect
  implicit none
  real(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res
  type(psb_s_multivect_type), intent(inout) :: x, y
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: global

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, iiy, jjy, nr
  integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx, ny
  logical :: global_
  character(len=20) :: name, ch_err

  name='psb_smlt_multivect'
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(y%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (present(global)) then
    global_ = global
  else
    global_ = .true.
  end if

  ix  = ione
  ijx = ione
  iy  = ione
  ijy = ione
  m   = desc_a%get_global_rows()
  nx  = x%get_ncols()
  ny  = y%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info == psb_success_) &
       & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  if ((iix /= ione).or.(iiy /= ione)) then
    info=psb_err_ix_n1_iy_n1_unsupported_
    call psb_errpush(info,name)
    goto 9999
  end if

  ! Both operands must have the same number of columns.
  if (nx /= ny) then
    info=psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  end if

  ! Release a result buffer of the wrong shape.
  if (allocated(res)) then
    if ((size(res,1) /= nx).or.(size(res,2) /= ny)) then
      deallocate(res,stat=info)
      if (info /= 0) then
        info = psb_err_internal_error_
        ch_err='deallocate res'
        call psb_errpush(info,name,a_err=ch_err)
        goto 9999
      end if
    end if
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    ! NOTE(review): x%mlt is assumed to allocate res when it arrives
    ! unallocated - confirm in the multivector module.
    call x%mlt(nr,y,res,info)
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%mlt'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
    ! FIXME: overlapped elements are counted more than once by the local
    ! product. The intended correction is
    !    res(:,:) = res(:,:) - x%v%v(idx,:)^T * y%v%v(idx,:)
    ! for every overlap index idx (case of AS-type descriptors); until
    ! that is implemented such descriptors are rejected.
    if (size(desc_a%ovrlap_elem,1) > 0) then
      if (x%is_dev()) call x%sync()
      if (y%is_dev()) call y%sync()
      info = psb_err_internal_error_
      ch_err='over_elem_unsup'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  else
    ! No local rows: the local contribution is zero. A scalar cannot be
    ! assigned to an unallocated array, so make sure res exists first.
    if (.not.allocated(res)) then
      allocate(res(nx,ny),stat=info)
      if (info /= 0) then
        info = psb_err_internal_error_
        ch_err='allocate res'
        call psb_errpush(info,name,a_err=ch_err)
        goto 9999
      end if
    end if
    res = szero
  end if

  ! compute global sum
  if (global_) call psb_sum(ctxt, res)

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_smlt_multivect

@ -356,3 +356,70 @@ subroutine psb_zdiv_vect2_check(x,y,z,desc_a,info,flag)
end subroutine psb_zdiv_vect2_check
!
! Subroutine: psb_zdiv_trslv
!    Validates the multivector X against the descriptor and, when this
!    process owns at least one local row, forwards the request to the
!    trslv method of X together with the dense matrix A.
!
! Arguments:
!    x      - type(psb_z_multivect_type)         The multivector handed to trslv (inout).
!    a      - complex(psb_dpk_), dimension(:,:)  The dense matrix forwarded to x%trslv.
!    desc_a - type(psb_desc_type)                The communication descriptor.
!    uplo   - character                          Forwarded unchanged to x%trslv.
!    info   - integer                            Return code.
!    alpha  - complex(psb_dpk_), optional        Forwarded unchanged to x%trslv.
!    trans  - character, optional                Forwarded unchanged to x%trslv.
!    diag   - character, optional                Forwarded unchanged to x%trslv.
!
! NOTE(review): the argument names suggest a triangular solve in the
! BLAS trsm style; the exact semantics are defined by x%trslv - confirm there.
!
subroutine psb_zdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_zdiv_trslv
  implicit none
  type(psb_z_multivect_type), intent (inout) :: x
  complex(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in) :: desc_a
  character(len=1), intent(in) :: uplo
  integer(psb_ipk_), intent(out) :: info
  complex(psb_dpk_), intent (in), optional :: alpha
  character(len=1), intent(in), optional :: trans
  character(len=1), intent(in), optional :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_) :: ix, ijx, m, nx
  character(len=20) :: name, ch_err

  name='psb_zdiv_trslv'
  ! Define info before any exit path: it is intent(out), so returning
  ! on a pending fatal error without setting it would hand the caller
  ! an undefined value.
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione
  m   = desc_a%get_global_rows()
  nx  = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    ! Absent optionals are passed straight through to x%trslv.
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! Report a failure of the local kernel through the error stack,
    ! in the same way as the psb_chkvect failure above.
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_zdiv_trslv

@ -478,12 +478,17 @@ function psb_zdot_mvect_vect(x, y, desc_a,info,global) result(res)
if (x%is_dev()) call x%sync()
if (y%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
!idx = desc_a%ovrlap_elem(i,1)
!ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via zgemv calls
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call zgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
!call zgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
! FIXME: To be fixed for overlapped communicators, e.g., AS
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else

@ -29,64 +29,208 @@
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_zvmlt.f90
! File: psb_dvmlt.f90
!!$
!!$subroutine psb_dvmlt(x,y,desc_a,info)
!!$ use psb_base_mod, psb_protect_name => psb_dvmlt
!!$ implicit none
!!$ type(psb_d_vect_type), intent (inout) :: x
!!$ type(psb_d_vect_type), intent (inout) :: y
!!$ type(psb_desc_type), intent (in) :: desc_a
!!$ integer(psb_ipk_), intent(out) :: info
!!$
!!$ ! locals
!!$ integer(psb_ipk_) :: ctxt, np, me,&
!!$ & err_act, iix, jjx, iiy, jjy
!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m
!!$ character(len=20) :: name, ch_err
!!$
!!$ name='psb_dgevmlt'
!!$ if (psb_errstatus_fatal()) return
!!$ info=psb_success_
!!$ call psb_erractionsave(err_act)
!!$
!!$ ctxt=desc_a%get_context()
!!$
!!$ call psb_info(ctxt, me, np)
!!$ if (np == -ione) then
!!$ info = psb_err_context_error_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(x%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$ if (.not.allocated(y%v)) then
!!$ info = psb_err_invalid_vect_state_
!!$ call psb_errpush(info,name)
!!$ goto 9999
!!$ endif
!!$
!!$
!!$ ix = ione
!!$ iy = ione
!!$
!!$ m = desc_a%get_global_rows()
!!$
!!$ ! check vector correctness
!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 1'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy)
!!$ if(info /= psb_success_) then
!!$ info=psb_err_from_subroutine_
!!$ ch_err='psb_chkvect 2'
!!$ call psb_errpush(info,name,a_err=ch_err)
!!$ goto 9999
!!$ end if
!!$
!!$ if ((iix /= ione).or.(iiy /= ione)) then
!!$ info=psb_err_ix_n1_iy_n1_unsupported_
!!$ call psb_errpush(info,name)
!!$ end if
!!$
!!$ if(desc_a%get_local_rows() > 0) then
!!$ call y%base_mlt_v(desc_a%get_local_rows(),&
!!$ & alpha,x,beta,info)
!!$ end if
!!$
!!$ call psb_erractionrestore(err_act)
!!$ return
!!$
!!$9999 call psb_error_handler(ctxt,err_act)
!!$
!!$ return
!!$
!!$end subroutine psb_dvmlt
subroutine psb_zvmlt(x,y,desc_a,info)
use psb_base_mod, psb_protect_name => psb_zvmlt
implicit none
type(psb_z_vect_type), intent (inout) :: x
type(psb_z_vect_type), intent (inout) :: y
type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info
!
! Parallel Sparse BLAS version 3.5
! (C) Copyright 2006-2018
! Salvatore Filippone
! Alfredo Buttari
!
! Redistribution and use in source and binary forms, with or without
! modification, are permitted provided that the following conditions
! are met:
! 1. Redistributions of source code must retain the above copyright
! notice, this list of conditions and the following disclaimer.
! 2. Redistributions in binary form must reproduce the above copyright
! notice, this list of conditions, and the following disclaimer in the
! documentation and/or other materials provided with the distribution.
! 3. The name of the PSBLAS group or the names of its contributors may
! not be used to endorse or promote products derived from this
! software without specific written permission.
!
! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS
! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
! POSSIBILITY OF SUCH DAMAGE.
!
!
! File: psb_zvmlt.f90
!
! Subroutine: psb_zmlt_multivect
! psb_zmlt_multivect computes the inner product of two distributed multivectors,
!
! res(:,:) := ( X(:,:) )**C * ( Y(:,:) )
!
!
! Arguments:
! x - type(psb_z_multivect_type) The input multivector containing the entries of sub( X ).
! y - type(psb_z_multivect_type) The input multivector containing the entries of sub( Y ).
! res - complex(psb_dpk_), allocatable, dimension(:,:) The result of the product.
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
!
! psb_zmlt_multivect: product of two distributed multivectors,
! res(:,:) := X**C * Y, optionally summed over all processes.
!
!    x, y   - multivectors; INOUT only because sync() may be called on them.
!    res    - allocatable result; a previously allocated array whose shape
!             is not (ncols(x), ncols(y)) is released before the product.
!    desc_a - the communication descriptor.
!    info   - return code.
!    global - optional; perform the global sum (default .true.).
!
! Descriptors with overlap elements are not supported yet and raise an
! internal error.
!
subroutine psb_zmlt_multivect(x, y, res,desc_a,info,global)
  use psb_desc_mod
  use psb_z_base_mat_mod
  use psb_check_mod
  use psb_error_mod
  use psb_penv_mod
  use psb_z_multivect_mod
  ! The protected name must be the routine defined here, otherwise the
  ! module interface of the same name clashes with this definition.
  use psb_z_psblas_mod, psb_protect_name => psb_zmlt_multivect
  implicit none
  complex(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res
  type(psb_z_multivect_type), intent(inout) :: x, y
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: global

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, err_act, iix, jjx, iiy, jjy, nr
  integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx, ny
  logical :: global_
  character(len=20) :: name, ch_err

  name='psb_zmlt_multivect'
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(y%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (present(global)) then
    global_ = global
  else
    global_ = .true.
  end if

  ix  = ione
  ijx = ione
  iy  = ione
  ijy = ione
  m   = desc_a%get_global_rows()
  nx  = x%get_ncols()
  ny  = y%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info == psb_success_) &
       & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  if ((iix /= ione).or.(iiy /= ione)) then
    info=psb_err_ix_n1_iy_n1_unsupported_
    call psb_errpush(info,name)
    goto 9999
  end if

  ! Both operands must have the same number of columns.
  if (nx /= ny) then
    info=psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  end if

  ! Release a result buffer of the wrong shape.
  if (allocated(res)) then
    if ((size(res,1) /= nx).or.(size(res,2) /= ny)) then
      deallocate(res,stat=info)
      if (info /= 0) then
        info = psb_err_internal_error_
        ch_err='deallocate res'
        call psb_errpush(info,name,a_err=ch_err)
        goto 9999
      end if
    end if
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    ! NOTE(review): x%mlt is assumed to allocate res when it arrives
    ! unallocated - confirm in the multivector module.
    call x%mlt(nr,y,res,info)
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%mlt'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
    ! FIXME: overlapped elements are counted more than once by the local
    ! product. The intended correction is
    !    res(:,:) = res(:,:) - x%v%v(idx,:)^T * y%v%v(idx,:)
    ! for every overlap index idx (case of AS-type descriptors); until
    ! that is implemented such descriptors are rejected.
    if (size(desc_a%ovrlap_elem,1) > 0) then
      if (x%is_dev()) call x%sync()
      if (y%is_dev()) call y%sync()
      info = psb_err_internal_error_
      ch_err='over_elem_unsup'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  else
    ! No local rows: the local contribution is zero. A scalar cannot be
    ! assigned to an unallocated array, so make sure res exists first.
    if (.not.allocated(res)) then
      allocate(res(nx,ny),stat=info)
      if (info /= 0) then
        info = psb_err_internal_error_
        ch_err='allocate res'
        call psb_errpush(info,name,a_err=ch_err)
        goto 9999
      end if
    end if
    res = zzero
  end if

  ! compute global sum
  if (global_) call psb_sum(ctxt, res)

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_zmlt_multivect

Loading…
Cancel
Save