From 0839165bdc3f98391fcf2073abcab363a809fc1b Mon Sep 17 00:00:00 2001 From: gabrielequatrana Date: Mon, 25 Mar 2024 19:56:50 +0100 Subject: [PATCH] Added convergence check --- base/modules/psblas/psb_d_psblas_mod.F90 | 14 +- base/modules/serial/psb_d_base_vect_mod.F90 | 9 +- base/modules/serial/psb_d_vect_mod.F90 | 2 +- base/psblas/psb_dnrm2.f90 | 14 +- base/psblas/psb_dprod.f90 | 349 ++++++++++++++------ base/psblas/psb_dqrfact.f90 | 10 +- krylov/psb_dbgmres.f90 | 276 +++++++++++----- test/block_krylov/psb_dbf_sample.f90 | 28 +- 8 files changed, 492 insertions(+), 210 deletions(-) diff --git a/base/modules/psblas/psb_d_psblas_mod.F90 b/base/modules/psblas/psb_d_psblas_mod.F90 index 3d9625f3..5377fc0b 100644 --- a/base/modules/psblas/psb_d_psblas_mod.F90 +++ b/base/modules/psblas/psb_d_psblas_mod.F90 @@ -119,6 +119,16 @@ module psb_d_psblas_mod logical, intent(in), optional :: trans logical, intent(in), optional :: global end function psb_dprod_multivect_a + function psb_dprod_m(x,y,desc_a,info,trans,global) result(res) + import :: psb_desc_type, psb_dpk_, psb_ipk_, & + & psb_d_multivect_type, psb_dspmat_type + real(psb_dpk_), allocatable :: res(:,:) + real(psb_dpk_), intent(in) :: x(:,:), y(:,:) + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: trans + logical, intent(in), optional :: global + end function psb_dprod_m end interface interface psb_geaxpby @@ -402,10 +412,10 @@ module psb_d_psblas_mod logical, intent(in), optional :: global type(psb_d_vect_type), intent (inout), optional :: aux end function psb_dnrm2_weightmask_vect - function psb_dnrm2_multivect(x, desc_a, info,global) result(res) + function psb_dnrm2_multivect(x, desc_a, info, global) result(res) import :: psb_desc_type, psb_dpk_, psb_ipk_, & & psb_d_multivect_type, psb_dspmat_type - real(psb_dpk_) :: res + real(psb_dpk_), allocatable :: res(:) type(psb_d_multivect_type), intent (inout) :: x type(psb_desc_type), intent (in) :: 
desc_a integer(psb_ipk_), intent(out) :: info diff --git a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index e2e56a60..d70536fb 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -3296,13 +3296,16 @@ contains implicit none class(psb_d_base_multivect_type), intent(inout) :: x integer(psb_ipk_), intent(in) :: nr - real(psb_dpk_), allocatable :: res + real(psb_dpk_), allocatable :: res(:) + integer(psb_ipk_) :: nc, col real(psb_dpk_), external :: dnrm2 - integer(psb_ipk_) :: j, nc if (x%is_dev()) call x%sync() nc = x%get_ncols() - res = dnrm2(nc*nr,x%v,1) + allocate(res(nc)) + do col=1,nc + res(col) = dnrm2(nr,x%v(:,col),1) + end do end function d_base_mlv_nrm2 diff --git a/base/modules/serial/psb_d_vect_mod.F90 b/base/modules/serial/psb_d_vect_mod.F90 index ec6ef0ad..5eb7de92 100644 --- a/base/modules/serial/psb_d_vect_mod.F90 +++ b/base/modules/serial/psb_d_vect_mod.F90 @@ -2051,7 +2051,7 @@ contains implicit none class(psb_d_multivect_type), intent(inout) :: x integer(psb_ipk_), intent(in) :: nr - real(psb_dpk_), allocatable :: res + real(psb_dpk_), allocatable :: res(:) if (allocated(x%v)) then res = x%v%nrm2(nr) diff --git a/base/psblas/psb_dnrm2.f90 b/base/psblas/psb_dnrm2.f90 index be0fb6df..81ec41b6 100644 --- a/base/psblas/psb_dnrm2.f90 +++ b/base/psblas/psb_dnrm2.f90 @@ -384,7 +384,7 @@ end function psb_dnrm2_vect ! info - integer. Return code ! global - logical(optional) Whether to perform the global reduction, default: .true. ! 
-function psb_dnrm2_multivect(x, desc_a, info,global) result(res) +function psb_dnrm2_multivect(x, desc_a, info, global) result(res) use psb_desc_mod use psb_check_mod use psb_error_mod @@ -392,7 +392,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res) use psb_d_multivect_mod implicit none - real(psb_dpk_) :: res + real(psb_dpk_), allocatable :: res(:) type(psb_d_multivect_type), intent (inout) :: x type(psb_desc_type), intent(in) :: desc_a integer(psb_ipk_), intent(out) :: info @@ -402,7 +402,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res) type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me, err_act, idx, i, j, iix, jjx, ldx, ndm real(psb_dpk_) :: dd - integer(psb_lpk_) :: ix, jx, m, n + integer(psb_lpk_) :: ix, jx, m logical :: global_ character(len=20) :: name, ch_err @@ -438,10 +438,9 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res) jx = 1 m = desc_a%get_global_rows() - n = x%get_ncols() ldx = x%get_nrows() - call psb_chkvect(m,n,ldx,ix,jx,desc_a,info,iix,jjx) + call psb_chkvect(m,x%get_ncols(),ldx,ix,jx,desc_a,info,iix,jjx) if(info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='psb_chkvect' @@ -455,7 +454,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res) end if if (desc_a%get_local_rows() > 0) then - res = x%nrm2(desc_a%get_local_rows()) + res = x%nrm2(desc_a%get_local_rows()) ! 
adjust because overlapped elements are computed more than once if (size(desc_a%ovrlap_elem,1)>0) then if (x%v%is_dev()) call x%sync() @@ -464,11 +463,12 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res) idx = desc_a%ovrlap_elem(i,1) ndm = desc_a%ovrlap_elem(i,2) dd = dble(ndm-1)/dble(ndm) - res = res * sqrt(done - dd*(abs(x%v%v(idx,j))/res)**2) + res(j) = res(j) * sqrt(done - dd*(abs(x%v%v(idx,j))/res(j))**2) end do end do end if else + allocate(res(x%get_ncols())) res = dzero end if diff --git a/base/psblas/psb_dprod.f90 b/base/psblas/psb_dprod.f90 index 0aa13a86..661325dc 100644 --- a/base/psblas/psb_dprod.f90 +++ b/base/psblas/psb_dprod.f90 @@ -187,108 +187,247 @@ end function psb_dprod_multivect ! ! function psb_dprod_multivect_a(x,y,desc_a,info,trans,global) result(res) - use psb_desc_mod - use psb_d_base_mat_mod - use psb_check_mod - use psb_error_mod - use psb_penv_mod - use psb_d_vect_mod - use psb_d_psblas_mod, psb_protect_name => psb_dprod_multivect_a - implicit none - real(psb_dpk_), allocatable :: res(:,:) - type(psb_d_multivect_type), intent(inout) :: x - real(psb_dpk_), intent(in) :: y(:,:) - type(psb_desc_type), intent(in) :: desc_a - integer(psb_ipk_), intent(out) :: info - logical, intent(in), optional :: trans - logical, intent(in), optional :: global - - ! 
locals - type(psb_ctxt_type) :: ctxt - integer(psb_ipk_) :: np, me, idx, ndm,& - & err_act, iix, jjx, iiy, jjy, i, j, nr - integer(psb_lpk_) :: ix, ijx, iy, ijy, m - logical :: global_, trans_ - character(len=20) :: name, ch_err - - name='psb_dprod_multivect' - info=psb_success_ - call psb_erractionsave(err_act) - if (psb_errstatus_fatal()) then - info = psb_err_internal_error_ ; goto 9999 - end if - - ctxt=desc_a%get_context() - call psb_info(ctxt, me, np) - if (np == -ione) then - info = psb_err_context_error_ - call psb_errpush(info,name) - goto 9999 - endif - if (.not.allocated(x%v)) then - info = psb_err_invalid_vect_state_ - call psb_errpush(info,name) - goto 9999 - endif - - if (present(trans)) then - trans_ = trans - else - trans_ = .false. - end if - - if (present(global)) then - global_ = global - else - global_ = .true. - end if - - ix = ione - ijx = ione - - m = desc_a%get_global_rows() - - ! check vector correctness - call psb_chkvect(m,x%get_ncols(),x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) - if(info /= psb_success_) then - info=psb_err_from_subroutine_ - ch_err='psb_chkvect' - call psb_errpush(info,name,a_err=ch_err) - goto 9999 - end if - - if ((iix /= ione)) then - info=psb_err_ix_n1_iy_n1_unsupported_ - call psb_errpush(info,name) - goto 9999 - end if - - nr = desc_a%get_local_rows() - if (nr > 0) then - res = x%prod(nr,y,trans_) - ! adjust dot_local because overlapped elements are computed more than once - if (size(desc_a%ovrlap_elem,1)>0) then - if (x%v%is_dev()) call x%sync() - do j=1,x%get_ncols() - do i=1,size(desc_a%ovrlap_elem,1) - idx = desc_a%ovrlap_elem(i,1) - ndm = desc_a%ovrlap_elem(i,2) - res(j,:) = res(j,:) - (real(ndm-1)/real(ndm))*(x%v%v(idx,:)*y(idx,:)) - end do - end do - end if - else - res = dzero - end if - - ! 
compute global sum - if (global_) call psb_sum(ctxt, res) - - call psb_erractionrestore(err_act) - return - - 9999 call psb_error_handler(ctxt,err_act) - - return - - end function psb_dprod_multivect_a \ No newline at end of file + use psb_desc_mod + use psb_d_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_d_vect_mod + use psb_d_psblas_mod, psb_protect_name => psb_dprod_multivect_a + implicit none + real(psb_dpk_), allocatable :: res(:,:) + type(psb_d_multivect_type), intent(inout) :: x + real(psb_dpk_), intent(in) :: y(:,:) + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: trans + logical, intent(in), optional :: global + + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, iiy, jjy, i, j, nr + integer(psb_lpk_) :: ix, ijx, iy, ijy, m + logical :: global_, trans_ + character(len=20) :: name, ch_err + + name='psb_dprod_multivect' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + if (.not.allocated(x%v)) then + info = psb_err_invalid_vect_state_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(trans)) then + trans_ = trans + else + trans_ = .false. + end if + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + m = desc_a%get_global_rows() + + ! 
check vector correctness + call psb_chkvect(m,x%get_ncols(),x%get_nrows(),ix,ijx,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + nr = desc_a%get_local_rows() + if (nr > 0) then + res = x%prod(nr,y,trans_) + ! adjust dot_local (X**T*Y only): a row shared by ndm processes is summed ndm times, each removes (ndm-1)/ndm of it + if (trans_ .and. (size(desc_a%ovrlap_elem,1)>0)) then + if (x%v%is_dev()) call x%sync() + do j=1,x%get_ncols() + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + res(j,:) = res(j,:) - (dble(ndm-1)/dble(ndm))*(x%v%v(idx,j)*y(idx,:)) + end do + end do + end if + else + res = dzero + end if + + ! compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end function psb_dprod_multivect_a +! +! Function: psb_dprod_m +! psb_dprod_m computes the product of two distributed multivectors stored as plain arrays, +! +! prod := ( X ) * ( Y ) or +! prod := ( X )**C * ( Y ) +! +! +! Arguments: +! x - real(:,:) The input array containing the entries of sub( X ). +! y - real(:,:) The input array containing the entries of sub( Y ). +! desc_a - type(psb_desc_type). The communication descriptor. +! info - integer. Return code +! trans - logical(optional) Whether multivector X is transposed, default: .false. +! global - logical(optional) Whether to perform the global reduce, default: .true. +! +! Note: X and Y are plain arrays declared INTENT(IN); no sync() is involved, +! so, unlike the multivect variants, they are never modified here. +! +! 
+function psb_dprod_m(x,y,desc_a,info,trans,global) result(res) + use psb_desc_mod + use psb_d_base_mat_mod + use psb_check_mod + use psb_error_mod + use psb_penv_mod + use psb_d_vect_mod + use psb_d_psblas_mod, psb_protect_name => psb_dprod_m + implicit none + real(psb_dpk_), allocatable :: res(:,:) + real(psb_dpk_), intent(in) :: x(:,:), y(:,:) + type(psb_desc_type), intent(in) :: desc_a + integer(psb_ipk_), intent(out) :: info + logical, intent(in), optional :: trans + logical, intent(in), optional :: global + + ! locals + type(psb_ctxt_type) :: ctxt + integer(psb_ipk_) :: np, me, idx, ndm,& + & err_act, iix, jjx, i, j, nr, x_n, y_n, lda, ldb + integer(psb_lpk_) :: ix, ijx, m + logical :: global_, trans_ + character(len=20) :: name, ch_err + + name='psb_dprod_m' + info=psb_success_ + call psb_erractionsave(err_act) + if (psb_errstatus_fatal()) then + info = psb_err_internal_error_ ; goto 9999 + end if + + ctxt=desc_a%get_context() + call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + if (present(trans)) then + trans_ = trans + else + trans_ = .false. + end if + + if (present(global)) then + global_ = global + else + global_ = .true. + end if + + ix = ione + ijx = ione + + m = desc_a%get_global_rows() + + ! 
check vector correctness + call psb_chkvect(m,size(x,2),size(x,1),ix,ijx,desc_a,info,iix,jjx) + if(info /= psb_success_) then + info=psb_err_from_subroutine_ + ch_err='psb_chkvect' + call psb_errpush(info,name,a_err=ch_err) + goto 9999 + end if + + if ((iix /= ione)) then + info=psb_err_ix_n1_iy_n1_unsupported_ + call psb_errpush(info,name) + goto 9999 + end if + + nr = desc_a%get_local_rows() + x_n = size(x,2) + y_n = size(y,2) + lda = size(x,1) + if (nr > 0) then + + if (trans_) then + allocate(res(x_n,y_n)) + res = dzero + ldb = size(y,1) + call dgemm('T','N',x_n,y_n,nr,done,x,lda,y,ldb,dzero,res,x_n) + else + allocate(res(lda,y_n)) + res = dzero + ldb = size(y,1) + call dgemm('N','N',nr,y_n,x_n,done,x,lda,y,ldb,dzero,res,lda) + end if + + ! adjust dot_local (X**T*Y only): a row shared by ndm processes is summed ndm times, each removes (ndm-1)/ndm of it + if (trans_ .and. (size(desc_a%ovrlap_elem,1)>0)) then + do j=1,x_n + do i=1,size(desc_a%ovrlap_elem,1) + idx = desc_a%ovrlap_elem(i,1) + ndm = desc_a%ovrlap_elem(i,2) + res(j,:) = res(j,:) - (dble(ndm-1)/dble(ndm))*(x(idx,j)*y(idx,:)) + end do + end do + end if + else + if (trans_) then + allocate(res(x_n,y_n)) + else + allocate(res(lda,y_n)) + end if + res = dzero + end if + + ! 
compute global sum + if (global_) call psb_sum(ctxt, res) + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + +end function psb_dprod_m diff --git a/base/psblas/psb_dqrfact.f90 b/base/psblas/psb_dqrfact.f90 index abfdeae0..2b7f2039 100644 --- a/base/psblas/psb_dqrfact.f90 +++ b/base/psblas/psb_dqrfact.f90 @@ -24,7 +24,7 @@ function psb_dqrfact(x, desc_a, info) result(res) integer(psb_lpk_) :: ix, ijx, m, n character(len=20) :: name, ch_err real(psb_dpk_), allocatable :: temp(:,:) - type(psb_d_base_multivect_type) :: qr_temp + type(psb_d_multivect_type) :: qr_temp name='psb_dgqrfact' if (psb_errstatus_fatal()) return @@ -69,6 +69,14 @@ function psb_dqrfact(x, desc_a, info) result(res) if (me == psb_root_) then call qr_temp%bld(temp) res = qr_temp%qr_fact(info) + + ! TODO Check the diagonal of R + do i=1,n + if (res(i,i) == dzero) then + write(*,*) 'DIAGONAL 0' + end if + end do + temp = qr_temp%get_vect() call psb_bcast(ctxt,res) else diff --git a/krylov/psb_dbgmres.f90 b/krylov/psb_dbgmres.f90 index 31d16480..cef9d12b 100644 --- a/krylov/psb_dbgmres.f90 +++ b/krylov/psb_dbgmres.f90 @@ -61,10 +61,10 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, integer(psb_ipk_), Optional, Intent(out) :: iter real(psb_dpk_), Optional, Intent(out) :: err - real(psb_dpk_), allocatable :: aux(:), h(:,:), beta(:,:), beta_e1(:,:) + real(psb_dpk_), allocatable :: aux(:), h(:,:), vt(:,:), beta(:,:), y(:,:) type(psb_d_multivect_type), allocatable :: v(:) - type(psb_d_multivect_type) :: v_tot, w + type(psb_d_multivect_type) :: w, xt, r real(psb_dpk_) :: t1, t2 @@ -72,13 +72,14 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, integer(psb_ipk_) :: litmax, naux, itrace_, n_row, n_col, nrhs, nrep integer(psb_lpk_) :: mglob, n_add - integer(psb_ipk_) :: i, j, k, istop_, err_act, idx_i, idx_j, idx + integer(psb_ipk_) :: i, j, k, col, istop_, err_act, idx_i, idx_j, idx 
integer(psb_ipk_) :: debug_level, debug_unit type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: np, me, itx - real(psb_dpk_) :: rni, xni, bni, ani, bn2, r0n2 - real(psb_dpk_) :: errnum, errden, deps, derr + real(psb_dpk_), allocatable :: r0n2(:), rmn2(:) + real(psb_dpk_), allocatable :: errnum(:), errden(:) + real(psb_dpk_) :: deps, derr character(len=20) :: name character(len=*), parameter :: methdname='BGMRES' @@ -112,10 +113,10 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, if (present(istop)) then istop_ = istop else - istop_ = 2 + istop_ = 1 endif - if ((istop_ < 1 ).or.(istop_ > 2 ) ) then + if (istop_ /= 1) then info=psb_err_invalid_istop_ err=info call psb_errpush(info,name,i_err=(/istop_/)) @@ -167,13 +168,17 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, naux = 4*n_col nrhs = x%get_ncols() - allocate(aux(naux),h((nrep+1)*nrhs,nrep*nrhs),stat=info) + allocate(aux(naux),h((nrep+1)*nrhs,nrep*nrhs),y(nrep*nrhs,nrhs),& + & vt(n_row,(nrep+1)*nrhs),r0n2(nrhs),rmn2(nrhs),& + & errnum(nrhs),errden(nrhs),stat=info) if (info == psb_success_) call psb_geall(v,desc_a,info,m=nrep+1,n=nrhs) - if (info == psb_success_) call psb_geall(v_tot,desc_a,info,n=nrep*nrhs) if (info == psb_success_) call psb_geall(w,desc_a,info,n=nrhs) + if (info == psb_success_) call psb_geall(xt,desc_a,info,n=nrhs) + if (info == psb_success_) call psb_geall(r,desc_a,info,n=nrhs) if (info == psb_success_) call psb_geasb(v,desc_a,info,mold=x%v,n=nrhs) - if (info == psb_success_) call psb_geasb(v_tot,desc_a,info,mold=x%v,n=nrep*nrhs) if (info == psb_success_) call psb_geasb(w,desc_a,info,mold=x%v,n=nrhs) + if (info == psb_success_) call psb_geasb(xt,desc_a,info,mold=x%v,n=nrhs) + if (info == psb_success_) call psb_geasb(r,desc_a,info,mold=x%v,n=nrhs) if (info /= psb_success_) then info=psb_err_from_subroutine_non_ call psb_errpush(info,name) @@ -185,11 +190,21 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, 
itmax, iter, & ' Size of V,W ',v(1)%get_nrows(),size(v),& & w%get_nrows() + ! Compute norm2 of R(0) if (istop_ == 1) then - ani = psb_spnrmi(a,desc_a,info) - bni = psb_geamax(b,desc_a,info) - else if (istop_ == 2) then - bn2 = psb_genrm2(b,desc_a,info) + call psb_geaxpby(done,b,dzero,r,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + call psb_spmm(-done,a,x,done,r,desc_a,info,work=aux) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + r0n2 = psb_genrm2(r,desc_a,info) endif if (info /= psb_success_) then info=psb_err_from_subroutine_non_ @@ -198,6 +213,7 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, end if h = dzero + y = dzero errnum = dzero errden = done deps = eps @@ -208,8 +224,10 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, ! BGMRES algorithm - ! TODO QR fact seriale per ora - ! TODO Con tanti ITRS NaN (forse genera righe dipendenti (vedere pargen)) + ! TODO Con tanti ITRS e tanti NRHS si ottengono NaN + ! TODO Deflazione e restart dopo aver trovato una colonna, difficile... + + ! TODO L'algo converge abbastanza bene. Capire come fare check residui ! STEP 1: Compute R(0) = B - A*X(0) @@ -237,31 +255,15 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, goto 9999 end if + ! Add V(1) to VT + vt(:,1:nrhs) = v(1)%get_vect() + ! STEP 3: Outer loop outer: do j=1,nrep - ! TODO Check convergence - ! if (istop_ == 1) then - ! rni = psb_geamax(v(1),desc_a,info) - ! xni = psb_geamax(x,desc_a,info) - ! errnum = rni - ! errden = (ani*xni+bni) - ! else if (istop_ == 2) then - ! rni = psb_genrm2(v(1),desc_a,info) - ! errnum = rni - ! errden = bn2 - ! endif - ! if (info /= psb_success_) then - ! info=psb_err_from_subroutine_non_ - ! call psb_errpush(info,name) - ! goto 9999 - ! end if - - ! 
if (errnum <= eps*errden) exit outer - - ! if (itrace_ > 0) call log_conv(methdname,me,itx,itrace_,errnum,errden,deps) - + ! Update itx counter itx = itx + 1 + if (itx >= litmax) exit outer ! Compute j index for H operations idx_j = (j-1)*nrhs+1 @@ -274,15 +276,13 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, goto 9999 end if - if (itx >= litmax) exit outer - ! STEP 5: Inner loop inner: do i=1,j ! Compute i index for H operations idx_i = (i-1)*nrhs+1 - ! STEP 6: Compute H(i,j) = V(i)_T*W + ! STEP 6: Compute H(i,j) = (V(i)**T)*W h(idx_i:idx_i+n_add,idx_j:idx_j+n_add) = psb_geprod(v(i),w,desc_a,info,trans=.true.) if (info /= psb_success_) then info=psb_err_from_subroutine_non_ @@ -292,8 +292,8 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, ! STEP 7: Compute W = W - V(i)*H(i,j) call psb_geaxpby(-done,& - & psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.),& - & done,w,desc_a,info) + & psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.),& + & done,w,desc_a,info) if (info /= psb_success_) then info=psb_err_from_subroutine_non_ call psb_errpush(info,name) @@ -320,41 +320,154 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, goto 9999 end if - end do outer + ! Add V(j+1) to VT + idx = j*nrhs+1 + vt(:,idx:idx+n_add) = v(j+1)%get_vect() - ! STEP 9: Compute Y(m) - call frobenius_norm_min() - if (info /= psb_success_) then - info=psb_err_from_subroutine_non_ - call psb_errpush(info,name) - goto 9999 - end if + ! STEP 9: Compute Y(j) + call frobenius_norm_min(j) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if - ! STEP 10: Compute V = {V(1),...,V(m)} - do i=1,nrep - idx = (i-1)*nrhs+1 - v_tot%v%v(1:n_row,idx:idx+n_add) = v(i)%v%v(1:n_row,1:nrhs) - enddo + ! Compute residues + if (istop_ == 1) then - ! 
STEP 11: X(m) = X(0) + V*Y(m) - call psb_geaxpby(done,psb_geprod(v_tot,beta_e1,desc_a,info,global=.false.),done,x,desc_a,info) - if (info /= psb_success_) then - info=psb_err_from_subroutine_non_ - call psb_errpush(info,name) - goto 9999 - end if + ! TODO Compute R(j) = R(0) - VT(j+1)*H(j)*Y(j) + call psb_geaxpby(-done,psb_geprod(psb_geprod(vt(:,1:(j+1)*nrhs),h(1:(j+1)*nrhs,1:j*nrhs),& + & desc_a,info,global=.false.),& + & y(1:j*nrhs,1:nrhs),desc_a,info,global=.false.),done,r,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + write(*,*) + do col=1,r%get_nrows() + write(*,*) r%v%v(col,:) + end do + write(*,*) + + ! TODO Calcolo soluzione al passo J e vedo i residui (se minore esco dal ciclo) + ! TODO Compute R(j) = B - A*X(j) + + ! Copy X in XT + call psb_geaxpby(done,x,dzero,xt,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + ! Compute current solution X(j) = X(0) + VT(j)*Y(j) + call psb_geaxpby(done,psb_geprod(vt(:,1:j*nrhs),y(1:j*nrhs,:),desc_a,info,global=.false.),done,xt,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + ! Copy B in R + call psb_geaxpby(done,b,dzero,r,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + ! Compute R(j) = B - A*X(j) + call psb_spmm(-done,a,xt,done,r,desc_a,info,work=aux) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + write(*,*) + do col=1,r%get_nrows() + write(*,*) r%v%v(col,:) + end do + write(*,*) + + ! Compute nrm2 of each column of R(j) + rmn2 = psb_genrm2(r,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + ! 
Set error num and den + errnum = rmn2 + errden = r0n2 + + ! TODO Ogni entrata della norma2 di R(m) deve essere più piccola di tolleranza*nrm2(r0) + do col=1,nrhs + write(*,*) rmn2(col), r0n2(col) + end do + end if + + ! TODO Norma dei residui con Xm devono essere minori di tolleranza * nrm2(R0)? + ! Check convergence + if (maxval(errnum) <= eps*maxval(errden)) then + + ! Compute result and exit + if (istop_ == 1) then + + ! Compute X(j) = X(0) + VT(j)*Y(j) + ! call psb_geaxpby(done,psb_geprod(vt(:,1:j*nrhs),y(1:j*nrhs,:),desc_a,info,global=.false.),done,x,desc_a,info) + ! if (info /= psb_success_) then + ! info=psb_err_from_subroutine_non_ + ! call psb_errpush(info,name) + ! goto 9999 + ! end if + + ! Copy current solution XT in X + call psb_geaxpby(done,xt,dzero,x,desc_a,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_non_ + call psb_errpush(info,name) + goto 9999 + end if + + end if + + ! Exit algorithm + exit outer + + end if + + ! Log update + if (itrace_ > 0) call log_conv(methdname,me,itx,ione,maxval(errnum),maxval(errden),deps) + + end do outer + + ! STEP 10: X(m) = X(0) + VT(m)*Y(m) + ! call psb_geaxpby(done,psb_geprod(vt(:,1:nrep*nrhs),y,desc_a,info,global=.false.),done,x,desc_a,info) + ! if (info /= psb_success_) then + ! info=psb_err_from_subroutine_non_ + ! call psb_errpush(info,name) + ! goto 9999 + ! end if ! END algorithm - if (itrace_ > 0) call log_conv(methdname,me,itx,ione,errnum,errden,deps) + ! TODO log_conv passa scalari errnum,errden,deps (servono vettori) + ! TODO Inizialmente versione verbosa che stampa errore per tutte le colonne + ! 
TODO Versione finale che stampa errore massimo (si può usare log_conv con questo) + if (itrace_ > 0) call log_conv(methdname,me,itx,ione,maxval(errnum),maxval(errden),deps) - call log_end(methdname,me,itx,itrace_,errnum,errden,deps,err=derr,iter=iter) + call log_end(methdname,me,itx,itrace_,maxval(errnum),maxval(errden),deps,err=derr,iter=iter) if (present(err)) err = derr if (info == psb_success_) call psb_gefree(v,desc_a,info) - if (info == psb_success_) call psb_gefree(v_tot,desc_a,info) if (info == psb_success_) call psb_gefree(w,desc_a,info) - if (info == psb_success_) deallocate(aux,h,stat=info) + if (info == psb_success_) call psb_gefree(xt,desc_a,info) + if (info == psb_success_) call psb_gefree(r,desc_a,info) + if (info == psb_success_) deallocate(aux,h,y,vt,stat=info) if (info /= psb_success_) then info=psb_err_from_subroutine_non_ call psb_errpush(info,name) @@ -370,36 +483,39 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter, contains ! Minimize Frobenius norm - subroutine frobenius_norm_min() + subroutine frobenius_norm_min(rep) implicit none + integer(psb_ipk_), intent(in) :: rep + integer(psb_ipk_) :: lwork - real(psb_dpk_), allocatable :: work(:), beta_temp(:,:) + real(psb_dpk_), allocatable :: work(:), beta_e1(:,:) - integer(psb_ipk_) :: m_h, n_h, mn + real(psb_dpk_), allocatable :: h_temp(:,:) + integer(psb_ipk_) :: m_h, n_h, mn ! Initialize params - m_h = (nrep+1)*nrhs - n_h = nrep*nrhs - mn = min(m_h,n_h) - lwork = max(1,mn+max(mn,nrhs)) + h_temp = h + m_h = (rep+1)*nrhs + n_h = rep*nrhs + mn = min(m_h,n_h) + lwork = max(1,mn+max(mn,nrhs)) allocate(work(lwork)) ! Compute E1*beta - allocate(beta_temp(m_h,nrhs)) - beta_temp = dzero - beta_temp(1:nrhs,1:nrhs) = beta + allocate(beta_e1(m_h,nrhs)) + beta_e1 = dzero + beta_e1(1:nrhs,1:nrhs) = beta ! 
Compute min Frobenius norm - call dgels('N',m_h,n_h,nrhs,h,m_h,beta_temp,m_h,work,lwork,info) + call dgels('N',m_h,n_h,nrhs,h_temp(1:m_h,1:n_h),m_h,beta_e1,m_h,work,lwork,info) ! Set solution - allocate(beta_e1(n_h,nrhs)) - beta_e1 = beta_temp(1:n_h,1:nrhs) + y = beta_e1(1:n_h,1:nrhs) ! Deallocate - deallocate(work,beta,beta_temp) + deallocate(work,h_temp,beta_e1) return diff --git a/test/block_krylov/psb_dbf_sample.f90 b/test/block_krylov/psb_dbf_sample.f90 index 5ec6d5d0..4e0c2a22 100644 --- a/test/block_krylov/psb_dbf_sample.f90 +++ b/test/block_krylov/psb_dbf_sample.f90 @@ -26,7 +26,8 @@ program psb_dbf_sample integer(psb_ipk_) :: m, nrhs real(psb_dpk_) :: random_value - real(psb_dpk_), allocatable :: test(:,:) + real(psb_dpk_), allocatable :: test(:) + ! communications data structure type(psb_desc_type) :: desc_a @@ -48,8 +49,10 @@ program psb_dbf_sample ! other variables integer(psb_ipk_) :: i, j, info real(psb_dpk_) :: t1, t2, tprec - real(psb_dpk_) :: resmx, resmxp + real(psb_dpk_), allocatable :: resmx(:) + real(psb_dpk_) :: resmxp integer(psb_ipk_), allocatable :: ivg(:) + logical :: print_matrix = .true. call psb_init(ctxt) call psb_info(ctxt,iam,np) @@ -130,9 +133,9 @@ program psb_dbf_sample b_mv_glob => aux_b(:,:) do i=1, m do j=1, nrhs - !b_mv_glob(i,j) = done - call random_number(random_value) - b_mv_glob(i,j) = random_value + b_mv_glob(i,j) = done + !call random_number(random_value) + !b_mv_glob(i,j) = random_value enddo enddo endif @@ -222,7 +225,8 @@ program psb_dbf_sample call psb_geaxpby(done,b_mv,dzero,r_mv,desc_a,info) call psb_spmm(-done,a,x_mv,done,r_mv,desc_a,info) - resmx = psb_genrm2(r_mv,desc_a,info) + ! 
TODO resmx vettore ogni entrata è più piccola della tolleranza (per* norma di r0) + resmx = psb_genrm2(r_mv,desc_a,info) resmxp = psb_geamax(r_mv,desc_a,info) amatsize = a%sizeof() @@ -253,12 +257,14 @@ program psb_dbf_sample write(psb_out_unit,'("Time to solve system: ",es12.5)')t2 write(psb_out_unit,'("Time per iteration: ",es12.5)')t2/(iter) write(psb_out_unit,'("Total time: ",es12.5)')t2+tprec - write(psb_out_unit,'("Residual norm 2: ",es12.5)')resmx + write(psb_out_unit,'("Residual norm 2: ",es12.5)')maxval(resmx) write(psb_out_unit,'("Residual norm inf: ",es12.5)')resmxp - write(psb_out_unit,'(" ")') - ! do i=1,m - ! write(psb_out_unit,993) i, x_mv_glob(i,:), r_mv_glob(i,:), b_mv_glob(i,:) - ! enddo + write(psb_out_unit,'(a8,4(2x,a20))') 'I','X(I)','R(I)','B(I)' + if (print_matrix) then + do i=1,m + write(psb_out_unit,993) i, x_mv_glob(i,:)!, ' ', r_mv_glob(i,:), ' ', b_mv_glob(i,:) + end do + end if end if 998 format(i8,4(2x,g20.14))