Implemented right triangular solve

randomized
Fabio Durastante 1 year ago
parent 96b509417a
commit e423e149fa

@ -543,7 +543,7 @@ module psb_d_psblas_mod
subroutine psb_dmlt_multivect(x, y, res, desc_a,info,global)
import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_multivect_type, psb_dspmat_type
real(psb_dpk_), dimension(:,:), allocatable :: res
real(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res
type(psb_d_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
@ -588,6 +588,18 @@ module psb_d_psblas_mod
integer(psb_ipk_), intent(out) :: info
logical, intent(in) :: flag
end subroutine psb_ddiv_vect2_check
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
import :: psb_desc_type, psb_ipk_, &
& psb_dpk_, psb_d_multivect_type
type(psb_d_multivect_type), intent (inout) :: x
real(psb_dpk_), intent (in), dimension(:,:) :: a
type(psb_desc_type), intent (in) :: desc_a
character(len=1), intent(in) :: uplo
integer(psb_ipk_), intent(out) :: info
real(psb_dpk_), intent (in), optional :: alpha
character(len=1), intent(in), optional :: trans
character(len=1), intent(in), optional :: diag
end subroutine psb_ddiv_trslv
end interface
interface psb_geinv

@ -2911,8 +2911,8 @@ contains
if (x%is_dev()) call x%sync()
if (x%is_sync()) then
! Call BLAS function to solve the system
lda = n
ldb = n
lda = size(a,1)
ldb = x%get_nrows()
side = 'R' ! X*op( A ) = alpha*B.
call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb)
end if
@ -3141,9 +3141,10 @@ contains
!! \param y The class(base_mlv_vect) to be multiplied by
!! \param a The resulting matrix
!! \param info return code
subroutine d_base_mlv_mlt_mv2(x,y,a,info)
subroutine d_base_mlv_mlt_mv2(n,x,y,a,info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: n
class(psb_d_base_multivect_type), intent(inout) :: x
class(psb_d_base_multivect_type), intent(inout) :: y
real(psb_dpk_), intent(inout), allocatable :: a(:,:)
@ -3171,7 +3172,7 @@ contains
! C = alpha*op( A )*op( B ) + beta*C
! In our case, we want to compute
! C = X'*Y
call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), x%get_nrows(), done, &
call dgemm('T', 'N', x%get_ncols(), y%get_ncols(), n, done, &
& x%v, x%get_nrows(), y%v, y%get_nrows(), dzero, a, x%get_ncols())
end subroutine d_base_mlv_mlt_mv2

@ -1976,16 +1976,17 @@ contains
end if
end subroutine d_mlv_trslv
subroutine d_mvect_mlt_mv2(x,y,a,info)
subroutine d_mvect_mlt_mv2(n,x,y,a,info)
use psi_serial_mod
implicit none
integer(psb_ipk_), intent(in) :: n
class(psb_d_multivect_type), intent(inout) :: x
class(psb_d_multivect_type), intent(inout) :: y
real(psb_dpk_), intent(inout), allocatable :: a(:,:)
integer(psb_ipk_), intent(out) :: info
if (allocated(x%v).and.allocated(y%v)) then
call y%v%mlt(x%v,a,info)
call y%v%mlt(n,x%v,a,info)
else
info = psb_err_invalid_vect_state_
return

@ -444,3 +444,70 @@ function psb_dminquotient_vect(x,y,desc_a,info,global) result(res)
return
end function psb_dminquotient_vect
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
use psb_base_mod, psb_protect_name => psb_ddiv_trslv
implicit none
type(psb_d_multivect_type), intent (inout) :: x
real(psb_dpk_), intent (in), dimension(:,:) :: a
type(psb_desc_type), intent (in) :: desc_a
character(len=1), intent(in) :: uplo
integer(psb_ipk_), intent(out) :: info
real(psb_dpk_), intent (in), optional :: alpha
character(len=1), intent(in), optional :: trans
character(len=1), intent(in), optional :: diag
! locals
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me,&
& err_act, iix, jjx, iiy, jjy, nr
integer(psb_lpk_) :: ix, ijx, iy, ijy, m, nx
character(len=20) :: name, ch_err
name='psb_ddiv_trslv'
if (psb_errstatus_fatal()) return
info=psb_success_
call psb_erractionsave(err_act)
ctxt=desc_a%get_context()
call psb_info(ctxt, me, np)
if (np == -ione) then
info = psb_err_context_error_
call psb_errpush(info,name)
goto 9999
endif
if (.not.allocated(x%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)
goto 9999
endif
ix = ione
ijx = ione
m = desc_a%get_global_rows()
nx = x%get_ncols()
! check vector correctness
call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='psb_chkvect 1'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
nr = desc_a%get_local_rows()
if(nr > 0) then
call x%trslv(nr,a,uplo,alpha,trans,diag,info)
end if
call psb_erractionrestore(err_act)
return
9999 call psb_error_handler(ctxt,err_act)
return
end subroutine psb_ddiv_trslv

@ -480,7 +480,8 @@ function psb_ddot_mvect_vect(x, y, desc_a,info,global) result(res)
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! Remove the overlapped elements via dgemv calls
! FIXME: MAKES NO SENSE!
! Remove the overlapped elements via dgemv calls which are axpy
! res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res
call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), &
& size(x%v%v,1),y%v%v(idx),ione,done,res,ione)

