diff --git a/mlprec/impl/mld_dmlprec_aply.f90 b/mlprec/impl/mld_dmlprec_aply.f90 index 94f54e86..11fe8643 100644 --- a/mlprec/impl/mld_dmlprec_aply.f90 +++ b/mlprec/impl/mld_dmlprec_aply.f90 @@ -1112,7 +1112,7 @@ subroutine mld_dmlprec_aply_vect(alpha,p,x,beta,y,desc_data,trans,work,info) contains - recursive subroutine inner_ml_aply(level,p,mlprec_wrk,trans,work,info) + recursive subroutine inner_ml_aply(level,p,mlprec_wrk,trans,work,info, u) implicit none @@ -1123,6 +1123,9 @@ contains character, intent(in) :: trans real(psb_dpk_),target :: work(:) integer(psb_ipk_), intent(out) :: info + type(psb_d_vect_type),intent(inout), optional :: u + + type(psb_d_vect_type) :: res ! Local variables integer(psb_ipk_) :: ictxt,np,me @@ -1668,6 +1671,313 @@ contains end select + + case(mld_vcycle_ml_, mld_wcycle_ml_) + + !V/W cycle + if (level > 1) then + ! Apply the restriction + call psb_map_X2Y(done,mlprec_wrk(level-1)%vty,& + & dzero,mlprec_wrk(level)%vx2l,& + & p%precv(level)%map,info,work=work) + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during restriction') + goto 9999 + end if + end if + + call psb_geaxpby(done,mlprec_wrk(level)%vx2l,& + & dzero,mlprec_wrk(level)%vtx,& + & p%precv(level)%base_desc,info) + ! + ! Apply the base preconditioner + ! 
+ if (level < nlev) then + + if (present(u)) then + call mlprec_wrk(level)%vy2l%set(u%get_vect()) + else + call mlprec_wrk(level)%vy2l%set(dzero) + endif + res = mlprec_wrk(level)%vx2l + + call psb_spmm(-done,p%precv(level)%base_a,mlprec_wrk(level)%vy2l,& + done, res, p%precv(level)%base_desc, info, work=work, trans=trans) + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during residue') + goto 9999 + end if + + if (trans == 'N') then + sweeps = p%precv(level)%parms%sweeps_pre + if (info == psb_success_) call p%precv(level)%sm%apply(done,& + & mlprec_wrk(level)%vx2l,dzero,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + else + sweeps = p%precv(level)%parms%sweeps_post + if (info == psb_success_) call p%precv(level)%sm2%apply(done,& + & mlprec_wrk(level)%vx2l,dzero,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + end if + else + sweeps = p%precv(level)%parms%sweeps + if (info == psb_success_) call p%precv(level)%sm%apply(done,& + & mlprec_wrk(level)%vx2l,dzero,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + end if + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during 2-PRE smoother_apply') + goto 9999 + end if + + + ! + ! Compute the residual (at all levels but the coarsest one) + ! and call recursively + ! 
+ if(level < nlev) then + + call psb_geaxpby(done,mlprec_wrk(level)%vx2l,& + & dzero,mlprec_wrk(level)%vty,& + & p%precv(level)%base_desc,info) + + if (info == psb_success_) call psb_spmm(-done,p%precv(level)%base_a,& + & mlprec_wrk(level)%vy2l,done,mlprec_wrk(level)%vty,& + & p%precv(level)%base_desc,info,work=work,trans=trans) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during residue') + goto 9999 + end if + + call inner_ml_aply(level+1,p,mlprec_wrk,trans,work,info) + + if (p%precv(level)%parms%ml_type == mld_wcycle_ml_) then + call inner_ml_aply(level+1,p,mlprec_wrk,trans,work,info, mlprec_wrk(level+1)%vy2l) + endif + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error in recursive call') + goto 9999 + end if + + + ! + ! Apply the prolongator + ! + call psb_map_Y2X(done,mlprec_wrk(level+1)%vy2l,& + & done,mlprec_wrk(level)%vy2l,& + & p%precv(level+1)%map,info,work=work) + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during prolongation') + goto 9999 + end if + + ! + ! Compute the residual + ! + call psb_spmm(-done,p%precv(level)%base_a,mlprec_wrk(level)%vy2l,& + & done,mlprec_wrk(level)%vtx,p%precv(level)%base_desc,info,& + & work=work,trans=trans) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during residue') + goto 9999 + end if + ! + ! Apply the base preconditioner + ! 
+ if (trans == 'N') then + sweeps = p%precv(level)%parms%sweeps_post + if (info == psb_success_) call p%precv(level)%sm2%apply(done,& + & mlprec_wrk(level)%vtx,done,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + else + sweeps = p%precv(level)%parms%sweeps_pre + if (info == psb_success_) call p%precv(level)%sm%apply(done,& + & mlprec_wrk(level)%vtx,done,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + end if + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during 2-POST smoother_apply') + goto 9999 + end if + + endif + + + case(mld_kcycle_ml_, mld_kcyclesym_ml_) + + + !K cycle + + call psb_geaxpby(done,mlprec_wrk(level)%vx2l,& + & dzero,mlprec_wrk(level)%vtx,& + & p%precv(level)%base_desc,info) + ! + ! Apply the base preconditioner + ! + if (level < nlev) then + + if (present(u)) then + call mlprec_wrk(level)%vy2l%set(u%get_vect()) + else + call mlprec_wrk(level)%vy2l%set(dzero) + endif + res = mlprec_wrk(level)%vx2l + + call psb_spmm(-done,p%precv(level)%base_a,mlprec_wrk(level)%vy2l,& + done, res, p%precv(level)%base_desc, info, work=work, trans=trans) + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during residue') + goto 9999 + end if + + if (trans == 'N') then + sweeps = p%precv(level)%parms%sweeps_pre + if (info == psb_success_) call p%precv(level)%sm%apply(done,& + & mlprec_wrk(level)%vx2l,dzero,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + else + sweeps = p%precv(level)%parms%sweeps_post + if (info == psb_success_) call p%precv(level)%sm2%apply(done,& + & mlprec_wrk(level)%vx2l,dzero,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + end if + else + sweeps = p%precv(level)%parms%sweeps + if (info == psb_success_) call p%precv(level)%sm%apply(done,& + & 
mlprec_wrk(level)%vx2l,dzero,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + end if + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during 2-PRE smoother_apply') + goto 9999 + end if + + + ! + ! Compute the residual (at all levels but the coarsest one) + ! and call recursively + ! + if(level < nlev) then + + call psb_geaxpby(done,mlprec_wrk(level)%vx2l,& + & dzero,mlprec_wrk(level)%vty,& + & p%precv(level)%base_desc,info) + + if (info == psb_success_) call psb_spmm(-done,p%precv(level)%base_a,& + & mlprec_wrk(level)%vy2l,done,mlprec_wrk(level)%vty,& + & p%precv(level)%base_desc,info,work=work,trans=trans) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during residue') + goto 9999 + end if + + ! Apply the restriction + call psb_map_X2Y(done,mlprec_wrk(level)%vty,& + & dzero,mlprec_wrk(level + 1)%vx2l,& + & p%precv(level + 1)%map,info,work=work) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during restriction') + goto 9999 + end if + + !Set the preconditioner + + + if ((level < nlev - 2)) then + if (p%precv(level)%parms%ml_type == mld_kcyclesym_ml_) then + call mld_dinneritkcycle(p, mlprec_wrk, level + 1, trans, work, 'FCG') + elseif (p%precv(level)%parms%ml_type == mld_kcycle_ml_) then + call mld_dinneritkcycle(p, mlprec_wrk, level + 1, trans, work, 'CGR') + endif + else + call inner_ml_aply(level + 1 ,p,mlprec_wrk,trans,work,info) + endif + + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error in recursive call') + goto 9999 + end if + + + ! + ! Apply the prolongator + ! 
+ call psb_map_Y2X(done,mlprec_wrk(level+1)%vy2l,& + & done,mlprec_wrk(level)%vy2l,& + & p%precv(level+1)%map,info,work=work) + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during prolongation') + goto 9999 + end if + + ! + ! Compute the residual + ! + call psb_spmm(-done,p%precv(level)%base_a,mlprec_wrk(level)%vy2l,& + & done,mlprec_wrk(level)%vtx,p%precv(level)%base_desc,info,& + & work=work,trans=trans) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during residue') + goto 9999 + end if + ! + ! Apply the base preconditioner + ! + if (trans == 'N') then + sweeps = p%precv(level)%parms%sweeps_post + if (info == psb_success_) call p%precv(level)%sm2%apply(done,& + & mlprec_wrk(level)%vtx,done,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + else + sweeps = p%precv(level)%parms%sweeps_pre + if (info == psb_success_) call p%precv(level)%sm%apply(done,& + & mlprec_wrk(level)%vtx,done,mlprec_wrk(level)%vy2l,& + & p%precv(level)%base_desc, trans,& + & sweeps,work,info) + end if + + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_,name,& + & a_err='Error during 2-POST smoother_apply') + goto 9999 + end if + + endif + case default info = psb_err_from_subroutine_ai_ call psb_errpush(info,name,a_err='invalid mltype',& @@ -1684,5 +1994,170 @@ contains end subroutine inner_ml_aply
+
+!
+! Subroutine: mld_dinneritkcycle
+!
+! Inner Krylov acceleration for the K-cycle: performs at most two steps of
+! an iteration preconditioned by inner_ml_aply at the same level, and
+! stores the resulting approximate solution in mlprec_wrk(level)%vy2l.
+!
+! Outline:
+!   1. w = mlprec_wrk(level)%vx2l (right-hand side), x = 0.
+!   2. d(0) = result of inner_ml_aply applied to vx2l, v = A*d(0).
+!   3. alpha = delta_old/tau, w = w - alpha*v.
+!   4. If dot(w,w) <= rtol*dot(rhs,rhs), then x = alpha*d(0); otherwise a
+!      second direction d(1) is obtained by applying inner_ml_aply to w,
+!      and x is built from both d(0) and d(1).
+!
+! Arguments:
+!   p          - multilevel preconditioner; p%precv(level)%base_a and
+!                p%precv(level)%base_desc are used here.
+!   mlprec_wrk - per-level work vectors. On entry mlprec_wrk(level)%vx2l
+!                holds the right-hand side; it is overwritten with the
+!                residual w when the second step is taken. On exit
+!                mlprec_wrk(level)%vy2l holds x.
+!   level      - level at which the inner iteration is run.
+!   trans      - passed unchanged to inner_ml_aply.
+!   work       - passed unchanged to inner_ml_aply.
+!   innersolv  - 'FCG' selects the FCG formulas for the scalar
+!                coefficients; any other value selects the CGR formulas.
+!                Declared with assumed length: with a length-1 dummy the
+!                comparison against 'FCG' could never be true.
+!
+! NOTE(review): name, nc2l and err_act are not declared in this routine;
+! presumably they are host-associated from the enclosing subroutine --
+! confirm. There is no psb_erractionsave here matching the
+! psb_erractionrestore at the exit.
+! NOTE(review): this routine has no info argument, so the caller can only
+! learn about failures through the PSBLAS error stack.
+!
+recursive subroutine mld_dinneritkcycle(p, mlprec_wrk, level, trans, work, innersolv)
+  use psb_base_mod
+  use mld_prec_mod
+  use mld_d_inner_mod, mld_protect_name => mld_dmlprec_aply
+
+  implicit none
+
+  ! Input/Output variables
+  type(mld_dprec_type), intent(inout) :: p
+
+  type(mld_mlprec_wrk_type), intent(inout) :: mlprec_wrk(:)
+  integer(psb_ipk_), intent(in) :: level
+  character, intent(in) :: trans
+  character(len=*), intent(in) :: innersolv
+  real(psb_dpk_),target :: work(:)
+
+  ! Local variables
+  ! Threshold on dot(w,w)/dot(rhs,rhs) below which one step is enough.
+  ! A named constant avoids the implicit SAVE of an initialised local.
+  real(psb_dpk_), parameter :: rtol = 0.25_psb_dpk_
+  type(psb_d_vect_type) :: v, w, v1, x
+  type(psb_d_vect_type), dimension(0:1) :: d
+  real(psb_dpk_) :: delta_old, delta, alpha, tau, tau1, tau2, tau3, tau4
+  real(psb_dpk_) :: l2_norm
+  integer(psb_ipk_) :: info
+  logical :: use_fcg
+
+  use_fcg = (innersolv == 'FCG')
+
+  ! Assemble w, v, v1, x
+
+  call psb_geasb(w,&
+       & p%precv(level)%base_desc,info,&
+       & scratch=.true.,mold=mlprec_wrk(level)%vx2l%v)
+  call psb_geasb(v,&
+       & p%precv(level)%base_desc,info,&
+       & scratch=.true.,mold=mlprec_wrk(level)%vx2l%v)
+  call psb_geasb(v1,&
+       & p%precv(level)%base_desc,info,&
+       & scratch=.true.,mold=mlprec_wrk(level)%vx2l%v)
+  call psb_geasb(x,&
+       & p%precv(level)%base_desc,info,&
+       & scratch=.true.,mold=mlprec_wrk(level)%vx2l%v)
+
+  call x%set(dzero)
+
+  ! w = vx2l (the right-hand side)
+  call psb_geaxpby(done,mlprec_wrk(level)%vx2l,dzero,w,&
+       & p%precv(level)%base_desc,info)
+
+  if (psb_errstatus_fatal()) then
+    nc2l = p%precv(level)%base_desc%get_local_cols()
+    info=psb_err_alloc_request_
+    call psb_errpush(info,name,i_err=(/2*nc2l,izero,izero,izero,izero/),&
+         & a_err='real(psb_dpk_)')
+    goto 9999
+  end if
+
+  ! delta = dot(rhs,rhs), reference value for the stopping test
+  delta = psb_gedot(w, w, p%precv(level)%base_desc, info)
+
+  ! Apply the preconditioner to the right-hand side
+
+  call mlprec_wrk(level)%vy2l%set(dzero)
+
+  call inner_ml_aply(level,p,mlprec_wrk,trans,work,info)
+  if (info /= psb_success_) then
+    call psb_errpush(psb_err_internal_error_,name,&
+         & a_err='Error in recursive call')
+    goto 9999
+  end if
+
+  ! Assemble d(0) and d(1)
+  call psb_geasb(d(0),&
+       & p%precv(level)%base_desc,info,&
+       & scratch=.true.,mold=mlprec_wrk(level)%vy2l%v)
+  call psb_geasb(d(1),&
+       & p%precv(level)%base_desc,info,&
+       & scratch=.true.,mold=mlprec_wrk(level)%vy2l%v)
+
+  ! First direction: d(0) = vy2l, v = A*d(0)
+  call psb_geaxpby(done,mlprec_wrk(level)%vy2l,dzero,d(0),&
+       & p%precv(level)%base_desc,info)
+
+  call psb_spmm(done,p%precv(level)%base_a,d(0),dzero,v,&
+       & p%precv(level)%base_desc,info)
+  if (info /= psb_success_) then
+    call psb_errpush(psb_err_internal_error_,name,&
+         & a_err='Error during residue')
+    goto 9999
+  end if
+
+  if (use_fcg) then
+    ! FCG
+    delta_old = psb_gedot(d(0), w, p%precv(level)%base_desc, info)
+    tau       = psb_gedot(d(0), v, p%precv(level)%base_desc, info)
+  else
+    ! CGR
+    delta_old = psb_gedot(v, w, p%precv(level)%base_desc, info)
+    tau       = psb_gedot(v, v, p%precv(level)%base_desc, info)
+  endif
+
+  alpha = delta_old/tau
+  ! Update residual w
+  call psb_geaxpby(-alpha, v, done, w, p%precv(level)%base_desc, info)
+
+  l2_norm = psb_gedot(w, w, p%precv(level)%base_desc, info)
+
+  if (l2_norm <= rtol*delta) then
+    ! One step is enough: x = alpha*d(0)
+    call psb_geaxpby(alpha, d(0), done, x, p%precv(level)%base_desc, info)
+  else
+    ! Second direction: apply the preconditioner to the residual w.
+    ! vx2l is overwritten with w because inner_ml_aply reads its input there.
+    call psb_geaxpby(done,w,dzero,mlprec_wrk(level)%vx2l,&
+         & p%precv(level)%base_desc,info)
+    call inner_ml_aply(level,p,mlprec_wrk,trans,work,info)
+    if (info /= psb_success_) then
+      call psb_errpush(psb_err_internal_error_,name,&
+           & a_err='Error in recursive call')
+      goto 9999
+    end if
+    call psb_geaxpby(done,mlprec_wrk(level)%vy2l,dzero,d(1),&
+         & p%precv(level)%base_desc,info)
+
+    ! Sparse matrix vector product: v1 = A*d(1)
+    call psb_spmm(done,p%precv(level)%base_a,d(1),dzero,v1,&
+         & p%precv(level)%base_desc,info)
+    if (info /= psb_success_) then
+      call psb_errpush(psb_err_internal_error_,name,&
+           & a_err='Error during residue')
+      goto 9999
+    end if
+
+    ! tau1, tau2, tau3
+    if (use_fcg) then
+      ! FCG
+      tau1 = psb_gedot(d(1), v,  p%precv(level)%base_desc, info)
+      tau2 = psb_gedot(d(1), v1, p%precv(level)%base_desc, info)
+      tau3 = psb_gedot(d(1), w,  p%precv(level)%base_desc, info)
+    else
+      ! CGR
+      tau1 = psb_gedot(v1, v,  p%precv(level)%base_desc, info)
+      tau2 = psb_gedot(v1, v1, p%precv(level)%base_desc, info)
+      tau3 = psb_gedot(v1, w,  p%precv(level)%base_desc, info)
+    endif
+    ! tau4 has the same expression in both variants
+    tau4 = tau2 - (tau1*tau1)/tau
+
+    ! Update solution: x = x + alpha0*d(0) + alpha1*d(1)
+    alpha = alpha-(tau1*tau3)/(tau*tau4)
+    call psb_geaxpby(alpha,d(0),done,x,p%precv(level)%base_desc,info)
+    alpha = tau3/tau4
+    call psb_geaxpby(alpha,d(1),done,x,p%precv(level)%base_desc,info)
+  endif
+
+  ! Copy the result into vy2l, then free the local vectors
+  call psb_geaxpby(done,x,dzero,mlprec_wrk(level)%vy2l,&
+       & p%precv(level)%base_desc,info)
+  call psb_gefree(v, p%precv(level)%base_desc, info)
+  call psb_gefree(v1, p%precv(level)%base_desc, info)
+  call psb_gefree(w, p%precv(level)%base_desc, info)
+  call psb_gefree(x, p%precv(level)%base_desc, info)
+  call psb_gefree(d(0), p%precv(level)%base_desc, info)
+  call psb_gefree(d(1), p%precv(level)%base_desc, info)
+
+  ! Both the normal path and the error path reach this label.
+9999 continue
+  call psb_erractionrestore(err_act)
+  if (err_act == psb_act_abort_) then
+    call psb_error()
+    return
+  end if
+  return
+end subroutine mld_dinneritkcycle
+
 end subroutine mld_dmlprec_aply_vect diff --git a/mlprec/mld_base_prec_type.F90 b/mlprec/mld_base_prec_type.F90 index 8dac161d..8c8f8f8e 100644 --- a/mlprec/mld_base_prec_type.F90 +++ b/mlprec/mld_base_prec_type.F90 @@ -219,11 +219,15 @@ module mld_base_prec_type ! ! Legal values for entry: mld_ml_type_ ! - integer(psb_ipk_), parameter :: mld_no_ml_ = 0 - integer(psb_ipk_), parameter :: mld_add_ml_ = 1 - integer(psb_ipk_), parameter :: mld_mult_ml_ = 2 - integer(psb_ipk_), parameter :: mld_new_ml_prec_ = 3 - integer(psb_ipk_), parameter :: mld_max_ml_type_ = mld_mult_ml_ + integer(psb_ipk_), parameter :: mld_no_ml_ = 0 + integer(psb_ipk_), parameter :: mld_add_ml_ = 1 + integer(psb_ipk_), parameter :: mld_mult_ml_ = 2 + integer(psb_ipk_), parameter :: mld_vcycle_ml_ = 3 + integer(psb_ipk_), parameter :: mld_wcycle_ml_ = 4 + integer(psb_ipk_), parameter :: mld_kcycle_ml_ = 5 + integer(psb_ipk_), parameter :: mld_kcyclesym_ml_ = 6 + integer(psb_ipk_), parameter :: mld_new_ml_prec_ = 7 + integer(psb_ipk_), parameter :: mld_max_ml_type_ = 8 ! ! Legal values for entry: mld_smoother_pos_ ! 
@@ -337,8 +341,8 @@ module mld_base_prec_type character(len=12), parameter, private :: & & prolong_names(0:3)=(/'none ','sum ','average ','square root'/) character(len=15), parameter, private :: & - & ml_names(0:3)=(/'none ','additive ','multiplicative',& - & 'new ML '/) + & ml_names(0:7)=(/'none ','additive ','multiplicative',& + & 'VCycle ','WCycle ','KCycle ','KCycleSym ','new ML '/) character(len=15), parameter :: & & mld_fact_names(0:mld_max_sub_solve_)=(/& & 'none ','none ',& @@ -422,6 +426,14 @@ contains val = mld_add_ml_ case('MULT') val = mld_mult_ml_ + case('VCYCLE') + val = mld_vcycle_ml_ + case('WCYCLE') + val = mld_wcycle_ml_ + case('KCYCLE') + val = mld_kcycle_ml_ + case('KCYCLESYM') + val = mld_kcyclesym_ml_ case('DEC') val = mld_dec_aggr_ case('SYMDEC')