From e423e149fa8c4f6402eb362f77782a8af3c7a0ca Mon Sep 17 00:00:00 2001 From: Fabio Durastante Date: Tue, 8 Apr 2025 23:26:46 +0200 Subject: [PATCH] Implemented right triangular solve --- base/modules/psblas/psb_d_psblas_mod.F90 | 14 +++- base/modules/serial/psb_d_base_vect_mod.F90 | 9 ++- base/modules/serial/psb_d_vect_mod.F90 | 5 +- base/psblas/psb_ddiv_vect.f90 | 67 +++++++++++++++++ base/psblas/psb_ddot.f90 | 3 +- base/psblas/psb_dvmlt.f90 | 24 +++--- test/kernel/vecoperation.f90 | 81 ++++++++++++++++++++- 7 files changed, 183 insertions(+), 20 deletions(-) diff --git a/base/modules/psblas/psb_d_psblas_mod.F90 b/base/modules/psblas/psb_d_psblas_mod.F90 index 6ed3eed7..48f2d30d 100644 --- a/base/modules/psblas/psb_d_psblas_mod.F90 +++ b/base/modules/psblas/psb_d_psblas_mod.F90 @@ -543,7 +543,7 @@ module psb_d_psblas_mod subroutine psb_dmlt_multivect(x, y, res, desc_a,info,global) import :: psb_desc_type, psb_dpk_, psb_ipk_, & & psb_d_multivect_type, psb_dspmat_type - real(psb_dpk_), dimension(:,:), allocatable :: res + real(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res type(psb_d_multivect_type), intent(inout) :: x, y type(psb_desc_type), intent(in) :: desc_a integer(psb_ipk_), intent(out) :: info @@ -588,6 +588,18 @@ module psb_d_psblas_mod integer(psb_ipk_), intent(out) :: info logical, intent(in) :: flag end subroutine psb_ddiv_vect2_check + subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag) + import :: psb_desc_type, psb_ipk_, & + & psb_dpk_, psb_d_multivect_type + type(psb_d_multivect_type), intent (inout) :: x + real(psb_dpk_), intent (in), dimension(:,:) :: a + type(psb_desc_type), intent (in) :: desc_a + character(len=1), intent(in) :: uplo + integer(psb_ipk_), intent(out) :: info + real(psb_dpk_), intent (in), optional :: alpha + character(len=1), intent(in), optional :: trans + character(len=1), intent(in), optional :: diag + end subroutine psb_ddiv_trslv end interface interface psb_geinv diff --git 
a/base/modules/serial/psb_d_base_vect_mod.F90 b/base/modules/serial/psb_d_base_vect_mod.F90 index 2a2a6902..4c415387 100644 --- a/base/modules/serial/psb_d_base_vect_mod.F90 +++ b/base/modules/serial/psb_d_base_vect_mod.F90 @@ -2911,8 +2911,8 @@ contains if (x%is_dev()) call x%sync() if (x%is_sync()) then ! Call BLAS function to solve the system - lda = n - ldb = n + lda = size(a,1) + ldb = x%get_nrows() side = 'R' ! X*op( A ) = alpha*B. call dtrsm(side, uplo, trans_, diag_, n, x%get_ncols(), alpha_, a, lda, x%v, ldb) end if @@ -3141,9 +3141,10 @@ contains !! \param y The class(base_mlv_vect) to be multiplied by !! \param a The resulting matrix !! \param info return code - subroutine d_base_mlv_mlt_mv2(x,y,a,info) + subroutine d_base_mlv_mlt_mv2(n,x,y,a,info) use psi_serial_mod implicit none + integer(psb_ipk_), intent(in) :: n class(psb_d_base_multivect_type), intent(inout) :: x class(psb_d_base_multivect_type), intent(inout) :: y real(psb_dpk_), intent(inout), allocatable :: a(:,:) @@ -3171,7 +3172,7 @@ contains ! C = alpha*op( A )*op( B ) + beta*C ! In our case, we want to compute ! 
! Subroutine: psb_ddiv_trslv
!   In-place triangular solve applied to the local rows of a distributed
!   multivector. The actual solve is delegated to x%trslv, which is handed
!   the number of local rows of desc_a together with the dense matrix a.
!
! Arguments:
!   x       - type(psb_d_multivect_type), inout. Multivector overwritten
!             with the solution; its inner vector x%v must be allocated.
!   a       - real(psb_dpk_), in, rank 2. Dense triangular factor. It must
!             provide at least ncols(x) rows and ncols(x) columns.
!   desc_a  - type(psb_desc_type), in. Communication descriptor; supplies
!             the context, the global row count and the local row count.
!   uplo    - character, in. Selects the triangle of a that is referenced
!             (the test driver passes 'U' or 'L').
!   info    - integer, out. psb_success_ on success, an error code otherwise.
!   alpha   - real(psb_dpk_), in, optional. Forwarded unchanged to x%trslv.
!   trans   - character, in, optional.      Forwarded unchanged to x%trslv.
!   diag    - character, in, optional.      Forwarded unchanged to x%trslv.
!
! NOTE(review): the defaults used when alpha/trans/diag are absent are chosen
! inside x%trslv and are not visible from this routine.
subroutine psb_ddiv_trslv(x,a,desc_a,uplo,info,alpha,trans,diag)
  use psb_base_mod, psb_protect_name => psb_ddiv_trslv
  implicit none
  type(psb_d_multivect_type), intent (inout) :: x
  real(psb_dpk_), intent (in), dimension(:,:) :: a
  type(psb_desc_type), intent (in)            :: desc_a
  character(len=1), intent(in)                :: uplo
  integer(psb_ipk_), intent(out)              :: info
  real(psb_dpk_), intent (in), optional       :: alpha
  character(len=1), intent(in), optional      :: trans
  character(len=1), intent(in), optional      :: diag

  ! locals
  type(psb_ctxt_type) :: ctxt
  integer(psb_ipk_)   :: np, me, err_act, iix, jjx, nr
  integer(psb_lpk_)   :: ix, ijx, m, nx
  character(len=20)   :: name, ch_err

  name='psb_ddiv_trslv'
  if (psb_errstatus_fatal()) then
    ! info is intent(out): never hand it back undefined
    info = psb_err_internal_error_
    return
  end if
  info=psb_success_
  call psb_erractionsave(err_act)

  ctxt=desc_a%get_context()

  call psb_info(ctxt, me, np)
  if (np == -ione) then
    info = psb_err_context_error_
    call psb_errpush(info,name)
    goto 9999
  endif
  if (.not.allocated(x%v)) then
    info = psb_err_invalid_vect_state_
    call psb_errpush(info,name)
    goto 9999
  endif

  ix  = ione
  ijx = ione

  m  = desc_a%get_global_rows()
  nx = x%get_ncols()

  ! check vector correctness
  call psb_chkvect(m,nx,x%get_nrows(),ix,ijx,desc_a,info,iix,jjx)
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='psb_chkvect 1'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  ! The triangular factor acts on the columns of x, so it must be at least
  ! ncols(x) x ncols(x); anything smaller would be read out of bounds by
  ! the underlying BLAS call.
  if ((size(a,1) < nx) .or. (size(a,2) < nx)) then
    info = psb_err_internal_error_
    ch_err='size(a) < ncols(x)'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  nr = desc_a%get_local_rows()
  if (nr > 0) then
    call x%trslv(nr,a,uplo,alpha,trans,diag,info)
    ! propagate a failure of the local solve instead of dropping it
    if (info /= psb_success_) then
      info=psb_err_from_subroutine_
      ch_err='x%trslv'
      call psb_errpush(info,name,a_err=ch_err)
      goto 9999
    end if
  end if

  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return