@ -169,7 +169,7 @@ subroutine psb_dmlt_multivect(x, y, res,desc_a,info,global)
use psb_d_multivect_mod
use psb_d_psblas_mod, psb_protect_name => psb_dmlt_multivect
implicit none
real(psb_dpk_), dimension(:,:), allocatable :: res
real(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res
type(psb_d_multivect_type), intent(inout) :: x, y
type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info
@ -246,17 +246,16 @@ subroutine psb_dmlt_multivect(x, y, res,desc_a,info,global)
call psb_errpush(info,name)
goto 9999
else
allocate(res(x%get_nrows(),x%get_ncols()),stat=info)
if (info /= 0) then
info=psb_err_alloc_dealloc_
call psb_errpush(info,name)
goto 9999
if (allocated(res)) then
if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then
deallocate(res,stat=info)
end if
end if
end if
nr = desc_a%get_local_rows()
if(nr > 0) then
call x%mlt(y,res,info)
call x%mlt(nr,y,res,info)
! FIXME
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
@ -265,15 +264,20 @@ subroutine psb_dmlt_multivect(x, y, res,desc_a,info,global)
do i=1,size(desc_a%ovrlap_elem,1)
idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2)
! FIXME: case of AS-type descriptors
! Since I'm coputing res via a dgemm on the whole vector, I need to adjust
! the result by removing the contribution of the overlapped elements
! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:)
! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y
! where R is the result, X' is the transpose of the matrix x%v%v(idx,:)
! and Y is the matrix y%v%v(idx,:)
call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),&
& size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
& y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
! call dgemm('T','N',size(x%v%v(idx,:),1),size(y%v%v(idx,:),1),&
! & size(x%v%v(idx,:),1),-done,x%v%v(idx,:),size(x%v%v(idx,:),1),&
! & y%v%v(idx,:),size(y%v%v(idx,:),1),done,res,y%get_ncols())
info = psb_err_internal_error_
ch_err='over_elem_unsup'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end do
end if
else

