[UPDATE] Changed all the interfaces that call psi_swapdata inside the PSBLAS internals for double-precision vectors. Also added tests under test/comm/ to check the psi_swapdata, psb_spmv and psb_cg calls.

communication_v2
Stack-1 1 month ago
parent 09a5a74d75
commit 33477e4f03

@ -82,12 +82,7 @@
submodule (psi_d_comm_v_mod) psi_d_swapdata_impl
use psb_desc_const_mod, only: psb_swap_start_, psb_swap_wait_
use psb_base_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
& psb_comm_persistent_ineighbor_alltoallv_, &
& psb_comm_status_start_, psb_comm_status_wait_, psb_comm_status_unknown_
use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free
use psb_comm_baseline_mod, only: psb_comm_baseline_handle, psb_comm_baseline_alloc_comid
use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
use psb_comm_factory_mod
contains
module subroutine psi_dswapdata_vect(swap_status,beta,y,desc_a,info,data)
@ -109,7 +104,6 @@ contains
! locals
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, total_send, total_recv, num_neighbors, data_, err_act
integer(psb_ipk_) :: setflag
class(psb_i_base_vect_type), pointer :: comm_indexes
! communication scheme/status selectors
@ -155,21 +149,13 @@ contains
! write(psb_err_unit,*) me, 'DBG: get_list_p -> num_neighbors=', &
! & num_neighbors, ' total_send=', total_send, ' total_recv=', total_recv
! end if
! Accept both new comm-status enums and legacy descriptor bitfields.
setflag = swap_status
if (swap_status == psb_swap_start_) then
setflag = psb_comm_status_start_
else if (swap_status == psb_swap_wait_) then
setflag = psb_comm_status_wait_
else if (iand(swap_status, psb_swap_start_) /= 0 .and. iand(swap_status, psb_swap_wait_) /= 0) then
setflag = psb_comm_status_unknown_
end if
if ((setflag /= psb_comm_status_start_) .and. (setflag /= psb_comm_status_wait_) .and. &
& (setflag /= psb_comm_status_unknown_)) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Invalid swap_status swap_status')
goto 9999
if( (swap_status /= psb_comm_status_start_).and.(swap_status /= psb_comm_status_wait_)&
& .and.(swap_status /= psb_comm_status_sync_) ) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Invalid swap_status swap_status')
goto 9999
end if
if (.not. allocated(y%comm_handle)) then
@ -189,7 +175,7 @@ contains
! end if
! end if
! Set the normalized swap status on the comm handle
call y%comm_handle%set_swap_status(setflag, info)
call y%comm_handle%set_swap_status(swap_status, info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='set_swap_status')
goto 9999
@ -292,9 +278,15 @@ contains
goto 9999
end select
if(swap_status == psb_comm_status_unknown_) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Invalid swap_status: psb_comm_status_unknown_ is not allowed in neighbor swap')
goto 9999
end if
n=1
do_send = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_unknown_)
do_recv = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_unknown_)
do_send = (swap_status == psb_comm_status_start_).or.(swap_status == psb_comm_status_sync_)
do_recv = (swap_status == psb_comm_status_wait_).or.(swap_status == psb_comm_status_sync_)
total_recv_ = total_recv * n
total_send_ = total_send * n
@ -539,8 +531,14 @@ contains
goto 9999
end select
do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_unknown_)
do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_unknown_)
if(swap_status == psb_comm_status_unknown_) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Invalid swap_status: psb_comm_status_unknown_ is not allowed in neighbor swap')
goto 9999
end if
do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_)
call comm_indexes%sync()

@ -92,10 +92,7 @@
!
submodule (psi_d_comm_v_mod) psi_d_swaptran_impl
use psb_base_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_
use psb_comm_factory_mod, only: psb_comm_init, psb_comm_free
use psb_comm_baseline_mod, only: psb_comm_baseline_handle
use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
use psb_comm_factory_mod
contains
module subroutine psi_dswaptran_vect(swap_status,beta,y,desc_a,info,data)
@ -122,10 +119,8 @@ contains
character(len=20) :: name
! local variables used to detect the communication scheme
logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait
logical :: baseline, neighbor_a2av
info = psb_success_
name = 'psi_dswaptran_vect'
call psb_erractionsave(err_act)
@ -157,20 +152,11 @@ contains
goto 9999
end if
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
swap_start = iand(swap_status,psb_swap_start_) /= 0
swap_wait = iand(swap_status,psb_swap_wait_) /= 0
baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync
neighbor_a2av = swap_start .or. swap_wait
if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible swap_status settings: both baseline and neighbor_a2av are true')
goto 9999
if( (swap_status /= psb_comm_status_start_).and.(swap_status /= psb_comm_status_wait_)&
& .and.(swap_status /= psb_comm_status_sync_) ) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Invalid swap_status swap_status')
goto 9999
end if
if (.not. allocated(y%comm_handle)) then
@ -181,6 +167,22 @@ contains
end if
end if
! Set the normalized swap status on the comm handle
call y%comm_handle%set_swap_status(swap_status, info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='set_swap_status')
goto 9999
end if
baseline = .false.
neighbor_a2av = .false.
select case(y%comm_handle%comm_type)
case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_)
neighbor_a2av = .true.
case default
baseline = .true.
end select
if (baseline) then
call psi_dtran_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
@ -276,12 +278,8 @@ contains
end select
n=1
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
do_send = swap_mpi .or. swap_sync .or. swap_send
do_recv = swap_mpi .or. swap_sync .or. swap_recv
do_send = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
do_recv = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_)
total_recv_ = total_recv * n
total_send_ = total_send * n
@ -534,8 +532,8 @@ contains
goto 9999
end select
do_start = iand(swap_status,psb_swap_start_) /= 0
do_wait = iand(swap_status,psb_swap_wait_) /= 0
do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_)
call comm_indexes%sync()
@ -690,8 +688,8 @@ contains
character(len=20) :: name
! local variables used to detect the communication scheme
logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait
logical :: baseline, neighbor_a2av
integer(psb_ipk_) :: setflag
info = psb_success_
@ -725,19 +723,20 @@ contains
goto 9999
end if
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
swap_start = iand(swap_status,psb_swap_start_) /= 0
swap_wait = iand(swap_status,psb_swap_wait_) /= 0
baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync
neighbor_a2av = swap_start .or. swap_wait
setflag = swap_status
if (swap_status == psb_swap_start_) then
setflag = psb_comm_status_start_
else if (swap_status == psb_swap_wait_) then
setflag = psb_comm_status_wait_
else if ((iand(swap_status, psb_swap_send_) /= 0) .or. (iand(swap_status, psb_swap_recv_) /= 0) .or. &
& (iand(swap_status, psb_swap_mpi_) /= 0) .or. (iand(swap_status, psb_swap_sync_) /= 0)) then
setflag = psb_comm_status_sync_
end if
if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then
if ((setflag /= psb_comm_status_start_) .and. (setflag /= psb_comm_status_wait_) .and. &
& (setflag /= psb_comm_status_sync_)) then
info = psb_err_mpi_error_
call psb_errpush(info,name,a_err='Incompatible swap_status settings: both baseline and neighbor_a2av are true')
call psb_errpush(info,name,a_err='Invalid swap_status')
goto 9999
end if
@ -749,14 +748,29 @@ contains
end if
end if
call y%comm_handle%set_swap_status(setflag, info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='set_swap_status')
goto 9999
end if
baseline = .false.
neighbor_a2av = .false.
select case(y%comm_handle%comm_type)
case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_)
neighbor_a2av = .true.
case default
baseline = .true.
end select
if (baseline) then
call psi_dtran_baseline_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info)
call psi_dtran_baseline_multivect(ctxt,setflag,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='baseline swap')
goto 9999
end if
else if (neighbor_a2av) then
call psi_dtran_neighbor_topology_multivect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,&
call psi_dtran_neighbor_topology_multivect(ctxt,setflag,beta,y,comm_indexes,num_neighbors,&
& total_send,total_recv,y%comm_handle,info)
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='neighbor a2av swap')
@ -832,12 +846,8 @@ contains
n = y%get_ncols()
swap_mpi = iand(swap_status,psb_swap_mpi_) /= 0
swap_sync = iand(swap_status,psb_swap_sync_) /= 0
swap_send = iand(swap_status,psb_swap_send_) /= 0
swap_recv = iand(swap_status,psb_swap_recv_) /= 0
do_send = swap_mpi .or. swap_sync .or. swap_send
do_recv = swap_mpi .or. swap_sync .or. swap_recv
do_send = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
do_recv = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_)
total_recv_ = total_recv * n
total_send_ = total_send * n
@ -1091,8 +1101,8 @@ contains
goto 9999
end select
do_start = iand(swap_status,psb_swap_start_) /= 0
do_wait = iand(swap_status,psb_swap_wait_) /= 0
do_start = (swap_status == psb_comm_status_start_) .or. (swap_status == psb_comm_status_sync_)
do_wait = (swap_status == psb_comm_status_wait_) .or. (swap_status == psb_comm_status_sync_)
call comm_indexes%sync()

@ -54,6 +54,8 @@
subroutine psb_dhalo_vect(x,desc_a,info,tran,mode,data)
use psb_base_mod, psb_protect_name => psb_dhalo_vect
use psi_mod
use psb_comm_factory_mod
implicit none
type(psb_d_vect_type), intent(inout) :: x
@ -115,7 +117,7 @@ subroutine psb_dhalo_vect(x,desc_a,info,tran,mode,data)
if (present(mode)) then
imode = mode
else
imode = IOR(psb_swap_send_,psb_swap_recv_) ! default base communication scheme Isend/Irecv
imode = psb_comm_status_sync_
endif
if ((info == 0).and.(lldx<ncol)) call x%reall(ncol,info)
@ -238,7 +240,7 @@ subroutine psb_dhalo_multivect(x,desc_a,info,tran,mode,data)
if (present(mode)) then
imode = mode
else
imode = IOR(psb_swap_send_,psb_swap_recv_)
imode = psb_comm_mov_
endif
if (lldx < ncol) call x%reall(ncol,x%get_ncols(),info)

@ -65,6 +65,7 @@
subroutine psb_dovrl_vect(x,desc_a,info,update,mode)
use psb_base_mod, psb_protect_name => psb_dovrl_vect
use psi_mod
use psb_comm_factory_mod
implicit none
type(psb_d_vect_type), intent(inout) :: x
@ -121,7 +122,7 @@ subroutine psb_dovrl_vect(x,desc_a,info,update,mode)
if (present(mode)) then
mode_ = mode
else
mode_ = IOR(psb_swap_send_,psb_swap_recv_)
mode_ = psb_comm_status_sync_
endif
do_swap = (mode_ /= 0)

@ -1,7 +1,9 @@
module psb_comm_factory_mod
use psb_const_mod
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_ineighbor_alltoallv_, &
& psb_comm_persistent_ineighbor_alltoallv_, psb_comm_unknown_
use psb_comm_schemes_mod, only: psb_comm_handle_type, psb_comm_isend_irecv_, &
& psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, &
& psb_comm_unknown_, psb_comm_status_start_, psb_comm_status_wait_, &
& psb_comm_status_sync_, psb_comm_status_unknown_
use psb_comm_baseline_mod, only: psb_comm_baseline_handle
use psb_comm_neighbor_impl_mod, only: psb_comm_neighbor_handle
implicit none
@ -14,18 +16,45 @@ contains
integer(psb_ipk_), intent(in) :: comm_type
class(psb_comm_handle_type), allocatable, intent(inout) :: handle
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: old_id, old_swap_status
info = 0
old_id = 0
old_swap_status = psb_comm_status_unknown_
if (allocated(handle)) then
info = -1
return
old_id = handle%id
old_swap_status = handle%swap_status
if (handle%comm_type == comm_type) then
call handle%free(info)
if (info /= 0) return
call handle%init(info)
if (info /= 0) return
handle%id = old_id
handle%swap_status = old_swap_status
select type(h => handle)
type is(psb_comm_neighbor_handle)
h%comm_type = comm_type
h%use_persistent_buffers = (comm_type == psb_comm_persistent_ineighbor_alltoallv_)
class default
! nothing else to configure
end select
return
else
call psb_comm_free(handle, info)
if (info /= 0) return
end if
end if
select case(comm_type)
case(psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_)
allocate(psb_comm_neighbor_handle :: handle, stat=info)
if (info /= 0) return
call handle%init(info)
if (info /= 0) return
handle%id = old_id
handle%swap_status = old_swap_status
select type(h => handle)
type is(psb_comm_neighbor_handle)
h%comm_type = comm_type
@ -35,6 +64,9 @@ contains
allocate(psb_comm_baseline_handle :: handle, stat=info)
if (info /= 0) return
call handle%init(info)
if (info /= 0) return
handle%id = old_id
handle%swap_status = old_swap_status
end select
end subroutine psb_comm_init

@ -17,6 +17,7 @@ module psb_comm_schemes_mod
enumerator psb_comm_status_unknown_
enumerator psb_comm_status_start_
enumerator psb_comm_status_wait_
enumerator psb_comm_status_sync_ ! Used in order to exchange data in a synchronous way (Start and recv in the same call)
end enum

@ -57,6 +57,8 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,&
& trans, doswap)
use psb_base_mod, psb_protect_name => psb_dspmv_vect
use psi_mod
use psb_comm_factory_mod
implicit none
real(psb_dpk_), intent(in) :: alpha, beta
@ -174,12 +176,12 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,&
!if (me==0) write(0,*) 'going for overlap ',a%ad%get_fmt(),' ',a%and%get_fmt()
if (do_timings) call psb_barrier(ctxt)
if (do_timings) call psb_tic(mv_phase1)
if (doswap_) call psi_swapdata(psb_swap_send_, dzero, x%v, desc_a, info, data=psb_comm_halo_)
if (doswap_) call psi_swapdata(psb_comm_status_start_, dzero, x%v, desc_a, info, data=psb_comm_halo_)
if (do_timings) call psb_toc(mv_phase1)
if (do_timings) call psb_tic(mv_phase2)
call a%ad%spmm(alpha,x%v,beta,y%v,info)
if (do_timings) call psb_tic(mv_phase3)
if (doswap_) call psi_swapdata(psb_swap_recv_, dzero, x%v, desc_a, info, data=psb_comm_halo_)
if (doswap_) call psi_swapdata(psb_comm_status_wait_, dzero, x%v, desc_a, info, data=psb_comm_halo_)
if (do_timings) call psb_toc(mv_phase3)
if (do_timings) call psb_tic(mv_phase4)
call a%and%spmm(alpha,x%v,done,y%v,info)
@ -194,7 +196,7 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,&
if (do_timings) call psb_tic(mv_phase11)
if (doswap_) then
call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_), dzero, x%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_sync_, dzero, x%v, desc_a, info, data=psb_comm_halo_)
end if
if (do_timings) call psb_toc(mv_phase11)
if (do_timings) call psb_tic(mv_phase12)
@ -237,9 +239,9 @@ subroutine psb_dspmv_vect(alpha,a,x,beta,y,desc_a,info,&
end if
if (doswap_) then
call psi_swaptran(ior(psb_swap_send_,psb_swap_recv_), done, y%v, desc_a, info)
call psi_swaptran(psb_comm_status_sync_, done, y%v, desc_a, info)
if (info == psb_success_) then
call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_), done, y%v, desc_a, info, data=psb_comm_ovr_)
call psi_swapdata(psb_comm_status_sync_, done, y%v, desc_a, info, data=psb_comm_ovr_)
end if
if (debug_level >= psb_debug_comp_) &
@ -301,6 +303,8 @@ subroutine psb_dspmm(alpha,a,x,beta,y,desc_a,info,&
& trans, k, jx, jy, work, doswap)
use psb_base_mod, psb_protect_name => psb_dspmm
use psi_mod
use psb_comm_factory_mod
implicit none
real(psb_dpk_), intent(in) :: alpha, beta
@ -560,9 +564,9 @@ subroutine psb_dspmm(alpha,a,x,beta,y,desc_a,info,&
if (doswap_)then
ik = lik ! This should not be an issue, we are expecting the values
! to be small, within PSB_IPK
call psi_swaptran(ior(psb_swap_send_,psb_swap_recv_),&
& ik,done,y(:,1:ik),desc_a,iwork,info)
if (info == psb_success_) call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),&
call psi_swaptran(psb_comm_status_sync_,&
& ik,done,y(:,1:ik),desc_a,iwork,info,data=psb_comm_ovr_)
if (info == psb_success_) call psi_swapdata(psb_comm_status_sync_,&
& ik,done,y(:,1:ik),desc_a,iwork,info,data=psb_comm_ovr_)
if (debug_level >= psb_debug_comp_) &
@ -653,6 +657,7 @@ subroutine psb_dspmv(alpha,a,x,beta,y,desc_a,info,&
& trans, work, doswap)
use psb_base_mod, psb_protect_name => psb_dspmv
use psi_mod
use psb_comm_factory_mod
implicit none
real(psb_dpk_), intent(in) :: alpha, beta
@ -799,8 +804,7 @@ subroutine psb_dspmv(alpha,a,x,beta,y,desc_a,info,&
end if
if (doswap_) then
call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),&
& dzero,x,desc_a,iwork,info,data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_sync_,dzero,x,desc_a,iwork,info,data=psb_comm_halo_)
end if
call psb_csmm(alpha,a,x,beta,y,info)
@ -866,11 +870,10 @@ subroutine psb_dspmv(alpha,a,x,beta,y,desc_a,info,&
end if
if (doswap_) then
call psi_swaptran(ior(psb_swap_send_,psb_swap_recv_),&
& done,yp,desc_a,iwork,info)
if (info == psb_success_) call psi_swapdata(ior(psb_swap_send_,psb_swap_recv_),&
& done,yp,desc_a,iwork,info,data=psb_comm_ovr_)
call psi_swaptran(psb_comm_status_sync_,done,yp,desc_a,iwork,info)
if (info == psb_success_) then
call psi_swapdata(psb_comm_status_sync_,done,yp,desc_a,iwork,info,data=psb_comm_ovr_)
end if
if (debug_level >= psb_debug_comp_) &
& write(debug_unit,*) me,' ',trim(name),' swaptran ', info
if(info /= psb_success_) then

1824
log.txt

File diff suppressed because one or more lines are too long

@ -0,0 +1,23 @@
Communication scheme tests
==========================
This directory contains tests created after adding multiple communication schemes
to PSBLAS.
The goal is to exercise and compare communication patterns at different layers:
- direct halo exchange (`psi_swapdata`)
- overlap exchange for transpose/SpMV workflows (`psi_swaptran` + `psi_swapdata`)
- full Krylov solver runs (`CG`) using different comm schemes.
Communication schemes covered in this area:
- `psb_comm_isend_irecv_` (baseline point-to-point)
- `psb_comm_ineighbor_alltoallv_` (neighbor collective)
- `psb_comm_persistent_ineighbor_alltoallv_` (persistent neighbor collective)
See:
- `swapdata/` for a direct halo-exchange test.
- `spmv/` for an overlap SpMV test that uses different communication schemes.
- `cg/` for a conjugate-gradient solve-time comparison across the three schemes.

@ -0,0 +1,37 @@
# Build and run the CG communication-scheme test (psb_comm_cg_test).
#
#   make          build the test and move it into $(EXEDIR)
#   make run      build, then launch it under MPI with NP ranks and grid size IDIM
#   make clean    remove objects, module files and the executable
INSTALLDIR=../../..
INCDIR=$(INSTALLDIR)/include/
MODDIR=$(INSTALLDIR)/modules/
include $(INCDIR)/Make.inc.psblas
LIBDIR=$(INSTALLDIR)/lib/
PSBLAS_LIB= -L$(LIBDIR) -lpsb_util -lpsb_linsolve -lpsb_prec -lpsb_base
LDLIBS=$(PSBLDLIBS)
FINCLUDES=$(FMFLAG)$(MODDIR) $(FMFLAG).

# Run-time knobs; override on the command line, e.g. `make run NP=8 IDIM=80`.
NP ?= 4
IDIM ?= 40
# MPI launcher, overridable for sites that use srun/mpiexec instead of mpirun.
MPIRUN ?= mpirun

PROGSRC=psb_comm_cg_test.F90
TOBJS=psb_comm_cg_test.o
EXEDIR=./runs
EXE=psb_comm_cg_test

# None of these targets names a file, so keep make from being fooled by a
# stray file called `run`, `clean`, ...
.PHONY: all runsd run clean

all: runsd $(EXE)

# Create the output directory if it does not exist yet.
runsd:
	(if test ! -d $(EXEDIR) ; then mkdir -p $(EXEDIR); fi)

psb_comm_cg_test.o: $(PROGSRC)
	$(FC) $(FCOPT) $(FINCLUDES) $(FDEFINES) -c $(PROGSRC) -o $@

$(EXE): $(TOBJS)
	$(FLINK) $(LOPT) $(TOBJS) -o $(EXE) $(PSBLAS_LIB) $(LDLIBS)
	/bin/mv $(EXE) $(EXEDIR)

run: all
	$(MPIRUN) -np $(NP) $(EXEDIR)/$(EXE) $(IDIM)

clean:
	/bin/rm -f $(TOBJS) *$(.mod) $(EXEDIR)/$(EXE)

@ -0,0 +1,37 @@
CG no-preconditioner communication test
=======================================
This test lives under `test/comm/cg` and builds a local executable:
- source: `psb_comm_cg_test.F90`
- executable: `runs/psb_comm_cg_test`
Behavior:
- generates a 3D PDE matrix using local `psb_d_gen_pde3d` (double precision)
- solves with `CG`
- uses preconditioner `NONE`
- runs CG three times, changing communication scheme of `x` each run
Communication pattern used in this test:
1. reset solution vector
2. set communication scheme on `x%v%comm_handle`
3. run full `psb_krylov('CG', ...)`
4. collect and compare solve time
Schemes compared:
- `psb_comm_isend_irecv_`
- `psb_comm_ineighbor_alltoallv_`
- `psb_comm_persistent_ineighbor_alltoallv_`
How to run
----------
From this directory:
- `make run` (defaults: `NP=4`, `IDIM=40`)
- `make run NP=8 IDIM=80`
The program accepts one optional CLI argument: `IDIM`.

@ -0,0 +1,401 @@
!> Timing comparison of the CG solver across PSBLAS communication schemes.
!!
!! Generates a 3D PDE test problem with the internal psb_d_gen_pde3d, builds a
!! 'NONE' preconditioner, and then, once per entry of scheme_types, zeroes the
!! solution vector, installs the scheme on x%v%comm_handle through
!! psb_comm_init and runs psb_krylov('CG', ...). The solve time (maximum over
!! processes), iteration count, error estimate and info code of every run are
!! printed by the root process.
!!
!! Usage: psb_comm_cg_test [IDIM]
!!   IDIM  number of grid points per direction (default 40). A value that
!!         cannot be parsed, or that is not strictly positive, falls back to
!!         the default.
program psb_comm_cg_test
  use psb_base_mod
  use psb_prec_mod
  use psb_linsolve_mod
  use psb_comm_factory_mod
  implicit none

  ! Grid size used when no (valid) command-line value is given.
  integer(psb_ipk_), parameter :: default_idim = 40

  type(psb_ctxt_type)   :: ctxt
  type(psb_dspmat_type) :: a
  type(psb_desc_type)   :: desc_a
  type(psb_d_vect_type) :: b, x
  type(psb_dprec_type)  :: prec
  integer(psb_ipk_) :: info, iam, np
  integer(psb_ipk_) :: idim, itmax, itrace, istop, iter, is
  integer(psb_ipk_) :: iter_arr(3), info_arr(3)
  integer(psb_ipk_) :: scheme_types(3)
  integer           :: ios
  real(psb_dpk_) :: eps, err, t1, t2
  real(psb_dpk_) :: tsolve(3), err_arr(3)
  character(len=25)  :: scheme_names(3)
  character(len=5)   :: afmt
  character(len=256) :: arg

  info   = psb_success_
  afmt   = 'CSR'
  idim   = default_idim
  itmax  = 500
  itrace = 0
  istop  = 2
  eps    = 1.d-6

  scheme_types = (/ psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
       & psb_comm_persistent_ineighbor_alltoallv_ /)
  scheme_names(1) = 'isend_irecv'
  scheme_names(2) = 'ineighbor_alltoallv'
  scheme_names(3) = 'persistent_ineighbor_a2av'

  ! Optional first argument: grid size. A dedicated iostat variable keeps the
  ! PSBLAS status code `info` untouched by the parse.
  call get_command_argument(1,arg)
  if (len_trim(arg) > 0) then
    read(arg,*,iostat=ios) idim
    ! A failed read leaves idim unreliable and a non-positive size would
    ! describe an empty grid: use the default in both cases.
    if ((ios /= 0) .or. (idim <= 0)) idim = default_idim
  end if

  call psb_init(ctxt)
  call psb_info(ctxt, iam, np)

  if (iam == psb_root_) then
    write(psb_out_unit,*) 'Welcome to PSBLAS version: ', psb_version_string_
    write(psb_out_unit,*) 'This is the comm/cg test program'
    write(psb_out_unit,'("Grid dimensions : ",i4," x ",i4," x ",i4)') idim,idim,idim
    write(psb_out_unit,'("Number of processors : ",i0)') np
    write(psb_out_unit,'("Iterative method : CG")')
    write(psb_out_unit,'("Preconditioner : NONE")')
    write(psb_out_unit,'(" ")')
  end if

  ! Problem setup: matrix, right-hand side, initial guess and preconditioner.
  call psb_barrier(ctxt)
  call psb_d_gen_pde3d(ctxt,idim,a,b,x,desc_a,afmt,info)
  if (info /= psb_success_) goto 9999
  call prec%init(ctxt,'NONE',info)
  if (info /= psb_success_) goto 9999
  call prec%build(a,desc_a,info)
  if (info /= psb_success_) goto 9999

  do is = 1, 3
    ! Same starting point for every scheme: x = 0*b + 0*x.
    call psb_geaxpby(dzero,b,dzero,x,desc_a,info)
    if (info /= psb_success_) goto 9999
    ! Select the communication scheme used for exchanges on x.
    call psb_comm_init(scheme_types(is),x%v%comm_handle,info)
    if (info /= psb_success_) goto 9999

    call psb_barrier(ctxt)
    t1 = psb_wtime()
    call psb_krylov('CG',a,prec,b,x,eps,desc_a,info,&
         & itmax=itmax,iter=iter,err=err,itrace=itrace,istop=istop)
    t2 = psb_wtime() - t1
    ! Report the slowest process.
    call psb_amx(ctxt,t2)

    tsolve(is)   = t2
    iter_arr(is) = iter
    err_arr(is)  = err
    info_arr(is) = info
    if (info /= psb_success_) goto 9999
  end do

  if (iam == psb_root_) then
    write(psb_out_unit,'(" ")')
    write(psb_out_unit,'("CG solve time by communication scheme")')
    write(psb_out_unit,'("--------------------------------------")')
    do is = 1, 3
      write(psb_out_unit,'(a25,2x,"time=",es12.5,2x,"iter=",i8,2x,"err=",es12.5,2x,"info=",i6)') &
           & trim(scheme_names(is)), tsolve(is), iter_arr(is), err_arr(is), info_arr(is)
    end do
  end if

  call psb_gefree(b,desc_a,info)
  call psb_gefree(x,desc_a,info)
  call psb_spfree(a,desc_a,info)
  call prec%free(info)
  call psb_cdfree(desc_a,info)
  call psb_exit(ctxt)
  stop

  ! Common error exit: hand control to the PSBLAS error handler.
9999 call psb_error(ctxt)
  stop 1
contains
!> Coefficient b1 of the PDE generator: it enters the two ix-direction
!! off-diagonal stencil entries divided by 2*deltah. Identically zero here.
function b1(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function b1
!> Coefficient b2 of the PDE generator: it enters the two iy-direction
!! off-diagonal stencil entries divided by 2*deltah. Identically zero here.
function b2(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function b2
!> Coefficient b3 of the PDE generator: it enters the two iz-direction
!! off-diagonal stencil entries divided by 2*deltah. Identically zero here.
function b3(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function b3
!> Coefficient added to the diagonal stencil entry by the PDE generator.
!! Identically zero here.
function cfun(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function cfun
!> Coefficient a1 of the PDE generator: it scales the ix-direction stencil
!! entries (divided by deltah**2) and contributes to the diagonal.
!! Constant, equal to 1/80.
function a1(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = done/80
end function a1
!> Coefficient a2 of the PDE generator: it scales the iy-direction stencil
!! entries (divided by deltah**2) and contributes to the diagonal.
!! Constant, equal to 1/80.
function a2(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = done/80
end function a2
!> Coefficient a3 of the PDE generator: it scales the iz-direction stencil
!! entries (divided by deltah**2) and contributes to the diagonal.
!! Constant, equal to 1/80.
function a3(x,y,z) result(coeff)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: coeff

  coeff = done/80
end function a3
!> Boundary value used by the PDE generator when a stencil neighbour falls
!! outside the grid.
!!
!! Returns 1 when x is exactly 1, exp(y**2 - z**2) when x is exactly 0, and
!! 0 for every other x. The comparisons are exact on purpose: the generator
!! passes the literal constants dzero/done for the boundary coordinate.
function gfun(x,y,z) result(bval)
  real(psb_dpk_), intent(in) :: x, y, z
  real(psb_dpk_) :: bval

  if (x == done) then
    bval = done
  else if (x == dzero) then
    bval = exp(y**2-z**2)
  else
    bval = dzero
  end if
end function gfun
!> Generate and assemble the distributed test problem on an
!! idim x idim x idim grid.
!!
!! Rows are distributed in contiguous blocks (psb_cdall with nl=nr). For every
!! local row a 7-point stencil is built from the host functions a1..a3,
!! b1..b3 and cfun; stencil neighbours that fall outside the grid are moved to
!! the right-hand side through gfun. The solution vector is filled with zeros.
!! Timings of the phases are printed by the root process.
!!
!! All global sizes and global indices are computed in psb_lpk_ arithmetic so
!! that idim**3 may exceed the range of psb_ipk_.
!!
!! @param[in]  ctxt    PSBLAS parallel context
!! @param[in]  idim    number of grid points per direction
!! @param[out] a       assembled sparse matrix
!! @param[out] bv      assembled right-hand side
!! @param[out] xv      assembled solution vector, all zeros
!! @param[out] desc_a  assembled communication descriptor
!! @param[in]  afmt    storage format passed to psb_spasb
!! @param[out] info    psb_success_ on success, a PSBLAS error code otherwise
subroutine psb_d_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,afmt,info)
  implicit none
  integer(psb_ipk_), intent(in)      :: idim
  type(psb_dspmat_type), intent(out) :: a
  type(psb_d_vect_type), intent(out) :: xv,bv
  type(psb_desc_type), intent(out)   :: desc_a
  type(psb_ctxt_type), intent(in)    :: ctxt
  integer(psb_ipk_), intent(out)     :: info
  character(len=*), intent(in)       :: afmt

  ! Number of rows generated per insertion call.
  integer(psb_ipk_), parameter :: nb=20
  real(psb_dpk_)    :: zt(nb),x,y,z
  integer(psb_lpk_) :: m,n,glob_row,nt
  ! Long-integer copies of idim and idim**2, plus scratch for the in-plane
  ! remainder, so that every global index expression is evaluated in psb_lpk_.
  integer(psb_lpk_) :: ld, ld2, rem
  integer(psb_lpk_) :: ix,iy,iz
  integer(psb_ipk_) :: nnz,nlr,i,ii,ib,k
  integer(psb_ipk_) :: np, iam, nr
  integer(psb_ipk_) :: icoeff
  integer(psb_lpk_), allocatable :: irow(:),icol(:),myidx(:)
  real(psb_dpk_), allocatable    :: val(:)
  real(psb_dpk_) :: deltah, sqdeltah, deltah2
  real(psb_dpk_) :: t0, t1, tasb, talc, ttot, tgen, tcdasb
  integer(psb_ipk_) :: err_act
  character(len=20) :: name, ch_err, tmpfmt

  info = psb_success_
  name = 'create_matrix'
  call psb_erractionsave(err_act)
  call psb_info(ctxt, iam, np)

  deltah   = done/(idim+2)
  sqdeltah = deltah*deltah
  deltah2  = 2.d0*deltah

  ! Widen before multiplying: idim*idim*idim in psb_ipk_ can overflow.
  ld  = int(idim,psb_lpk_)
  ld2 = ld*ld
  m   = ld2*ld
  n   = m
  nnz = ((n*9)/(np))
  if(iam == psb_root_) write(psb_out_unit,'("Generating Matrix (size=",i0,")...")')n

  ! Block row distribution: every process takes up to ceil(m/np) rows, the
  ! last ones may get fewer (or none).
  nt = (m+np-1)/np
  nr = int(max(0_psb_lpk_,min(nt,m-int(iam,psb_lpk_)*nt)),psb_ipk_)
  ! Sanity check: the local sizes must add up to the global size.
  nt = nr
  call psb_sum(ctxt,nt)
  if (nt /= m) then
    write(psb_err_unit,*) iam, 'Initialization error ',nr,nt,m
    info = -1
    call psb_barrier(ctxt)
    call psb_abort(ctxt)
    return
  end if

  call psb_barrier(ctxt)
  t0 = psb_wtime()
  call psb_cdall(ctxt,desc_a,info,nl=nr)
  if (info == psb_success_) call psb_spall(a,desc_a,info,nnz=nnz)
  if (info == psb_success_) call psb_geall(xv,desc_a,info)
  if (info == psb_success_) call psb_geall(bv,desc_a,info)
  call psb_barrier(ctxt)
  talc = psb_wtime()-t0
  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='allocation rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  allocate(val(20*nb),irow(20*nb),icol(20*nb),stat=info)
  if (info /= psb_success_) then
    info=psb_err_alloc_dealloc_
    call psb_errpush(info,name)
    goto 9999
  endif

  myidx = desc_a%get_global_indices()
  nlr   = size(myidx)

  call psb_barrier(ctxt)
  t1 = psb_wtime()
  ! Generate and insert the local rows in chunks of at most nb rows.
  do ii=1, nlr,nb
    ib = min(nb,nlr-ii+1)
    icoeff = 1
    do k=1,ib
      i=ii+k-1
      glob_row=myidx(i)
      ! Recover the 1-based grid coordinates from the global row number,
      ! glob_row = (ix-1)*idim**2 + (iy-1)*idim + iz.
      ix  = (glob_row-1)/ld2 + 1
      rem = glob_row - (ix-1)*ld2
      iy  = (rem-1)/ld + 1
      iz  = rem - (iy-1)*ld
      x = (ix-1)*deltah
      y = (iy-1)*deltah
      z = (iz-1)*deltah
      zt(k) = dzero
      ! For each neighbour: compute the coefficient, then either store it or,
      ! if the neighbour lies outside the grid, fold it into the right-hand
      ! side (the slot val(icoeff) is then reused by the next coefficient).

      ! Neighbour (ix-1,iy,iz)
      val(icoeff) = -a1(x,y,z)/sqdeltah-b1(x,y,z)/deltah2
      if (ix == 1) then
        zt(k) = gfun(dzero,y,z)*(-val(icoeff)) + zt(k)
      else
        icol(icoeff) = (ix-2)*ld2+(iy-1)*ld+iz
        irow(icoeff) = glob_row
        icoeff = icoeff+1
      endif
      ! Neighbour (ix,iy-1,iz)
      val(icoeff) = -a2(x,y,z)/sqdeltah-b2(x,y,z)/deltah2
      if (iy == 1) then
        zt(k) = gfun(x,dzero,z)*(-val(icoeff)) + zt(k)
      else
        icol(icoeff) = (ix-1)*ld2+(iy-2)*ld+iz
        irow(icoeff) = glob_row
        icoeff = icoeff+1
      endif
      ! Neighbour (ix,iy,iz-1)
      val(icoeff) = -a3(x,y,z)/sqdeltah-b3(x,y,z)/deltah2
      if (iz == 1) then
        zt(k) = gfun(x,y,dzero)*(-val(icoeff)) + zt(k)
      else
        icol(icoeff) = (ix-1)*ld2+(iy-1)*ld+(iz-1)
        irow(icoeff) = glob_row
        icoeff = icoeff+1
      endif
      ! Diagonal (ix,iy,iz)
      val(icoeff)=2.d0*(a1(x,y,z)+a2(x,y,z)+a3(x,y,z))/sqdeltah + cfun(x,y,z)
      icol(icoeff) = (ix-1)*ld2+(iy-1)*ld+iz
      irow(icoeff) = glob_row
      icoeff = icoeff+1
      ! Neighbour (ix,iy,iz+1)
      val(icoeff) = -a3(x,y,z)/sqdeltah+b3(x,y,z)/deltah2
      if (iz == ld) then
        zt(k) = gfun(x,y,done)*(-val(icoeff)) + zt(k)
      else
        icol(icoeff) = (ix-1)*ld2+(iy-1)*ld+(iz+1)
        irow(icoeff) = glob_row
        icoeff = icoeff+1
      endif
      ! Neighbour (ix,iy+1,iz)
      val(icoeff) = -a2(x,y,z)/sqdeltah+b2(x,y,z)/deltah2
      if (iy == ld) then
        zt(k) = gfun(x,done,z)*(-val(icoeff)) + zt(k)
      else
        icol(icoeff) = (ix-1)*ld2+iy*ld+iz
        irow(icoeff) = glob_row
        icoeff = icoeff+1
      endif
      ! Neighbour (ix+1,iy,iz)
      val(icoeff) = -a1(x,y,z)/sqdeltah+b1(x,y,z)/deltah2
      if (ix == ld) then
        zt(k) = gfun(done,y,z)*(-val(icoeff)) + zt(k)
      else
        icol(icoeff) = ix*ld2+(iy-1)*ld+iz
        irow(icoeff) = glob_row
        icoeff = icoeff+1
      endif
    end do
    call psb_spins(icoeff-1,irow,icol,val,a,desc_a,info)
    if(info /= psb_success_) exit
    call psb_geins(ib,myidx(ii:ii+ib-1),zt(1:ib),bv,desc_a,info)
    if(info /= psb_success_) exit
    ! The solution vector gets zeros for the same rows.
    zt(:)=dzero
    call psb_geins(ib,myidx(ii:ii+ib-1),zt(1:ib),xv,desc_a,info)
    if(info /= psb_success_) exit
  end do
  tgen = psb_wtime()-t1
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='insert rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if
  deallocate(val,irow,icol)

  call psb_barrier(ctxt)
  t1 = psb_wtime()
  call psb_cdasb(desc_a,info)
  tcdasb = psb_wtime()-t1
  call psb_barrier(ctxt)
  t1 = psb_wtime()
  if (info == psb_success_) call psb_spasb(a,desc_a,info,afmt=afmt)
  call psb_barrier(ctxt)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='asb rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if
  if (info == psb_success_) call psb_geasb(xv,desc_a,info)
  if (info == psb_success_) call psb_geasb(bv,desc_a,info)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='asb rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if
  tasb = psb_wtime()-t1
  call psb_barrier(ctxt)
  ttot = psb_wtime() - t0

  ! Report the maximum over all processes for every phase, including the
  ! descriptor assembly.
  call psb_amx(ctxt,talc)
  call psb_amx(ctxt,tgen)
  call psb_amx(ctxt,tcdasb)
  call psb_amx(ctxt,tasb)
  call psb_amx(ctxt,ttot)
  if(iam == psb_root_) then
    tmpfmt = a%get_fmt()
    write(psb_out_unit,'("The matrix has been generated and assembled in ",a3," format.")') tmpfmt
    write(psb_out_unit,'("-allocation time : ",es12.5)') talc
    write(psb_out_unit,'("-coeff. gen. time : ",es12.5)') tgen
    write(psb_out_unit,'("-desc asbly time : ",es12.5)') tcdasb
    write(psb_out_unit,'("- mat asbly time : ",es12.5)') tasb
    write(psb_out_unit,'("-total time : ",es12.5)') ttot
  end if
  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)
  return
end subroutine psb_d_gen_pde3d
end program psb_comm_cg_test

@ -0,0 +1,18 @@
17 Number of entries below this
CG Iterative method BICGSTAB CGS BICG BICGSTABL RGMRES FCG CGR RICHARDSON
NONE Preconditioner NONE DIAG BJAC
CSR Storage format for matrix A: CSR COO
100 Domain size (actual system is this**3 in pde3d)
3 Partition: 1 BLOCK 3 3D
2 Stopping criterion 1 2
0200 MAXIT
10 ITRACE
002 IRST restart for RGMRES and BiCGSTABL
INVK Block Solver ILU,ILUT,INVK,INVT,AINV
NONE If ILU : MILU or NONE otherwise ignored
NONE Scaling if ILUT: NONE, MAXVAL otherwise ignored
0 Level of fill for forward factorization
1 Level of fill for inverse factorization (only INVK,INVT)
1E-1 Threshold for forward factorization
1E-1 Threshold for inverse factorization (Only INVK, INVT)
LLK Orthogonalization algorithm (only AINV)

@ -0,0 +1,33 @@
# Build rules for the spmv_overlap communication test.
#
# NOTE(review): INSTALLDIR assumes this directory sits two levels below the
# PSBLAS root (where include/, modules/ and lib/ live) -- confirm against the
# actual location of this test in the tree.
INSTALLDIR=../..
INCDIR=$(INSTALLDIR)/include/
MODDIR=$(INSTALLDIR)/modules/
include $(INCDIR)/Make.inc.psblas
#
# Libraries used
#
LIBDIR=$(INSTALLDIR)/lib/
PSBLAS_LIB= -L$(LIBDIR) -lpsb_util -lpsb_linsolve -lpsb_prec -lpsb_base
LDLIBS=$(PSBLDLIBS)

FINCLUDES=$(FMFLAG)$(MODDIR) $(FMFLAG).

TOBJS=psb_spmv_overlap_test.o spmv_overlap.o
EXEDIR=./runs

# None of these targets names a file that the rule creates in place.
.PHONY: all runsd clean lib verycleanlib

all: runsd spmv_overlap

runsd:
	(if test ! -d $(EXEDIR) ; then mkdir $(EXEDIR); fi)

# The driver program uses module psb_spmv_overlap_test, so the module object
# must be built first (matters for parallel make).
# NOTE(review): assumes the driver lives in spmv_overlap.f90 -- confirm.
spmv_overlap.o: psb_spmv_overlap_test.o

spmv_overlap: $(TOBJS)
	$(FLINK) $(LOPT) $(TOBJS) -o spmv_overlap $(PSBLAS_LIB) $(LDLIBS)
	/bin/mv spmv_overlap $(EXEDIR)

clean:
	/bin/rm -f $(TOBJS) *$(.mod) $(EXEDIR)/spmv_overlap

lib:
	(cd $(INSTALLDIR); $(MAKE) library)

verycleanlib:
	(cd $(INSTALLDIR); $(MAKE) veryclean)

@ -0,0 +1,21 @@
spmv overlap communication test
===============================
This test was added after introducing different communication schemes in PSBLAS.
It exercises the overlapped SpMV communication path inside `psb_spmm`.
Communication pattern:
- split exchange/computation flow (`start` + local compute + `wait`)
- halo/overlap update through internal swap routines used by SpMV kernels
- same matrix/vector workload repeated across schemes for timing comparison
Communication schemes compared:
- `psb_comm_isend_irecv_`
- `psb_comm_ineighbor_alltoallv_`
- `psb_comm_persistent_ineighbor_alltoallv_`
Unlike `swapdata/`, which checks direct halo exchange, this test covers the
overlapped SpMV workflow.

@ -0,0 +1,680 @@
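!> Test program for overlapping communication and computation with psb_spmm.
!!
!! This benchmark runs the same overlapped SpMV, psb_spmm(..., doswap=.true.),
!! under three communication schemes and compares their timings and results:
!! 1. psb_comm_isend_irecv_ (baseline)
!! 2. psb_comm_ineighbor_alltoallv_
!! 3. psb_comm_persistent_ineighbor_alltoallv_
!!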
module psb_spmv_overlap_test
use psb_base_mod
use psb_util_mod
use psb_comm_factory_mod, only: psb_comm_init
use psb_comm_schemes_mod, only: psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
& psb_comm_persistent_ineighbor_alltoallv_
implicit none
! Explicit interface for a real-valued function of three real coordinates.
! It is used below as the interface of the optional source-term argument f
! of psb_d_gen_pde3d and of its local procedure pointer f_.
interface
function d_func_3d(x,y,z) result(val)
! Only the real kind is imported from the host scope.
import :: psb_dpk_
real(psb_dpk_), intent(in) :: x,y,z
real(psb_dpk_) :: val
end function d_func_3d
end interface
contains
!> Identically-zero function of three coordinates.
!! Used by psb_d_gen_pde3d as the source term when no f is supplied.
function d_null_func_3d(x,y,z) result(fval)
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: fval

  ! The coordinates are deliberately ignored.
  fval = dzero
end function d_null_func_3d
!
! functions parametrizing the differential equation
!
!
! Note: b1, b2 and b3 are the coefficients of the first
! derivative of the unknown function. The default
! we apply here is to have them zero, so that the resulting
! matrix is symmetric/hermitian and suitable for
! testing with CG and FCG.
! When testing methods for non-hermitian matrices you can
! change the B1/B2/B3 functions to e.g. done/sqrt((3*done))
!
!> Coefficient of the first derivative in x; constant zero here.
function b1(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, dzero
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function b1
!> Coefficient of the first derivative in y; constant zero here.
function b2(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, dzero
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function b2
!> Coefficient of the first derivative in z; constant zero here.
function b3(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, dzero
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function b3
!> Coefficient of the zero-order term (c u); constant zero here.
function c(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, dzero
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = dzero
end function c
!> Coefficient of the second derivative in x; the constant 1/80.
function a1(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, done
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = done/80
end function a1
!> Coefficient of the second derivative in y; the constant 1/80.
function a2(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, done
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = done/80
end function a2
!> Coefficient of the second derivative in z; the constant 1/80.
function a3(x,y,z) result(coeff)
  use psb_base_mod, only : psb_dpk_, done
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: coeff

  coeff = done/80
end function a3
!> Dirichlet boundary value at (x,y,z).
!! Returns 1 where x equals one exactly, exp(y**2-z**2) where x equals
!! zero exactly, and zero everywhere else.
function g(x,y,z) result(bval)
  use psb_base_mod, only : psb_dpk_, done, dzero
  implicit none
  real(psb_dpk_), intent(in) :: x,y,z
  real(psb_dpk_) :: bval

  ! Exact comparisons are intended: callers pass the literal constants
  ! done/dzero when evaluating the x faces.
  if (x == done) then
    bval = done
  else if (x == dzero) then
    bval = exp(y**2-z**2)
  else
    bval = dzero
  end if
end function g
!
! subroutine to allocate and fill in the coefficient matrix and
! the rhs.
!
subroutine psb_d_gen_pde3d(ctxt,idim,a,bv,xv,desc_a,afmt,info,&
     & f,amold,vmold,imold,partition,nrl,iv,tnd)
  use psb_base_mod
  use psb_util_mod
  !
  !   Discretizes the partial differential equation
  !
  !   a1 dd(u)  a2 dd(u)    a3 dd(u)    b1 d(u)   b2 d(u)  b3 d(u)
  ! -   ------ -  ------ -  ------ +  -----  +  ------  +  ------ + c u = f
  !      dxdx     dydy       dzdz        dx       dy         dz
  !
  ! with Dirichlet boundary conditions
  !   u = g
  !
  !  on the unit cube  0<=x,y,z<=1.
  !
  !
  ! Note that if b1=b2=b3=c=0., the PDE is the  Laplace equation.
  !
  ! Arguments:
  !   ctxt      parallel context
  !   idim      number of grid points per dimension (global size idim**3)
  !   a         sparse matrix, allocated, filled and assembled here
  !   bv        right-hand side built from f and the boundary values g
  !   xv        vector filled with zeros (initial guess)
  !   desc_a    communication descriptor, allocated and assembled here
  !   afmt      storage format passed to psb_spasb when amold is absent
  !   info      return code, psb_success_ on success
  !   f         optional source term; d_null_func_3d is used when absent
  !   amold     optional mold passed to psb_spasb
  !   vmold     optional mold passed to psb_geasb
  !   imold     optional mold passed to psb_cdasb
  !   partition optional: 1 BLOCK, 2 user-defined through iv, 3 3D (default)
  !   nrl       optional number of local rows (partition=1 only)
  !   iv        optional owner process of each global row (partition=2 only)
  !   tnd       optional flag forwarded to psb_spasb as bld_and
  !
  implicit none
  integer(psb_ipk_)     :: idim
  type(psb_dspmat_type) :: a
  type(psb_d_vect_type) :: xv,bv
  type(psb_desc_type)   :: desc_a
  type(psb_ctxt_type)   :: ctxt
  integer(psb_ipk_)     :: info
  character(len=*)      :: afmt
  procedure(d_func_3d), optional :: f
  class(psb_d_base_sparse_mat), optional :: amold
  class(psb_d_base_vect_type), optional :: vmold
  class(psb_i_base_vect_type), optional :: imold
  integer(psb_ipk_), optional :: partition, nrl,iv(:)
  logical, optional :: tnd
  ! Local variables.
  integer(psb_ipk_), parameter :: nb=20
  real(psb_dpk_)    :: zt(nb),x,y,z
  integer(psb_ipk_) :: nnz,nr,nlr,i,j,ii,ib,k, partition_
  integer(psb_lpk_) :: m,n,glob_row,nt
  integer(psb_ipk_) :: ix,iy,iz
  ! For 3D partition
  ! Note: integer control variables going directly into an MPI call
  ! must be 4 bytes, i.e. psb_mpk_
  integer(psb_mpk_) :: npdims(3), npp, minfo
  integer(psb_ipk_) :: npx,npy,npz, iamx,iamy,iamz,mynx,myny,mynz
  integer(psb_ipk_), allocatable :: bndx(:),bndy(:),bndz(:)
  ! Process grid
  integer(psb_ipk_) :: np, iam
  integer(psb_ipk_) :: icoeff
  integer(psb_lpk_), allocatable :: irow(:),icol(:),myidx(:)
  real(psb_dpk_), allocatable :: val(:)
  ! deltah dimension of each grid cell
  real(psb_dpk_)    :: deltah, sqdeltah, deltah2
  real(psb_dpk_)    :: t0, t1, tasb, talc, ttot, tgen, tcdasb
  integer(psb_ipk_) :: err_act
  procedure(d_func_3d), pointer :: f_
  character(len=20) :: name, ch_err,tmpfmt

  info = psb_success_
  name = 'create_matrix'
  call psb_erractionsave(err_act)

  call psb_info(ctxt, iam, np)

  ! Source term: caller-supplied or identically zero.
  if (present(f)) then
    f_ => f
  else
    f_ => d_null_func_3d
  end if

  deltah   = done/(idim+1)
  sqdeltah = deltah*deltah
  deltah2  = (2*done)* deltah

  if (present(partition)) then
    if ((1<= partition).and.(partition <= 3)) then
      partition_ = partition
    else
      write(*,*) 'Invalid partition choice ',partition,' defaulting to 3'
      partition_ = 3
    end if
  else
    partition_ = 3
  end if

  ! initialize array descriptor and sparse matrix storage. provide an
  ! estimate of the number of non zeroes
  m   = (1_psb_lpk_*idim)*idim*idim
  n   = m
  nnz = ((n*7)/(np))
  if(iam == psb_root_) write(psb_out_unit,'("Generating Matrix (size=",i0,")...")')n
  t0 = psb_wtime()
  select case(partition_)
  case(1)
    ! A BLOCK partition
    if (present(nrl)) then
      nr = nrl
    else
      !
      ! Using a simple BLOCK distribution.
      !
      nt = (m+np-1)/np
      nr = max(0,min(nt,m-(iam*nt)))
    end if

    ! The local row counts must add up to the global size.
    nt = nr
    call psb_sum(ctxt,nt)
    if (nt /= m) then
      write(psb_err_unit,*) iam, 'Initialization error ',nr,nt,m
      info = -1
      call psb_barrier(ctxt)
      call psb_abort(ctxt)
      return
    end if

    !
    ! First example of use of CDALL: specify for each process a number of
    ! contiguous rows
    !
    call psb_cdall(ctxt,desc_a,info,nl=nr)
    myidx = desc_a%get_global_indices()
    nlr = size(myidx)

  case(2)
    ! A partition defined by the user through IV
    if (present(iv)) then
      if (size(iv) /= m) then
        write(psb_err_unit,*) iam, 'Initialization error: wrong IV size',size(iv),m
        info = -1
        call psb_barrier(ctxt)
        call psb_abort(ctxt)
        return
      end if
    else
      write(psb_err_unit,*) iam, 'Initialization error: IV not present'
      info = -1
      call psb_barrier(ctxt)
      call psb_abort(ctxt)
      return
    end if

    !
    ! Second example of use of CDALL: specify for each row the
    ! process that owns it
    !
    call psb_cdall(ctxt,desc_a,info,vg=iv)
    myidx = desc_a%get_global_indices()
    nlr = size(myidx)

  case(3)
    ! A 3-dimensional partition
    ! A nifty MPI function will split the process list
    ! All integer arguments are converted to psb_mpk_ before the MPI call,
    ! so the call stays correct when psb_ipk_ is not a 4-byte kind.
    npdims = 0
    npp    = np
    call mpi_dims_create(npp,3_psb_mpk_,npdims,minfo)
    info = minfo
    npx = npdims(1)
    npy = npdims(2)
    npz = npdims(3)

    allocate(bndx(0:npx),bndy(0:npy),bndz(0:npz))
    ! We can reuse idx2ijk for process indices as well.
    call idx2ijk(iamx,iamy,iamz,iam,npx,npy,npz,base=0)
    ! Now let's split the 3D cube in hexahedra
    call dist1Didx(bndx,idim,npx)
    mynx = bndx(iamx+1)-bndx(iamx)
    call dist1Didx(bndy,idim,npy)
    myny = bndy(iamy+1)-bndy(iamy)
    call dist1Didx(bndz,idim,npz)
    mynz = bndz(iamz+1)-bndz(iamz)

    ! How many indices do I own?
    nlr = mynx*myny*mynz
    allocate(myidx(nlr))
    ! Now, let's generate the list of indices I own
    nr = 0
    do i=bndx(iamx),bndx(iamx+1)-1
      do j=bndy(iamy),bndy(iamy+1)-1
        do k=bndz(iamz),bndz(iamz+1)-1
          nr = nr + 1
          call ijk2idx(myidx(nr),i,j,k,idim,idim,idim)
        end do
      end do
    end do
    if (nr /= nlr) then
      write(psb_err_unit,*) iam,iamx,iamy,iamz, 'Initialization error: NR vs NLR ',&
           & nr,nlr,mynx,myny,mynz
      info = -1
      call psb_barrier(ctxt)
      call psb_abort(ctxt)
      ! Same exit path as the other initialization failures.
      return
    end if

    !
    ! Third example of use of CDALL: specify for each process
    ! the set of global indices it owns.
    !
    call psb_cdall(ctxt,desc_a,info,vl=myidx)

  case default
    write(psb_err_unit,*) iam, 'Initialization error: should not get here'
    info = -1
    call psb_barrier(ctxt)
    call psb_abort(ctxt)
    return
  end select

  if (info == psb_success_) call psb_spall(a,desc_a,info,nnz=nnz)
  ! define  rhs from boundary conditions; also build initial guess
  if (info == psb_success_) call psb_geall(xv,desc_a,info)
  if (info == psb_success_) call psb_geall(bv,desc_a,info)

  call psb_barrier(ctxt)
  talc = psb_wtime()-t0

  if (info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='allocation rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  ! we build an auxiliary matrix consisting of one row at a
  ! time; just a small matrix. might be extended to generate
  ! a bunch of rows per call.
  !
  allocate(val(20*nb),irow(20*nb),&
       &icol(20*nb),stat=info)
  if (info /= psb_success_ ) then
    info=psb_err_alloc_dealloc_
    call psb_errpush(info,name)
    goto 9999
  endif

  ! loop over rows belonging to current process in a block
  ! distribution.
  call psb_barrier(ctxt)
  t1 = psb_wtime()
  do ii=1, nlr,nb
    ib = min(nb,nlr-ii+1)
    icoeff = 1
    do k=1,ib
      i=ii+k-1
      ! local matrix pointer
      glob_row=myidx(i)
      ! compute gridpoint coordinates
      call idx2ijk(ix,iy,iz,glob_row,idim,idim,idim)
      ! x, y, z coordinates
      x = (ix-1)*deltah
      y = (iy-1)*deltah
      z = (iz-1)*deltah
      zt(k) = f_(x,y,z)
      ! internal point: build discretization
      !
      ! For every neighbour the coefficient is computed first; if the
      ! neighbour falls on the boundary its contribution goes into the
      ! rhs and the slot icoeff is reused by the next term.
      !
      !  term depending on   (x-1,y,z)
      !
      val(icoeff) = -a1(x,y,z)/sqdeltah-b1(x,y,z)/deltah2
      if (ix == 1) then
        zt(k) = g(dzero,y,z)*(-val(icoeff)) + zt(k)
      else
        call ijk2idx(icol(icoeff),ix-1,iy,iz,idim,idim,idim)
        irow(icoeff) = glob_row
        icoeff       = icoeff+1
      endif
      !  term depending on     (x,y-1,z)
      val(icoeff)  = -a2(x,y,z)/sqdeltah-b2(x,y,z)/deltah2
      if (iy == 1) then
        zt(k) = g(x,dzero,z)*(-val(icoeff))   + zt(k)
      else
        call ijk2idx(icol(icoeff),ix,iy-1,iz,idim,idim,idim)
        irow(icoeff) = glob_row
        icoeff       = icoeff+1
      endif
      !  term depending on     (x,y,z-1)
      val(icoeff)=-a3(x,y,z)/sqdeltah-b3(x,y,z)/deltah2
      if (iz == 1) then
        zt(k) = g(x,y,dzero)*(-val(icoeff))   + zt(k)
      else
        call ijk2idx(icol(icoeff),ix,iy,iz-1,idim,idim,idim)
        irow(icoeff) = glob_row
        icoeff       = icoeff+1
      endif

      !  term depending on     (x,y,z)
      val(icoeff)=(2*done)*(a1(x,y,z)+a2(x,y,z)+a3(x,y,z))/sqdeltah &
           & + c(x,y,z)
      call ijk2idx(icol(icoeff),ix,iy,iz,idim,idim,idim)
      irow(icoeff) = glob_row
      icoeff       = icoeff+1
      !  term depending on     (x,y,z+1)
      val(icoeff)=-a3(x,y,z)/sqdeltah+b3(x,y,z)/deltah2
      if (iz == idim) then
        zt(k) = g(x,y,done)*(-val(icoeff))   + zt(k)
      else
        call ijk2idx(icol(icoeff),ix,iy,iz+1,idim,idim,idim)
        irow(icoeff) = glob_row
        icoeff       = icoeff+1
      endif
      !  term depending on     (x,y+1,z)
      val(icoeff)=-a2(x,y,z)/sqdeltah+b2(x,y,z)/deltah2
      if (iy == idim) then
        zt(k) = g(x,done,z)*(-val(icoeff))   + zt(k)
      else
        call ijk2idx(icol(icoeff),ix,iy+1,iz,idim,idim,idim)
        irow(icoeff) = glob_row
        icoeff       = icoeff+1
      endif
      !  term depending on     (x+1,y,z)
      val(icoeff)=-a1(x,y,z)/sqdeltah+b1(x,y,z)/deltah2
      if (ix==idim) then
        zt(k) = g(done,y,z)*(-val(icoeff))   + zt(k)
      else
        call ijk2idx(icol(icoeff),ix+1,iy,iz,idim,idim,idim)
        irow(icoeff) = glob_row
        icoeff       = icoeff+1
      endif

    end do
    call psb_spins(icoeff-1,irow,icol,val,a,desc_a,info)
    if(info /= psb_success_) exit
    call psb_geins(ib,myidx(ii:ii+ib-1),zt(1:ib),bv,desc_a,info)
    if(info /= psb_success_) exit
    ! The initial guess is identically zero.
    zt(:)=dzero
    call psb_geins(ib,myidx(ii:ii+ib-1),zt(1:ib),xv,desc_a,info)
    if(info /= psb_success_) exit
  end do

  tgen = psb_wtime()-t1
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='insert rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if

  deallocate(val,irow,icol)

  call psb_barrier(ctxt)
  t1 = psb_wtime()
  call psb_cdasb(desc_a,info,mold=imold)
  tcdasb = psb_wtime()-t1
  call psb_barrier(ctxt)
  t1 = psb_wtime()
  if (info == psb_success_) then
    if (present(amold)) then
      call psb_spasb(a,desc_a,info,mold=amold,bld_and=tnd)
    else
      call psb_spasb(a,desc_a,info,afmt=afmt,bld_and=tnd)
    end if
  end if
  call psb_barrier(ctxt)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='asb rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if
  if (info == psb_success_) call psb_geasb(xv,desc_a,info,mold=vmold)
  if (info == psb_success_) call psb_geasb(bv,desc_a,info,mold=vmold)
  if(info /= psb_success_) then
    info=psb_err_from_subroutine_
    ch_err='asb rout.'
    call psb_errpush(info,name,a_err=ch_err)
    goto 9999
  end if
  tasb = psb_wtime()-t1
  call psb_barrier(ctxt)
  ttot = psb_wtime() - t0

  ! Reduce every reported timer across processes, tcdasb included.
  call psb_amx(ctxt,talc)
  call psb_amx(ctxt,tgen)
  call psb_amx(ctxt,tcdasb)
  call psb_amx(ctxt,tasb)
  call psb_amx(ctxt,ttot)
  if(iam == psb_root_) then
    tmpfmt = a%get_fmt()
    write(psb_out_unit,'("The matrix has been generated and assembled in ",a3," format.")')&
         & tmpfmt
    write(psb_out_unit,'("-allocation time : ",es12.5)') talc
    write(psb_out_unit,'("-coeff. gen. time : ",es12.5)') tgen
    write(psb_out_unit,'("-desc asbly time : ",es12.5)') tcdasb
    write(psb_out_unit,'("- mat asbly time : ",es12.5)') tasb
    write(psb_out_unit,'("-total time : ",es12.5)') ttot
  end if
  call psb_erractionrestore(err_act)
  return

9999 call psb_error_handler(ctxt,err_act)

  return
end subroutine psb_d_gen_pde3d
!> Times the overlapped SpMV, psb_spmm(..., doswap=.true.), under three
!! communication schemes and checks that the three results agree.
!!
!! The matrix is the idim**3 pde3d operator with a BLOCK partition
!! (partition=1). One x/y vector pair is used per scheme; the scheme is
!! selected through the comm_handle of the x vector. Timings and the
!! maximum absolute differences against the baseline are printed by the
!! root process.
!!
!! ctxt  parallel context on which the test runs
subroutine psb_spmv_overlap_kernel(ctxt)
  use psb_base_mod
  use psb_util_mod
  implicit none
  type(psb_ctxt_type), intent(in) :: ctxt

  real(psb_dpk_)        :: alpha, beta
  type(psb_dspmat_type) :: a
  type(psb_d_vect_type) :: x_baseline, x_neighbor, x_persistent
  type(psb_d_vect_type) :: y_baseline, y_neighbor, y_persistent
  type(psb_desc_type)   :: desc_a
  real(psb_dpk_), allocatable :: x_global(:), y_global(:)
  integer(psb_ipk_) :: my_rank, np, info, err_act
  integer(psb_ipk_) :: n_global, idim
  integer(psb_ipk_) :: i, times
  real(psb_dpk_)    :: t0, dt
  real(psb_dpk_)    :: tsum_baseline, tsum_neighbor, tsum_persistent
  real(psb_dpk_)    :: err_bn, err_bp, tol

  info = psb_success_
  ! err_act is consumed by psb_error_handler at label 9999.
  call psb_erractionsave(err_act)

  tol      = 1.0e-10_psb_dpk_
  times    = 100
  idim     = 10
  n_global = idim * idim * idim
  alpha    = done
  beta     = dzero

  call psb_info(ctxt, my_rank, np)
  call psb_barrier(ctxt)

  ! y_baseline/x_baseline are allocated and assembled by the generator.
  call psb_d_gen_pde3d(ctxt,idim,a,y_baseline,x_baseline,desc_a,"CSR",info,partition=1)
  if (info /= psb_success_) goto 9999
  call psb_barrier(ctxt)

  ! Deterministic global input vectors, filled on the root process only.
  if (my_rank == psb_root_) then
    allocate(x_global(n_global), y_global(n_global), stat=info)
    if (info == psb_success_) then
      do i = 1, n_global
        x_global(i) = real(mod(i,17)+1, psb_dpk_) / real(17, psb_dpk_)
        y_global(i) = real(mod(i,13), psb_dpk_) / real(29, psb_dpk_)
      end do
    end if
  else
    ! An unallocated array is not a valid actual argument for psb_scatter,
    ! so the other processes pass zero-size arrays.
    ! NOTE(review): assumes psb_scatter reads the global vector only on root.
    allocate(x_global(0), y_global(0), stat=info)
  end if
  if (info /= psb_success_) goto 9999

  ! Stop at the first failure instead of overwriting info.
  call psb_geall(x_neighbor, desc_a, info)
  if (info == psb_success_) call psb_geall(x_persistent, desc_a, info)
  if (info == psb_success_) call psb_geall(y_neighbor, desc_a, info)
  if (info == psb_success_) call psb_geall(y_persistent, desc_a, info)
  if (info /= psb_success_) goto 9999

  call psb_scatter(x_global, x_baseline, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(y_global, y_baseline, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(y_global, y_neighbor, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(y_global, y_persistent, desc_a, info, root=psb_root_)
  if (info /= psb_success_) goto 9999

  ! Set communication schemes on the x vectors used by psb_spmm.
  call psb_comm_init(psb_comm_isend_irecv_, x_baseline%v%comm_handle, info)
  if (info /= psb_success_) goto 9999
  call psb_comm_init(psb_comm_ineighbor_alltoallv_, x_neighbor%v%comm_handle, info)
  if (info /= psb_success_) goto 9999
  call psb_comm_init(psb_comm_persistent_ineighbor_alltoallv_, x_persistent%v%comm_handle, info)
  if (info /= psb_success_) goto 9999

  ! Warm-up all schemes once.
  call psb_spmm(alpha, a, x_baseline, beta, y_baseline, desc_a, info, doswap=.true.)
  if (info == psb_success_) &
       & call psb_spmm(alpha, a, x_neighbor, beta, y_neighbor, desc_a, info, doswap=.true.)
  if (info == psb_success_) &
       & call psb_spmm(alpha, a, x_persistent, beta, y_persistent, desc_a, info, doswap=.true.)
  if (info /= psb_success_) goto 9999

  ! Restore vectors so timed loops start from same initial state.
  ! NOTE(review): the scheme selected above is stored in x%v; confirm that
  ! psb_scatter does not rebuild x%v and thereby drop that comm_handle.
  call psb_scatter(x_global, x_baseline, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(y_global, y_baseline, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(y_global, y_neighbor, desc_a, info, root=psb_root_)
  if (info == psb_success_) call psb_scatter(y_global, y_persistent, desc_a, info, root=psb_root_)
  if (info /= psb_success_) goto 9999

  ! Baseline (isend/irecv) overlapped SpMV.
  t0 = psb_wtime()
  do i = 1, times
    call psb_spmm(alpha, a, x_baseline, beta, y_baseline, desc_a, info, doswap=.true.)
    if (info /= psb_success_) exit
  end do
  dt = psb_wtime() - t0
  if (info /= psb_success_) goto 9999
  call psb_amx(ctxt, dt)
  tsum_baseline = dt

  ! Neighbor alltoallv overlapped SpMV.
  t0 = psb_wtime()
  do i = 1, times
    call psb_spmm(alpha, a, x_neighbor, beta, y_neighbor, desc_a, info, doswap=.true.)
    if (info /= psb_success_) exit
  end do
  dt = psb_wtime() - t0
  if (info /= psb_success_) goto 9999
  call psb_amx(ctxt, dt)
  tsum_neighbor = dt

  ! Persistent-neighbor overlapped SpMV.
  t0 = psb_wtime()
  do i = 1, times
    call psb_spmm(alpha, a, x_persistent, beta, y_persistent, desc_a, info, doswap=.true.)
    if (info /= psb_success_) exit
  end do
  dt = psb_wtime() - t0
  if (info /= psb_success_) goto 9999
  call psb_amx(ctxt, dt)
  tsum_persistent = dt

  ! Largest absolute difference against the baseline, over all processes.
  err_bn = maxval(abs(y_baseline%get_vect() - y_neighbor%get_vect()))
  err_bp = maxval(abs(y_baseline%get_vect() - y_persistent%get_vect()))
  call psb_amx(ctxt, err_bn)
  call psb_amx(ctxt, err_bp)

  if (my_rank == psb_root_) then
    write(psb_out_unit,'(" Avg baseline time : ",es12.5)') tsum_baseline / real(times, psb_dpk_)
    write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline
    write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') tsum_neighbor / real(times, psb_dpk_)
    write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor
    write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') tsum_persistent / real(times, psb_dpk_)
    write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_persistent
    write(psb_out_unit,'(" Check baseline vs neighbor err = ",es12.5)') err_bn
    write(psb_out_unit,'(" Check baseline vs persistent err = ",es12.5)') err_bp
    if ((err_bn > tol) .or. (err_bp > tol)) then
      write(psb_out_unit,'(" WARNING: mismatch exceeds tolerance ",es12.5)') tol
    end if
  end if

  call psb_gefree(x_baseline, desc_a, info)
  call psb_gefree(x_neighbor, desc_a, info)
  call psb_gefree(x_persistent, desc_a, info)
  call psb_gefree(y_baseline, desc_a, info)
  call psb_gefree(y_neighbor, desc_a, info)
  call psb_gefree(y_persistent, desc_a, info)
  call psb_spfree(a, desc_a, info)
  call psb_cdfree(desc_a, info)

  if (allocated(x_global)) deallocate(x_global)
  if (allocated(y_global)) deallocate(y_global)

  call psb_erractionrestore(err_act)
  return

9999 call psb_error(ctxt)
  call psb_error_handler(ctxt, err_act)
end subroutine psb_spmv_overlap_kernel
end module psb_spmv_overlap_test

@ -0,0 +1,28 @@
!> Driver for the overlapped SpMV communication test: initializes PSBLAS,
!! prints a banner on the root process, runs psb_spmv_overlap_kernel and
!! shuts PSBLAS down.
program main
  use psb_spmv_overlap_test
  use psb_base_mod
  implicit none

  integer(psb_ipk_)   :: my_rank, np
  type(psb_ctxt_type) :: ctxt

  call psb_init(ctxt)
  call psb_info(ctxt, my_rank, np)

  if (my_rank == psb_root_) then
    write(psb_out_unit,*) 'Welcome to PSBLAS version: ', psb_version_string_
    write(psb_out_unit,*) 'This is the psb_spmv_overlap_test sample program'
  end if

  ! Let the banner be issued before any process starts the kernel.
  call psb_barrier(ctxt)
  call psb_spmv_overlap_kernel(ctxt)

  call psb_exit(ctxt)
end program main

@ -0,0 +1,19 @@
swapdata communication test
============================
This test was added after introducing different communication schemes in PSBLAS.
It focuses on direct halo exchange through the `swapdata` path:
- index list type: halo (`psb_comm_halo_`)
- exchange API: `psi_swapdata`
- phases: `start`, `wait`, and `sync` (depending on test section)
Communication patterns exercised:
- baseline point-to-point (`isend/irecv`)
- neighbor collective (`ineighbor_alltoallv`)
- persistent neighbor collective (`persistent_ineighbor_alltoallv`)
This test validates the low-level communication behavior in isolation, without
the full SpMV overlap pipeline.

@ -30,7 +30,7 @@ set(PSBLAS_LIBS psblas::util psblas::prec psblas::base)
include(${CMAKE_CURRENT_LIST_DIR}/geaxpby/CMakeLists.txt)
include(${CMAKE_CURRENT_LIST_DIR}/gedot/CMakeLists.txt)
include(${CMAKE_CURRENT_LIST_DIR}/spmm/CMakeLists.txt)
# Create executables
add_executable(psb_geaxpby_test ${geaxpby_source_files})
@ -42,6 +42,7 @@ target_link_libraries(psb_geaxpby_test ${PSBLAS_LIBS})
target_link_libraries(psb_gedot_test ${PSBLAS_LIBS})
target_link_libraries(psb_spmm_test ${PSBLAS_LIBS})
# Set output directory
set_target_properties(psb_geaxpby_test PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${EXEDIR}

@ -6,3 +6,11 @@ foreach(file IN LISTS PSB_spmm_source_files)
list(APPEND spmm_source_files ${CMAKE_CURRENT_LIST_DIR}/${file})
endforeach()
set(PSB_spmm_overlap_source_files
psb_spmm_overlap_test.f90
spmm_overlap.f90
)
foreach(file IN LISTS PSB_spmm_overlap_source_files)
list(APPEND spmm_overlap_source_files ${CMAKE_CURRENT_LIST_DIR}/${file})
endforeach()

@ -21,7 +21,7 @@ YELLOW=\033[33m
END_COLOUR=\033[0m
all: runsd psb_spmm_test
all: runsd psb_spmm_test psb_spmm_overlap_test
@printf "$(GREEN)[INFO]\tCompilation success!$(END_COLOUR)\n"
runsd:
@ -36,7 +36,7 @@ psb_spmm_test:
clean:
@rm -f $(OBJS)\
*$(.mod) $(EXEDIR)/psb_spmm_test
*$(.mod) $(EXEDIR)/psb_spmm_test $(EXEDIR)/psb_spmm_overlap_test
.PHONY: all runsd clean

@ -670,6 +670,8 @@ program psb_d_pde3d
use psb_linsolve_mod
use psb_util_mod
use psb_d_pde3d_mod
use psb_comm_factory_mod
#if defined(PSB_OPENMP)
use omp_lib
#endif
@ -835,6 +837,14 @@ program psb_d_pde3d
& err=err,itrace=itrace,&
& istop=istopc)
case('BICGSTAB','BICGSTABL','BICG','CG','CGS','FCG','GCR','RGMRES')
call psb_comm_init(psb_comm_persistent_ineighbor_alltoallv_,xxv%v%comm_handle,info)
if(info /= psb_success_) then
info=psb_err_from_subroutine_
ch_err='comm init'
call psb_errpush(info,name,a_err=ch_err)
goto 9999
end if
call psb_krylov(kmethd,a,prec,bv,xxv,eps,&
& desc_a,info,itmax=itmax,iter=iter,err=err,itrace=itrace,&
& istop=istopc,irst=irst)

Loading…
Cancel
Save