diff --git a/base/modules/psblas/psb_c_psblas_mod.F90 b/base/modules/psblas/psb_c_psblas_mod.F90 index 4a606f7e..c5a5a40f 100644 --- a/base/modules/psblas/psb_c_psblas_mod.F90 +++ b/base/modules/psblas/psb_c_psblas_mod.F90 @@ -529,6 +529,15 @@ module psb_c_psblas_mod integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy end subroutine psb_cmlt_vect2 + subroutine psb_cmlt_multivect(x, y, res, desc_a,info,global) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_c_multivect_type, psb_cspmat_type + complex(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res + type(psb_c_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end subroutine psb_cmlt_multivect end interface interface psb_gediv @@ -568,6 +577,18 @@ module psb_c_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in) :: flag end subroutine psb_cdiv_vect2_check + subroutine psb_cdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + import :: psb_desc_type, psb_ipk_, & + & psb_spk_, psb_c_multivect_type + type(psb_c_multivect_type), intent (inout) :: x + complex(psb_spk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + complex(psb_spk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + end subroutine psb_cdiv_trslv end interface interface psb_geinv diff --git a/base/modules/psblas/psb_s_psblas_mod.F90 b/base/modules/psblas/psb_s_psblas_mod.F90 index f8da1c7c..4b14cae0 100644 --- a/base/modules/psblas/psb_s_psblas_mod.F90 +++ b/base/modules/psblas/psb_s_psblas_mod.F90 @@ -540,6 +540,15 @@ module psb_s_psblas_mod integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy end subroutine psb_smlt_vect2 + 
subroutine psb_smlt_multivect(x, y, res, desc_a,info,global) + import :: psb_desc_type, psb_spk_, psb_ipk_, & + & psb_s_multivect_type, psb_sspmat_type + real(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res + type(psb_s_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + end subroutine psb_smlt_multivect end interface interface psb_gediv @@ -579,6 +588,18 @@ module psb_s_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in) :: flag end subroutine psb_sdiv_vect2_check + subroutine psb_sdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + import :: psb_desc_type, psb_ipk_, & + & psb_spk_, psb_s_multivect_type + type(psb_s_multivect_type), intent (inout) :: x + real(psb_spk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + real(psb_spk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + end subroutine psb_sdiv_trslv end interface interface psb_geinv diff --git a/base/modules/psblas/psb_z_psblas_mod.F90 b/base/modules/psblas/psb_z_psblas_mod.F90 index c45f02af..506fcd07 100644 --- a/base/modules/psblas/psb_z_psblas_mod.F90 +++ b/base/modules/psblas/psb_z_psblas_mod.F90 @@ -529,6 +529,15 @@ module psb_z_psblas_mod integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy end subroutine psb_zmlt_vect2 + subroutine psb_zmlt_multivect(x, y, res, desc_a,info,global) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_z_multivect_type, psb_zspmat_type + complex(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res + type(psb_z_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global + 
end subroutine psb_zmlt_multivect end interface interface psb_gediv @@ -568,6 +577,18 @@ module psb_z_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in) :: flag end subroutine psb_zdiv_vect2_check + subroutine psb_zdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + import :: psb_desc_type, psb_ipk_, & + & psb_dpk_, psb_z_multivect_type + type(psb_z_multivect_type), intent (inout) :: x + complex(psb_dpk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + complex(psb_dpk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + end subroutine psb_zdiv_trslv end interface interface psb_geinv diff --git a/base/modules/serial/psb_c_base_vect_mod.F90 b/base/modules/serial/psb_c_base_vect_mod.F90 index 142add83..cde1fd36 100644 --- a/base/modules/serial/psb_c_base_vect_mod.F90 +++ b/base/modules/serial/psb_c_base_vect_mod.F90 @@ -2177,13 +2177,15 @@ module psb_c_base_multivect_mod procedure, pass(y) :: mlt_ar2 => c_base_mlv_mlt_ar2 procedure, pass(z) :: mlt_a_2 => c_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_v_2 => c_base_mlv_mlt_v_2 + procedure, pass(x) :: mlt_mv2 => c_base_mlv_mlt_mv2 !!$ procedure, pass(z) :: mlt_va => c_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_av => c_base_mlv_mlt_av generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & - & mlt_a_2, mlt_v_2 !, mlt_av, mlt_va + & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va ! ! Scaling and norms ! + procedure, pass(x) :: trslv => c_base_mlv_trslv procedure, pass(x) :: scal => c_base_mlv_scal procedure, pass(x) :: nrm2 => c_base_mlv_nrm2 procedure, pass(x) :: amax => c_base_mlv_amax @@ -2672,6 +2674,70 @@ contains res(1:m,1:n) = x%v(1:m,1:n) end function c_base_mlv_get_vect + ! + !> subroutine d_base_mlv_trslv + !! \memberof psb_d_base_multivect_type + !! 
\brief Solves X*op(A) = alpha*X in place, with A a triangular matrix
+  !!   \param n Number of rows of X to be considered
+  !!   \param x The multivector, overwritten with the solution
+  !!   \param uplo 'U' for upper triangular, 'L' for lower triangular
+  !!   \param a The square triangular matrix, of order x%get_ncols()
+  !!   \param alpha (optional) The scaling factor, default cone
+  !!   \param trans (optional) 'N' no transpose (default), 'T' transpose, 'C' conjugate transpose
+  !!   \param diag (optional) 'N' for non-unit diagonal (default), 'U' for unit diagonal
+  !!   \param info return code, psb_err_invalid_input_ on a size mismatch
+  !!
+  subroutine c_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
+    implicit none
+    class(psb_c_base_multivect_type), intent(inout) :: x
+    complex(psb_spk_), intent(in) :: a(:,:)
+    integer(psb_ipk_), intent(in) :: n
+    character(len=1), intent(in) :: uplo
+    complex(psb_spk_), intent(in), optional :: alpha
+    character(len=1), intent(in), optional :: trans, diag
+    integer(psb_ipk_), intent(out) :: info
+    ! Local variables
+    integer(psb_ipk_) :: lda, ldb
+    character(len=1) :: trans_, diag_, side
+    complex(psb_spk_) :: alpha_
+
+    ! Default values for the optional arguments
+    if (.not.present(alpha)) then
+      alpha_ = cone
+    else
+      alpha_ = alpha
+    end if
+    if (.not.present(trans)) then
+      trans_ = 'N'
+    else
+      trans_ = trans
+    end if
+    if (.not.present(diag)) then
+      diag_ = 'N'
+    else
+      diag_ = diag
+    end if
+
+    info = psb_success_
+    ! Check that a is square
+    if (size(a,1) /= size(a,2)) then
+      info = psb_err_invalid_input_
+      return
+    end if
+    ! Check that a has the same number of columns as x
+    if (size(a,2) /= x%get_ncols()) then
+      info = psb_err_invalid_input_
+      return
+    end if
+    if (x%is_dev()) call x%sync()
+    if (x%is_sync()) then
+      ! complex(psb_spk_) data needs the single precision complex BLAS (ctrsm, not dtrsm)
+      lda = size(a,1)
+      ldb = x%get_nrows()
+      side = 'R' ! X*op( A ) = alpha*B.
+      call ctrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
+    end if
+  end subroutine c_base_mlv_trslv
   !
   ! Reset all values
   !
@@ -2886,6 +2952,55 @@ contains
 
   end subroutine c_base_mlv_axpby_a
 
+  !> Function base_mlv_mlt_mv2
+  !!   \memberof psb_c_base_multivect_type
+  !!   \brief computes A = conjugatetranspose(X)*Y on the first n rows
+  !!   \param n The number of rows of X and Y entering the product
+  !!   \param x The class(base_mlv_vect) to be multiplied by
+  !!   \param y The class(base_mlv_vect) to be multiplied by
+  !!   \param a The resulting matrix, x%get_ncols() by y%get_ncols();
+  !!            it is allocated here when not already allocated
+  !!   \param info return code
+  subroutine c_base_mlv_mlt_mv2(n,x,y,a,info)
+    use psi_serial_mod
+    implicit none
+    integer(psb_ipk_), intent(in) :: n
+    class(psb_c_base_multivect_type), intent(inout) :: x
+    class(psb_c_base_multivect_type), intent(inout) :: y
+    complex(psb_spk_), intent(inout), allocatable :: a(:,:)
+    integer(psb_ipk_), intent(out) :: info
+
+    info = psb_success_
+    if (x%is_dev()) call x%sync()
+    if (y%is_dev()) call y%sync()
+
+    if (allocated(a)) then
+      if (size(a,1) /= x%get_ncols()) then
+        info = psb_err_invalid_input_
+        return
+      end if
+      if (size(a,2) /= y%get_ncols()) then
+        info = psb_err_invalid_input_
+        return
+      end if
+    else
+      allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
+      if (info /= 0) then
+        ! Never hand an unallocated result array to the BLAS
+        info = psb_err_alloc_dealloc_
+        call psb_errpush(info,'base_mlv_mlt_mv2')
+        return
+      end if
+    end if
+    ! We do the multiplication by using the BLAS function
+    ! cgemm (the variant for complex(psb_spk_) data), which computes
+    !    C = alpha*op( A )*op( B ) + beta*C
+    ! In our case, we want to compute
+    !    C = X^H*Y
+    call cgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, cone, &
+         & x%v, x%get_nrows(), y%v, y%get_nrows(), czero, a, x%get_ncols())
+
+  end subroutine c_base_mlv_mlt_mv2
 
   !
   !
Multiple variants of two operations:
diff --git a/base/modules/serial/psb_c_vect_mod.F90 b/base/modules/serial/psb_c_vect_mod.F90
index 0371b4a7..9dab78d8 100644
--- a/base/modules/serial/psb_c_vect_mod.F90
+++ b/base/modules/serial/psb_c_vect_mod.F90
@@ -1344,6 +1344,9 @@ module psb_c_multivect_mod
     procedure, pass(x) :: dot_a => c_mvect_dot_a
     procedure, pass(x) :: dot_a_vect => c_mvect_dot_vect
     generic, public :: dot => dot_v, dot_a, dot_a_vect
+    procedure, pass(x) :: trslv => c_mlv_trslv
+    procedure, pass(x) :: mlt_mv2 => c_mvect_mlt_mv2
+    generic, public :: mlt => mlt_mv2
 !!$    procedure, pass(y) :: axpby_v => c_mvect_axpby_v
 !!$    procedure, pass(y) :: axpby_a => c_mvect_axpby_a
 !!$    generic, public :: axpby => axpby_v, axpby_a
@@ -1876,6 +1879,53 @@ contains
 
   end function c_mvect_dot_a
 
+  !
+  ! Triangular solve on the wrapped multivector: forwards every
+  ! argument to the trslv method of x%v, and returns
+  ! psb_err_invalid_vect_state_ when x%v is not allocated.
+  !
+  subroutine c_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
+    implicit none
+    class(psb_c_multivect_type), intent(inout) :: x
+    complex(psb_spk_), intent(in) :: a(:,:)
+    integer(psb_ipk_), intent(in) :: n
+    character(len=1), intent(in) :: uplo
+    complex(psb_spk_), intent(in), optional :: alpha
+    character(len=1), intent(in), optional :: trans, diag
+    integer(psb_ipk_), intent(out) :: info
+
+    if (.not.allocated(x%v)) then
+      info = psb_err_invalid_vect_state_
+      return
+    else
+      call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
+    end if
+  end subroutine c_mlv_trslv
+
+  !
+  ! Computes a = x^H * y through the mlt method of the inner vectors.
+  ! The base mlt_mv2 binding is pass(x): it has to be invoked on x%v
+  ! with y%v as explicit argument, otherwise the operands are swapped.
+  !
+  subroutine c_mvect_mlt_mv2(n,x,y,a,info)
+    use psi_serial_mod
+    implicit none
+    integer(psb_ipk_), intent(in) :: n
+    class(psb_c_multivect_type), intent(inout) :: x
+    class(psb_c_multivect_type), intent(inout) :: y
+    complex(psb_spk_), intent(inout), allocatable :: a(:,:)
+    integer(psb_ipk_), intent(out) :: info
+
+    if (allocated(x%v).and.allocated(y%v)) then
+      call x%v%mlt(n,y%v,a,info)
+    else
+      info = psb_err_invalid_vect_state_
+      return
+    end if
+
+  end subroutine c_mvect_mlt_mv2
+
+
 !!$  subroutine c_mvect_axpby_v(m,alpha, x, beta, y, info)
 !!$    use psi_serial_mod
 !!$    implicit none
diff --git
a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 4c415387..014dd7f8 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -2917,8 +2917,6 @@ contains call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb) end if end subroutine d_base_mlv_trslv - - ! ! Reset all values ! @@ -3133,10 +3131,9 @@ contains end subroutine d_base_mlv_axpby_a - - !> Function base_mlv_mlt_mv2 + !> Function base_mlv_mlt_mv2 !! \memberof psb_d_base_multivect_type - !! \brief computes A = transpose(X)*Y + !! \brief computes A = transpose(X)*Y / conjugatetranspose(X)*Y !! \param x The class(base_mlv_vect) to be multiplied by !! \param y The class(base_mlv_vect) to be multiplied by !! \param a The resulting matrix @@ -3172,13 +3169,11 @@ contains ! C = alpha*op( A )*op( B ) + beta*C ! In our case, we want to compute ! C = X'*Y - call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), n, done, & + call dgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, done, & & x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols()) end subroutine d_base_mlv_mlt_mv2 - - ! ! Multiple variants of two operations: ! 
Simple multiplication Y(:.:) = X(:,:)*Y(:,:) diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index 91862891..2145b05f 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -1994,6 +1994,7 @@ contains end subroutine d_mvect_mlt_mv2 + !!$ subroutine d_mvect_axpby_v(m,alpha, x, beta, y, info) !!$ use psi_serial_mod !!$ implicit none diff --git a/base/modules/serial/psb_s_base_vect_mod.F90 b/base/modules/serial/psb_s_base_vect_mod.F90 index d4a8cdf3..3460c601 100644 --- a/base/modules/serial/psb_s_base_vect_mod.F90 +++ b/base/modules/serial/psb_s_base_vect_mod.F90 @@ -2356,13 +2356,15 @@ module psb_s_base_multivect_mod procedure, pass(y) :: mlt_ar2 => s_base_mlv_mlt_ar2 procedure, pass(z) :: mlt_a_2 => s_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_v_2 => s_base_mlv_mlt_v_2 + procedure, pass(x) :: mlt_mv2 => s_base_mlv_mlt_mv2 !!$ procedure, pass(z) :: mlt_va => s_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_av => s_base_mlv_mlt_av generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & - & mlt_a_2, mlt_v_2 !, mlt_av, mlt_va + & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va ! ! Scaling and norms ! + procedure, pass(x) :: trslv => s_base_mlv_trslv procedure, pass(x) :: scal => s_base_mlv_scal procedure, pass(x) :: nrm2 => s_base_mlv_nrm2 procedure, pass(x) :: amax => s_base_mlv_amax @@ -2851,6 +2853,70 @@ contains res(1:m,1:n) = x%v(1:m,1:n) end function s_base_mlv_get_vect + ! + !> subroutine d_base_mlv_trslv + !! \memberof psb_d_base_multivect_type + !! \brief Computes X = X / A with A an upper triangular matrix + !! \param n Number of entries to be considered + !! \param x The multivector to be used for the division + !! \param uplo 'U' for upper triangular, 'L' for lower triangular + !! \param a The matrix to be used for the division + !! \param alpha (optional) The scaling factor + !! \param trans (optional) 'N' for no transpose, 'T' for transpose + !! 
\param diag (optional) 'N' for non-unit diagonal (default), 'U' for unit diagonal
+  !!   \param info return code, psb_err_invalid_input_ on a size mismatch
+  !!
+  subroutine s_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
+    implicit none
+    class(psb_s_base_multivect_type), intent(inout) :: x
+    real(psb_spk_), intent(in) :: a(:,:)
+    integer(psb_ipk_), intent(in) :: n
+    character(len=1), intent(in) :: uplo
+    real(psb_spk_), intent(in), optional :: alpha
+    character(len=1), intent(in), optional :: trans, diag
+    integer(psb_ipk_), intent(out) :: info
+    ! Local variables
+    integer(psb_ipk_) :: lda, ldb
+    character(len=1) :: trans_, diag_, side
+    real(psb_spk_) :: alpha_
+
+    ! Default values for the optional arguments
+    if (.not.present(alpha)) then
+      alpha_ = sone
+    else
+      alpha_ = alpha
+    end if
+    if (.not.present(trans)) then
+      trans_ = 'N'
+    else
+      trans_ = trans
+    end if
+    if (.not.present(diag)) then
+      diag_ = 'N'
+    else
+      diag_ = diag
+    end if
+
+    info = psb_success_
+    ! Check that a is square
+    if (size(a,1) /= size(a,2)) then
+      info = psb_err_invalid_input_
+      return
+    end if
+    ! Check that a has the same number of columns as x
+    if (size(a,2) /= x%get_ncols()) then
+      info = psb_err_invalid_input_
+      return
+    end if
+    if (x%is_dev()) call x%sync()
+    if (x%is_sync()) then
+      ! real(psb_spk_) data needs the single precision BLAS (strsm, not dtrsm)
+      lda = size(a,1)
+      ldb = x%get_nrows()
+      side = 'R' ! X*op( A ) = alpha*B.
+      call strsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
+    end if
+  end subroutine s_base_mlv_trslv
   !
   ! Reset all values
   !
@@ -3065,6 +3131,55 @@ contains
 
   end subroutine s_base_mlv_axpby_a
 
+  !> Function base_mlv_mlt_mv2
+  !!   \memberof psb_s_base_multivect_type
+  !!   \brief computes A = transpose(X)*Y on the first n rows
+  !!   \param n The number of rows of X and Y entering the product
+  !!   \param x The class(base_mlv_vect) to be multiplied by
+  !!   \param y The class(base_mlv_vect) to be multiplied by
+  !!   \param a The resulting matrix, x%get_ncols() by y%get_ncols();
+  !!            it is allocated here when not already allocated
+  !!   \param info return code
+  subroutine s_base_mlv_mlt_mv2(n,x,y,a,info)
+    use psi_serial_mod
+    implicit none
+    integer(psb_ipk_), intent(in) :: n
+    class(psb_s_base_multivect_type), intent(inout) :: x
+    class(psb_s_base_multivect_type), intent(inout) :: y
+    real(psb_spk_), intent(inout), allocatable :: a(:,:)
+    integer(psb_ipk_), intent(out) :: info
+
+    info = psb_success_
+    if (x%is_dev()) call x%sync()
+    if (y%is_dev()) call y%sync()
+
+    if (allocated(a)) then
+      if (size(a,1) /= x%get_ncols()) then
+        info = psb_err_invalid_input_
+        return
+      end if
+      if (size(a,2) /= y%get_ncols()) then
+        info = psb_err_invalid_input_
+        return
+      end if
+    else
+      allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
+      if (info /= 0) then
+        ! Never hand an unallocated result array to the BLAS
+        info = psb_err_alloc_dealloc_
+        call psb_errpush(info,'base_mlv_mlt_mv2')
+        return
+      end if
+    end if
+    ! We do the multiplication by using the BLAS function
+    ! sgemm (the variant for real(psb_spk_) data), which computes
+    !    C = alpha*op( A )*op( B ) + beta*C
+    ! In our case, we want to compute
+    !    C = X'*Y   (for real data 'C' is the same as 'T')
+    call sgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, sone, &
+         & x%v, x%get_nrows(), y%v, y%get_nrows(), szero, a, x%get_ncols())
+
+  end subroutine s_base_mlv_mlt_mv2
 
   !
   !
Multiple variants of two operations:
diff --git a/base/modules/serial/psb_s_vect_mod.F90 b/base/modules/serial/psb_s_vect_mod.F90
index 95a7ab02..df1df43a 100644
--- a/base/modules/serial/psb_s_vect_mod.F90
+++ b/base/modules/serial/psb_s_vect_mod.F90
@@ -1423,6 +1423,9 @@ module psb_s_multivect_mod
     procedure, pass(x) :: dot_a => s_mvect_dot_a
     procedure, pass(x) :: dot_a_vect => s_mvect_dot_vect
     generic, public :: dot => dot_v, dot_a, dot_a_vect
+    procedure, pass(x) :: trslv => s_mlv_trslv
+    procedure, pass(x) :: mlt_mv2 => s_mvect_mlt_mv2
+    generic, public :: mlt => mlt_mv2
 !!$    procedure, pass(y) :: axpby_v => s_mvect_axpby_v
 !!$    procedure, pass(y) :: axpby_a => s_mvect_axpby_a
 !!$    generic, public :: axpby => axpby_v, axpby_a
@@ -1955,6 +1958,53 @@ contains
 
   end function s_mvect_dot_a
 
+  !
+  ! Triangular solve on the wrapped multivector: forwards every
+  ! argument to the trslv method of x%v, and returns
+  ! psb_err_invalid_vect_state_ when x%v is not allocated.
+  !
+  subroutine s_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
+    implicit none
+    class(psb_s_multivect_type), intent(inout) :: x
+    real(psb_spk_), intent(in) :: a(:,:)
+    integer(psb_ipk_), intent(in) :: n
+    character(len=1), intent(in) :: uplo
+    real(psb_spk_), intent(in), optional :: alpha
+    character(len=1), intent(in), optional :: trans, diag
+    integer(psb_ipk_), intent(out) :: info
+
+    if (.not.allocated(x%v)) then
+      info = psb_err_invalid_vect_state_
+      return
+    else
+      call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
+    end if
+  end subroutine s_mlv_trslv
+
+  !
+  ! Computes a = x^T * y through the mlt method of the inner vectors.
+  ! The base mlt_mv2 binding is pass(x): it has to be invoked on x%v
+  ! with y%v as explicit argument, otherwise the operands are swapped.
+  !
+  subroutine s_mvect_mlt_mv2(n,x,y,a,info)
+    use psi_serial_mod
+    implicit none
+    integer(psb_ipk_), intent(in) :: n
+    class(psb_s_multivect_type), intent(inout) :: x
+    class(psb_s_multivect_type), intent(inout) :: y
+    real(psb_spk_), intent(inout), allocatable :: a(:,:)
+    integer(psb_ipk_), intent(out) :: info
+
+    if (allocated(x%v).and.allocated(y%v)) then
+      call x%v%mlt(n,y%v,a,info)
+    else
+      info = psb_err_invalid_vect_state_
+      return
+    end if
+
+  end subroutine s_mvect_mlt_mv2
+
+
 !!$  subroutine s_mvect_axpby_v(m,alpha, x, beta, y, info)
 !!$    use psi_serial_mod
 !!$    implicit none
diff --git
a/base/modules/serial/psb_z_base_vect_mod.F90 b/base/modules/serial/psb_z_base_vect_mod.F90 index 0321ff77..0c97f05f 100644 --- a/base/modules/serial/psb_z_base_vect_mod.F90 +++ b/base/modules/serial/psb_z_base_vect_mod.F90 @@ -2177,13 +2177,15 @@ module psb_z_base_multivect_mod procedure, pass(y) :: mlt_ar2 => z_base_mlv_mlt_ar2 procedure, pass(z) :: mlt_a_2 => z_base_mlv_mlt_a_2 procedure, pass(z) :: mlt_v_2 => z_base_mlv_mlt_v_2 + procedure, pass(x) :: mlt_mv2 => z_base_mlv_mlt_mv2 !!$ procedure, pass(z) :: mlt_va => z_base_mlv_mlt_va !!$ procedure, pass(z) :: mlt_av => z_base_mlv_mlt_av generic, public :: mlt => mlt_mv, mlt_mv_v, mlt_ar1, mlt_ar2, & - & mlt_a_2, mlt_v_2 !, mlt_av, mlt_va + & mlt_a_2, mlt_v_2, mlt_mv2 !, mlt_av, mlt_va ! ! Scaling and norms ! + procedure, pass(x) :: trslv => z_base_mlv_trslv procedure, pass(x) :: scal => z_base_mlv_scal procedure, pass(x) :: nrm2 => z_base_mlv_nrm2 procedure, pass(x) :: amax => z_base_mlv_amax @@ -2672,6 +2674,70 @@ contains res(1:m,1:n) = x%v(1:m,1:n) end function z_base_mlv_get_vect + ! + !> subroutine d_base_mlv_trslv + !! \memberof psb_d_base_multivect_type + !! \brief Computes X = X / A with A an upper triangular matrix + !! \param n Number of entries to be considered + !! \param x The multivector to be used for the division + !! \param uplo 'U' for upper triangular, 'L' for lower triangular + !! \param a The matrix to be used for the division + !! \param alpha (optional) The scaling factor + !! \param trans (optional) 'N' for no transpose, 'T' for transpose + !! \param diag (optional) 'N' for non-unit diagonal, 'U' for unit diagonal + !! \param info return code + !! 
+  subroutine z_base_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
+    implicit none
+    class(psb_z_base_multivect_type), intent(inout) :: x
+    complex(psb_dpk_), intent(in) :: a(:,:)
+    integer(psb_ipk_), intent(in) :: n
+    character(len=1), intent(in) :: uplo
+    complex(psb_dpk_), intent(in), optional :: alpha
+    character(len=1), intent(in), optional :: trans, diag
+    integer(psb_ipk_), intent(out) :: info
+    ! Local variables
+    integer(psb_ipk_) :: lda, ldb
+    character(len=1) :: trans_, diag_, side
+    complex(psb_dpk_) :: alpha_
+
+    ! Default values for the optional arguments
+    if (.not.present(alpha)) then
+      alpha_ = zone
+    else
+      alpha_ = alpha
+    end if
+    if (.not.present(trans)) then
+      trans_ = 'N'
+    else
+      trans_ = trans
+    end if
+    if (.not.present(diag)) then
+      diag_ = 'N'
+    else
+      diag_ = diag
+    end if
+
+    info = psb_success_
+    ! Check that a is square
+    if (size(a,1) /= size(a,2)) then
+      info = psb_err_invalid_input_
+      return
+    end if
+    ! Check that a has the same number of columns as x
+    if (size(a,2) /= x%get_ncols()) then
+      info = psb_err_invalid_input_
+      return
+    end if
+    if (x%is_dev()) call x%sync()
+    if (x%is_sync()) then
+      ! complex(psb_dpk_) data needs the double precision complex BLAS (ztrsm, not dtrsm)
+      lda = size(a,1)
+      ldb = x%get_nrows()
+      side = 'R' ! X*op( A ) = alpha*B.
+      call ztrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
+    end if
+  end subroutine z_base_mlv_trslv
   !
   ! Reset all values
   !
@@ -2886,6 +2952,55 @@ contains
 
   end subroutine z_base_mlv_axpby_a
 
+  !> Function base_mlv_mlt_mv2
+  !!   \memberof psb_z_base_multivect_type
+  !!   \brief computes A = conjugatetranspose(X)*Y on the first n rows
+  !!   \param n The number of rows of X and Y entering the product
+  !!   \param x The class(base_mlv_vect) to be multiplied by
+  !!   \param y The class(base_mlv_vect) to be multiplied by
+  !!   \param a The resulting matrix, x%get_ncols() by y%get_ncols();
+  !!            it is allocated here when not already allocated
+  !!   \param info return code
+  subroutine z_base_mlv_mlt_mv2(n,x,y,a,info)
+    use psi_serial_mod
+    implicit none
+    integer(psb_ipk_), intent(in) :: n
+    class(psb_z_base_multivect_type), intent(inout) :: x
+    class(psb_z_base_multivect_type), intent(inout) :: y
+    complex(psb_dpk_), intent(inout), allocatable :: a(:,:)
+    integer(psb_ipk_), intent(out) :: info
+
+    info = psb_success_
+    if (x%is_dev()) call x%sync()
+    if (y%is_dev()) call y%sync()
+
+    if (allocated(a)) then
+      if (size(a,1) /= x%get_ncols()) then
+        info = psb_err_invalid_input_
+        return
+      end if
+      if (size(a,2) /= y%get_ncols()) then
+        info = psb_err_invalid_input_
+        return
+      end if
+    else
+      allocate(a(x%get_ncols(),y%get_ncols()),stat=info)
+      if (info /= 0) then
+        ! Never hand an unallocated result array to the BLAS
+        info = psb_err_alloc_dealloc_
+        call psb_errpush(info,'base_mlv_mlt_mv2')
+        return
+      end if
+    end if
+    ! We do the multiplication by using the BLAS function
+    ! zgemm (the variant for complex(psb_dpk_) data), which computes
+    !    C = alpha*op( A )*op( B ) + beta*C
+    ! In our case, we want to compute
+    !    C = X^H*Y
+    call zgemm('C', 'N', x%get_ncols(), y%get_ncols(), n, zone, &
+         & x%v, x%get_nrows(), y%v, y%get_nrows(), zzero, a, x%get_ncols())
+
+  end subroutine z_base_mlv_mlt_mv2
 
   !
   !
Multiple variants of two operations:
diff --git a/base/modules/serial/psb_z_vect_mod.F90 b/base/modules/serial/psb_z_vect_mod.F90
index c7c0148d..bd014212 100644
--- a/base/modules/serial/psb_z_vect_mod.F90
+++ b/base/modules/serial/psb_z_vect_mod.F90
@@ -1344,6 +1344,9 @@ module psb_z_multivect_mod
     procedure, pass(x) :: dot_a => z_mvect_dot_a
     procedure, pass(x) :: dot_a_vect => z_mvect_dot_vect
     generic, public :: dot => dot_v, dot_a, dot_a_vect
+    procedure, pass(x) :: trslv => z_mlv_trslv
+    procedure, pass(x) :: mlt_mv2 => z_mvect_mlt_mv2
+    generic, public :: mlt => mlt_mv2
 !!$    procedure, pass(y) :: axpby_v => z_mvect_axpby_v
 !!$    procedure, pass(y) :: axpby_a => z_mvect_axpby_a
 !!$    generic, public :: axpby => axpby_v, axpby_a
@@ -1876,6 +1879,53 @@ contains
 
   end function z_mvect_dot_a
 
+  !
+  ! Triangular solve on the wrapped multivector: forwards every
+  ! argument to the trslv method of x%v, and returns
+  ! psb_err_invalid_vect_state_ when x%v is not allocated.
+  !
+  subroutine z_mlv_trslv(n,x,a,uplo,alpha,trans,diag,info)
+    implicit none
+    class(psb_z_multivect_type), intent(inout) :: x
+    complex(psb_dpk_), intent(in) :: a(:,:)
+    integer(psb_ipk_), intent(in) :: n
+    character(len=1), intent(in) :: uplo
+    complex(psb_dpk_), intent(in), optional :: alpha
+    character(len=1), intent(in), optional :: trans, diag
+    integer(psb_ipk_), intent(out) :: info
+
+    if (.not.allocated(x%v)) then
+      info = psb_err_invalid_vect_state_
+      return
+    else
+      call x%v%trslv(n,a,uplo,alpha=alpha,trans=trans,diag=diag,info=info)
+    end if
+  end subroutine z_mlv_trslv
+
+  !
+  ! Computes a = x^H * y through the mlt method of the inner vectors.
+  ! The base mlt_mv2 binding is pass(x): it has to be invoked on x%v
+  ! with y%v as explicit argument, otherwise the operands are swapped.
+  !
+  subroutine z_mvect_mlt_mv2(n,x,y,a,info)
+    use psi_serial_mod
+    implicit none
+    integer(psb_ipk_), intent(in) :: n
+    class(psb_z_multivect_type), intent(inout) :: x
+    class(psb_z_multivect_type), intent(inout) :: y
+    complex(psb_dpk_), intent(inout), allocatable :: a(:,:)
+    integer(psb_ipk_), intent(out) :: info
+
+    if (allocated(x%v).and.allocated(y%v)) then
+      call x%v%mlt(n,y%v,a,info)
+    else
+      info = psb_err_invalid_vect_state_
+      return
+    end if
+
+  end subroutine z_mvect_mlt_mv2
+
+
 !!$  subroutine z_mvect_axpby_v(m,alpha, x, beta, y, info)
 !!$    use psi_serial_mod
 !!$    implicit none
diff --git
a/base/psblas/psb_cdiv_vect.f90 b/base/psblas/psb_cdiv_vect.f90
index 0fe4594a..47d4be84 100644
--- a/base/psblas/psb_cdiv_vect.f90
+++ b/base/psblas/psb_cdiv_vect.f90
@@ -356,3 +356,91 @@ subroutine psb_cdiv_vect2_check(x,y,z,desc_a,info,flag)
 
 end subroutine psb_cdiv_vect2_check
 
+!
+! Subroutine: psb_cdiv_trslv
+!    Triangular solve on a distributed multivector: applies the trslv
+!    method of X to its desc_a%get_local_rows() local rows, using the
+!    dense triangular matrix A.
+!
+! Arguments:
+!    x      - type(psb_c_multivect_type) The multivector, updated in place.
+!    a      - complex, dimension(:,:)    The triangular matrix.
+!    desc_a - type(psb_desc_type).       The communication descriptor.
+!    uplo   - character.                 Which triangle of A is used.
+!    info   - integer.                   Return code
+!    alpha  - complex(optional).         Scaling factor.
+!    trans  - character(optional).       Transposition option for A.
+!    diag   - character(optional).       Unit/non-unit diagonal option.
+!
+subroutine psb_cdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
+  use psb_base_mod, psb_protect_name => psb_cdiv_trslv
+  implicit none
+  type(psb_c_multivect_type), intent (inout) :: x
+  complex(psb_spk_), intent (in), dimension(:,:) :: a
+  type(psb_desc_type), intent (in) :: desc_a
+  character(len=1), intent(in) :: uplo
+  integer(psb_ipk_), intent(out) :: info
+  complex(psb_spk_), intent (in), optional :: alpha
+  character(len=1), intent(in), optional :: trans
+  character(len=1), intent(in), optional :: diag
+
+  ! locals
+  type(psb_ctxt_type) :: ctxt
+  integer(psb_ipk_) :: np, me,&
+       & err_act, iix, jjx, nr
+  integer(psb_lpk_) :: ix, ijx, m, nx
+  character(len=20) :: name, ch_err
+
+  name='psb_cdiv_trslv'
+  info=psb_success_
+  if (psb_errstatus_fatal()) then
+    ! A fatal error is already pending: give the intent(out) return
+    ! code a defined value instead of leaving it unset
+    info = psb_err_internal_error_
+    return
+  end if
+  call psb_erractionsave(err_act)
+
+  ctxt=desc_a%get_context()
+
+  call psb_info(ctxt, me, np)
+  if (np == -ione) then
+    info = psb_err_context_error_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+  if (.not.allocated(x%v)) then
+    info = psb_err_invalid_vect_state_
+    call psb_errpush(info,name)
+    goto 9999
+  endif
+
+  ix = ione
+  ijx = ione
+
+  m = desc_a%get_global_rows()
+  nx = x%get_ncols()
+
+  ! check vector correctness
+  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
+  if(info /= psb_success_) then
+    info=psb_err_from_subroutine_
+    ch_err='psb_chkvect 1'
+    call psb_errpush(info,name,a_err=ch_err)
+    goto 9999
+  end if
+
+  nr = desc_a%get_local_rows()
+  if(nr > 0) then
+    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
+  end if
+
+  call psb_erractionrestore(err_act)
+  return
+
+9999 call psb_error_handler(ctxt,err_act)
+
+  return
+
+end subroutine psb_cdiv_trslv
+
diff --git a/base/psblas/psb_cdot.f90 b/base/psblas/psb_cdot.f90
index d5381207..2f64ec40 100644
--- a/base/psblas/psb_cdot.f90
+++ b/base/psblas/psb_cdot.f90
@@ -478,12 +478,17 @@ function psb_cdot_mvect_vect(x, y, desc_a,info,global) result(res)
       if (x%is_dev()) call x%sync()
       if (y%is_dev()) call y%sync()
       do i=1,size(desc_a%ovrlap_elem,1)
-        idx = desc_a%ovrlap_elem(i,1)
-        ndm = desc_a%ovrlap_elem(i,2)
+        !idx = desc_a%ovrlap_elem(i,1)
+        !ndm = desc_a%ovrlap_elem(i,2)
         ! Remove the overlapped elements via cgemv calls
         ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
-        call cgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
-             & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
+        !call cgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
+        !     & size(x%v%v,1),y%v%v(idx),ione,done,res,ione)
+        ! FIXME: To be fixed for overlapped communicators, e.g., AS
+        info = psb_err_internal_error_
+        ch_err='over_elem_unsup'
+        call psb_errpush(info,name,a_err=ch_err)
+        goto 9999
       end do
     end if
   else
diff --git a/base/psblas/psb_cvmlt.f90 b/base/psblas/psb_cvmlt.f90
index a5ee7bbc..613c8000 100644
--- a/base/psblas/psb_cvmlt.f90
+++ b/base/psblas/psb_cvmlt.f90
@@ -29,64 +29,208 @@
 ! POSSIBILITY OF SUCH DAMAGE.
 !
 !
-! File: psb_cvmlt.f90
+!
File: psb_dvmlt.f90 +!!$ +!!$subroutine psb_dvmlt(x,y,desc_a,info) +!!$ use psb_base_mod, psb_protect_name => psb_dvmlt +!!$ implicit none +!!$ type(psb_d_vect_type), intent (inout) :: x +!!$ type(psb_d_vect_type), intent (inout) :: y +!!$ type(psb_desc_type), intent (in) :: desc_a +!!$ integer(psb_ipk_), intent(out) :: info +!!$ +!!$ ! locals +!!$ integer(psb_ipk_) :: ctxt, np, me,& +!!$ & err_act, iix, jjx, iiy, jjy +!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m +!!$ character(len=20) :: name, ch_err +!!$ +!!$ name='psb_dgevmlt' +!!$ if (psb_errstatus_fatal()) return +!!$ info=psb_success_ +!!$ call psb_erractionsave(err_act) +!!$ +!!$ ctxt=desc_a%get_context() +!!$ +!!$ call psb_info(ctxt, me, np) +!!$ if (np == -ione) then +!!$ info = psb_err_context_error_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ if (.not.allocated(x%v)) then +!!$ info = psb_err_invalid_vect_state_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ if (.not.allocated(y%v)) then +!!$ info = psb_err_invalid_vect_state_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ +!!$ +!!$ ix = ione +!!$ iy = ione +!!$ +!!$ m = desc_a%get_global_rows() +!!$ +!!$ ! 
check vector correctness +!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) +!!$ if(info /= psb_success_) then +!!$ info=psb_err_from_subroutine_ +!!$ ch_err='psb_chkvect 1' +!!$ call psb_errpush(info,name,a_err=ch_err) +!!$ goto 9999 +!!$ end if +!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy) +!!$ if(info /= psb_success_) then +!!$ info=psb_err_from_subroutine_ +!!$ ch_err='psb_chkvect 2' +!!$ call psb_errpush(info,name,a_err=ch_err) +!!$ goto 9999 +!!$ end if +!!$ +!!$ if ((iix /= ione).or.(iiy /= ione)) then +!!$ info=psb_err_ix_n1_iy_n1_unsupported_ +!!$ call psb_errpush(info,name) +!!$ end if +!!$ +!!$ if(desc_a%get_local_rows() > 0) then +!!$ call y%base_mlt_v(desc_a%get_local_rows(),& +!!$ & alpha,x,beta,info) +!!$ end if +!!$ +!!$ call psb_erractionrestore(err_act) +!!$ return +!!$ +!!$9999 call psb_error_handler(ctxt,err_act) +!!$ +!!$ return +!!$ +!!$end subroutine psb_dvmlt -subroutine psb_cvmlt(x,y,desc_a,info) - use psb_base_mod, psb_protect_name => psb_cvmlt - implicit none - type(psb_c_vect_type), intent (inout) :: x - type(psb_c_vect_type), intent (inout) :: y - type(psb_desc_type), intent (in) :: desc_a - integer(psb_ipk_), intent(out) :: info +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! 
software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_cvmlt.f90 +! +! Subroutine: psb_cmlt_multivect +! psb_cmlt_multivect computes the inner product of two distributed multivectors, +! +! res(:,:) := ( X(:,:) )**C * ( Y(:,:) ) +! +! +! Arguments: +! x - type(psb_c_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_c_multivect_type) The input multivector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! 
+subroutine psb_cmlt_multivect(x, y, res,desc_a,info,global) + use psb_desc_mod + use psb_c_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_c_multivect_mod + use psb_c_psblas_mod, psb_protect_name => psb_cmlt_multivect + implicit none + complex(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res + type(psb_c_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global ! locals - integer(psb_ipk_) :: ctxt, np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m - character(len=20) :: name, ch_err + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny + logical :: global_ + character(len=20) :: name, ch_err - name='psb_cgevmlt' - if (psb_errstatus_fatal()) return + name='psb_cmlt_multivect' info=psb_success_ call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if ctxt=desc_a%get_context() - call psb_info(ctxt, me, np) if (np == -ione) then info = psb_err_context_error_ call psb_errpush(info,name) goto 9999 endif - if (.not.allocated(x%v)) then + if (.not.allocated(x%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) goto 9999 endif - if (.not.allocated(y%v)) then + if (.not.allocated(y%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) goto 9999 endif + if (present(global)) then + global_ = global + else + global_ = .true. + end if ix = ione + ijx = ione + iy = ione + ijy = ione m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() ! 
check vector correctness - call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) if(info /= psb_success_) then info=psb_err_from_subroutine_ - ch_err='psb_chkvect 1' - call psb_errpush(info,name,a_err=ch_err) - goto 9999 - end if - call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy) - if(info /= psb_success_) then - info=psb_err_from_subroutine_ - ch_err='psb_chkvect 2' + ch_err='psb_chkvect' call psb_errpush(info,name,a_err=ch_err) goto 9999 end if @@ -94,18 +238,60 @@ subroutine psb_cvmlt(x,y,desc_a,info) if ((iix /= ione).or.(iiy /= ione)) then info=psb_err_ix_n1_iy_n1_unsupported_ call psb_errpush(info,name) + goto 9999 end if - if(desc_a%get_local_rows() > 0) then - call y%base_mlt_v(desc_a%get_local_rows(),& - & alpha,x,beta,info) + if (x%get_ncols() /= y%get_ncols()) then + info=psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + else + if (allocated(res)) then + if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then + deallocate(res,stat=info) + end if + end if + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%mlt(nr,y,res,info) + ! FIXME + ! adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + ! FIXME: case of AS-type descriptors + ! Since I'm computing res via a dgemm on the whole vector, I need to adjust + ! the result by removing the contribution of the overlapped elements + ! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:) + ! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y + ! 
where R is the result, X' is the transpose of the matrix x%v%v(idx,:) + ! and Y is the matrix y%v%v(idx,:) + ! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),& + ! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),& + ! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols()) + info = psb_err_internal_error_ + ch_err='over_elem_unsup' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end do + end if + else + res = czero end if + ! compute global sum + if (global_) call psb_sum(ctxt, res) + call psb_erractionrestore(err_act) - return + return 9999 call psb_error_handler(ctxt,err_act) return -end subroutine psb_cvmlt +end subroutine psb_cmlt_multivect diff --git a/base/psblas/psb_ddiv_vect.f90 b/base/psblas/psb_ddiv_vect.f90 index 8eba3ac1..8b3d8b96 100644 --- a/base/psblas/psb_ddiv_vect.f90 +++ b/base/psblas/psb_ddiv_vect.f90 @@ -356,6 +356,73 @@ subroutine psb_ddiv_vect2_check(x,y,z,desc_a,info,flag) end subroutine psb_ddiv_vect2_check +subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + use psb_base_mod, psb_protect_name => psb_ddiv_trslv + implicit none + type(psb_d_multivect_type), intent (inout) :: x + real(psb_dpk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + real(psb_dpk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + + ! 
locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx + character(len=20) :: name, ch_err + + name='psb_ddiv_trslv' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + ijx = ione + + m = desc_a%get_global_rows() + nx = x%get_ncols() + + ! check vector correctness + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%trslv(nr,a,uplo,alpha,trans,diag,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_ddiv_trslv + function psb_dminquotient_vect(x,y,desc_a,info,global) result(res) use psb_penv_mod use psb_serial_mod @@ -444,70 +511,3 @@ function psb_dminquotient_vect(x,y,desc_a,info,global) result(res) return end function psb_dminquotient_vect - -subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) - use psb_base_mod, psb_protect_name => psb_ddiv_trslv - implicit none - type(psb_d_multivect_type), intent (inout) :: x - real(psb_dpk_), intent (in), dimension(:,:) :: a - type(psb_desc_type), intent (in) :: desc_a - character(len=1), intent(in) :: uplo - integer(psb_ipk_), intent(out) :: info - real(psb_dpk_), intent (in), optional :: alpha - character(len=1), intent(in), optional :: trans - character(len=1), intent(in), optional :: diag - - ! 
locals - type(psb_ctxt_type) :: ctxt - integer(psb_ipk_) :: np, me,& - & err_act, iix, jjx, iiy, jjy, nr - integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx - character(len=20) :: name, ch_err - - name='psb_ddiv_trslv' - if (psb_errstatus_fatal()) return - info=psb_success_ - call psb_erractionsave(err_act) - - ctxt=desc_a%get_context() - - call psb_info(ctxt, me, np) - if (np == -ione) then - info = psb_err_context_error_ - call psb_errpush(info,name) - goto 9999 - endif - if (.not.allocated(x%v)) then - info = psb_err_invalid_vect_state_ - call psb_errpush(info,name) - goto 9999 - endif - - ix = ione - ijx = ione - - m = desc_a%get_global_rows() - nx = x%get_ncols() - - ! check vector correctness - call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) - if(info /= psb_success_) then - info=psb_err_from_subroutine_ - ch_err='psb_chkvect 1' - call psb_errpush(info,name,a_err=ch_err) - goto 9999 - end if - - nr = desc_a%get_local_rows() - if(nr > 0) then - call x%trslv(nr,a,uplo,alpha,trans,diag,info) - end if - - call psb_erractionrestore(err_act) - return - -9999 call psb_error_handler(ctxt,err_act) - - return - -end subroutine psb_ddiv_trslv diff --git a/base/psblas/psb_ddot.f90 b/base/psblas/psb_ddot.f90 index 949efb2c..e2d77192 100644 --- a/base/psblas/psb_ddot.f90 +++ b/base/psblas/psb_ddot.f90 @@ -478,13 +478,17 @@ function psb_ddot_mvect_vect(x, y, desc_a,info,global) result(res) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() do i=1,size(desc_a%ovrlap_elem,1) - idx = desc_a%ovrlap_elem(i,1) - ndm = desc_a%ovrlap_elem(i,2) - ! FIXME: MAKES NO SENSE! - ! Remove the overlapped elements via dgemv calls which are axpy + !idx = desc_a%ovrlap_elem(i,1) + !ndm = desc_a%ovrlap_elem(i,2) + ! Remove the overlapped elements via dgemv calls ! 
res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res - call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & - & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + !call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & + ! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + ! FIXME: To be fixed for overlapped communicators, e.g., AS + info = psb_err_internal_error_ + ch_err='over_elem_unsup' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 end do end if else diff --git a/base/psblas/psb_sdiv_vect.f90 b/base/psblas/psb_sdiv_vect.f90 index 70bb96d0..fdbda3ed 100644 --- a/base/psblas/psb_sdiv_vect.f90 +++ b/base/psblas/psb_sdiv_vect.f90 @@ -356,6 +356,73 @@ subroutine psb_sdiv_vect2_check(x,y,z,desc_a,info,flag) end subroutine psb_sdiv_vect2_check +subroutine psb_sdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + use psb_base_mod, psb_protect_name => psb_sdiv_trslv + implicit none + type(psb_s_multivect_type), intent (inout) :: x + real(psb_spk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + real(psb_spk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx + character(len=20) :: name, ch_err + + name='psb_sdiv_trslv' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + ijx = ione + + m = desc_a%get_global_rows() + nx = x%get_ncols() + + ! 
check vector correctness + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%trslv(nr,a,uplo,alpha,trans,diag,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_sdiv_trslv + function psb_sminquotient_vect(x,y,desc_a,info,global) result(res) use psb_penv_mod use psb_serial_mod diff --git a/base/psblas/psb_sdot.f90 b/base/psblas/psb_sdot.f90 index 1ee5ae7e..f639f1cb 100644 --- a/base/psblas/psb_sdot.f90 +++ b/base/psblas/psb_sdot.f90 @@ -478,12 +478,17 @@ function psb_sdot_mvect_vect(x, y, desc_a,info,global) result(res) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() do i=1,size(desc_a%ovrlap_elem,1) - idx = desc_a%ovrlap_elem(i,1) - ndm = desc_a%ovrlap_elem(i,2) + !idx = desc_a%ovrlap_elem(i,1) + !ndm = desc_a%ovrlap_elem(i,2) ! Remove the overlapped elements via sgemv calls ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res - call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & - & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + !call sgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & + ! & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + ! FIXME: To be fixed for overlapped communicators, e.g., AS + info = psb_err_internal_error_ + ch_err='over_elem_unsup' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 end do end if else diff --git a/base/psblas/psb_svmlt.f90 b/base/psblas/psb_svmlt.f90 index a9b506c6..e173182d 100644 --- a/base/psblas/psb_svmlt.f90 +++ b/base/psblas/psb_svmlt.f90 @@ -29,64 +29,208 @@ ! POSSIBILITY OF SUCH DAMAGE. ! ! -! File: psb_svmlt.f90 +! 
File: psb_dvmlt.f90 +!!$ +!!$subroutine psb_dvmlt(x,y,desc_a,info) +!!$ use psb_base_mod, psb_protect_name => psb_dvmlt +!!$ implicit none +!!$ type(psb_d_vect_type), intent (inout) :: x +!!$ type(psb_d_vect_type), intent (inout) :: y +!!$ type(psb_desc_type), intent (in) :: desc_a +!!$ integer(psb_ipk_), intent(out) :: info +!!$ +!!$ ! locals +!!$ integer(psb_ipk_) :: ctxt, np, me,& +!!$ & err_act, iix, jjx, iiy, jjy +!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m +!!$ character(len=20) :: name, ch_err +!!$ +!!$ name='psb_dgevmlt' +!!$ if (psb_errstatus_fatal()) return +!!$ info=psb_success_ +!!$ call psb_erractionsave(err_act) +!!$ +!!$ ctxt=desc_a%get_context() +!!$ +!!$ call psb_info(ctxt, me, np) +!!$ if (np == -ione) then +!!$ info = psb_err_context_error_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ if (.not.allocated(x%v)) then +!!$ info = psb_err_invalid_vect_state_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ if (.not.allocated(y%v)) then +!!$ info = psb_err_invalid_vect_state_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ +!!$ +!!$ ix = ione +!!$ iy = ione +!!$ +!!$ m = desc_a%get_global_rows() +!!$ +!!$ ! 
check vector correctness +!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) +!!$ if(info /= psb_success_) then +!!$ info=psb_err_from_subroutine_ +!!$ ch_err='psb_chkvect 1' +!!$ call psb_errpush(info,name,a_err=ch_err) +!!$ goto 9999 +!!$ end if +!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy) +!!$ if(info /= psb_success_) then +!!$ info=psb_err_from_subroutine_ +!!$ ch_err='psb_chkvect 2' +!!$ call psb_errpush(info,name,a_err=ch_err) +!!$ goto 9999 +!!$ end if +!!$ +!!$ if ((iix /= ione).or.(iiy /= ione)) then +!!$ info=psb_err_ix_n1_iy_n1_unsupported_ +!!$ call psb_errpush(info,name) +!!$ end if +!!$ +!!$ if(desc_a%get_local_rows() > 0) then +!!$ call y%base_mlt_v(desc_a%get_local_rows(),& +!!$ & alpha,x,beta,info) +!!$ end if +!!$ +!!$ call psb_erractionrestore(err_act) +!!$ return +!!$ +!!$9999 call psb_error_handler(ctxt,err_act) +!!$ +!!$ return +!!$ +!!$end subroutine psb_dvmlt -subroutine psb_svmlt(x,y,desc_a,info) - use psb_base_mod, psb_protect_name => psb_svmlt - implicit none - type(psb_s_vect_type), intent (inout) :: x - type(psb_s_vect_type), intent (inout) :: y - type(psb_desc_type), intent (in) :: desc_a - integer(psb_ipk_), intent(out) :: info +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! 
software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_svmlt.f90 +! +! Subroutine: psb_smlt_multivect +! psb_smlt_multivect computes the inner product of two distributed multivectors, +! +! res(:,:) := ( X(:,:) )**C * ( Y(:,:) ) +! +! +! Arguments: +! x - type(psb_s_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_s_multivect_type) The input multivector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! 
+subroutine psb_smlt_multivect(x, y, res,desc_a,info,global) + use psb_desc_mod + use psb_s_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_s_multivect_mod + use psb_s_psblas_mod, psb_protect_name => psb_smlt_multivect + implicit none + real(psb_spk_), dimension(:,:), allocatable, intent(inout) :: res + type(psb_s_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global ! locals - integer(psb_ipk_) :: ctxt, np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m - character(len=20) :: name, ch_err + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny + logical :: global_ + character(len=20) :: name, ch_err - name='psb_sgevmlt' - if (psb_errstatus_fatal()) return + name='psb_smlt_multivect' info=psb_success_ call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if ctxt=desc_a%get_context() - call psb_info(ctxt, me, np) if (np == -ione) then info = psb_err_context_error_ call psb_errpush(info,name) goto 9999 endif - if (.not.allocated(x%v)) then + if (.not.allocated(x%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) goto 9999 endif - if (.not.allocated(y%v)) then + if (.not.allocated(y%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) goto 9999 endif + if (present(global)) then + global_ = global + else + global_ = .true. + end if ix = ione + ijx = ione + iy = ione + ijy = ione m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() ! 
check vector correctness - call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) if(info /= psb_success_) then info=psb_err_from_subroutine_ - ch_err='psb_chkvect 1' - call psb_errpush(info,name,a_err=ch_err) - goto 9999 - end if - call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy) - if(info /= psb_success_) then - info=psb_err_from_subroutine_ - ch_err='psb_chkvect 2' + ch_err='psb_chkvect' call psb_errpush(info,name,a_err=ch_err) goto 9999 end if @@ -94,18 +238,60 @@ subroutine psb_svmlt(x,y,desc_a,info) if ((iix /= ione).or.(iiy /= ione)) then info=psb_err_ix_n1_iy_n1_unsupported_ call psb_errpush(info,name) + goto 9999 end if - if(desc_a%get_local_rows() > 0) then - call y%base_mlt_v(desc_a%get_local_rows(),& - & alpha,x,beta,info) + if (x%get_ncols() /= y%get_ncols()) then + info=psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + else + if (allocated(res)) then + if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then + deallocate(res,stat=info) + end if + end if + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%mlt(nr,y,res,info) + ! FIXME + ! adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + ! FIXME: case of AS-type descriptors + ! Since I'm computing res via a dgemm on the whole vector, I need to adjust + ! the result by removing the contribution of the overlapped elements + ! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:) + ! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y + ! 
where R is the result, X' is the transpose of the matrix x%v%v(idx,:) + ! and Y is the matrix y%v%v(idx,:) + ! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),& + ! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),& + ! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols()) + info = psb_err_internal_error_ + ch_err='over_elem_unsup' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end do + end if + else + res = szero end if + ! compute global sum + if (global_) call psb_sum(ctxt, res) + call psb_erractionrestore(err_act) - return + return 9999 call psb_error_handler(ctxt,err_act) return -end subroutine psb_svmlt +end subroutine psb_smlt_multivect diff --git a/base/psblas/psb_zdiv_vect.f90 b/base/psblas/psb_zdiv_vect.f90 index 22d8b21c..eab30cd4 100644 --- a/base/psblas/psb_zdiv_vect.f90 +++ b/base/psblas/psb_zdiv_vect.f90 @@ -356,3 +356,70 @@ subroutine psb_zdiv_vect2_check(x,y,z,desc_a,info,flag) end subroutine psb_zdiv_vect2_check +subroutine psb_zdiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + use psb_base_mod, psb_protect_name => psb_zdiv_trslv + implicit none + type(psb_z_multivect_type), intent (inout) :: x + complex(psb_dpk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + complex(psb_dpk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + + ! 
locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me,& + & err_act, iix, jjx, iiy, jjy, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx + character(len=20) :: name, ch_err + + name='psb_zdiv_trslv' + if (psb_errstatus_fatal()) return + info=psb_success_ + call psb_erractionsave(err_act) + + ctxt=desc_a%get_context() + + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + ix = ione + ijx = ione + + m = desc_a%get_global_rows() + nx = x%get_ncols() + + ! check vector correctness + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect 1' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%trslv(nr,a,uplo,alpha,trans,diag,info) + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end subroutine psb_zdiv_trslv + diff --git a/base/psblas/psb_zdot.f90 b/base/psblas/psb_zdot.f90 index ff2d7401..49c59819 100644 --- a/base/psblas/psb_zdot.f90 +++ b/base/psblas/psb_zdot.f90 @@ -478,12 +478,17 @@ function psb_zdot_mvect_vect(x, y, desc_a,info,global) result(res) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() do i=1,size(desc_a%ovrlap_elem,1) - idx = desc_a%ovrlap_elem(i,1) - ndm = desc_a%ovrlap_elem(i,2) + !idx = desc_a%ovrlap_elem(i,1) + !ndm = desc_a%ovrlap_elem(i,2) ! Remove the overlapped elements via zgemv calls ! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res - call zgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & - & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + !call zgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & + ! 
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione) + ! FIXME: To be fixed for overlapped communicators, e.g., AS + info = psb_err_internal_error_ + ch_err='over_elem_unsup' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 end do end if else diff --git a/base/psblas/psb_zvmlt.f90 b/base/psblas/psb_zvmlt.f90 index a4c06fc4..077d0a5f 100644 --- a/base/psblas/psb_zvmlt.f90 +++ b/base/psblas/psb_zvmlt.f90 @@ -29,64 +29,208 @@ ! POSSIBILITY OF SUCH DAMAGE. ! ! -! File: psb_zvmlt.f90 +! File: psb_dvmlt.f90 +!!$ +!!$subroutine psb_dvmlt(x,y,desc_a,info) +!!$ use psb_base_mod, psb_protect_name => psb_dvmlt +!!$ implicit none +!!$ type(psb_d_vect_type), intent (inout) :: x +!!$ type(psb_d_vect_type), intent (inout) :: y +!!$ type(psb_desc_type), intent (in) :: desc_a +!!$ integer(psb_ipk_), intent(out) :: info +!!$ +!!$ ! locals +!!$ integer(psb_ipk_) :: ctxt, np, me,& +!!$ & err_act, iix, jjx, iiy, jjy +!!$ integer(psb_lpk_) :: ix, ijx, iy, ijy, m +!!$ character(len=20) :: name, ch_err +!!$ +!!$ name='psb_dgevmlt' +!!$ if (psb_errstatus_fatal()) return +!!$ info=psb_success_ +!!$ call psb_erractionsave(err_act) +!!$ +!!$ ctxt=desc_a%get_context() +!!$ +!!$ call psb_info(ctxt, me, np) +!!$ if (np == -ione) then +!!$ info = psb_err_context_error_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ if (.not.allocated(x%v)) then +!!$ info = psb_err_invalid_vect_state_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ if (.not.allocated(y%v)) then +!!$ info = psb_err_invalid_vect_state_ +!!$ call psb_errpush(info,name) +!!$ goto 9999 +!!$ endif +!!$ +!!$ +!!$ ix = ione +!!$ iy = ione +!!$ +!!$ m = desc_a%get_global_rows() +!!$ +!!$ ! 
check vector correctness +!!$ call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) +!!$ if(info /= psb_success_) then +!!$ info=psb_err_from_subroutine_ +!!$ ch_err='psb_chkvect 1' +!!$ call psb_errpush(info,name,a_err=ch_err) +!!$ goto 9999 +!!$ end if +!!$ call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy) +!!$ if(info /= psb_success_) then +!!$ info=psb_err_from_subroutine_ +!!$ ch_err='psb_chkvect 2' +!!$ call psb_errpush(info,name,a_err=ch_err) +!!$ goto 9999 +!!$ end if +!!$ +!!$ if ((iix /= ione).or.(iiy /= ione)) then +!!$ info=psb_err_ix_n1_iy_n1_unsupported_ +!!$ call psb_errpush(info,name) +!!$ end if +!!$ +!!$ if(desc_a%get_local_rows() > 0) then +!!$ call y%base_mlt_v(desc_a%get_local_rows(),& +!!$ & alpha,x,beta,info) +!!$ end if +!!$ +!!$ call psb_erractionrestore(err_act) +!!$ return +!!$ +!!$9999 call psb_error_handler(ctxt,err_act) +!!$ +!!$ return +!!$ +!!$end subroutine psb_dvmlt -subroutine psb_zvmlt(x,y,desc_a,info) - use psb_base_mod, psb_protect_name => psb_zvmlt - implicit none - type(psb_z_vect_type), intent (inout) :: x - type(psb_z_vect_type), intent (inout) :: y - type(psb_desc_type), intent (in) :: desc_a - integer(psb_ipk_), intent(out) :: info +! +! Parallel Sparse BLAS version 3.5 +! (C) Copyright 2006-2018 +! Salvatore Filippone +! Alfredo Buttari +! +! Redistribution and use in source and binary forms, with or without +! modification, are permitted provided that the following conditions +! are met: +! 1. Redistributions of source code must retain the above copyright +! notice, this list of conditions and the following disclaimer. +! 2. Redistributions in binary form must reproduce the above copyright +! notice, this list of conditions, and the following disclaimer in the +! documentation and/or other materials provided with the distribution. +! 3. The name of the PSBLAS group or the names of its contributors may +! not be used to endorse or promote products derived from this +! 
software without specific written permission. +! +! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +! ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +! TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +! PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE PSBLAS GROUP OR ITS CONTRIBUTORS +! BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +! SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +! INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +! CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +! ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +! POSSIBILITY OF SUCH DAMAGE. +! +! +! File: psb_zvmlt.f90 +! +! Subroutine: psb_zmlt_multivect +! psb_zmlt_multivect computes the inner product of two distributed multivectors, +! +! res(:,:) := ( X(:,:) )**C * ( Y(:,:) ) +! +! +! Arguments: +! x - type(psb_z_multivect_type) The input multivector containing the entries of sub( X ). +! y - type(psb_z_multivect_type) The input multivector containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! global - logical(optional) Whether to perform the global sum, default: .true. +! +! Note: from a functional point of view, X and Y are input, but here +! they are declared INOUT because of the sync() methods. +! +! 
+subroutine psb_zmlt_multivect(x, y, res,desc_a,info,global) + use psb_desc_mod + use psb_z_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_z_multivect_mod + use psb_z_psblas_mod, psb_protect_name => psb_zmlt_multivect + implicit none + complex(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res + type(psb_z_multivect_type), intent(inout) :: x, y + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: global ! locals - integer(psb_ipk_) :: ctxt, np, me,& - & err_act, iix, jjx, iiy, jjy - integer(psb_lpk_) :: ix, ijx, iy, ijy, m - character(len=20) :: name, ch_err + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n, nx, ny + logical :: global_ + character(len=20) :: name, ch_err - name='psb_zgevmlt' - if (psb_errstatus_fatal()) return + name='psb_zmlt_multivect' info=psb_success_ call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if ctxt=desc_a%get_context() - call psb_info(ctxt, me, np) if (np == -ione) then info = psb_err_context_error_ call psb_errpush(info,name) goto 9999 endif - if (.not.allocated(x%v)) then + if (.not.allocated(x%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) goto 9999 endif - if (.not.allocated(y%v)) then + if (.not.allocated(y%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) goto 9999 endif + if (present(global)) then + global_ = global + else + global_ = .true. + end if ix = ione + ijx = ione + iy = ione + ijy = ione m = desc_a%get_global_rows() + nx = x%get_ncols() + ny = y%get_ncols() ! 
check vector correctness - call psb_chkvect(m,lone,x%get_nrows(),ix,lone,desc_a,info,iix,jjx) + call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if (info == psb_success_) & + & call psb_chkvect(m,ny,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) if(info /= psb_success_) then info=psb_err_from_subroutine_ - ch_err='psb_chkvect 1' - call psb_errpush(info,name,a_err=ch_err) - goto 9999 - end if - call psb_chkvect(m,lone,y%get_nrows(),iy,lone,desc_a,info,iiy,jjy) - if(info /= psb_success_) then - info=psb_err_from_subroutine_ - ch_err='psb_chkvect 2' + ch_err='psb_chkvect' call psb_errpush(info,name,a_err=ch_err) goto 9999 end if @@ -94,18 +238,60 @@ subroutine psb_zvmlt(x,y,desc_a,info) if ((iix /= ione).or.(iiy /= ione)) then info=psb_err_ix_n1_iy_n1_unsupported_ call psb_errpush(info,name) + goto 9999 end if - if(desc_a%get_local_rows() > 0) then - call y%base_mlt_v(desc_a%get_local_rows(),& - & alpha,x,beta,info) + if (x%get_ncols() /= y%get_ncols()) then + info=psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + else + if (allocated(res)) then + if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then + deallocate(res,stat=info) + end if + end if + end if + + nr = desc_a%get_local_rows() + if(nr > 0) then + call x%mlt(nr,y,res,info) + ! FIXME + ! adjust dot_local because overlapped elements are computed more than once + if (size(desc_a%ovrlap_elem,1)>0) then + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + ! FIXME: case of AS-type descriptors + ! Since I'm coputing res via a dgemm on the whole vector, I need to adjust + ! the result by removing the contribution of the overlapped elements + ! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:) + ! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y + ! 
where R is the result, X' is the transpose of the matrix x%v%v(idx,:) + ! and Y is the matrix y%v%v(idx,:) + ! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),& + ! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),& + ! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols()) + info = psb_err_internal_error_ + ch_err='over_elem_unsup' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end do + end if + else + res = dzero end if + ! compute global sum + if (global_) call psb_sum(ctxt, res) + call psb_erractionrestore(err_act) - return + return 9999 call psb_error_handler(ctxt,err_act) return -end subroutine psb_zvmlt +end subroutine psb_zmlt_multivect