Added convergence check

psblas-bgmres
gabrielequatrana 7 months ago
parent 1b79939255
commit 0839165bdc

@ -119,6 +119,16 @@ module psb_d_psblas_mod
logical, intent(in), optional :: trans logical, intent(in), optional :: trans
logical, intent(in), optional :: global logical, intent(in), optional :: global
end function psb_dprod_multivect_a end function psb_dprod_multivect_a
function psb_dprod_m(x,y,desc_a,info,trans,global) result(res)
  ! Explicit interface for the overload whose two operands are plain
  ! rank-2 real arrays; the result is an allocatable rank-2 real array.
  import :: psb_desc_type, psb_dpk_, psb_ipk_, &
       & psb_d_multivect_type, psb_dspmat_type
  type(psb_desc_type), intent(in)  :: desc_a
  real(psb_dpk_), intent(in)       :: x(:,:), y(:,:)
  logical, intent(in), optional    :: trans, global
  integer(psb_ipk_), intent(out)   :: info
  real(psb_dpk_), allocatable      :: res(:,:)
end function psb_dprod_m
end interface end interface
interface psb_geaxpby interface psb_geaxpby
@ -402,10 +412,10 @@ module psb_d_psblas_mod
logical, intent(in), optional :: global logical, intent(in), optional :: global
type(psb_d_vect_type), intent (inout), optional :: aux type(psb_d_vect_type), intent (inout), optional :: aux
end function psb_dnrm2_weightmask_vect end function psb_dnrm2_weightmask_vect
function psb_dnrm2_multivect(x, desc_a, info,global) result(res) function psb_dnrm2_multivect(x, desc_a, info, global) result(res)
import :: psb_desc_type, psb_dpk_, psb_ipk_, & import :: psb_desc_type, psb_dpk_, psb_ipk_, &
& psb_d_multivect_type, psb_dspmat_type & psb_d_multivect_type, psb_dspmat_type
real(psb_dpk_) :: res real(psb_dpk_), allocatable :: res(:)
type(psb_d_multivect_type), intent (inout) :: x type(psb_d_multivect_type), intent (inout) :: x
type(psb_desc_type), intent (in) :: desc_a type(psb_desc_type), intent (in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info

@ -3296,13 +3296,16 @@ contains
implicit none implicit none
class(psb_d_base_multivect_type), intent(inout) :: x class(psb_d_base_multivect_type), intent(inout) :: x
integer(psb_ipk_), intent(in) :: nr integer(psb_ipk_), intent(in) :: nr
real(psb_dpk_), allocatable :: res real(psb_dpk_), allocatable :: res(:)
integer(psb_ipk_) :: nc, col
real(psb_dpk_), external :: dnrm2 real(psb_dpk_), external :: dnrm2
integer(psb_ipk_) :: j, nc
if (x%is_dev()) call x%sync() if (x%is_dev()) call x%sync()
nc = x%get_ncols() nc = x%get_ncols()
res = dnrm2(nc*nr,x%v,1) allocate(res(nc))
do col=1,nc
res(col) = dnrm2(nr,x%v(:,col),1)
end do
end function d_base_mlv_nrm2 end function d_base_mlv_nrm2

@ -2051,7 +2051,7 @@ contains
implicit none implicit none
class(psb_d_multivect_type), intent(inout) :: x class(psb_d_multivect_type), intent(inout) :: x
integer(psb_ipk_), intent(in) :: nr integer(psb_ipk_), intent(in) :: nr
real(psb_dpk_), allocatable :: res real(psb_dpk_), allocatable :: res(:)
if (allocated(x%v)) then if (allocated(x%v)) then
res = x%v%nrm2(nr) res = x%v%nrm2(nr)

@ -384,7 +384,7 @@ end function psb_dnrm2_vect
! info - integer. Return code ! info - integer. Return code
! global - logical(optional) Whether to perform the global reduction, default: .true. ! global - logical(optional) Whether to perform the global reduction, default: .true.
! !
function psb_dnrm2_multivect(x, desc_a, info,global) result(res) function psb_dnrm2_multivect(x, desc_a, info, global) result(res)
use psb_desc_mod use psb_desc_mod
use psb_check_mod use psb_check_mod
use psb_error_mod use psb_error_mod
@ -392,7 +392,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res)
use psb_d_multivect_mod use psb_d_multivect_mod
implicit none implicit none
real(psb_dpk_) :: res real(psb_dpk_), allocatable :: res(:)
type(psb_d_multivect_type), intent (inout) :: x type(psb_d_multivect_type), intent (inout) :: x
type(psb_desc_type), intent(in) :: desc_a type(psb_desc_type), intent(in) :: desc_a
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
@ -402,7 +402,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res)
type(psb_ctxt_type) :: ctxt type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, err_act, idx, i, j, iix, jjx, ldx, ndm integer(psb_ipk_) :: np, me, err_act, idx, i, j, iix, jjx, ldx, ndm
real(psb_dpk_) :: dd real(psb_dpk_) :: dd
integer(psb_lpk_) :: ix, jx, m, n integer(psb_lpk_) :: ix, jx, m
logical :: global_ logical :: global_
character(len=20) :: name, ch_err character(len=20) :: name, ch_err
@ -438,10 +438,9 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res)
jx = 1 jx = 1
m = desc_a%get_global_rows() m = desc_a%get_global_rows()
n = x%get_ncols()
ldx = x%get_nrows() ldx = x%get_nrows()
call psb_chkvect(m,n,ldx,ix,jx,desc_a,info,iix,jjx) call psb_chkvect(m,x%get_ncols(),ldx,ix,jx,desc_a,info,iix,jjx)
if(info /= psb_success_) then if(info /= psb_success_) then
info=psb_err_from_subroutine_ info=psb_err_from_subroutine_
ch_err='psb_chkvect' ch_err='psb_chkvect'
@ -455,7 +454,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res)
end if end if
if (desc_a%get_local_rows() > 0) then if (desc_a%get_local_rows() > 0) then
res = x%nrm2(desc_a%get_local_rows()) res = x%nrm2(desc_a%get_local_rows())
! adjust because overlapped elements are computed more than once ! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync() if (x%v%is_dev()) call x%sync()
@ -464,11 +463,12 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res)
idx = desc_a%ovrlap_elem(i,1) idx = desc_a%ovrlap_elem(i,1)
ndm = desc_a%ovrlap_elem(i,2) ndm = desc_a%ovrlap_elem(i,2)
dd = dble(ndm-1)/dble(ndm) dd = dble(ndm-1)/dble(ndm)
res = res * sqrt(done - dd*(abs(x%v%v(idx,j))/res)**2) res(j) = res(j) * sqrt(done - dd*(abs(x%v%v(idx,j))/res(j))**2)
end do end do
end do end do
end if end if
else else
allocate(res(x%get_ncols()))
res = dzero res = dzero
end if end if

