Added distributed MGS QR working on GPU

psblas-bgmres
gabrielequatrana 3 months ago
parent 3f22276aff
commit b22cba4413

@ -54,6 +54,16 @@ module psb_d_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global logical, intent(in), optional :: global
end function psb_ddot_multivect end function psb_ddot_multivect
function psb_ddot_multivect_col(col_x, col_y, x, y, desc_a,info,global) result(res)
import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_multivect_type, psb_dspmat_type
real(psb_dpk_) :: res
integer(psb_ipk_), intent(in) :: col_x, col_y
type(psb_d_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
end function psb_ddot_multivect_col
function psb_ddotv(x, y, desc_a,info,global) function psb_ddotv(x, y, desc_a,info,global)
import :: psb_desc_type, psb_dpk_, psb_ipk_, & import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_vect_type, psb_dspmat_type & psb_d_vect_type, psb_dspmat_type
@ -137,15 +147,16 @@ module psb_d_psblas_mod
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
end subroutine psb_daxpby_multivect end subroutine psb_daxpby_multivect
subroutine psb_daxpby_multivect_a(alpha, x, beta, y, desc_a, info) subroutine psb_daxpby_multivect_col(col_x, col_y, alpha, x, beta, y, desc_a, info)
import :: psb_desc_type, psb_dpk_, psb_ipk_, & import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_multivect_type, psb_dspmat_type & psb_d_multivect_type, psb_dspmat_type
real(psb_dpk_), intent(in) :: x(:,:) type(psb_d_multivect_type), intent (inout) :: x
type(psb_d_multivect_type), intent (inout) :: y type(psb_d_multivect_type), intent (inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
end subroutine psb_daxpby_multivect_a end subroutine psb_daxpby_multivect_col
subroutine psb_daxpby_vect_out(alpha, x, beta, y,& subroutine psb_daxpby_vect_out(alpha, x, beta, y,&
& z, desc_a, info) & z, desc_a, info)
import :: psb_desc_type, psb_dpk_, psb_ipk_, & import :: psb_desc_type, psb_dpk_, psb_ipk_, &
@ -421,6 +432,16 @@ module psb_d_psblas_mod
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global logical, intent(in), optional :: global
end function psb_dnrm2_multivect end function psb_dnrm2_multivect
function psb_dnrm2_multivect_col(x, col, desc_a, info, global) result(res)
import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_multivect_type, psb_dspmat_type
real(psb_dpk_) :: res
type(psb_d_multivect_type), intent (inout) :: x
integer(psb_ipk_), intent(in) :: col
type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
end function psb_dnrm2_multivect_col
end interface end interface
#if ! defined(HAVE_BUGGY_GENERICS) #if ! defined(HAVE_BUGGY_GENERICS)

@ -2309,15 +2309,19 @@ module psb_d_base_multivect_mod
! !
! Product, dot-product (col-by-col) and AXPBY ! Product, dot-product (col-by-col) and AXPBY
! !
procedure, pass(x) :: prod_v => d_base_mlv_prod_v procedure, pass(x) :: prod_v => d_base_mlv_prod_v
procedure, pass(x) :: prod_a => d_base_mlv_prod_a procedure, pass(x) :: prod_a => d_base_mlv_prod_a
generic, public :: prod => prod_v, prod_a generic, public :: prod => prod_v, prod_a
procedure, pass(x) :: dot_v => d_base_mlv_dot_v procedure, pass(x) :: dot_v => d_base_mlv_dot_v
procedure, pass(x) :: dot_a => d_base_mlv_dot_a procedure, pass(x) :: dot_v_col => d_base_mlv_dot_v_col
generic, public :: dot => dot_v, dot_a procedure, pass(x) :: dot_a => d_base_mlv_dot_a
procedure, pass(y) :: axpby_v => d_base_mlv_axpby_v procedure, pass(x) :: dot_a_col => d_base_mlv_dot_a_col
procedure, pass(y) :: axpby_a => d_base_mlv_axpby_a generic, public :: dot => dot_v, dot_v_col, dot_a, dot_a_col
generic, public :: axpby => axpby_v, axpby_a procedure, pass(y) :: axpby_v => d_base_mlv_axpby_v
procedure, pass(y) :: axpby_v_col => d_base_mlv_axpby_v_col
procedure, pass(y) :: axpby_a => d_base_mlv_axpby_a
procedure, pass(y) :: axpby_a_col => d_base_mlv_axpby_a_col
generic, public :: axpby => axpby_v, axpby_v_col, axpby_a, axpby_a_col
! !
! MultiVector by vector/multivector multiplication. Need all variants ! MultiVector by vector/multivector multiplication. Need all variants
! to handle multiple requirements from preconditioners ! to handle multiple requirements from preconditioners
@ -2336,7 +2340,9 @@ module psb_d_base_multivect_mod
! Scaling and norms ! Scaling and norms
! !
procedure, pass(x) :: scal => d_base_mlv_scal procedure, pass(x) :: scal => d_base_mlv_scal
procedure, pass(x) :: nrm2 => d_base_mlv_nrm2 procedure, pass(x) :: nrm2_mv => d_base_mlv_nrm2
procedure, pass(x) :: nrm2_col => d_base_mlv_nrm2_col
generic, public :: nrm2 => nrm2_mv, nrm2_col
procedure, pass(x) :: amax => d_base_mlv_amax procedure, pass(x) :: amax => d_base_mlv_amax
procedure, pass(x) :: asum => d_base_mlv_asum procedure, pass(x) :: asum => d_base_mlv_asum
procedure, pass(x) :: absval1 => d_base_mlv_absval1 procedure, pass(x) :: absval1 => d_base_mlv_absval1
@ -3031,6 +3037,43 @@ contains
end function d_base_mlv_dot_v end function d_base_mlv_dot_v
!
! Dot products
!
!
!> Function base_mlv_dot_v_col
!! \memberof psb_d_base_multivect_type
!! \brief Dot product by another base_mlv_vector
!! \param nr Number of rows to be considered
!! \param col Column index
!! \param y The other (base_mlv_vect) to be multiplied by
!! \param res Result matrix
!!
function d_base_mlv_dot_v_col(nr,col_x,col_y,x,y) result(res)
implicit none
class(psb_d_base_multivect_type), intent(inout) :: x, y
integer(psb_ipk_), intent(in) :: nr, col_x, col_y
real(psb_dpk_) :: res
real(psb_dpk_), external :: ddot
if (x%is_dev()) call x%sync()
!
! Note: this is the base implementation.
! When we get here, we are sure that X is of
! TYPE psb_d_base_mlv_vect (or its class does not care).
! If Y is not, throw the burden on it, implicitly
! calling dot_a
!
select type(yy => y)
type is (psb_d_base_multivect_type)
if (y%is_dev()) call y%sync()
res = ddot(nr,x%v(1:nr,col_x),1,y%v(1:nr,col_y),1)
class default
res = x%dot(nr,col_x,col_y,y%v)
end select
end function d_base_mlv_dot_v_col
! !
! Base workhorse is good old BLAS1 ! Base workhorse is good old BLAS1
! !
@ -3061,6 +3104,30 @@ contains
end do end do
end function d_base_mlv_dot_a end function d_base_mlv_dot_a
!
! Base workhorse is good old BLAS1
!
!
!> Function base_mlv_dot_a
!! \memberof psb_d_base_multivect_type
!! \brief Dot product by a normal array
!! \param nr Number of rows to be considered
!! \param y(:,:) The array to be multiplied by
!! \param res Result matrix
!!
function d_base_mlv_dot_a_col(nr,col_x,col_y,x,y) result(res)
class(psb_d_base_multivect_type), intent(inout) :: x
real(psb_dpk_), intent(in) :: y(:,:)
integer(psb_ipk_), intent(in) :: nr, col_x, col_y
real(psb_dpk_) :: res
real(psb_dpk_), external :: ddot
if (x%is_dev()) call x%sync()
res = ddot(nr,x%v(1:nr,col_x),1,y(1:nr,col_y),1)
end function d_base_mlv_dot_a_col
! !
! AXPBY is invoked via Y, hence the structure below. ! AXPBY is invoked via Y, hence the structure below.
! !
@ -3100,6 +3167,39 @@ contains
end subroutine d_base_mlv_axpby_v end subroutine d_base_mlv_axpby_v
!
! AXPBY is invoked via Y, hence the structure below.
!
!
!
!> Function base_mlv_axpby_v_col
!! \memberof psb_d_base_multivect_type
!! \brief AXPBY by a (base_mlv_vect) y=alpha*x+beta*y
!! \param m Number of entries to be considered
!! \param alpha scalar alpha
!! \param x The class(base_mlv_vect) to be added
!! \param beta scalar alpha
!! \param col column index
!! \param info return code
!!
subroutine d_base_mlv_axpby_v_col(m, col_x, col_y, alpha, x, beta, y, info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: m
class(psb_d_base_multivect_type), intent(inout) :: x
class(psb_d_base_multivect_type), intent(inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
integer(psb_ipk_), intent(out) :: info
select type(xx => x)
type is (psb_d_base_multivect_type)
call psb_geaxpby(m,alpha,x%v(:,col_x),beta,y%v(:,col_y),info)
class default
call y%axpby(m,col_x,col_y,alpha,x%v,beta,info)
end select
end subroutine d_base_mlv_axpby_v_col
! !
! AXPBY is invoked via Y, hence the structure below. ! AXPBY is invoked via Y, hence the structure below.
! !
@ -3133,6 +3233,33 @@ contains
end subroutine d_base_mlv_axpby_a end subroutine d_base_mlv_axpby_a
!
! AXPBY is invoked via Y, hence the structure below.
!
!
!> Function base_mlv_axpby_a_col
!! \memberof psb_d_base_multivect_type
!! \brief AXPBY by a normal array y=alpha*x+beta*y
!! \param m Number of entries to be considered
!! \param alpha scalar alpha
!! \param x(:,:) The array to be added
!! \param beta scalar alpha
!! \param col column index
!! \param info return code
!!
subroutine d_base_mlv_axpby_a_col(m, col_x, col_y, alpha, x, beta, y, info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: m
real(psb_dpk_), intent(in) :: x(:,:)
class(psb_d_base_multivect_type), intent(inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
integer(psb_ipk_), intent(out) :: info
call psb_geaxpby(m,alpha,x(:,col_x),beta,y%v(:,col_y),info)
end subroutine d_base_mlv_axpby_a_col
! !
! Multiple variants of two operations: ! Multiple variants of two operations:
@ -3399,6 +3526,26 @@ contains
end function d_base_mlv_nrm2 end function d_base_mlv_nrm2
!
! Norms 1, 2 and infinity
!
!> Function base_mlv_nrm2_col
!! \memberof psb_d_base_multivect_type
!! \brief 2-norm |x(1:nr,col)|_2
!! \param col column index to consider
!! \param nr how many rows to consider
function d_base_mlv_nrm2_col(nr,col,x) result(res)
implicit none
class(psb_d_base_multivect_type), intent(inout) :: x
integer(psb_ipk_), intent(in) :: nr, col
real(psb_dpk_) :: res
real(psb_dpk_), external :: dnrm2
if (x%is_dev()) call x%sync()
res = dnrm2(nr,x%v(:,col),1)
end function d_base_mlv_nrm2_col
! !
!> Function base_mlv_amax !> Function base_mlv_amax
!! \memberof psb_d_base_multivect_type !! \memberof psb_d_base_multivect_type

@ -1424,15 +1424,19 @@ module psb_d_multivect_mod
! !
! Produc, dot-product and AXPBY ! Produc, dot-product and AXPBY
! !
procedure, pass(x) :: prod_v => d_vect_prod_v procedure, pass(x) :: prod_v => d_vect_prod_v
procedure, pass(x) :: prod_a => d_vect_prod_a procedure, pass(x) :: prod_a => d_vect_prod_a
generic, public :: prod => prod_v, prod_a generic, public :: prod => prod_v, prod_a
procedure, pass(x) :: dot_v => d_vect_dot_v procedure, pass(x) :: dot_v => d_vect_dot_v
procedure, pass(x) :: dot_a => d_vect_dot_a procedure, pass(x) :: dot_v_col => d_vect_dot_v_col
generic, public :: dot => dot_v, dot_a procedure, pass(x) :: dot_a => d_vect_dot_a
procedure, pass(y) :: axpby_v => d_vect_axpby_v procedure, pass(x) :: dot_a_col => d_vect_dot_a_col
procedure, pass(y) :: axpby_a => d_vect_axpby_a generic, public :: dot => dot_v, dot_v_col, dot_a, dot_a_col
generic, public :: axpby => axpby_v, axpby_a procedure, pass(y) :: axpby_v => d_vect_axpby_v
procedure, pass(y) :: axpby_v_col => d_vect_axpby_v_col
procedure, pass(y) :: axpby_a => d_vect_axpby_a
procedure, pass(y) :: axpby_a_col => d_vect_axpby_a_col
generic, public :: axpby => axpby_v, axpby_v_col, axpby_a, axpby_a_col
! !
! MultiVector by vector/multivector multiplication. Need all variants ! MultiVector by vector/multivector multiplication. Need all variants
! to handle multiple requirements from preconditioners ! to handle multiple requirements from preconditioners
@ -1449,7 +1453,9 @@ module psb_d_multivect_mod
! Scaling and norms ! Scaling and norms
! !
!!$ procedure, pass(x) :: scal => d_vect_scal !!$ procedure, pass(x) :: scal => d_vect_scal
procedure, pass(x) :: nrm2 => d_vect_nrm2 procedure, pass(x) :: nrm2_mv => d_vect_nrm2
procedure, pass(x) :: nrm2_col => d_vect_nrm2_col
generic, public :: nrm2 => nrm2_mv, nrm2_col
procedure, pass(x) :: amax => d_vect_amax procedure, pass(x) :: amax => d_vect_amax
procedure, pass(x) :: asum => d_vect_asum procedure, pass(x) :: asum => d_vect_asum
procedure, pass(x) :: qr_fact => d_vect_qr_fact procedure, pass(x) :: qr_fact => d_vect_qr_fact
@ -1927,6 +1933,17 @@ contains
end function d_vect_dot_v end function d_vect_dot_v
function d_vect_dot_v_col(nr,col_x,col_y,x,y) result(res)
implicit none
class(psb_d_multivect_type), intent(inout) :: x, y
integer(psb_ipk_), intent(in) :: nr, col_x, col_y
real(psb_dpk_) :: res
if (allocated(x%v).and.allocated(y%v)) &
& res = x%v%dot(nr,col_x,col_y,y%v)
end function d_vect_dot_v_col
function d_vect_dot_a(nr,x,y) result(res) function d_vect_dot_a(nr,x,y) result(res)
implicit none implicit none
class(psb_d_multivect_type), intent(inout) :: x class(psb_d_multivect_type), intent(inout) :: x
@ -1939,6 +1956,18 @@ contains
end function d_vect_dot_a end function d_vect_dot_a
function d_vect_dot_a_col(nr,col_x,col_y,x,y) result(res)
implicit none
class(psb_d_multivect_type), intent(inout) :: x
real(psb_dpk_), intent(in) :: y(:,:)
integer(psb_ipk_), intent(in) :: nr, col_x, col_y
real(psb_dpk_) :: res
if (allocated(x%v)) &
& res = x%v%dot(nr,col_x,col_y,y)
end function d_vect_dot_a_col
subroutine d_vect_axpby_v(m,alpha, x, beta, y, info) subroutine d_vect_axpby_v(m,alpha, x, beta, y, info)
use psi_serial_mod use psi_serial_mod
implicit none implicit none
@ -1956,6 +1985,24 @@ contains
end subroutine d_vect_axpby_v end subroutine d_vect_axpby_v
subroutine d_vect_axpby_v_col(m, col_x, col_y, alpha, x, beta, y, info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: m
class(psb_d_multivect_type), intent(inout) :: x
class(psb_d_multivect_type), intent(inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
integer(psb_ipk_), intent(out) :: info
if (allocated(x%v).and.allocated(y%v)) then
call y%v%axpby(m,col_x,col_y,alpha,x%v,beta,info)
else
info = psb_err_invalid_vect_state_
end if
end subroutine d_vect_axpby_v_col
subroutine d_vect_axpby_a(m,alpha, x, beta, y, info) subroutine d_vect_axpby_a(m,alpha, x, beta, y, info)
use psi_serial_mod use psi_serial_mod
implicit none implicit none
@ -1970,6 +2017,21 @@ contains
end subroutine d_vect_axpby_a end subroutine d_vect_axpby_a
subroutine d_vect_axpby_a_col(m, col_x, col_y, alpha, x, beta, y, info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: m
real(psb_dpk_), intent(in) :: x(:,:)
class(psb_d_multivect_type), intent(inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
integer(psb_ipk_), intent(out) :: info
if (allocated(y%v)) &
& call y%v%axpby(m,col_x,col_y,alpha,x,beta,info)
end subroutine d_vect_axpby_a_col
!!$ subroutine d_vect_mlt_v(x, y, info) !!$ subroutine d_vect_mlt_v(x, y, info)
!!$ use psi_serial_mod !!$ use psi_serial_mod
!!$ implicit none !!$ implicit none
@ -2094,6 +2156,20 @@ contains
end function d_vect_nrm2 end function d_vect_nrm2
function d_vect_nrm2_col(nr,col,x) result(res)
implicit none
class(psb_d_multivect_type), intent(inout) :: x
integer(psb_ipk_), intent(in) :: col, nr
real(psb_dpk_) :: res
if (allocated(x%v)) then
res = x%v%nrm2(nr,col)
else
res = dzero
end if
end function d_vect_nrm2_col
function d_vect_amax(nr,x) result(res) function d_vect_amax(nr,x) result(res)
implicit none implicit none
class(psb_d_multivect_type), intent(inout) :: x class(psb_d_multivect_type), intent(inout) :: x

@ -230,38 +230,40 @@ subroutine psb_daxpby_multivect(alpha, x, beta, y, desc_a, info)
end subroutine psb_daxpby_multivect end subroutine psb_daxpby_multivect
! !
! Subroutine: psb_daxpby_multivect ! Subroutine: psb_daxpby_multivect_col
! Adds one distributed multivector to another, ! Adds one distributed multivector column to another,
! !
! Y := beta * Y + alpha * X ! Y := beta * Y + alpha * X
! !
! Arguments: ! Arguments:
! alpha - real,input The scalar used to multiply each component of X ! alpha - real,input The scalar used to multiply each component of X
! x - real(psb_dpk_)(:,:) The input multivector containing the entries of X ! x - type(psb_d_multivect_type) The input multivector containing the entries of X
! beta - real,input The scalar used to multiply each component of Y ! beta - real,input The scalar used to multiply each component of Y
! y - type(psb_d_multivect_type) The input/output multivector Y ! y - type(psb_d_multivect_type) The input/output multivector Y
! desc_a - type(psb_desc_type) The communication descriptor. ! col - integer Column index
! info - integer Return code ! desc_a - type(psb_desc_type) The communication descriptor.
! info - integer Return code
! !
! Note: from a functional point of view, X is input, but here ! Note: from a functional point of view, X is input, but here
! it's declared INOUT because of the sync() methods. ! it's declared INOUT because of the sync() methods.
! !
subroutine psb_daxpby_multivect_a(alpha, x, beta, y, desc_a, info) subroutine psb_daxpby_multivect_col(col_x, col_y, alpha, x, beta, y, desc_a, info)
use psb_base_mod, psb_protect_name => psb_daxpby_multivect_a use psb_base_mod, psb_protect_name => psb_daxpby_multivect_col
implicit none implicit none
real(psb_dpk_), intent(in) :: x(:,:) type(psb_d_multivect_type), intent (inout) :: x
type(psb_d_multivect_type), intent (inout) :: y type(psb_d_multivect_type), intent (inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
! locals ! locals
type(psb_ctxt_type) :: ctxt type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, err_act, iiy, jjy integer(psb_ipk_) :: np, me, err_act, iix, jjx, iiy, jjy
integer(psb_lpk_) :: iy, ijy, m, n integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n
character(len=20) :: name, ch_err character(len=20) :: name, ch_err
name='psb_dgeaxpby' name='psb_dgeaxpby_mv_col'
if (psb_errstatus_fatal()) return if (psb_errstatus_fatal()) return
info=psb_success_ info=psb_success_
call psb_erractionsave(err_act) call psb_erractionsave(err_act)
@ -274,16 +276,33 @@ subroutine psb_daxpby_multivect_a(alpha, x, beta, y, desc_a, info)
call psb_errpush(info,name) call psb_errpush(info,name)
goto 9999 goto 9999
endif endif
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(y%v)) then if (.not.allocated(y%v)) then
info = psb_err_invalid_vect_state_ info = psb_err_invalid_vect_state_
call psb_errpush(info,name) call psb_errpush(info,name)
goto 9999 goto 9999
endif endif
ix = ione
ijx = ione
iy = ione iy = ione
ijy = ione ijy = ione
m = desc_a%get_global_rows() m = desc_a%get_global_rows()
n = x%get_ncols()
! check vector correctness
call psb_chkvect(m,n,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect 1'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
n = y%get_ncols() n = y%get_ncols()
call psb_chkvect(m,n,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy) call psb_chkvect(m,n,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
if(info /= psb_success_) then if(info /= psb_success_) then
@ -293,13 +312,13 @@ subroutine psb_daxpby_multivect_a(alpha, x, beta, y, desc_a, info)
goto 9999 goto 9999
end if end if
if (iiy /= ione) then if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_ info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name) call psb_errpush(info,name)
end if end if
if(desc_a%get_local_rows() > 0) then if(desc_a%get_local_rows() > 0) then
call y%axpby(desc_a%get_local_rows(),alpha,x,beta,info) call y%axpby(desc_a%get_local_rows(),col_x,col_y,alpha,x,beta,info)
end if end if
call psb_erractionrestore(err_act) call psb_erractionrestore(err_act)
@ -309,7 +328,7 @@ subroutine psb_daxpby_multivect_a(alpha, x, beta, y, desc_a, info)
return return
end subroutine psb_daxpby_multivect_a end subroutine psb_daxpby_multivect_col
! !
! Parallel Sparse BLAS version 3.5 ! Parallel Sparse BLAS version 3.5

@ -287,6 +287,133 @@ function psb_ddot_multivect(x, y, desc_a,info,global) result(res)
end function psb_ddot_multivect end function psb_ddot_multivect
! !
! Function: psb_ddot_multivect
! psb_ddot computes the col-by-col dot product of two distributed vectors,
!
! dot := ( X )**C * ( Y )
!
!
! Arguments:
! x - type(psb_d_multivect_type) The input vector containing the entries of sub( X ).
! y - type(psb_d_multivect_type) The input vector containing the entries of sub( Y ).
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global sum, default: .true.
!
! Note: from a functional point of view, X and Y are input, but here
! they are declared INOUT because of the sync() methods.
!
!
function psb_ddot_multivect_col(col_x, col_y, x, y, desc_a,info,global) result(res)
use psb_desc_mod
use psb_d_base_mat_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_d_vect_mod
use psb_d_psblas_mod, psb_protect_name => psb_ddot_multivect_col
implicit none
real(psb_dpk_) :: res
integer(psb_ipk_), intent(in) :: col_x, col_y
type(psb_d_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, idx, ndm,&
& err_act, iix, jjx, iiy, jjy, i, j, nr
integer(psb_lpk_) :: ix, ijx, iy, ijy, m, n
logical :: global_
character(len=20) :: name, ch_err
name='psb_ddot_multivect_col'
info=psb_success_
call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
ctxt=desc_a%get_context()
call psb_info(ctxt, me, np)
if (np == -ione) then
info = psb_err_context_error_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(y%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = ione
ijx = ione
iy = ione
ijy = ione
m = desc_a%get_global_rows()
n = x%get_ncols()
! check vector correctness
call psb_chkvect(m,n,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
n = y%get_ncols()
if (info == psb_success_) &
& call psb_chkvect(m,n,y%get_nrows(),iy,ijy,desc_a,info,iiy,jjy)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
if ((iix /= ione).or.(iiy /= ione)) then
info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name)
goto 9999
end if
nr = desc_a%get_local_rows()
if(nr > 0) then
res = x%dot(nr,col_x,col_y,y)
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync()
if (y%v%is_dev()) call y%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
res = res - (real(ndm-1)/real(ndm))*(x%v%v(idx,col_x)*y%v%v(idx,col_y))
end do
end if
else
res = dzero
end if
! compute global sum
if (global_) call psb_sum(ctxt, res)
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(ctxt,err_act)
return
end function psb_ddot_multivect_col
!
! Function: psb_ddot ! Function: psb_ddot
! psb_ddot computes the dot product of two distributed vectors, ! psb_ddot computes the dot product of two distributed vectors,
! !

@ -483,6 +483,115 @@ function psb_dnrm2_multivect(x, desc_a, info, global) result(res)
end function psb_dnrm2_multivect end function psb_dnrm2_multivect
! Function: psb_dnrm2_multivect_col
! Computes the norm2 of a distributed multivector column,
!
! norm2 := sqrt ( X**C * X)
!
! Arguments:
! x - type(psb_d_multivect_type) The input vector containing the entries of X.
! col - integer. Multivector column index
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true.
!
function psb_dnrm2_multivect_col(x, col, desc_a, info, global) result(res)
use psb_desc_mod
use psb_check_mod
use psb_error_mod
use psb_penv_mod
use psb_d_multivect_mod
implicit none
real(psb_dpk_) :: res
type(psb_d_multivect_type), intent (inout) :: x
integer(psb_ipk_), intent(in) :: col
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
logical, intent(in), optional :: global
! locals
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, err_act, idx, i, j, iix, jjx, ldx, ndm
real(psb_dpk_) :: dd
integer(psb_lpk_) :: ix, jx, m, n
logical :: global_
character(len=20) :: name, ch_err
name='psb_dnrm2mv_col'
call psb_erractionsave(err_act)
if (psb_errstatus_fatal()) then
info = psb_err_internal_error_ ; goto 9999
end if
info=psb_success_
ctxt=desc_a%get_context()
call psb_info(ctxt, me, np)
if (np == -1) then
info=psb_err_context_error_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
if (present(global)) then
global_ = global
else
global_ = .true.
end if
ix = 1
jx = 1
m = desc_a%get_global_rows()
n = x%get_ncols()
ldx = x%get_nrows()
call psb_chkvect(m,n,ldx,ix,jx,desc_a,info,iix,jjx)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect'
call psb_errpush(info,name,a_err=ch_err)
end if
if (iix /= 1) then
info=psb_err_ix_n1_iy_n1_unsupported_
call psb_errpush(info,name)
goto 9999
end if
if (desc_a%get_local_rows() > 0) then
res = x%nrm2(desc_a%get_local_rows(),col)
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync()
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
dd = dble(ndm-1)/dble(ndm)
res = res * sqrt(done - dd*(abs(x%v%v(idx,col))/res)**2)
end do
end if
else
res = dzero
end if
if (global_) call psb_nrm2(ctxt,res)
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(ctxt,err_act)
return
end function psb_dnrm2_multivect_col
! Function: psb_dnrm2_weight_vect ! Function: psb_dnrm2_weight_vect
! Computes the weighted norm2 of a distributed vector, ! Computes the weighted norm2 of a distributed vector,
! !

@ -243,6 +243,15 @@ int nrm2MultiVecDeviceDouble(double* y_res, int n, void* devMultiVecA)
return(i); return(i);
} }
int nrm2MultiVecDeviceDoubleCol(double* y_res, int n, int col, void* devMultiVecA)
{ int i=0;
spgpuHandle_t handle=psb_cudaGetHandle();
struct MultiVectDevice *devVecA = (struct MultiVectDevice *) devMultiVecA;
spgpuDmnrm2(handle, y_res, n, ((double *)devVecA->v_)+devVecA->pitch_*(col-1), 1, devVecA->pitch_);
return(i);
}
int amaxMultiVecDeviceDouble(double* y_res, int n, void* devMultiVecA) int amaxMultiVecDeviceDouble(double* y_res, int n, void* devMultiVecA)
{ int i=0; { int i=0;
spgpuHandle_t handle=psb_cudaGetHandle(); spgpuHandle_t handle=psb_cudaGetHandle();
@ -273,6 +282,17 @@ int dotMultiVecDeviceDouble(double* y_res, int n, void* devMultiVecA, void* devM
return(i); return(i);
} }
int dotMultiVecDeviceDoubleCol(double* y_res, int n, int col_x, int col_y, void* devMultiVecA, void* devMultiVecB)
{int i=0;
struct MultiVectDevice *devVecA = (struct MultiVectDevice *) devMultiVecA;
struct MultiVectDevice *devVecB = (struct MultiVectDevice *) devMultiVecB;
spgpuHandle_t handle=psb_cudaGetHandle();
spgpuDmdot(handle, y_res, n, ((double*)devVecA->v_)+devVecA->pitch_*(col_x-1),
((double*)devVecB->v_)+devVecB->pitch_*(col_y-1), 1, devVecA->pitch_);
return(i);
}
int axpbyMultiVecDeviceDouble(int n,double alpha, void* devMultiVecX, int axpbyMultiVecDeviceDouble(int n,double alpha, void* devMultiVecX,
double beta, void* devMultiVecY) double beta, void* devMultiVecY)
{ int j=0, i=0; { int j=0, i=0;
@ -290,6 +310,22 @@ int axpbyMultiVecDeviceDouble(int n,double alpha, void* devMultiVecX,
return(i); return(i);
} }
int axpbyMultiVecDeviceDoubleCol(int n, int col_x, int col_y, double alpha, void* devMultiVecX,
double beta, void* devMultiVecY)
{ int i=0;
int pitch = 0;
struct MultiVectDevice *devVecX = (struct MultiVectDevice *) devMultiVecX;
struct MultiVectDevice *devVecY = (struct MultiVectDevice *) devMultiVecY;
spgpuHandle_t handle=psb_cudaGetHandle();
pitch = devVecY->pitch_;
if ((n > devVecY->size_) || (n>devVecX->size_ ))
return SPGPU_UNSUPPORTED;
spgpuDaxpby(handle,((double *)devVecY->v_)+pitch*(col_y-1), n, beta,
((double *)devVecY->v_)+pitch*(col_y-1), alpha,((double *)devVecX->v_)+pitch*(col_x-1));
return(i);
}
int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta, int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta,
void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ) void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ)
{ int j=0, i=0; { int j=0, i=0;

@ -68,11 +68,14 @@ int iscatMultiVecDeviceDouble(void* deviceVec, int vectorId, int n, int first, v
int scalMultiVecDeviceDouble(double alpha, void* devMultiVecA); int scalMultiVecDeviceDouble(double alpha, void* devMultiVecA);
int nrm2MultiVecDeviceDouble(double* y_res, int n, void* devVecA); int nrm2MultiVecDeviceDouble(double* y_res, int n, void* devVecA);
int nrm2MultiVecDeviceDoubleCol(double* y_res, int n, int col, void* devVecA);
int amaxMultiVecDeviceDouble(double* y_res, int n, void* devVecA); int amaxMultiVecDeviceDouble(double* y_res, int n, void* devVecA);
int asumMultiVecDeviceDouble(double* y_res, int n, void* devVecA); int asumMultiVecDeviceDouble(double* y_res, int n, void* devVecA);
int dotMultiVecDeviceDouble(double* y_res, int n, void* devVecA, void* devVecB); int dotMultiVecDeviceDouble(double* y_res, int n, void* devVecA, void* devVecB);
int dotMultiVecDeviceDoubleCol(double* y_res, int n, int col_x, int col_y, void* devVecA, void* devVecB);
int axpbyMultiVecDeviceDouble(int n, double alpha, void* devVecX, double beta, void* devVecY); int axpbyMultiVecDeviceDouble(int n, double alpha, void* devVecX, double beta, void* devVecY);
int axpbyMultiVecDeviceDoubleCol(int n, int col_x, int col_y, double alpha, void* devVecX, double beta, void* devVecY);
int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta, int abgdxyzMultiVecDeviceDouble(int n,double alpha,double beta, double gamma, double delta,
void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ); void* devMultiVecX, void* devMultiVecY, void* devMultiVecZ);
int xyzwMultiVecDeviceDouble(int n,double a, double b, double c, double d, double e, double f, int xyzwMultiVecDeviceDouble(int n,double a, double b, double c, double d, double e, double f,

@ -1399,12 +1399,16 @@ module psb_d_cuda_multivect_mod
! !
! Product, dot-product (col-by-col) and AXPBY ! Product, dot-product (col-by-col) and AXPBY
! !
procedure, pass(x) :: prod_v => d_cuda_multi_prod_v procedure, pass(x) :: prod_v => d_cuda_multi_prod_v
procedure, pass(x) :: prod_a => d_cuda_multi_prod_a procedure, pass(x) :: prod_a => d_cuda_multi_prod_a
procedure, pass(x) :: dot_v => d_cuda_multi_dot_v procedure, pass(x) :: dot_v => d_cuda_multi_dot_v
procedure, pass(x) :: dot_a => d_cuda_multi_dot_a procedure, pass(x) :: dot_v_col => d_cuda_multi_dot_v_col
procedure, pass(y) :: axpby_v => d_cuda_multi_axpby_v procedure, pass(x) :: dot_a => d_cuda_multi_dot_a
procedure, pass(y) :: axpby_a => d_cuda_multi_axpby_a procedure, pass(x) :: dot_a_col => d_cuda_multi_dot_a_col
procedure, pass(y) :: axpby_v => d_cuda_multi_axpby_v
procedure, pass(y) :: axpby_v_col => d_cuda_multi_axpby_v_col
procedure, pass(y) :: axpby_a => d_cuda_multi_axpby_a
procedure, pass(y) :: axpby_a_col => d_cuda_multi_axpby_a_col
! !
! MultiVector by vector/multivector multiplication. Need all variants ! MultiVector by vector/multivector multiplication. Need all variants
! to handle multiple requirements from preconditioners ! to handle multiple requirements from preconditioners
@ -1416,7 +1420,8 @@ module psb_d_cuda_multivect_mod
! !
! Scaling and norms ! Scaling and norms
! !
procedure, pass(x) :: nrm2 => d_cuda_multi_nrm2 procedure, pass(x) :: nrm2_mv => d_cuda_multi_nrm2
procedure, pass(x) :: nrm2_col => d_cuda_multi_nrm2_col
procedure, pass(x) :: amax => d_cuda_multi_amax procedure, pass(x) :: amax => d_cuda_multi_amax
procedure, pass(x) :: asum => d_cuda_multi_asum procedure, pass(x) :: asum => d_cuda_multi_asum
!!$ procedure, pass(x) :: scal => d_cuda_multi_scal !!$ procedure, pass(x) :: scal => d_cuda_multi_scal
@ -1725,6 +1730,31 @@ contains
end function d_cuda_multi_dot_v end function d_cuda_multi_dot_v
function d_cuda_multi_dot_v_col(nr,col_x,col_y,x,y) result(res)
implicit none
class(psb_d_multivect_cuda), intent(inout) :: x
class(psb_d_base_multivect_type), intent(inout) :: y
integer(psb_ipk_), intent(in) :: nr, col_x, col_y
real(psb_dpk_) :: res
integer(psb_ipk_) :: info
!
! Note: this is the gpu implementation.
! When we get here, we are sure that X is of
! TYPE psb_d_vect
!
select type(yy => y)
type is (psb_d_multivect_cuda)
if (x%is_host()) call x%sync()
if (yy%is_host()) call yy%sync()
info = dotMultiVecDevice(res,nr,col_x,col_y,x%deviceVect,yy%deviceVect)
class default
if (y%is_host()) call y%sync()
res = x%dot(nr,col_x,col_y,y%v)
end select
end function d_cuda_multi_dot_v_col
function d_cuda_multi_dot_a(nr,x,y) result(res) function d_cuda_multi_dot_a(nr,x,y) result(res)
implicit none implicit none
class(psb_d_multivect_cuda), intent(inout) :: x class(psb_d_multivect_cuda), intent(inout) :: x
@ -1746,6 +1776,19 @@ contains
end function d_cuda_multi_dot_a end function d_cuda_multi_dot_a
function d_cuda_multi_dot_a_col(nr,col_x,col_y,x,y) result(res)
implicit none
class(psb_d_multivect_cuda), intent(inout) :: x
real(psb_dpk_), intent(in) :: y(:,:)
integer(psb_ipk_), intent(in) :: nr, col_x, col_y
real(psb_dpk_) :: res
real(psb_dpk_), external :: ddot
if (x%is_dev()) call x%sync()
res = ddot(nr,x%v(1:nr,col_x),1,y(1:nr,col_y),1)
end function d_cuda_multi_dot_a_col
subroutine d_cuda_multi_axpby_v(m,alpha, x, beta, y, info, n) subroutine d_cuda_multi_axpby_v(m,alpha, x, beta, y, info, n)
use psi_serial_mod use psi_serial_mod
implicit none implicit none
@ -1778,6 +1821,38 @@ contains
end subroutine d_cuda_multi_axpby_v end subroutine d_cuda_multi_axpby_v
subroutine d_cuda_multi_axpby_v_col(m, col_x, col_y, alpha, x, beta, y, info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: m
class(psb_d_base_multivect_type), intent(inout) :: x
class(psb_d_multivect_cuda), intent(inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: nx, ny
info = psb_success_
select type(xx => x)
type is (psb_d_multivect_cuda)
if ((beta /= dzero).and.(y%is_host())) call y%sync()
if (xx%is_host()) call xx%sync()
nx = getMultiVecDeviceSize(xx%deviceVect)
ny = getMultiVecDeviceSize(y%deviceVect)
if ((nx<m).or.(ny<m)) then
info = psb_err_internal_error_
else
info = axpbyMultiVecDevice(m,col_x,col_y,alpha,xx%deviceVect,beta,y%deviceVect)
end if
call y%set_dev()
class default
! Do it on the host side
if ((alpha /= dzero).and.(x%is_dev())) call x%sync()
call y%axpby(m,col_x,col_y,alpha,x%v,beta,info)
end select
end subroutine d_cuda_multi_axpby_v_col
subroutine d_cuda_multi_axpby_a(m,alpha, x, beta, y, info, n) subroutine d_cuda_multi_axpby_a(m,alpha, x, beta, y, info, n)
use psi_serial_mod use psi_serial_mod
implicit none implicit none
@ -1799,6 +1874,21 @@ contains
call y%set_host() call y%set_host()
end subroutine d_cuda_multi_axpby_a end subroutine d_cuda_multi_axpby_a
subroutine d_cuda_multi_axpby_a_col(m, col_x, col_y, alpha, x, beta, y, info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: m
real(psb_dpk_), intent(in) :: x(:,:)
class(psb_d_multivect_cuda), intent(inout) :: y
real(psb_dpk_), intent (in) :: alpha, beta
integer(psb_ipk_), intent(in) :: col_x, col_y
integer(psb_ipk_), intent(out) :: info
if ((beta /= dzero).and.(y%is_dev())) call y%sync()
call psb_geaxpby(m,alpha,x(:,col_x),beta,y%v(:,col_y),info)
call y%set_host()
end subroutine d_cuda_multi_axpby_a_col
!!$ subroutine d_cuda_multi_mlt_v(x, y, info) !!$ subroutine d_cuda_multi_mlt_v(x, y, info)
!!$ use psi_serial_mod !!$ use psi_serial_mod
!!$ implicit none !!$ implicit none
@ -1972,6 +2062,17 @@ contains
info = nrm2MultiVecDevice(res,nr,x%deviceVect) info = nrm2MultiVecDevice(res,nr,x%deviceVect)
end function d_cuda_multi_nrm2 end function d_cuda_multi_nrm2
function d_cuda_multi_nrm2_col(nr,col,x) result(res)
implicit none
class(psb_d_multivect_cuda), intent(inout) :: x
integer(psb_ipk_), intent(in) :: nr, col
real(psb_dpk_) :: res
integer(psb_ipk_) :: info
! WARNING: this should be changed.
if (x%is_host()) call x%sync()
info = nrm2MultiVecDevice(res,nr,col,x%deviceVect)
end function d_cuda_multi_nrm2_col
! TODO CUDA ! TODO CUDA
function d_cuda_multi_amax(nr,x) result(res) function d_cuda_multi_amax(nr,x) result(res)
implicit none implicit none

@ -279,7 +279,7 @@ module psb_d_vectordev_mod
end interface end interface
interface dotMultiVecDevice interface dotMultiVecDevice
function dotMultiVecDeviceDouble(res, n,deviceVecA,deviceVecB) & function dotMultiVecDeviceDouble(res,n,deviceVecA,deviceVecB) &
& result(val) bind(c,name='dotMultiVecDeviceDouble') & result(val) bind(c,name='dotMultiVecDeviceDouble')
use iso_c_binding use iso_c_binding
integer(c_int) :: val integer(c_int) :: val
@ -287,7 +287,7 @@ module psb_d_vectordev_mod
real(c_double) :: res real(c_double) :: res
type(c_ptr), value :: deviceVecA, deviceVecB type(c_ptr), value :: deviceVecA, deviceVecB
end function dotMultiVecDeviceDouble end function dotMultiVecDeviceDouble
function dotMultiVecDeviceDoubleR2(res, n,deviceVecA,deviceVecB,ld) & function dotMultiVecDeviceDoubleR2(res,n,deviceVecA,deviceVecB,ld) &
& result(val) bind(c,name='dotMultiVecDeviceDouble') & result(val) bind(c,name='dotMultiVecDeviceDouble')
use iso_c_binding use iso_c_binding
integer(c_int) :: val integer(c_int) :: val
@ -296,6 +296,14 @@ module psb_d_vectordev_mod
integer(c_int), value :: ld integer(c_int), value :: ld
type(c_ptr), value :: deviceVecA, deviceVecB type(c_ptr), value :: deviceVecA, deviceVecB
end function dotMultiVecDeviceDoubleR2 end function dotMultiVecDeviceDoubleR2
function dotMultiVecDeviceDoubleCol(res,n,col_x,col_y,deviceVecA,deviceVecB) &
& result(val) bind(c,name='dotMultiVecDeviceDoubleCol')
use iso_c_binding
integer(c_int) :: val
integer(c_int), value :: n, col_x, col_y
real(c_double) :: res
type(c_ptr), value :: deviceVecA, deviceVecB
end function dotMultiVecDeviceDoubleCol
end interface end interface
interface nrm2MultiVecDevice interface nrm2MultiVecDevice
@ -304,17 +312,25 @@ module psb_d_vectordev_mod
use iso_c_binding use iso_c_binding
integer(c_int) :: val integer(c_int) :: val
integer(c_int), value :: n integer(c_int), value :: n
real(c_double) :: res real(c_double) :: res
type(c_ptr), value :: deviceVecA type(c_ptr), value :: deviceVecA
end function nrm2MultiVecDeviceDouble end function nrm2MultiVecDeviceDouble
function nrm2MultiVecDeviceDoubleR2(res,n,deviceVecA) & function nrm2MultiVecDeviceDoubleR2(res,n,deviceVecA) &
& result(val) bind(c,name='nrm2MultiVecDeviceDouble') & result(val) bind(c,name='nrm2MultiVecDeviceDouble')
use iso_c_binding use iso_c_binding
integer(c_int) :: val integer(c_int) :: val
integer(c_int), value :: n integer(c_int), value :: n
real(c_double) :: res(*) real(c_double) :: res(*)
type(c_ptr), value :: deviceVecA type(c_ptr), value :: deviceVecA
end function nrm2MultiVecDeviceDoubleR2 end function nrm2MultiVecDeviceDoubleR2
function nrm2MultiVecDeviceDoubleCol(res,n,col,deviceVecA) &
& result(val) bind(c,name='nrm2MultiVecDeviceDoubleCol')
use iso_c_binding
integer(c_int) :: val
integer(c_int), value :: n, col
real(c_double) :: res
type(c_ptr), value :: deviceVecA
end function nrm2MultiVecDeviceDoubleCol
end interface end interface
interface amaxMultiVecDevice interface amaxMultiVecDevice
@ -348,6 +364,14 @@ module psb_d_vectordev_mod
real(c_double), value :: alpha, beta real(c_double), value :: alpha, beta
type(c_ptr), value :: deviceVecA, deviceVecB type(c_ptr), value :: deviceVecA, deviceVecB
end function axpbyMultiVecDeviceDouble end function axpbyMultiVecDeviceDouble
function axpbyMultiVecDeviceDoubleCol(n,col_x,col_y,alpha,deviceVecA,beta,deviceVecB) &
& result(res) bind(c,name='axpbyMultiVecDeviceDoubleCol')
use iso_c_binding
integer(c_int) :: res
integer(c_int), value :: n, col_x, col_y
real(c_double), value :: alpha, beta
type(c_ptr), value :: deviceVecA, deviceVecB
end function axpbyMultiVecDeviceDoubleCol
end interface end interface
interface abgdxyzMultiVecDevice interface abgdxyzMultiVecDevice

@ -155,7 +155,7 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
if (info == psb_success_) call psb_geall(r0,desc_a,info,n=nrhs) if (info == psb_success_) call psb_geall(r0,desc_a,info,n=nrhs)
if (info == psb_success_) call psb_geall(rm,desc_a,info,n=nrhs) if (info == psb_success_) call psb_geall(rm,desc_a,info,n=nrhs)
if (info == psb_success_) call psb_geall(pd,desc_a,info,n=nrhs) if (info == psb_success_) call psb_geall(pd,desc_a,info,n=nrhs)
if (info == psb_success_) call psb_geasb(v,desc_a,info,mold=x%v,n=nrhs) if (info == psb_success_) call psb_geasb(v(1),desc_a,info,mold=x%v,n=nrhs)
if (info == psb_success_) call psb_geasb(w,desc_a,info,mold=x%v,n=nrhs) if (info == psb_success_) call psb_geasb(w,desc_a,info,mold=x%v,n=nrhs)
if (info == psb_success_) call psb_geasb(r0,desc_a,info,mold=x%v,n=nrhs) if (info == psb_success_) call psb_geasb(r0,desc_a,info,mold=x%v,n=nrhs)
if (info == psb_success_) call psb_geasb(rm,desc_a,info,mold=x%v,n=nrhs) if (info == psb_success_) call psb_geasb(rm,desc_a,info,mold=x%v,n=nrhs)
@ -224,7 +224,7 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
end if end if
! STEP 2: Compute QR_fact(R(0)) ! STEP 2: Compute QR_fact(R(0))
beta(1:nrhs,1:nrhs) = qr_fact(v(1)) beta(1:nrhs,1:nrhs) = mgs_qr_fact(v(1))
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
call psb_errpush(info,name) call psb_errpush(info,name)
@ -234,6 +234,14 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
! STEP 3: Outer loop ! STEP 3: Outer loop
outer: do j=1,nrep outer: do j=1,nrep
! Assembly next iteration
call psb_geasb(v(j+1),desc_a,info,mold=x%v,n=nrhs)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
! Update itx counter ! Update itx counter
itx = itx + 1 itx = itx + 1
@ -285,7 +293,7 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
! STEP 8: Compute QR_fact(W) ! STEP 8: Compute QR_fact(W)
! Store R in H(j+1,j) ! Store R in H(j+1,j)
h(idx_j+nrhs:idx_j+nrhs+n_add,idx_j:idx_j+n_add) = qr_fact(w) h(idx_j+nrhs:idx_j+nrhs+n_add,idx_j:idx_j+n_add) = mgs_qr_fact(w)
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
call psb_errpush(info,name) call psb_errpush(info,name)
@ -380,8 +388,8 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
contains contains
! QR factorization ! Modified Gram-Schmidt QR factorization
function qr_fact(mv) result(res) function mgs_qr_fact(mv) result(res)
implicit none implicit none
! I/O parameters ! I/O parameters
@ -389,51 +397,37 @@ contains
real(psb_dpk_), allocatable :: res(:,:) real(psb_dpk_), allocatable :: res(:,:)
! Utils ! Utils
real(psb_dpk_), allocatable :: vec(:,:) real(psb_dpk_) :: scal
real(psb_dpk_), allocatable :: tau(:), work(:) integer(psb_ipk_) :: i, j
integer(psb_ipk_) :: ii, jj, m, lda, lwork
! Allocate output ! Allocate output
allocate(res(nrhs,nrhs)) allocate(res(nrhs,nrhs))
! Gather multivector to factorize ! Start factorization
call psb_gather(vec,mv,desc_a,info,root=psb_root_) do i=1,nrhs
! If root
if (me == psb_root_) then
! Initialize params
m = size(vec,1)
lda = m
lwork = nrhs
allocate(tau(nrhs),work(lwork))
! Perform QR factorization ! Compute R(i,i) = nrm2(W(:,i))
call dgeqrf(m,nrhs,vec,lda,tau,work,lwork,info) res(i,i) = psb_genrm2(mv,i,desc_a,info)
! Set R ! Compute 1/R(i,i)
res = vec(1:nrhs,1:nrhs) scal = done/res(i,i)
do ii = 2, nrhs
do jj = 1, ii-1
res(ii,jj) = dzero
enddo
enddo
! Generate Q matrix ! Compute W(:,i) = W(:,i)/R(i,i)
call dorgqr(m,nrhs,nrhs,vec,lda,tau,work,lwork,info) call psb_geaxpby(i,i,scal,mv,dzero,mv,desc_a,info)
! Deallocate ! Iterate over columns
deallocate(tau,work) do j=i+1,nrhs
end if ! Compute R(i,j) = W(:,i)'*W(:,j)
res(i,j) = psb_gedot(i,j,mv,mv,desc_a,info)
! Share R ! Compute W(:,j) = W(:,j) - R(i,j)*W(:,i)
call psb_bcast(ctxt,res) call psb_geaxpby(i,j,-res(i,j),mv,done,mv,desc_a,info)
! Scatter Q end do
call psb_scatter(vec,mv,desc_a,info,root=psb_root_,mold=mv%v) end do
end function qr_fact end function mgs_qr_fact
function givens_rotation(rep) result(res) function givens_rotation(rep) result(res)

@ -641,8 +641,9 @@ program dpdegen
end if end if
! set random RHS ! set random RHS
call random_number(b_mv%v%v) call b_mv%zero()
b_mv%v%v = -10 + (20)*b_mv%v%v call random_number(b_mv%v%v(1:desc_a%get_local_rows(),:))
b_mv%v%v(1:desc_a%get_local_rows(),:) = -10 + (20)*b_mv%v%v
call b_mv%v%set_host() call b_mv%v%set_host()
call b_mv%sync() call b_mv%sync()

Loading…
Cancel
Save