Fixed some bugs (QR_fact serial)

psblas-bgmres
gabrielequatrana 2 years ago
parent 676652fcff
commit 1b79939255

@ -2936,7 +2936,7 @@ contains
end do
end do
class default
res = x%dot_col(nr,y%v)
res = x%dot(nr,y%v)
end select
end function d_base_mlv_dot_v

@ -1395,9 +1395,6 @@ module psb_d_multivect_mod
procedure, pass(x) :: dot_v => d_vect_dot_v
procedure, pass(x) :: dot_a => d_vect_dot_a
generic, public :: dot => dot_v, dot_a
procedure, pass(x) :: dot_row_v => d_vect_dot_row_v
procedure, pass(x) :: dot_row_a => d_vect_dot_row_a
generic, public :: dot_row => dot_row_v, dot_row_a
procedure, pass(y) :: axpby_v => d_vect_axpby_v
procedure, pass(y) :: axpby_a => d_vect_axpby_a
generic, public :: axpby => axpby_v, axpby_a

@ -256,10 +256,8 @@ function psb_ddot_multivect(x, y, desc_a,info,global) result(res)
nr = desc_a%get_local_rows()
if(nr > 0) then
res = x%dot_col(nr,y)
! TODO adjust dot_local because overlapped elements are computed more than once
res = x%dot(nr,y)
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync()
if (y%v%is_dev()) call y%sync()

@ -456,8 +456,7 @@ function psb_dnrm2_multivect(x, desc_a, info,global) result(res)
if (desc_a%get_local_rows() > 0) then
res = x%nrm2(desc_a%get_local_rows())
! TODO adjust because overlapped elements are computed more than once
! adjust because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync()
do j=1,x%get_ncols()

@ -139,8 +139,7 @@ function psb_dprod_multivect(x,y,desc_a,info,trans,global) result(res)
nr = desc_a%get_local_rows()
if (nr > 0) then
res = x%prod(nr,y,trans_)
! TODO adjust dot_local because overlapped elements are computed more than once
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync()
if (y%v%is_dev()) call y%sync()
@ -156,7 +155,6 @@ function psb_dprod_multivect(x,y,desc_a,info,trans,global) result(res)
res = dzero
end if
! TODO forse è meglio global false di default
! compute global sum
if (global_) call psb_sum(ctxt, res)
@ -268,8 +266,7 @@ function psb_dprod_multivect_a(x,y,desc_a,info,trans,global) result(res)
nr = desc_a%get_local_rows()
if (nr > 0) then
res = x%prod(nr,y,trans_)
! TODO adjust dot_local because overlapped elements are computed more than once
! adjust dot_local because overlapped elements are computed more than once
if (size(desc_a%ovrlap_elem,1)>0) then
if (x%v%is_dev()) call x%sync()
do j=1,x%get_ncols()
@ -284,7 +281,6 @@ function psb_dprod_multivect_a(x,y,desc_a,info,trans,global) result(res)
res = dzero
end if
! TODO
! compute global sum
if (global_) call psb_sum(ctxt, res)

@ -24,6 +24,7 @@ function psb_dqrfact(x, desc_a, info) result(res)
integer(psb_lpk_) :: ix, ijx, m, n
character(len=20) :: name, ch_err
real(psb_dpk_), allocatable :: temp(:,:)
type(psb_d_base_multivect_type) :: qr_temp
name='psb_dgqrfact'
if (psb_errstatus_fatal()) return
@ -66,15 +67,15 @@ function psb_dqrfact(x, desc_a, info) result(res)
call psb_gather(temp,x,desc_a,info,root=psb_root_)
if (me == psb_root_) then
call x%set(temp)
res = x%qr_fact(info)
call qr_temp%bld(temp)
res = qr_temp%qr_fact(info)
temp = qr_temp%get_vect()
call psb_bcast(ctxt,res)
else
allocate(res(n,n))
call psb_bcast(ctxt,res)
end if
temp = x%get_vect()
call psb_scatter(temp,x,desc_a,info,root=psb_root_)
call psb_erractionrestore(err_act)

@ -209,6 +209,7 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
! BGMRES algorithm
! TODO QR fact seriale per ora
! TODO Con tanti ITRS NaN (forse genera righe dipendenti (vedere pargen))
! STEP 1: Compute R(0) = B - A*X(0)
@ -290,10 +291,9 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
end if
! STEP 7: Compute W = W - V(i)*H(i,j)
! TODO si blocca con NRHS grandi?
!temp = psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.)
call psb_geaxpby(-done,psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.),done,w,desc_a,info)
call psb_geaxpby(-done,&
& psb_geprod(v(i),h(idx_i:idx_i+n_add,idx_j:idx_j+n_add),desc_a,info,global=.false.),&
& done,w,desc_a,info)
if (info /= psb_success_) then
info=psb_err_from_subroutine_non_
call psb_errpush(info,name)
@ -330,7 +330,6 @@ subroutine psb_dbgmres_multivect(a, prec, b, x, eps, desc_a, info, itmax, iter,
goto 9999
end if
! TODO V_tot comprende V(nrep+1)?
! STEP 10: Compute V = {V(1),...,V(m)}
do i=1,nrep
idx = (i-1)*nrhs+1

@ -130,9 +130,9 @@ program psb_dbf_sample
b_mv_glob => aux_b(:,:)
do i=1, m
do j=1, nrhs
b_mv_glob(i,j) = done
!call random_number(random_value)
!b_mv_glob(i,j) = random_value
!b_mv_glob(i,j) = done
call random_number(random_value)
b_mv_glob(i,j) = random_value
enddo
enddo
endif
@ -219,7 +219,6 @@ program psb_dbf_sample
write(psb_out_unit,'(" ")')
end if
! TODO spmm cambia X (che senso ha?)
call psb_geaxpby(done,b_mv,dzero,r_mv,desc_a,info)
call psb_spmm(-done,a,x_mv,done,r_mv,desc_a,info)

Loading…
Cancel
Save