@ -187,108 +187,247 @@ end function psb_dprod_multivect
! !
! !
function psb_dprod_multivect_a(x,y,desc_a,info,trans,global) result(res)
  !
  ! Product of a distributed multivector X with a plain rank-2 array Y.
  !
  !   x      - multivector operand; INOUT only because sync() may be called.
  !   y      - real(:,:) operand.
  !   desc_a - communication descriptor.
  !   info   - return code (psb_success_ when no error was raised).
  !   trans  - optional; use the transposed X, default .false.
  !   global - optional; sum the result over all processes, default .true.
  !
  ! The local product itself is delegated to x%prod(nr,y,trans_).
  !
  use psb_desc_mod
  use psb_d_base_mat_mod
  use psb_check_mod
  use psb_error_mod
  use psb_penv_mod
  use psb_d_vect_mod
  use psb_d_psblas_mod, psb_protect_name => psb_dprod_multivect_a
  implicit none
  real(psb_dpk_), allocatable :: res(:,:)
  type(psb_d_multivect_type), intent(inout) :: x
  real(psb_dpk_), intent(in) :: y(:,:)
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: trans
  logical, intent(in), optional :: global

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, idx, ndm, err_act, iix, jjx, i, j, nr, ncx
  integer(psb_lpk_) :: ix, ijx, m
  real(psb_dpk_) :: dd
  logical :: global_, trans_
  character(len=20) :: name, ch_err

  name='psb_dprod_multivect'
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (present(trans)) then
    trans_ = trans
  else
    trans_ = .false.
  end if

  if (present(global)) then
    global_ = global
  else
    global_ = .true.
  end if

  ix  = ione
  ijx = ione
  m   = desc_a%get_global_rows()

  ! check vector correctness
  call psb_chkvect(m,x%get_ncols(),x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  if ((iix /= ione)) then
    info=psb_err_ix_n1_iy_n1_unsupported_
    call psb_errpush(info,name)
    goto 9999
  end if

  nr  = desc_a%get_local_rows()
  ncx = x%get_ncols()

  if (nr > 0) then
    res = x%prod(nr,y,trans_)

    ! Overlapped rows are held by ndm processes, so each of them must only
    ! contribute 1/ndm of that row's term to the global sum. This applies
    ! only to the transposed product, which reduces over the rows:
    !   res(j,k) = sum_i x(i,j)*y(i,k).
    ! The non-transposed product keeps one result row per local row, so
    ! nothing is counted twice (and y is not indexed by local row there).
    if (trans_ .and. (size(desc_a%ovrlap_elem,1)>0)) then
      if (x%v%is_dev()) call x%sync()
      do i=1,size(desc_a%ovrlap_elem,1)
        idx = desc_a%ovrlap_elem(i,1)
        ndm = desc_a%ovrlap_elem(i,2)
        ! keep the factor in working precision
        dd  = real(ndm-1,psb_dpk_)/real(ndm,psb_dpk_)
        do j=1,ncx
          res(j,:) = res(j,:) - dd*x%v%v(idx,j)*y(idx,:)
        end do
      end do
    end if
  else
    ! No local rows: the result must be allocated before it can be zeroed.
    ! NOTE(review): the shapes mirror psb_dprod_m; confirm they match what
    ! x%prod returns for the same arguments.
    if (trans_) then
      allocate(res(ncx,size(y,2)))
    else
      allocate(res(x%get_nrows(),size(y,2)))
    end if
    res = dzero
  end if

  ! compute global sum
  if (global_) call psb_sum(ctxt, res)

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end function psb_dprod_multivect_a
!
! Function: psb_dprod_m
! psb_dprod computes the product of two distributed multivectors,
!
! prod := ( X ) * ( Y ) or
! prod := ( X )**C * ( Y )
!
!
! Arguments:
! x - real(:,:) The input vector containing the entries of sub( X ).
! y - real(:,:) The input vector containing the entries of sub( Y ).
! desc_a - type(psb_desc_type). The communication descriptor.
! info - integer. Return code
! trans - logical(optional) Whether multivector X is transposed, default: .false.
! global - logical(optional) Whether to perform the global reduce, default: .true.
!
! Note: X and Y are plain arrays declared INTENT(IN); unlike the multivector
! variants no sync() is required, so they are never modified here.
!
!
function psb_dprod_m(x,y,desc_a,info,trans,global) result(res)
  !
  ! Product of two plain rank-2 arrays via DGEMM.
  !
  !   trans = .true.  : res = X(1:nr,:)**T * Y(1:nr,:), shape (size(x,2),size(y,2))
  !   trans = .false. : res(1:nr,:) = X(1:nr,:) * Y,    shape (size(x,1),size(y,2));
  !                     rows beyond nr are left at zero.
  !
  !   x, y   - real(:,:) operands, read only.
  !   desc_a - communication descriptor.
  !   info   - return code (psb_success_ when no error was raised).
  !   trans  - optional; use the transposed X, default .false.
  !   global - optional; sum the result over all processes, default .true.
  !
  use psb_desc_mod
  use psb_d_base_mat_mod
  use psb_check_mod
  use psb_error_mod
  use psb_penv_mod
  use psb_d_vect_mod
  use psb_d_psblas_mod, psb_protect_name => psb_dprod_m
  implicit none
  real(psb_dpk_), allocatable :: res(:,:)
  real(psb_dpk_), intent(in) :: x(:,:), y(:,:)
  type(psb_desc_type), intent(in) :: desc_a
  integer(psb_ipk_), intent(out) :: info
  logical, intent(in), optional :: trans
  logical, intent(in), optional :: global

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_) :: np, me, idx, ndm, err_act, iix, jjx, i, j, nr, &
       & x_n, y_n, lda, ldb, ldc
  integer(psb_lpk_) :: ix, ijx, m
  real(psb_dpk_) :: dd
  logical :: global_, trans_
  character(len=20) :: name, ch_err

  name='psb_dprod_m'
  info=psb_success_
  call psb_erractionsave(err_act)
  if (psb_errstatus_fatal()) then
    info = psb_err_internal_error_ ; goto 9999
  end if

  ctxt=desc_a%get_context()
  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif

  if (present(trans)) then
    trans_ = trans
  else
    trans_ = .false.
  end if

  if (present(global)) then
    global_ = global
  else
    global_ = .true.
  end if

  ix  = ione
  ijx = ione
  m   = desc_a%get_global_rows()

  ! check vector correctness
  call psb_chkvect(m,size(x,2),size(x,1),ix,ijx,desc_a,info,iix,jjx)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  if ((iix /= ione)) then
    info=psb_err_ix_n1_iy_n1_unsupported_
    call psb_errpush(info,name)
    goto 9999
  end if

  nr  = desc_a%get_local_rows()
  x_n = size(x,2)
  y_n = size(y,2)

  ! The result shape depends only on trans_, so allocate and zero it once;
  ! with no local rows (nr == 0) the zero matrix is the final local result.
  if (trans_) then
    allocate(res(x_n,y_n))
  else
    allocate(res(size(x,1),y_n))
  end if
  res = dzero

  if (nr > 0) then
    ! Leading dimensions are those of the arrays actually passed; DGEMM
    ! rejects values below 1, hence the max().
    lda = max(1_psb_ipk_, int(size(x,1),psb_ipk_))
    ldb = max(1_psb_ipk_, int(size(y,1),psb_ipk_))
    ldc = max(1_psb_ipk_, int(size(res,1),psb_ipk_))

    if (trans_) then
      call dgemm('T','N',x_n,y_n,nr,done,x,lda,y,ldb,dzero,res,ldc)

      ! Overlapped rows are held by ndm processes, so each of them must
      ! only contribute 1/ndm of that row's term to the global sum:
      !   res(j,k) = sum_i x(i,j)*y(i,k).
      do i=1,size(desc_a%ovrlap_elem,1)
        idx = desc_a%ovrlap_elem(i,1)
        ndm = desc_a%ovrlap_elem(i,2)
        ! keep the factor in working precision
        dd  = real(ndm-1,psb_dpk_)/real(ndm,psb_dpk_)
        do j=1,x_n
          res(j,:) = res(j,:) - dd*x(idx,j)*y(idx,:)
        end do
      end do
    else
      ! One result row per local row: nothing is reduced over the rows, so
      ! no overlap correction is needed here.
      call dgemm('N','N',nr,y_n,x_n,done,x,lda,y,ldb,dzero,res,ldc)
    end if
  end if

  ! compute global sum
  ! NOTE(review): for trans = .false. the rows of res are local rows, so a
  ! global sum looks meaningful only for the transposed product; the call
  ! in the BGMRES driver passes global=.false. Confirm before relying on
  ! the default here.
  if (global_) call psb_sum(ctxt, res)

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end function psb_dprod_m