end subroutine psb_ddiv_trslv
res = - (real(ndm-1)/real(ndm))* x(idx,:)^T y(idx) + 1.0 res call dgemv('C',size(x%v%v,1),size(x%v%v,2),-(real(ndm-1)/real(ndm)), & & size(x%v%v,1),y%v%v(idx),ione,done,res,ione) diff --git a/base/psblas/psb_dvmlt.f90 b/base/psblas/psb_dvmlt.f90 index 42c33bfa..c0c9d72f 100644 --- a/base/psblas/psb_dvmlt.f90 +++ b/base/psblas/psb_dvmlt.f90 @@ -169,7 +169,7 @@ subroutine psb_dmlt_multivect(x, y, res,desc_a,info,global) use psb_d_multivect_mod use psb_d_psblas_mod, psb_protect_name => psb_dmlt_multivect implicit none - real(psb_dpk_), dimension(:,:), allocatable :: res + real(psb_dpk_), dimension(:,:), allocatable, intent(inout) :: res type(psb_d_multivect_type), intent(inout) :: x, y type(psb_desc_type), intent(in) :: desc_a integer(psb_ipk_), intent(out) :: info @@ -246,17 +246,16 @@ subroutine psb_dmlt_multivect(x, y, res,desc_a,info,global) call psb_errpush(info,name) goto 9999 else - allocate(res(x%get_nrows(),x%get_ncols()),stat=info) - if (info /= 0) then - info=psb_err_alloc_dealloc_ - call psb_errpush(info,name) - goto 9999 + if (allocated(res)) then + if ((size(res,1) /= x%get_ncols()).or.(size(res,2) /= y%get_ncols())) then + deallocate(res,stat=info) + end if end if end if nr = desc_a%get_local_rows() if(nr > 0) then - call x%mlt(y,res,info) + call x%mlt(nr,y,res,info) ! FIXME ! adjust dot_local because overlapped elements are computed more than once if (size(desc_a%ovrlap_elem,1)>0) then @@ -265,15 +264,20 @@ subroutine psb_dmlt_multivect(x, y, res,desc_a,info,global) do i=1,size(desc_a%ovrlap_elem,1) idx = desc_a%ovrlap_elem(i,1) ndm = desc_a%ovrlap_elem(i,2) + ! FIXME: case of AS-type descriptors ! Since I'm coputing res via a dgemm on the whole vector, I need to adjust ! the result by removing the contribution of the overlapped elements ! specifically: res(:,:) = res(:,:) - x%v%v(idx,:)^T*y%v%v(idx,:) ! using dgemm to compute the matrix-matrix product of the form R = R - X'*Y ! where R is the result, X' is the transpose of the matrix x%v%v(idx,:) ! 
!
! Function: psb_d_check_ans_mv_a
!   Checks that every local row of the multivector v equals the reference
!   row val, i.e. v(i,j) == val(j) for all local i and all j, and combines
!   the outcome over all processes of ctxt with psb_sum.
!
! Arguments:
!   v     - type(psb_d_multivect_type). Multivector under test; its local
!           entries are read through v%get_vect().
!   val   - real(psb_dpk_), rank 1, in. Expected value for each column.
!   ctxt  - type(psb_ctxt_type). Context used for the global reduction.
!
! Result:
!   ans   - .true. only if no process found a mismatching column.
!
! The comparison is exact (no tolerance), as in the original check.
! Every process reaches psb_sum, including those that own no rows or that
! detect a shape mismatch, so the reduction cannot hang.
!
function psb_d_check_ans_mv_a(v,val,ctxt) result(ans)
  use psb_base_mod

  implicit none

  type(psb_d_multivect_type) :: v
  real(psb_dpk_), intent(in) :: val(:)
  type(psb_ctxt_type)        :: ctxt
  logical                    :: ans

  ! Local variables
  integer(psb_ipk_)           :: j
  real(psb_dpk_)              :: check
  real(psb_dpk_), allocatable :: va(:,:)

  va = v%get_vect()

  ! check counts the local failures; it is non-negative on every process,
  ! so contributions cannot cancel each other in the global sum.
  check = 0.0_psb_dpk_
  if (size(val) /= size(va,2)) then
    ! the reference row does not match the number of columns
    check = 1.0_psb_dpk_
  else
    ! Walk column by column (contiguous access). Using "/=" rather than
    ! maxval of the difference catches negative deviations and NaNs, and
    ! any() of an empty section is .false., so a process without local
    ! rows contributes zero.
    do j = 1, size(va,2)
      if (any(va(:,j) /= val(j))) check = check + 1.0_psb_dpk_
    end do
  end if

  call psb_sum(ctxt,check)

  ans = (check == 0.0_psb_dpk_)