@ -45,7 +45,8 @@ module unittestvector_mod
module procedure psb_d_check_ans_v, psb_c_check_ans_v, &
& psb_z_check_ans_v, psb_s_check_ans_v, &
& psb_d_check_ans_mv, psb_s_check_ans_mv, &
& psb_c_check_ans_mv, psb_z_check_ans_mv
& psb_c_check_ans_mv, psb_z_check_ans_mv, &
& psb_d_check_ans_mv_a
end interface psb_check_ans
contains
@ -210,6 +211,43 @@ contains
end function psb_d_check_ans_mv
function psb_d_check_ans_mv_a(v,val,ctxt) result(ans)
use psb_base_mod
implicit none
type(psb_d_multivect_type) :: v
real(psb_dpk_) :: val(:)
type(psb_ctxt_type) :: ctxt
logical :: ans
! Local variables
integer(psb_ipk_) :: np, iam, info,i
real(psb_dpk_) :: check
real(psb_dpk_), allocatable :: va(:,:)
call psb_info(ctxt,iam,np)
va = v%get_vect()
! subtract the row vector val from every row of va
do i=1,size(va,1)
va(i,:) = va(i,:) - val;
end do
check = maxval(va);
call psb_sum(ctxt,check)
if(check == 0.d0) then
ans = .true.
else
ans = .false.
end if
end function psb_d_check_ans_mv_a
function psb_s_check_ans_mv(v,val,ctxt) result(ans)
use psb_base_mod
@ -1050,11 +1088,13 @@ program vecoperation
type(psb_z_multivect_type) :: zmv1, zmv2
! scalars
real(psb_dpk_), allocatable, dimension(:,:) :: res
real(psb_dpk_), allocatable, dimension(:,:) :: a
real(psb_dpk_), allocatable, dimension(:) :: check_row
! blacs parameters
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: iam, np
! auxiliary parameters
integer(psb_ipk_) :: ii
integer(psb_ipk_) :: ii,jj
integer(psb_ipk_) :: info
character(len=20) :: name,ch_err,readinput
real(psb_dpk_) :: ans
@ -1491,6 +1531,41 @@ program vecoperation
if(hasitnotfailed) write(psb_out_unit,'("TEST PASSED >>> Constant multivector (complex double precision)")')
if(.not.hasitnotfailed) write(psb_out_unit,'("TEST FAILED --- Constant multivector (complex double precision)")')
end if
! X = 1, T = upper triangular of all ones
call psb_d_gen_const_multi(mv1,done,idim,nmv,ctxt,desc_a,info)
allocate(a(nmv,nmv),check_row(nmv))
do ii=1,nmv
do jj=ii,nmv
a(ii,jj) = done
end do
end do
check_row = 0
check_row(1) = done
call psb_gediv(mv1,a,desc_a,'U',info)
hasitnotfailed = psb_check_ans(mv1,check_row,ctxt)
if (iam == psb_root_) then
if(hasitnotfailed) write(psb_out_unit,'("TEST PASSED >>> Triangular solve (UP) mv1 = mv1 / T")')
if(.not.hasitnotfailed) write(psb_out_unit,'("TEST FAILED --- Triangular solve (UP) mv1 = mv1 / T")')
end if
! X = 1, T = lower triangular of all ones
call psb_d_gen_const_multi(mv1,done,idim,nmv,ctxt,desc_a,info)
if (allocated(a)) deallocate(a)
if (allocated(check_row)) deallocate(check_row)
allocate(a(nmv,nmv),check_row(nmv))
do ii=1,nmv
do jj=1,ii
a(ii,jj) = done
end do
end do
check_row = 0
check_row(nmv) = done
call psb_gediv(mv1,a,desc_a,'L',info)
hasitnotfailed = psb_check_ans(mv1,check_row,ctxt)
if (iam == psb_root_) then
if(hasitnotfailed) write(psb_out_unit,'("TEST PASSED >>> Triangular solve (LOW) mv1 = mv1 / T")')
if(.not.hasitnotfailed) write(psb_out_unit,'("TEST FAILED --- Triangular solve (LOW) mv1 = mv1 / T")')
end if
!
! Multivector to field operation
@ -1604,6 +1679,8 @@ program vecoperation
call psb_gefree(zmv2,desc_a,info)
call psb_cdfree(desc_a,info)
if(allocated(res)) deallocate(res)
if(allocated(a)) deallocate(a)
if(allocated(check_row)) deallocate(check_row)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='free routine'

Loading…
Cancel
Save