@ -24,7 +24,7 @@ function psb_dqrfact(x, desc_a, info) result(res)
integer(psb_lpk_) :: ix, ijx, m, n integer(psb_lpk_) :: ix, ijx, m, n
character(len=20) :: name, ch_err character(len=20) :: name, ch_err
real(psb_dpk_), allocatable :: temp(:,:) real(psb_dpk_), allocatable :: temp(:,:)
type(psb_d_base_multivect_type) :: qr_temp type(psb_d_multivect_type) :: qr_temp
name='psb_dgqrfact' name='psb_dgqrfact'
if (psb_errstatus_fatal()) return if (psb_errstatus_fatal()) return
@ -69,6 +69,14 @@ function psb_dqrfact(x, desc_a, info) result(res)
if (me == psb_root_) then if (me == psb_root_) then
call qr_temp%bld(temp) call qr_temp%bld(temp)
res = qr_temp%qr_fact(info) res = qr_temp%qr_fact(info)
! TODO Check the diagonal of R
do i=1,n
if (res(i,i) == dzero) then
write(*,*) 'DIAGONAL 0'
end if
end do
temp = qr_temp%get_vect() temp = qr_temp%get_vect()
call psb_bcast(ctxt,res) call psb_bcast(ctxt,res)
else else

@ -61,10 +61,10 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
integer(psb_ipk_), Optional, Intent(out) :: iter integer(psb_ipk_), Optional, Intent(out) :: iter
real(psb_dpk_), Optional, Intent(out) :: err real(psb_dpk_), Optional, Intent(out) :: err
real(psb_dpk_), allocatable :: aux(:), h(:,:), beta(:,:), beta_e1(:,:) real(psb_dpk_), allocatable :: aux(:), h(:,:), vt(:,:), beta(:,:), y(:,:)
type(psb_d_multivect_type), allocatable :: v(:) type(psb_d_multivect_type), allocatable :: v(:)
type(psb_d_multivect_type) :: v_tot, w type(psb_d_multivect_type) :: w, xt, r
real(psb_dpk_) :: t1, t2 real(psb_dpk_) :: t1, t2
@ -72,13 +72,14 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
integer(psb_ipk_) :: litmax, naux, itrace_, n_row, n_col, nrhs, nrep integer(psb_ipk_) :: litmax, naux, itrace_, n_row, n_col, nrhs, nrep
integer(psb_lpk_) :: mglob, n_add integer(psb_lpk_) :: mglob, n_add
integer(psb_ipk_) :: i, j, k, istop_, err_act, idx_i, idx_j, idx integer(psb_ipk_) :: i, j, k, col, istop_, err_act, idx_i, idx_j, idx
integer(psb_ipk_) :: debug_level, debug_unit integer(psb_ipk_) :: debug_level, debug_unit
type(psb_ctxt_type) :: ctxt type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, itx integer(psb_ipk_) :: np, me, itx
real(psb_dpk_) :: rni, xni, bni, ani, bn2, r0n2 real(psb_dpk_), allocatable :: r0n2(:), rmn2(:)
real(psb_dpk_) :: errnum, errden, deps, derr real(psb_dpk_), allocatable :: errnum(:), errden(:)
real(psb_dpk_) :: deps, derr
character(len=20) :: name character(len=20) :: name
character(len=*), parameter :: methdname='BGMRES' character(len=*), parameter :: methdname='BGMRES'
@ -112,10 +113,10 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
if (present(istop)) then if (present(istop)) then
istop_ = istop istop_ = istop
else else
istop_ = 2 istop_ = 1
endif endif
if ((istop_ < 1 ).or.(istop_ > 2 ) ) then if (istop_ /= 1) then
info=psb_err_invalid_istop_ info=psb_err_invalid_istop_
err=info err=info
call psb_errpush(info,name,i_err=(/istop_/)) call psb_errpush(info,name,i_err=(/istop_/))
@ -167,13 +168,17 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
naux = 4*n_col naux = 4*n_col
nrhs = x%get_ncols() nrhs = x%get_ncols()
allocate(aux(naux),h((nrep+1)*nrhs,nrep*nrhs),stat=info) allocate(aux(naux),h((nrep+1)*nrhs,nrep*nrhs),y(nrep*nrhs,nrhs),&
& vt(n_row,(nrep+1)*nrhs),r0n2(nrhs),rmn2(nrhs),&
& errnum(nrhs),errden(nrhs),stat=info)
if (info == psb_success_) call psb_geall(v,desc_a,info,m=nrep+1,n=nrhs) if (info == psb_success_) call psb_geall(v,desc_a,info,m=nrep+1,n=nrhs)
if (info == psb_success_) call psb_geall(v_tot,desc_a,info,n=nrep*nrhs)
if (info == psb_success_) call psb_geall(w,desc_a,info,n=nrhs) if (info == psb_success_) call psb_geall(w,desc_a,info,n=nrhs)
if (info == psb_success_) call psb_geall(xt,desc_a,info,n=nrhs)
if (info == psb_success_) call psb_geall(r,desc_a,info,n=nrhs)
if (info == psb_success_) call psb_geasb(v,desc_a,info,mold=x%v,n=nrhs) if (info == psb_success_) call psb_geasb(v,desc_a,info,mold=x%v,n=nrhs)
if (info == psb_success_) call psb_geasb(v_tot,desc_a,info,mold=x%v,n=nrep*nrhs)
if (info == psb_success_) call psb_geasb(w,desc_a,info,mold=x%v,n=nrhs) if (info == psb_success_) call psb_geasb(w,desc_a,info,mold=x%v,n=nrhs)
if (info == psb_success_) call psb_geasb(xt,desc_a,info,mold=x%v,n=nrhs)
if (info == psb_success_) call psb_geasb(r,desc_a,info,mold=x%v,n=nrhs)
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
call psb_errpush(info,name) call psb_errpush(info,name)
@ -185,11 +190,21 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
& ' Size of V,W ',v(1)%get_nrows(),size(v),& & ' Size of V,W ',v(1)%get_nrows(),size(v),&
& w%get_nrows() & w%get_nrows()
! Compute norm2 of R(0)
if (istop_ == 1) then if (istop_ == 1) then
ani = psb_spnrmi(a,desc_a,info) call psb_geaxpby(done,b,dzero,r,desc_a,info)
bni = psb_geamax(b,desc_a,info) if (info /= psb_success_) then
else if (istop_ == 2) then info=psb_err_from_subroutine_non_
bn2 = psb_genrm2(b,desc_a,info) call psb_errpush(info,name)
goto 9999
end if
call psb_spmm(-done,a,x,done,r,desc_a,info,work=aux)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
r0n2 = psb_genrm2(r,desc_a,info)
endif endif
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
@ -198,6 +213,7 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
end if end if
h = dzero h = dzero
y = dzero
errnum = dzero errnum = dzero
errden = done errden = done
deps = eps deps = eps
@ -208,8 +224,10 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
! BGMRES algorithm ! BGMRES algorithm
! TODO QR fact seriale per ora ! TODO Con tanti ITRS e tanti NRHS si ottengono NaN
! TODO Con tanti ITRS NaN (forse genera righe dipendenti (vedere pargen)) ! TODO Deflazione e restart dopo aver trovato una colonna, difficile...
! TODO L'algo converge abbastanza bene. Capire come fare check residui
! STEP 1: Compute R(0) = B - A*X(0) ! STEP 1: Compute R(0) = B - A*X(0)
@ -237,31 +255,15 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
goto 9999 goto 9999
end if end if
! Add V(1) to VT
vt(:,1:nrhs) = v(1)%get_vect()
! STEP 3: Outer loop ! STEP 3: Outer loop
outer: do j=1,nrep outer: do j=1,nrep
! TODO Check convergence ! Update itx counter
! if (istop_ == 1) then
! rni = psb_geamax(v(1),desc_a,info)
! xni = psb_geamax(x,desc_a,info)
! errnum = rni
! errden = (ani*xni+bni)
! else if (istop_ == 2) then
! rni = psb_genrm2(v(1),desc_a,info)
! errnum = rni
! errden = bn2
! endif
! if (info /= psb_success_) then
! info=psb_err_from_subroutine_non_
! call psb_errpush(info,name)
! goto 9999
! end if
! if (errnum <= eps*errden) exit outer
! if (itrace_ > 0) call log_conv(methdname,me,itx,itrace_,errnum,errden,deps)
itx = itx + 1 itx = itx + 1
if (itx >= litmax) exit outer
! Compute j index for H operations ! Compute j index for H operations
idx_j = (j-1)*nrhs+1 idx_j = (j-1)*nrhs+1
@ -274,15 +276,13 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
goto 9999 goto 9999
end if end if
if (itx >= litmax) exit outer
! STEP 5: Inner loop ! STEP 5: Inner loop
inner: do i=1,j inner: do i=1,j
! Compute i index for H operations ! Compute i index for H operations
idx_i = (i-1)*nrhs+1 idx_i = (i-1)*nrhs+1
! STEP 6: Compute H(i,j) = V(i)_T*W ! STEP 6: Compute H(i,j) = (V(i)**T)*W
h(idx_i:idx_i+n_add,idx_j:idx_j+n_add) = psb_geprod(v(i),w,desc_a,info,trans=.true.) h(idx_i:idx_i+n_add,idx_j:idx_j+n_add) = psb_geprod(v(i),w,desc_a,info,trans=.true.)
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
@ -292,8 +292,8 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
! STEP 7: Compute W = W - V(i)*H(i,j) ! STEP 7: Compute W = W - V(i)*H(i,j)
call psb_geaxpby(-done,& call psb_geaxpby(-done,&
& psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.),& & psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.),&
& done,w,desc_a,info) & done,w,desc_a,info)
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
call psb_errpush(info,name) call psb_errpush(info,name)
@ -320,41 +320,154 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
goto 9999 goto 9999
end if end if
end do outer ! Add V(j+1) to VT
idx = j*nrhs+1
vt(:,idx:idx+n_add) = v(j+1)%get_vect()
! STEP 9: Compute Y(m) ! STEP 9: Compute Y(j)
call frobenius_norm_min() call frobenius_norm_min(j)
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
call psb_errpush(info,name) call psb_errpush(info,name)
goto 9999 goto 9999
end if end if
! STEP 10: Compute V = {V(1),...,V(m)} ! Compute residues
do i=1,nrep if (istop_ == 1) then
idx = (i-1)*nrhs+1
v_tot%v%v(1:n_row,idx:idx+n_add) = v(i)%v%v(1:n_row,1:nrhs)
enddo
! STEP 11: X(m) = X(0) + V*Y(m) ! TODO Compute R(j) = R(0) - VT(j+1)*H(j)*Y(j)
call psb_geaxpby(done,psb_geprod(v_tot,beta_e1,desc_a,info,global=.false.),done,x,desc_a,info) call psb_geaxpby(-done,psb_geprod(psb_geprod(vt(:,1:(j+1)*nrhs),h(1:(j+1)*nrhs,1:j*nrhs),&
if (info /= psb_success_) then & desc_a,info,global=.false.),&
info=psb_err_from_subroutine_non_ & y(1:j*nrhs,1:nrhs),desc_a,info,global=.false.),done,r,desc_a,info)
call psb_errpush(info,name) if (info /= psb_success_) then
goto 9999 info=psb_err_from_subroutine_non_
end if call psb_errpush(info,name)
goto 9999
end if
write(*,*)
do col=1,r%get_nrows()
write(*,*) r%v%v(col,:)
end do
write(*,*)
! TODO Compute the solution at step J and check the residuals (exit the loop if they are below the tolerance)
! TODO Compute R(j) = B - A*X(j)
! Copy X in XT
call psb_geaxpby(done,x,dzero,xt,desc_a,info)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
! Compute current solution X(j) = X(0) + VT(j)*Y(j)
call psb_geaxpby(done,psb_geprod(vt(:,1:j*nrhs),y(1:j*nrhs,:),desc_a,info,global=.false.),done,xt,desc_a,info)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
! Copy B in R
call psb_geaxpby(done,b,dzero,r,desc_a,info)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
! Compute R(j) = B - A*X(j)
call psb_spmm(-done,a,xt,done,r,desc_a,info,work=aux)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
write(*,*)
do col=1,r%get_nrows()
write(*,*) r%v%v(col,:)
end do
write(*,*)
! Compute nrm2 of each column of R(j)
rmn2 = psb_genrm2(r,desc_a,info)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
! Set error num and den
errnum = rmn2
errden = r0n2
! TODO Every entry of the 2-norm of R(m) must be smaller than tolerance*nrm2(r0)
do col=1,nrhs
write(*,*) rmn2(col), r0n2(col)
end do
end if
! TODO Should the residual norms computed with Xm be smaller than tolerance * nrm2(R0)?
! Check convergence
if (maxval(errnum) <= eps*maxval(errden)) then
! Compute result and exit
if (istop_ == 1) then
! Compute X(j) = X(0) + VT(j)*Y(j)
! call psb_geaxpby(done,psb_geprod(vt(:,1:j*nrhs),y(1:j*nrhs,:),desc_a,info,global=.false.),done,x,desc_a,info)
! if (info /= psb_success_) then
! info=psb_err_from_subroutine_non_
! call psb_errpush(info,name)
! goto 9999
! end if
! Copy current solution XT in X
call psb_geaxpby(done,xt,dzero,x,desc_a,info)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
goto 9999
end if
end if
! Exit algorithm
exit outer
end if
! Log update
if (itrace_ > 0) call log_conv(methdname,me,itx,ione,maxval(errnum),maxval(errden),deps)
end do outer
! STEP 10: X(m) = X(0) + VT(m)*Y(m)
! call psb_geaxpby(done,psb_geprod(vt(:,1:nrep*nrhs),y,desc_a,info,global=.false.),done,x,desc_a,info)
! if (info /= psb_success_) then
! info=psb_err_from_subroutine_non_
! call psb_errpush(info,name)
! goto 9999
! end if
! END algorithm ! END algorithm
if (itrace_ > 0) call log_conv(methdname,me,itx,ione,errnum,errden,deps) ! TODO log_conv passa scalari errnum,errden,deps (servono vettori)
! TODO Initially, a verbose version that prints the error for every column
! TODO Final version that prints the maximum error (log_conv can be used for this)
if (itrace_ > 0) call log_conv(methdname,me,itx,ione,maxval(errnum),maxval(errden),deps)
call log_end(methdname,me,itx,itrace_,errnum,errden,deps,err=derr,iter=iter) call log_end(methdname,me,itx,itrace_,maxval(errnum),maxval(errden),deps,err=derr,iter=iter)
if (present(err)) err = derr if (present(err)) err = derr
if (info == psb_success_) call psb_gefree(v,desc_a,info) if (info == psb_success_) call psb_gefree(v,desc_a,info)
if (info == psb_success_) call psb_gefree(v_tot,desc_a,info)
if (info == psb_success_) call psb_gefree(w,desc_a,info) if (info == psb_success_) call psb_gefree(w,desc_a,info)
if (info == psb_success_) deallocate(aux,h,stat=info) if (info == psb_success_) call psb_gefree(xt,desc_a,info)
if (info == psb_success_) call psb_gefree(r,desc_a,info)
if (info == psb_success_) deallocate(aux,h,y,vt,stat=info)
if (info /= psb_success_) then if (info /= psb_success_) then
info=psb_err_from_subroutine_non_ info=psb_err_from_subroutine_non_
call psb_errpush(info,name) call psb_errpush(info,name)
@ -370,36 +483,39 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
contains contains
! Minimize Frobenius norm ! Minimize Frobenius norm
subroutine frobenius_norm_min() subroutine frobenius_norm_min(rep)
implicit none implicit none
integer(psb_ipk_), intent(in) :: rep
integer(psb_ipk_) :: lwork integer(psb_ipk_) :: lwork
real(psb_dpk_), allocatable :: work(:), beta_temp(:,:) real(psb_dpk_), allocatable :: work(:), beta_e1(:,:)
integer(psb_ipk_) :: m_h, n_h, mn real(psb_dpk_), allocatable :: h_temp(:,:)
integer(psb_ipk_) :: m_h, n_h, mn
! Initialize params ! Initialize params
m_h = (nrep+1)*nrhs h_temp = h
n_h = nrep*nrhs m_h = (rep+1)*nrhs
mn = min(m_h,n_h) n_h = rep*nrhs
lwork = max(1,mn+max(mn,nrhs)) mn = min(m_h,n_h)
lwork = max(1,mn+max(mn,nrhs))
allocate(work(lwork)) allocate(work(lwork))
! Compute E1*beta ! Compute E1*beta
allocate(beta_temp(m_h,nrhs)) allocate(beta_e1(m_h,nrhs))
beta_temp = dzero beta_e1 = dzero
beta_temp(1:nrhs,1:nrhs) = beta beta_e1(1:nrhs,1:nrhs) = beta
! Compute min Frobenius norm ! Compute min Frobenius norm
call dgels('N',m_h,n_h,nrhs,h,m_h,beta_temp,m_h,work,lwork,info) call dgels('N',m_h,n_h,nrhs,h_temp(1:m_h,1:n_h),m_h,beta_e1,m_h,work,lwork,info)
! Set solution ! Set solution
allocate(beta_e1(n_h,nrhs)) y = beta_e1(1:n_h,1:nrhs)
beta_e1 = beta_temp(1:n_h,1:nrhs)
! Deallocate ! Deallocate
deallocate(work,beta,beta_temp) deallocate(work,h_temp,beta_e1)
return return