end function psb_d_check_ans_mv_a
scalars real(psb_dpk_), allocatable, dimension(:,:) :: res + real(psb_dpk_), allocatable, dimension(:,:) :: a + real(psb_dpk_), allocatable, dimension(:) :: check_row ! blacs parameters type(psb_ctxt_type) :: ctxt integer(psb_ipk_) :: iam, np ! auxiliary parameters - integer(psb_ipk_) :: ii + integer(psb_ipk_) :: ii,jj integer(psb_ipk_) :: info character(len=20) :: name,ch_err,readinput real(psb_dpk_) :: ans @@ -1491,6 +1531,41 @@ program vecoperation if(hasitnotfailed) write(psb_out_unit,'("TEST PASSED >>> Constant multivector (complex double precision)")') if(.not.hasitnotfailed) write(psb_out_unit,'("TEST FAILED --- Constant multivector (complex double precision)")') end if + ! X = 1, T = upper triangular of all ones + call psb_d_gen_const_multi(mv1,done,idim,nmv,ctxt,desc_a,info) + allocate(a(nmv,nmv),check_row(nmv)) + do ii=1,nmv + do jj=ii,nmv + a(ii,jj) = done + end do + end do + check_row = 0 + check_row(1) = done + call psb_gediv(mv1,a,desc_a,'U',info) + hasitnotfailed = psb_check_ans(mv1,check_row,ctxt) + if (iam == psb_root_) then + if(hasitnotfailed) write(psb_out_unit,'("TEST PASSED >>> Triangular solve (UP) mv1 = mv1 / T")') + if(.not.hasitnotfailed) write(psb_out_unit,'("TEST FAILED --- Triangular solve (UP) mv1 = mv1 / T")') + end if + ! X = 1, T = lower triangular of all ones + call psb_d_gen_const_multi(mv1,done,idim,nmv,ctxt,desc_a,info) + if (allocated(a)) deallocate(a) + if (allocated(check_row)) deallocate(check_row) + allocate(a(nmv,nmv),check_row(nmv)) + do ii=1,nmv + do jj=1,ii + a(ii,jj) = done + end do + end do + check_row = 0 + check_row(nmv) = done + call psb_gediv(mv1,a,desc_a,'L',info) + hasitnotfailed = psb_check_ans(mv1,check_row,ctxt) + if (iam == psb_root_) then + if(hasitnotfailed) write(psb_out_unit,'("TEST PASSED >>> Triangular solve (LOW) mv1 = mv1 / T")') + if(.not.hasitnotfailed) write(psb_out_unit,'("TEST FAILED --- Triangular solve (LOW) mv1 = mv1 / T")') + end if + ! ! 
Multivector to field operation @@ -1604,6 +1679,8 @@ program vecoperation call psb_gefree(zmv2,desc_a,info) call psb_cdfree(desc_a,info) if(allocated(res)) deallocate(res) + if(allocated(a)) deallocate(a) + if(allocated(check_row)) deallocate(check_row) if(info /= psb_success_) then info=psb_err_from_subroutine_ ch_err='free routine'