@ -26,7 +26,8 @@ program psb_dbf_sample
integer(psb_ipk_) :: m, nrhs integer(psb_ipk_) :: m, nrhs
real(psb_dpk_) :: random_value real(psb_dpk_) :: random_value
real(psb_dpk_), allocatable :: test(:,:) real(psb_dpk_), allocatable :: test(:)
! communications data structure ! communications data structure
type(psb_desc_type) :: desc_a type(psb_desc_type) :: desc_a
@ -48,8 +49,10 @@ program psb_dbf_sample
! other variables ! other variables
integer(psb_ipk_) :: i, j, info integer(psb_ipk_) :: i, j, info
real(psb_dpk_) :: t1, t2, tprec real(psb_dpk_) :: t1, t2, tprec
real(psb_dpk_) :: resmx, resmxp real(psb_dpk_), allocatable :: resmx(:)
real(psb_dpk_) :: resmxp
integer(psb_ipk_), allocatable :: ivg(:) integer(psb_ipk_), allocatable :: ivg(:)
logical :: print_matrix = .true.
call psb_init(ctxt) call psb_init(ctxt)
call psb_info(ctxt,iam,np) call psb_info(ctxt,iam,np)
@ -130,9 +133,9 @@ program psb_dbf_sample
b_mv_glob => aux_b(:,:) b_mv_glob => aux_b(:,:)
do i=1, m do i=1, m
do j=1, nrhs do j=1, nrhs
!b_mv_glob(i,j) = done b_mv_glob(i,j) = done
call random_number(random_value) !call random_number(random_value)
b_mv_glob(i,j) = random_value !b_mv_glob(i,j) = random_value
enddo enddo
enddo enddo
endif endif
@ -222,7 +225,8 @@ program psb_dbf_sample
call psb_geaxpby(done,b_mv,dzero,r_mv,desc_a,info) call psb_geaxpby(done,b_mv,dzero,r_mv,desc_a,info)
call psb_spmm(-done,a,x_mv,done,r_mv,desc_a,info) call psb_spmm(-done,a,x_mv,done,r_mv,desc_a,info)
resmx = psb_genrm2(r_mv,desc_a,info) ! TODO resmx vettore ogni entrata è più piccola della tolleranza (per* norma di r0)
resmx = psb_genrm2(r_mv,desc_a,info)
resmxp = psb_geamax(r_mv,desc_a,info) resmxp = psb_geamax(r_mv,desc_a,info)
amatsize = a%sizeof() amatsize = a%sizeof()
@ -253,12 +257,14 @@ program psb_dbf_sample
write(psb_out_unit,'("Time to solve system: ",es12.5)')t2 write(psb_out_unit,'("Time to solve system: ",es12.5)')t2
write(psb_out_unit,'("Time per iteration: ",es12.5)')t2/(iter) write(psb_out_unit,'("Time per iteration: ",es12.5)')t2/(iter)
write(psb_out_unit,'("Total time: ",es12.5)')t2+tprec write(psb_out_unit,'("Total time: ",es12.5)')t2+tprec
write(psb_out_unit,'("Residual norm 2: ",es12.5)')resmx write(psb_out_unit,'("Residual norm 2: ",es12.5)')maxval(resmx)
write(psb_out_unit,'("Residual norm inf: ",es12.5)')resmxp write(psb_out_unit,'("Residual norm inf: ",es12.5)')resmxp
write(psb_out_unit,'(" ")') write(psb_out_unit,'(a8,4(2x,a20))') 'I','X(I)','R(I)','B(I)'
! do i=1,m if (print_matrix) then
! write(psb_out_unit,993) i, x_mv_glob(i,:), r_mv_glob(i,:), b_mv_glob(i,:) do i=1,m
! enddo write(psb_out_unit,993) i, x_mv_glob(i,:)!, ' ', r_mv_glob(i,:), ' ', b_mv_glob(i,:)
end do
end if
end if end if
998 format(i8,4(2x,g20.14)) 998 format(i8,4(2x,g20.14))

Loading…
Cancel
Save