[UPDATE] Transient version useful for debugging on a local server

communication_v2
Stack-1 4 weeks ago
parent 5ed9643fe6
commit fcae4a1633

@ -82,8 +82,61 @@
submodule (psi_d_comm_v_mod) psi_d_swapdata_impl
use psb_desc_const_mod, only: psb_swap_start_, psb_swap_wait_
use psb_base_mod
use psb_error_mod, only: psb_get_debug_level, psb_get_debug_unit, psb_debug_ext_
use psb_comm_factory_mod
logical, save :: psb_swap_timing_inited = .false.
logical, save :: psb_swap_timing_enabled = .false.
integer(psb_ipk_), save :: psb_swap_timing_max_report = 32
integer(psb_ipk_), save :: psb_swap_timing_report_count = 0
integer(psb_ipk_), save :: psb_swap_timing_wrapper_calls = 0
integer(psb_ipk_), save :: psb_swap_timing_baseline_calls = 0
integer(psb_ipk_), save :: psb_swap_timing_neighbor_calls = 0
contains
!
! psb_swap_timing_setup: one-time initialisation of the swap timing controls.
!
! On the first call it marks the timing state as initialised, resets the
! saved controls to their defaults (timing disabled, default report cap) and
! then reads two environment variables; later calls return immediately.
!
!   PSB_SWAP_TIMING             timing is enabled when the value starts with
!                               one of 1, t, T, y, Y; any other value, or an
!                               unset/empty variable, leaves it disabled.
!   PSB_SWAP_TIMING_MAX_REPORT  maximum number of timing reports; a value
!                               that cannot be parsed or is < 1 keeps the
!                               default.
!
subroutine psb_swap_timing_setup()
  implicit none
  ! Report cap used when the environment gives no usable value.
  integer(psb_ipk_), parameter :: default_max_report = 32
  character(len=64) :: env_buf
  ! Default-kind integers on purpose: Fortran 2003/2008 require the LENGTH
  ! and STATUS arguments of get_environment_variable to be default integer,
  ! which psb_ipk_ is not guaranteed to be.
  integer :: env_len, env_status, ios
  integer(psb_ipk_) :: max_report

  if (psb_swap_timing_inited) return

  psb_swap_timing_inited     = .true.
  psb_swap_timing_enabled    = .false.
  psb_swap_timing_max_report = default_max_report

  call get_environment_variable('PSB_SWAP_TIMING', env_buf, &
       & length=env_len, status=env_status)
  if ((env_status == 0) .and. (env_len > 0)) then
    select case(env_buf(1:1))
    case('1','t','T','y','Y')
      psb_swap_timing_enabled = .true.
    case default
      psb_swap_timing_enabled = .false.
    end select
  end if

  call get_environment_variable('PSB_SWAP_TIMING_MAX_REPORT', env_buf, &
       & length=env_len, status=env_status)
  if ((env_status == 0) .and. (env_len > 0)) then
    ! Parse into a local so that a failed read never touches the saved cap.
    max_report = default_max_report
    read(env_buf(1:env_len), *, iostat=ios) max_report
    if (ios == 0) then
      if (max_report >= 1) psb_swap_timing_max_report = max_report
    end if
  end if
end subroutine psb_swap_timing_setup
!
! psb_swap_timing_should_report: decide whether the caller may emit one more
! timing report.
!
! Runs the one-time timing setup, then returns .true. only when timing is
! enabled and fewer than psb_swap_timing_max_report reports have been granted
! so far. Each .true. result increments the saved report counter, so calling
! this function has a side effect.
!
function psb_swap_timing_should_report() result(do_report)
  implicit none
  logical :: do_report

  call psb_swap_timing_setup()

  do_report = psb_swap_timing_enabled
  if (do_report) then
    do_report = (psb_swap_timing_report_count < psb_swap_timing_max_report)
  end if
  ! Consume one report slot only when the report is actually granted.
  if (do_report) then
    psb_swap_timing_report_count = psb_swap_timing_report_count + 1
  end if
end function psb_swap_timing_should_report
module subroutine psi_dswapdata_vect(swap_status,beta,y,desc_a,info,data)
#ifdef PSB_MPI_MOD
@ -103,8 +156,14 @@ contains
! locals
type(psb_ctxt_type) :: ctxt
integer(psb_ipk_) :: np, me, total_send, total_recv, num_neighbors, data_, err_act
integer(psb_ipk_) :: np, me, total_send, total_recv, num_neighbors, data_
class(psb_i_base_vect_type), pointer :: comm_indexes
logical :: debug_on
integer(psb_ipk_) :: dbg_unit
logical :: timing_on, timing_report
real(psb_dpk_) :: t0, t1, t_get_list, t_kernel, t_total
integer(psb_ipk_) :: call_idx
character(len=24) :: phase_name, scheme_name, exchange_name
! communication scheme/status selectors
logical :: baseline, neighbor_a2av
@ -126,6 +185,26 @@ contains
goto 9999
endif
debug_on = (psb_get_debug_level() >= psb_debug_ext_)
call psb_swap_timing_setup()
timing_on = psb_swap_timing_enabled
timing_report = .false.
if (timing_on) then
t_get_list = dzero
t_kernel = dzero
t_total = dzero
t0 = psb_wtime()
call_idx = psb_swap_timing_wrapper_calls + 1
end if
if (debug_on) then
dbg_unit = psb_get_debug_unit()
if (dbg_unit <= 0) dbg_unit = psb_err_unit
write(dbg_unit,*) me, trim(name), ': enter swap_status=', swap_status, &
& ' data=', data_, ' local_rows=', desc_a%get_local_rows(), &
& ' local_cols=', desc_a%get_local_cols(), ' y_nrows=', y%get_nrows()
end if
if (.not.psb_is_asb_desc(desc_a)) then
info=psb_err_invalid_cd_state_
call psb_errpush(info,name)
@ -138,17 +217,19 @@ contains
data_ = psb_comm_halo_
end if
if (timing_on) t1 = psb_wtime()
call desc_a%get_list_p(data_,comm_indexes,num_neighbors,total_recv,total_send,info)
if (timing_on) t_get_list = psb_wtime() - t1
if (info /= psb_success_) then
call psb_errpush(psb_err_internal_error_,name,a_err='desc_a%get_list_p')
goto 9999
end if
! Debug: report list sizes
! if(me == 0) then
! write(psb_err_unit,*) me, 'DBG: get_list_p -> num_neighbors=', &
! & num_neighbors, ' total_send=', total_send, ' total_recv=', total_recv
! end if
if (debug_on) then
write(dbg_unit,*) me, trim(name), ': list_p num_neighbors=', num_neighbors, &
& ' total_send=', total_send, ' total_recv=', total_recv, &
& ' comm_indexes_size=', size(comm_indexes%v)
end if
if( (swap_status /= psb_comm_status_start_).and.(swap_status /= psb_comm_status_wait_)&
@ -181,6 +262,11 @@ contains
goto 9999
end if
if (debug_on) then
write(dbg_unit,*) me, trim(name), ': comm_type=', y%comm_handle%comm_type, &
& ' swap_status=', swap_status
end if
! if(me == 0) then
! write(psb_err_unit,*) me, 'DBG: after set_swap_status, info=', info
! end if
@ -199,14 +285,18 @@ contains
! end if
if (baseline) then
if (timing_on) t1 = psb_wtime()
call psi_dswap_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info)
if (timing_on) t_kernel = psb_wtime() - t1
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='baseline swap')
goto 9999
end if
else if (neighbor_a2av) then
if (timing_on) t1 = psb_wtime()
call psi_dswap_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,&
& total_send,total_recv,y%comm_handle,info)
if (timing_on) t_kernel = psb_wtime() - t1
if (info /= psb_success_) then
call psb_errpush(info,name,a_err='neighbor a2av swap')
goto 9999
@ -217,6 +307,45 @@ contains
goto 9999
end if
if (timing_on) then
t_total = psb_wtime() - t0
call psb_amx(ctxt, t_get_list)
call psb_amx(ctxt, t_kernel)
call psb_amx(ctxt, t_total)
if (me == psb_root_) timing_report = psb_swap_timing_should_report()
if ((me == psb_root_) .and. timing_report) then
psb_swap_timing_wrapper_calls = call_idx
select case(swap_status)
case(psb_comm_status_start_)
phase_name = 'start'
case(psb_comm_status_wait_)
phase_name = 'wait'
case(psb_comm_status_sync_)
phase_name = 'sync'
case default
phase_name = 'unknown'
end select
if (baseline) then
scheme_name = 'baseline'
else
if (y%comm_handle%comm_type == psb_comm_persistent_ineighbor_alltoallv_) then
scheme_name = 'persistent_neighbor'
else
scheme_name = 'neighbor'
end if
end if
if (call_idx == 1) then
exchange_name = 'first'
else
exchange_name = 'steady'
end if
write(psb_out_unit,'("SWAP_TIMING wrapper scheme=",a,", phase=",a,", exchange=",a,", call=",i0)') &
& trim(scheme_name), trim(phase_name), trim(exchange_name), call_idx
write(psb_out_unit,'(" get_list=",es12.5,", kernel=",es12.5,", total=",es12.5)') &
& t_get_list, t_kernel, t_total
end if
end if
call psb_erractionrestore(err_act)
return
@ -253,8 +382,15 @@ contains
integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,&
& snd_pt, rcv_pt, pnti, n
logical :: do_send,do_recv
logical, parameter :: usersend=.false., debug=.false.
logical, parameter :: usersend=.false.
logical :: debug
logical :: timing_on, timing_report
integer(psb_ipk_) :: dbg_unit
character(len=20) :: name
real(psb_dpk_) :: t0, t1
real(psb_dpk_) :: t_buf, t_gth, t_post, t_wait, t_sct, t_dev, t_total
integer(psb_ipk_) :: call_idx
character(len=12) :: exchange_name
info = psb_success_
name = 'psi_dswap_baseline_vect'
@ -268,6 +404,24 @@ contains
icomm = ctxt%get_mpic()
debug = (psb_get_debug_level() >= psb_debug_ext_)
call psb_swap_timing_setup()
timing_on = psb_swap_timing_enabled
timing_report = .false.
if (timing_on) then
t_buf = dzero
t_gth = dzero
t_post = dzero
t_wait = dzero
t_sct = dzero
t_dev = dzero
t_total = dzero
t0 = psb_wtime()
call_idx = psb_swap_timing_baseline_calls + 1
end if
dbg_unit = psb_get_debug_unit()
if (dbg_unit <= 0) dbg_unit = psb_err_unit
baseline_comm_handle => null()
select type(ch => comm_handle)
type is(psb_comm_baseline_handle)
@ -292,7 +446,7 @@ contains
total_send_ = total_send * n
call comm_indexes%sync()
if (debug) write(*,*) me,'Internal buffer'
if (debug) write(dbg_unit,*) me,'Internal buffer'
if (do_send) then
if (allocated(baseline_comm_handle%comid)) then
if (any(baseline_comm_handle%comid /= mpi_request_null)) then
@ -304,13 +458,16 @@ contains
goto 9999
end if
end if
if (debug) write(*,*) me,'do_send start'
if (debug) write(dbg_unit,*) me,'do_send start'
if (timing_on) t1 = psb_wtime()
call y%new_buffer(ione*size(comm_indexes%v),info)
call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info)
baseline_comm_handle%comid = mpi_request_null
call psb_realloc(num_neighbors,prcid,info)
if (timing_on) t_buf = t_buf + (psb_wtime() - t1)
! First I post all the non blocking receives
pnti = 1
if (timing_on) t1 = psb_wtime()
do i=1, num_neighbors
proc_to_comm = comm_indexes%v(pnti+psb_proc_id_)
nerv = comm_indexes%v(pnti+psb_n_elem_recv_)
@ -319,7 +476,7 @@ contains
rcv_pt = 1+pnti+psb_n_elem_recv_
prcid(i) = psb_get_mpi_rank(ctxt,proc_to_comm)
if ((nerv>0).and.(proc_to_comm /= me)) then
if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt
if (debug) write(dbg_unit,*) me,'Posting receive from',prcid(i),rcv_pt
p2ptag = psb_double_swap_tag
call mpi_irecv(y%combuf(rcv_pt),nerv,&
& psb_mpi_r_dpk_,prcid(i),&
@ -327,11 +484,13 @@ contains
end if
pnti = pnti + nerv + nesd + 3
end do
if (debug) write(*,*) me,' Gather '
if (timing_on) t_post = t_post + (psb_wtime() - t1)
if (debug) write(dbg_unit,*) me,' Gather '
!
! Then gather for sending.
!
pnti = 1
if (timing_on) t1 = psb_wtime()
do i=1, num_neighbors
nerv = comm_indexes%v(pnti+psb_n_elem_recv_)
nesd = comm_indexes%v(pnti+nerv+psb_n_elem_send_)
@ -351,13 +510,16 @@ contains
call y%gth(idx_pt,nesd,comm_indexes)
pnti = pnti + nerv + nesd + 3
end do
if (timing_on) t_gth = t_gth + (psb_wtime() - t1)
!
! Then wait
!
if (timing_on) t1 = psb_wtime()
call y%device_wait()
if (timing_on) t_dev = t_dev + (psb_wtime() - t1)
if (debug) write(*,*) me,' isend'
if (debug) write(dbg_unit,*) me,' isend'
!
! Then send
!
@ -366,6 +528,7 @@ contains
snd_pt = 1
rcv_pt = 1
p2ptag = psb_double_swap_tag
if (timing_on) t1 = psb_wtime()
do i=1, num_neighbors
proc_to_comm = comm_indexes%v(pnti+psb_proc_id_)
nerv = comm_indexes%v(pnti+psb_n_elem_recv_)
@ -387,10 +550,11 @@ contains
pnti = pnti + nerv + nesd + 3
end do
if (timing_on) t_post = t_post + (psb_wtime() - t1)
end if
if (do_recv) then
if (debug) write(*,*) me,' do_Recv'
if (debug) write(dbg_unit,*) me,' do_Recv'
if (.not.allocated(baseline_comm_handle%comid)) then
!
! No matching send? Something is wrong....
@ -401,9 +565,10 @@ contains
end if
call psb_realloc(num_neighbors,prcid,info)
if (debug) write(*,*) me,' wait'
if (debug) write(dbg_unit,*) me,' wait'
pnti = 1
p2ptag = psb_double_swap_tag
if (timing_on) t1 = psb_wtime()
do i=1, num_neighbors
proc_to_comm = comm_indexes%v(pnti+psb_proc_id_)
nerv = comm_indexes%v(pnti+psb_n_elem_recv_)
@ -438,11 +603,13 @@ contains
end if
pnti = pnti + nerv + nesd + 3
end do
if (timing_on) t_wait = t_wait + (psb_wtime() - t1)
if (debug) write(*,*) me,' scatter'
if (debug) write(dbg_unit,*) me,' scatter'
pnti = 1
snd_pt = 1
rcv_pt = 1
if (timing_on) t1 = psb_wtime()
do i=1, num_neighbors
proc_to_comm = comm_indexes%v(pnti+psb_proc_id_)
nerv = comm_indexes%v(pnti+psb_n_elem_recv_)
@ -462,11 +629,12 @@ contains
goto 9999
end if
if (debug) write(0,*)me,' Received from: ',prcid(i),&
if (debug) write(dbg_unit,*)me,' Received from: ',prcid(i),&
& y%combuf(rcv_pt:rcv_pt+nerv-1)
call y%sct(rcv_pt,nerv,comm_indexes,beta)
pnti = pnti + nerv + nesd + 3
end do
if (timing_on) t_sct = t_sct + (psb_wtime() - t1)
!
! Waited for everybody, clean up
!
@ -475,9 +643,10 @@ contains
!
! Then wait for device
!
if (debug) write(*,*) me,' wait'
if (debug) write(dbg_unit,*) me,' wait'
if (timing_on) t1 = psb_wtime()
call y%device_wait()
if (debug) write(*,*) me,' free buffer'
if (debug) write(dbg_unit,*) me,' free buffer'
call y%maybe_free_buffer(info)
if (info == 0) then
if (allocated(y%comm_handle)) call y%comm_handle%free(info)
@ -486,7 +655,33 @@ contains
call psb_errpush(psb_err_alloc_dealloc_,name)
goto 9999
end if
if (debug) write(*,*) me,' done'
if (timing_on) t_dev = t_dev + (psb_wtime() - t1)
if (debug) write(dbg_unit,*) me,' done'
end if
if (timing_on) then
t_total = psb_wtime() - t0
call psb_amx(ctxt, t_buf)
call psb_amx(ctxt, t_gth)
call psb_amx(ctxt, t_post)
call psb_amx(ctxt, t_wait)
call psb_amx(ctxt, t_sct)
call psb_amx(ctxt, t_dev)
call psb_amx(ctxt, t_total)
if (me == psb_root_) timing_report = psb_swap_timing_should_report()
if ((me == psb_root_) .and. timing_report) then
psb_swap_timing_baseline_calls = call_idx
if (call_idx == 1) then
exchange_name = 'first'
else
exchange_name = 'steady'
end if
write(psb_out_unit,'("SWAP_TIMING baseline phase start=",l1,", wait=",l1)') do_send, do_recv
write(psb_out_unit,'(" exchange=",a,", call=",i0)') trim(exchange_name), call_idx
write(psb_out_unit,'(" buf=",es12.5,", gth=",es12.5,", post=",es12.5,", wait=",es12.5)') &
& t_buf, t_gth, t_post, t_wait
write(psb_out_unit,'(" sct=",es12.5,", dev=",es12.5,", total=",es12.5)') t_sct, t_dev, t_total
end if
end if
@ -526,8 +721,14 @@ contains
type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle
integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size
logical :: do_start, do_wait
logical, parameter :: debug = .false.
logical :: debug
logical :: timing_on, timing_report
integer(psb_ipk_) :: dbg_unit
character(len=30) :: name
real(psb_dpk_) :: t0, t1
real(psb_dpk_) :: t_topo, t_buf, t_gth, t_init, t_post, t_wait, t_sct, t_dev, t_total
integer(psb_ipk_) :: call_idx
character(len=12) :: exchange_name
info = psb_success_
@ -541,6 +742,25 @@ contains
endif
icomm = ctxt%get_mpic()
dbg_unit = psb_get_debug_unit()
if (dbg_unit <= 0) dbg_unit = psb_err_unit
debug = (psb_get_debug_level() >= psb_debug_ext_)
call psb_swap_timing_setup()
timing_on = psb_swap_timing_enabled
timing_report = .false.
if (timing_on) then
t_topo = dzero
t_buf = dzero
t_gth = dzero
t_init = dzero
t_post = dzero
t_wait = dzero
t_sct = dzero
t_dev = dzero
t_total = dzero
t0 = psb_wtime()
call_idx = psb_swap_timing_neighbor_calls + 1
end if
neighbor_comm_handle => null()
select type(ch => comm_handle)
@ -567,7 +787,7 @@ contains
! START phase: build topology (if needed), gather, post MPI
! ---------------------------------------------------------
if (do_start) then
if(debug) write(*,*) me,' nbr_vect: starting data exchange'
if(debug) write(dbg_unit,*) me,' nbr_vect: starting data exchange'
if (neighbor_comm_handle%use_persistent_buffers) then
if (neighbor_comm_handle%persistent_in_flight) then
info = psb_err_mpi_error_
@ -576,8 +796,10 @@ contains
end if
end if
if (.not. neighbor_comm_handle%is_initialized) then
if (debug) write(*,*) me,' nbr_vect: building topology via handle'
if (debug) write(dbg_unit,*) me,' nbr_vect: building topology via handle'
if (timing_on) t1 = psb_wtime()
call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info)
if (timing_on) t_topo = t_topo + (psb_wtime() - t1)
if (info /= psb_success_) then
call psb_errpush(psb_err_internal_error_, name, a_err='neighbor_topology_init')
goto 9999
@ -592,6 +814,7 @@ contains
buffer_size = topology_total_send + topology_total_recv
if (buffer_size > 0) then
if (timing_on) t1 = psb_wtime()
if (neighbor_comm_handle%use_persistent_buffers) then
if (.not. allocated(y%combuf)) then
neighbor_comm_handle%diag_buffer_reallocs = neighbor_comm_handle%diag_buffer_reallocs + 1
@ -633,13 +856,16 @@ contains
goto 9999
end if
end if
if (timing_on) t_buf = t_buf + (psb_wtime() - t1)
neighbor_comm_handle%comm_request = mpi_request_null
! Gather send data into contiguous send buffer (polymorphic for GPU)
if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems'
if (timing_on) t1 = psb_wtime()
call y%gth(int(topology_total_send,psb_mpk_), &
& neighbor_comm_handle%send_indexes, &
& y%combuf(1:topology_total_send))
if (timing_on) t_gth = t_gth + (psb_wtime() - t1)
else
! No data to send/recv: ensure requests/buffers indicate idle state
neighbor_comm_handle%comm_request = mpi_request_null
@ -648,7 +874,9 @@ contains
end if
! Wait for device (important for GPU subclasses)
if (timing_on) t1 = psb_wtime()
call y%device_wait()
if (timing_on) t_dev = t_dev + (psb_wtime() - t1)
if (neighbor_comm_handle%use_persistent_buffers) then
! Lazy persistent-init: build the request once, then reuse with START/WAIT.
@ -656,6 +884,7 @@ contains
#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT
if (buffer_size > 0) then
if (debug) write(*,*) me,' nbr_vect: posting MPI_Neighbor_alltoallv_init'
if (timing_on) t1 = psb_wtime()
call mpi_neighbor_alltoallv_init( &
& y%combuf(1), & ! send buffer
& neighbor_comm_handle%send_counts, &
@ -668,6 +897,7 @@ contains
& neighbor_comm_handle%graph_comm, &
& mpi_info_null, &
& neighbor_comm_handle%persistent_request, iret)
if (timing_on) t_init = t_init + (psb_wtime() - t1)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -689,7 +919,9 @@ contains
#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT
if (buffer_size > 0) then
if (timing_on) t1 = psb_wtime()
call mpi_start(neighbor_comm_handle%persistent_request, iret)
if (timing_on) t_post = t_post + (psb_wtime() - t1)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -702,6 +934,7 @@ contains
end if
#else
if (buffer_size > 0) then
if (timing_on) t1 = psb_wtime()
call mpi_ineighbor_alltoallv( &
& y%combuf(1), & ! send buffer
& neighbor_comm_handle%send_counts, &
@ -713,6 +946,7 @@ contains
& psb_mpi_r_dpk_, &
& neighbor_comm_handle%graph_comm, &
& neighbor_comm_handle%comm_request, iret)
if (timing_on) t_post = t_post + (psb_wtime() - t1)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -728,6 +962,7 @@ contains
! Post non-blocking neighborhood alltoallv
if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv'
if (buffer_size > 0) then
if (timing_on) t1 = psb_wtime()
call mpi_ineighbor_alltoallv( &
& y%combuf(1), & ! send buffer
& neighbor_comm_handle%send_counts, &
@ -739,6 +974,7 @@ contains
& psb_mpi_r_dpk_, &
& neighbor_comm_handle%graph_comm, &
& neighbor_comm_handle%comm_request, iret)
if (timing_on) t_post = t_post + (psb_wtime() - t1)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -788,6 +1024,7 @@ contains
if ((topology_total_send + topology_total_recv) > 0) then
! Wait for the non-blocking collective to complete
if (debug) write(*,*) me,' nbr_vect: waiting on MPI request'
if (timing_on) t1 = psb_wtime()
if (neighbor_comm_handle%use_persistent_buffers) then
#ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT
call mpi_wait(neighbor_comm_handle%persistent_request, p2pstat, iret)
@ -797,6 +1034,7 @@ contains
else
call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret)
end if
if (timing_on) t_wait = t_wait + (psb_wtime() - t1)
if (iret /= mpi_success) then
info = psb_err_mpi_error_
call psb_errpush(info, name, m_err=(/iret/))
@ -811,10 +1049,12 @@ contains
! Scatter received data to local vector positions (polymorphic for GPU)
if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems'
if (timing_on) t1 = psb_wtime()
call y%sct(int(topology_total_recv,psb_mpk_), &
& neighbor_comm_handle%recv_indexes, &
& y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), &
& beta)
if (timing_on) t_sct = t_sct + (psb_wtime() - t1)
else
! nothing to wait/scatter
end if
@ -825,6 +1065,7 @@ contains
& (neighbor_comm_handle%use_persistent_buffers .and. .not. neighbor_comm_handle%persistent_request_ready)) then
neighbor_comm_handle%comm_request = mpi_request_null
end if
if (timing_on) t1 = psb_wtime()
call y%device_wait()
if (.not. neighbor_comm_handle%use_persistent_buffers) then
call y%maybe_free_buffer(info)
@ -833,10 +1074,43 @@ contains
goto 9999
end if
end if
if (timing_on) t_dev = t_dev + (psb_wtime() - t1)
if (debug) write(*,*) me,' nbr_vect: done'
end if ! do_wait
if (timing_on) then
t_total = psb_wtime() - t0
call psb_amx(ctxt, t_topo)
call psb_amx(ctxt, t_buf)
call psb_amx(ctxt, t_gth)
call psb_amx(ctxt, t_init)
call psb_amx(ctxt, t_post)
call psb_amx(ctxt, t_wait)
call psb_amx(ctxt, t_sct)
call psb_amx(ctxt, t_dev)
call psb_amx(ctxt, t_total)
if (me == psb_root_) timing_report = psb_swap_timing_should_report()
if ((me == psb_root_) .and. timing_report) then
psb_swap_timing_neighbor_calls = call_idx
if (call_idx == 1) then
exchange_name = 'first'
else
exchange_name = 'steady'
end if
if (neighbor_comm_handle%use_persistent_buffers) then
write(psb_out_unit,'("SWAP_TIMING persistent_neighbor phase start=",l1,", wait=",l1)') do_start, do_wait
else
write(psb_out_unit,'("SWAP_TIMING neighbor phase start=",l1,", wait=",l1)') do_start, do_wait
end if
write(psb_out_unit,'(" exchange=",a,", call=",i0)') trim(exchange_name), call_idx
write(psb_out_unit,'(" topo=",es12.5,", buf=",es12.5,", gth=",es12.5,", init=",es12.5)') &
& t_topo, t_buf, t_gth, t_init
write(psb_out_unit,'(" post=",es12.5,", wait=",es12.5,", sct=",es12.5,", dev=",es12.5,", total=",es12.5)') &
& t_post, t_wait, t_sct, t_dev, t_total
end if
end if
call psb_erractionrestore(err_act)
return

@ -140,6 +140,11 @@ subroutine psb_dcg_vect(a,prec,b,x,eps,desc_a,info,&
ctxt = desc_a%get_context()
call psb_info(ctxt, me, np)
if (np == -ione) then
info = psb_err_context_error_
call psb_errpush(info,name,a_err='invalid desc_a context in psb_dcg_vect')
goto 9999
end if
if (.not.allocated(b%v)) then
info = psb_err_invalid_vect_state_
call psb_errpush(info,name)

@ -9,12 +9,14 @@ program psb_comm_cg_test
implicit none
type(psb_ctxt_type) :: ctxt
type(psb_ctxt_type) :: desc_ctxt
type(psb_dspmat_type) :: a
type(psb_desc_type) :: desc_a
type(psb_d_vect_type) :: b, x
type(psb_dprec_type) :: prec
integer(psb_ipk_) :: info, iam, np
integer(psb_ipk_) :: desc_me, desc_np
integer(psb_ipk_) :: idim, itmax, itrace, istop, iter
integer(psb_ipk_) :: scheme_idx, prec_idx, rep, nrep, nwarm
integer(psb_ipk_), parameter :: n_schemes=3, n_precs=2
@ -42,7 +44,7 @@ program psb_comm_cg_test
nrep = 5
nwarm = 1
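! Keep itrace positive to avoid modulo-by-zero paths in convergence logging.
! NOTE(review): the assignment that follows now sets itrace = 0, which contradicts
! the guidance above; confirm that itrace = 0 cannot reach a modulo-by-zero path,
! or restore a positive value.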
itrace = 1
itrace = 0
istop = 2
eps = 1.d-6
scheme_type = (/ psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, &
@ -127,6 +129,14 @@ program psb_comm_cg_test
! call probe_ieee('after psb_d_gen_pde3d')
if (info /= psb_success_) goto 9999
! desc_ctxt = desc_a%get_context()
! call psb_info(desc_ctxt, desc_me, desc_np)
! if (desc_np == -1) then
! info = psb_err_context_error_
! write(psb_err_unit,*) 'Invalid descriptor context after psb_d_gen_pde3d'
! goto 9999
! end if
do prec_idx = 1, n_precs
do scheme_idx = 1, n_schemes
do rep = 1, nrep
@ -169,14 +179,6 @@ program psb_comm_cg_test
setup_time(prec_idx,scheme_idx,rep) = prec_init_time(prec_idx,scheme_idx,rep) + &
& prec_bld_time(prec_idx,scheme_idx,rep) + comm_set_time(prec_idx,scheme_idx,rep)
do iter = 1, nwarm
call psb_geaxpby(dzero,b,dzero,x,desc_a,info)
if (info /= psb_success_) goto 9999
call psb_krylov('CG',a,prec,b,x,eps,desc_a,info,&
& itmax=itmax,itrace=itrace,istop=istop)
if (info /= psb_success_) goto 9999
end do
call psb_geaxpby(dzero,b,dzero,x,desc_a,info)
if (info /= psb_success_) goto 9999
@ -212,20 +214,20 @@ program psb_comm_cg_test
final_error(prec_idx,scheme_idx,rep) = err
solve_info(prec_idx,scheme_idx,rep) = info
if (iam == psb_root_) then
select type(ch => x%v%comm_handle)
type is(psb_comm_neighbor_handle)
write(psb_out_unit,'("DIAG_COMM scheme=",a,", prec=",a,", rep=",i0)') &
& trim(scheme_name(scheme_idx)), trim(prec_name(prec_idx)), rep
write(psb_out_unit,'("DIAG_COMM counters: init=",i0,", start=",i0,", wait=",i0,", realloc=",i0)') &
& ch%diag_init_calls, ch%diag_start_calls, ch%diag_wait_calls, &
& ch%diag_buffer_reallocs
write(psb_out_unit,'("DIAG_COMM state: ready=",l1,", bsz=",i0)') &
& ch%persistent_request_ready, ch%persistent_buffer_size
class default
continue
end select
end if
! if (iam == psb_root_) then
! select type(ch => x%v%comm_handle)
! type is(psb_comm_neighbor_handle)
! write(psb_out_unit,'("DIAG_COMM scheme=",a,", prec=",a,", rep=",i0)') &
! & trim(scheme_name(scheme_idx)), trim(prec_name(prec_idx)), rep
! write(psb_out_unit,'("DIAG_COMM counters: init=",i0,", start=",i0,", wait=",i0,", realloc=",i0)') &
! & ch%diag_init_calls, ch%diag_start_calls, ch%diag_wait_calls, &
! & ch%diag_buffer_reallocs
! write(psb_out_unit,'("DIAG_COMM state: ready=",l1,", bsz=",i0)') &
! & ch%persistent_request_ready, ch%persistent_buffer_size
! class default
! continue
! end select
! end if
if (info /= psb_success_) goto 9999
end do

@ -0,0 +1,50 @@
#!/usr/bin/env bash
#SBATCH --job-name=psb_spmv_overlap
#SBATCH --partition=boost_usr_prod
#SBATCH --time=01:00:00
#SBATCH --nodes=2
#SBATCH --ntasks=64
#SBATCH --ntasks-per-node=32
#SBATCH --cpus-per-task=1
#SBATCH --threads-per-core=1
#SBATCH --gpus-per-node=4
#SBATCH --export=ALL
#SBATCH -A CNHPC_1736213
#SBATCH --output=psb_spmv_overlap_%j.out
#SBATCH --error=psb_spmv_overlap_%j.err
#
# Runs the SpMV overlap communication test once per problem size.
#
# Tunables (taken from the environment, with defaults):
#   IDIM_LIST         whitespace-separated list of IDIM values to run
#   TIMES             value exported as TIMES to every run
#   BUILD_IF_MISSING  only echoed below
#                     NOTE(review): no build step in this script uses it; a
#                     missing executable is always a hard error.
set -euo pipefail
# Environment tuned like the existing comm test script.
export UCX_VFS_ENABLE=n
export UCX_VFS_USE_FUSE=n
export UCX_STATS_DEST=none
export UCX_LOG_LEVEL=error
export UCX_TLS=dc,sm,self
export UCX_NET_DEVICES=mlx5_0:1
export OMPI_MCA_coll=^hcoll,han
export OMPI_MCA_coll_hcoll_enable=0
export UCX_MEMTYPE_CACHE=n
export UCX_CLOSE_TIMEOUT=10s
EXEC=./test/comm/spmv/runs/spmv_overlap
IDIM_LIST=${IDIM_LIST:-"20 40 60 80 100 140 180 220 260 300"}
TIMES=${TIMES:-100}
BUILD_IF_MISSING=${BUILD_IF_MISSING:-1}
# Take the job geometry from the allocation so that srun cannot drift from
# the #SBATCH header (or from sbatch command-line overrides). The fallbacks
# match the header values for runs where Slurm did not set the variables.
NNODES=${SLURM_JOB_NUM_NODES:-2}
NTASKS=${SLURM_NTASKS:-64}
if [[ ! -x "$EXEC" ]]; then
  echo "Executable not found: $EXEC" >&2
  exit 1
fi
echo "Running SpMV overlap comm test"
echo " EXEC=$EXEC"
echo " IDIM_LIST=$IDIM_LIST"
echo " TIMES=$TIMES"
echo " BUILD_IF_MISSING=$BUILD_IF_MISSING"
# IDIM_LIST is intentionally left unquoted: word splitting yields one
# iteration per problem size.
for IDIM in $IDIM_LIST; do
  echo ""
  echo "=== Running IDIM=$IDIM TIMES=$TIMES ==="
  IDIM="$IDIM" TIMES="$TIMES" srun --exclusive -N "$NNODES" -n "$NTASKS" "$EXEC"
done

@ -524,29 +524,47 @@ contains
real(psb_dpk_) :: alpha, beta
type(psb_dspmat_type) :: a
type(psb_d_vect_type) :: x_baseline, x_neighbor, x_persistent
type(psb_d_vect_type) :: y_baseline, y_neighbor, y_persistent
type(psb_d_vect_type) :: x_isend, x_neighbor, x_persistent
type(psb_d_vect_type) :: y_ov_isend, y_ov_neighbor, y_ov_persistent
type(psb_d_vect_type) :: y_no_isend, y_no_neighbor, y_no_persistent
type(psb_desc_type) :: desc_a
character(len=:), allocatable :: output_file_name
character(len=32) :: idim_str
character(len=64) :: env_buf
real(psb_dpk_), allocatable :: x_global(:), y_global(:)
integer(psb_ipk_) :: my_rank, np, info, err_act
integer :: env_len, env_status, ios
integer(psb_ipk_) :: n_global, idim
integer(psb_ipk_) :: i, times
real(psb_dpk_) :: t0, t1, dt
real(psb_dpk_) :: tsum_baseline, tsum_neighbor, tsum_persistent
real(psb_dpk_) :: err_bn, err_bp, tol
logical :: tnd
real(psb_dpk_) :: t_ov_isend, t_ov_neighbor, t_ov_persistent
real(psb_dpk_) :: t_no_isend, t_no_neighbor, t_no_persistent
real(psb_dpk_) :: err_isend, err_neighbor, err_persistent, tol
real(psb_dpk_) :: avg_ov, avg_no, speedup, gain_pct
info = psb_success_
tol = 1.0d-10
times = 100
tsum_baseline = 0.0_psb_dpk_
tsum_neighbor = 0.0_psb_dpk_
tsum_persistent = 0.0_psb_dpk_
tnd = .false.
t_ov_isend = 0.0_psb_dpk_
t_ov_neighbor = 0.0_psb_dpk_
t_ov_persistent = 0.0_psb_dpk_
t_no_isend = 0.0_psb_dpk_
t_no_neighbor = 0.0_psb_dpk_
t_no_persistent = 0.0_psb_dpk_
idim = 10
call psb_erractionsave(err_act)
call get_environment_variable('IDIM', env_buf, length=env_len, status=env_status)
if ((env_status == 0) .and. (env_len > 0)) then
read(env_buf(1:env_len), *, iostat=ios) idim
if ((ios /= 0) .or. (idim < 2)) idim = 10
end if
call get_environment_variable('TIMES', env_buf, length=env_len, status=env_status)
if ((env_status == 0) .and. (env_len > 0)) then
read(env_buf(1:env_len), *, iostat=ios) times
if ((ios /= 0) .or. (times < 1)) times = 100
end if
n_global = idim * idim * idim
alpha = done
beta = dzero
@ -555,7 +573,7 @@ contains
call psb_barrier(ctxt)
call psb_d_gen_pde3d(ctxt,idim,a,y_baseline,x_baseline,desc_a,"CSR",info,partition=1)
call psb_d_gen_pde3d(ctxt,idim,a,y_ov_isend,x_isend,desc_a,"CSR",info,partition=1)
if (info /= psb_success_) goto 9999
call psb_barrier(ctxt)
@ -571,96 +589,200 @@ contains
call psb_geall(x_neighbor, desc_a, info)
call psb_geall(x_persistent, desc_a, info)
call psb_geall(y_neighbor, desc_a, info)
call psb_geall(y_persistent, desc_a, info)
call psb_geall(y_ov_neighbor, desc_a, info)
call psb_geall(y_ov_persistent, desc_a, info)
call psb_geall(y_no_isend, desc_a, info)
call psb_geall(y_no_neighbor, desc_a, info)
call psb_geall(y_no_persistent, desc_a, info)
if (info /= psb_success_) goto 9999
call psb_scatter(x_global, x_baseline, desc_a, info, root=psb_root_)
call psb_scatter(x_global, x_isend, desc_a, info, root=psb_root_)
call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_baseline, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_persistent, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_ov_isend, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_ov_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_ov_persistent, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_no_isend, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_no_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_no_persistent, desc_a, info, root=psb_root_)
if (info /= psb_success_) goto 9999
! Set communication schemes on the x vectors used by psb_spmm.
call psb_comm_set(psb_comm_isend_irecv_, x_baseline%v%comm_handle, info)
call psb_comm_set(psb_comm_isend_irecv_, x_isend%v%comm_handle, info)
if (info /= psb_success_) goto 9999
call psb_comm_set(psb_comm_ineighbor_alltoallv_, x_neighbor%v%comm_handle, info)
if (info /= psb_success_) goto 9999
call psb_comm_set(psb_comm_persistent_ineighbor_alltoallv_, x_persistent%v%comm_handle, info)
if (info /= psb_success_) goto 9999
! Warm-up all schemes once.
call psb_spmm(alpha, a, x_baseline, beta, y_baseline, desc_a, info, doswap=.true.)
call psb_spmm(alpha, a, x_neighbor, beta, y_neighbor, desc_a, info, doswap=.true.)
call psb_spmm(alpha, a, x_persistent, beta, y_persistent, desc_a, info, doswap=.true.)
! Warm-up all schemes once: overlap and non-overlap paths.
call psb_spmm(alpha, a, x_isend, beta, y_ov_isend, desc_a, info, doswap=.true.)
call psb_halo(x_isend, desc_a, info)
call psb_spmm(alpha, a, x_isend, beta, y_no_isend, desc_a, info, doswap=.false.)
call psb_spmm(alpha, a, x_neighbor, beta, y_ov_neighbor, desc_a, info, doswap=.true.)
call psb_halo(x_neighbor, desc_a, info)
call psb_spmm(alpha, a, x_neighbor, beta, y_no_neighbor, desc_a, info, doswap=.false.)
call psb_spmm(alpha, a, x_persistent, beta, y_ov_persistent, desc_a, info, doswap=.true.)
call psb_halo(x_persistent, desc_a, info)
call psb_spmm(alpha, a, x_persistent, beta, y_no_persistent, desc_a, info, doswap=.false.)
if (info /= psb_success_) goto 9999
! Restore vectors so timed loops start from same initial state.
call psb_scatter(x_global, x_baseline, desc_a, info, root=psb_root_)
! -----------------------------
! isend/irecv scheme
! -----------------------------
call psb_scatter(x_global, x_isend, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_ov_isend, desc_a, info, root=psb_root_)
call psb_barrier(ctxt)
t0 = psb_wtime()
do i = 1, times
call psb_spmm(alpha, a, x_isend, beta, y_ov_isend, desc_a, info, doswap=.true.)
end do
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
t_ov_isend = dt
call psb_scatter(x_global, x_isend, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_no_isend, desc_a, info, root=psb_root_)
call psb_barrier(ctxt)
t0 = psb_wtime()
do i = 1, times
call psb_halo(x_isend, desc_a, info)
call psb_spmm(alpha, a, x_isend, beta, y_no_isend, desc_a, info, doswap=.false.)
end do
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
t_no_isend = dt
! -----------------------------
! ineighbor_alltoallv scheme
! -----------------------------
call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_baseline, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_persistent, desc_a, info, root=psb_root_)
if (info /= psb_success_) goto 9999
call psb_scatter(y_global, y_ov_neighbor, desc_a, info, root=psb_root_)
call psb_barrier(ctxt)
t0 = psb_wtime()
do i = 1, times
call psb_spmm(alpha, a, x_neighbor, beta, y_ov_neighbor, desc_a, info, doswap=.true.)
end do
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
t_ov_neighbor = dt
! Baseline (isend/irecv) overlapped SpMV.
call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_no_neighbor, desc_a, info, root=psb_root_)
call psb_barrier(ctxt)
t0 = psb_wtime()
do i = 1, times
call psb_spmm(alpha, a, x_baseline, beta, y_baseline, desc_a, info, doswap=.true.)
call psb_halo(x_neighbor, desc_a, info)
call psb_spmm(alpha, a, x_neighbor, beta, y_no_neighbor, desc_a, info, doswap=.false.)
end do
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
tsum_baseline = tsum_baseline + dt
t_no_neighbor = dt
! Neighbor alltoallv overlapped SpMV.
! ----------------------------------------
! persistent_ineighbor_alltoallv scheme
! ----------------------------------------
call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_ov_persistent, desc_a, info, root=psb_root_)
call psb_barrier(ctxt)
t0 = psb_wtime()
do i = 1, times
call psb_spmm(alpha, a, x_neighbor, beta, y_neighbor, desc_a, info, doswap=.true.)
call psb_spmm(alpha, a, x_persistent, beta, y_ov_persistent, desc_a, info, doswap=.true.)
end do
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
tsum_neighbor = tsum_neighbor + dt
t_ov_persistent = dt
! Persistent-neighbor overlapped SpMV.
call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_)
call psb_scatter(y_global, y_no_persistent, desc_a, info, root=psb_root_)
call psb_barrier(ctxt)
t0 = psb_wtime()
do i = 1, times
call psb_spmm(alpha, a, x_persistent, beta, y_persistent, desc_a, info, doswap=.true.)
call psb_halo(x_persistent, desc_a, info)
call psb_spmm(alpha, a, x_persistent, beta, y_no_persistent, desc_a, info, doswap=.false.)
end do
t1 = psb_wtime()
dt = t1 - t0
call psb_amx(ctxt, dt)
tsum_persistent = tsum_persistent + dt
t_no_persistent = dt
if (info /= psb_success_) goto 9999
err_bn = maxval(abs(y_baseline%get_vect() - y_neighbor%get_vect()))
err_bp = maxval(abs(y_baseline%get_vect() - y_persistent%get_vect()))
call psb_amx(ctxt, err_bn)
call psb_amx(ctxt, err_bp)
err_isend = maxval(abs(y_ov_isend%get_vect() - y_no_isend%get_vect()))
err_neighbor = maxval(abs(y_ov_neighbor%get_vect() - y_no_neighbor%get_vect()))
err_persistent = maxval(abs(y_ov_persistent%get_vect() - y_no_persistent%get_vect()))
call psb_amx(ctxt, err_isend)
call psb_amx(ctxt, err_neighbor)
call psb_amx(ctxt, err_persistent)
if (my_rank == 0) then
write(psb_out_unit,'(" Avg baseline time : ",es12.5)') tsum_baseline / real(times, psb_dpk_)
write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline
write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') tsum_neighbor / real(times, psb_dpk_)
write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor
write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') tsum_persistent / real(times, psb_dpk_)
write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_persistent
write(psb_out_unit,'(" Check baseline vs neighbor err = ",es12.5)') err_bn
write(psb_out_unit,'(" Check baseline vs persistent err = ",es12.5)') err_bp
if ((err_bn > tol) .or. (err_bp > tol)) then
write(psb_out_unit,'(/,"SpMV overlap benchmark")')
write(psb_out_unit,'(" idim : ",i0)') idim
write(psb_out_unit,'(" global unknowns : ",i0)') n_global
write(psb_out_unit,'(" repetitions : ",i0)') times
write(psb_out_unit,'(" timing metric : max over MPI ranks")')
write(psb_out_unit,'(" gain(%) = 100*(1 - overlap/no_overlap)")')
write(psb_out_unit,'(/,"Scheme: isend_irecv")')
avg_ov = t_ov_isend / real(times, psb_dpk_)
avg_no = t_no_isend / real(times, psb_dpk_)
speedup = t_no_isend / max(t_ov_isend, tiny(done))
gain_pct = 100.0_psb_dpk_ * (done - (t_ov_isend / max(t_no_isend, tiny(done))))
write(psb_out_unit,'(" total overlap : ",es12.5)') t_ov_isend
write(psb_out_unit,'(" total no_overlap : ",es12.5)') t_no_isend
write(psb_out_unit,'(" avg overlap : ",es12.5)') avg_ov
write(psb_out_unit,'(" avg no_overlap : ",es12.5)') avg_no
write(psb_out_unit,'(" speedup (no/ov) : ",f10.4)') speedup
write(psb_out_unit,'(" gain (%) : ",f10.4)') gain_pct
write(psb_out_unit,'(" overlap vs no_overlap err = ",es12.5)') err_isend
write(psb_out_unit,'(/,"Scheme: ineighbor_alltoallv")')
avg_ov = t_ov_neighbor / real(times, psb_dpk_)
avg_no = t_no_neighbor / real(times, psb_dpk_)
speedup = t_no_neighbor / max(t_ov_neighbor, tiny(done))
gain_pct = 100.0_psb_dpk_ * (done - (t_ov_neighbor / max(t_no_neighbor, tiny(done))))
write(psb_out_unit,'(" total overlap : ",es12.5)') t_ov_neighbor
write(psb_out_unit,'(" total no_overlap : ",es12.5)') t_no_neighbor
write(psb_out_unit,'(" avg overlap : ",es12.5)') avg_ov
write(psb_out_unit,'(" avg no_overlap : ",es12.5)') avg_no
write(psb_out_unit,'(" speedup (no/ov) : ",f10.4)') speedup
write(psb_out_unit,'(" gain (%) : ",f10.4)') gain_pct
write(psb_out_unit,'(" overlap vs no_overlap err = ",es12.5)') err_neighbor
write(psb_out_unit,'(/,"Scheme: persistent_ineighbor_alltoallv")')
avg_ov = t_ov_persistent / real(times, psb_dpk_)
avg_no = t_no_persistent / real(times, psb_dpk_)
speedup = t_no_persistent / max(t_ov_persistent, tiny(done))
gain_pct = 100.0_psb_dpk_ * (done - (t_ov_persistent / max(t_no_persistent, tiny(done))))
write(psb_out_unit,'(" total overlap : ",es12.5)') t_ov_persistent
write(psb_out_unit,'(" total no_overlap : ",es12.5)') t_no_persistent
write(psb_out_unit,'(" avg overlap : ",es12.5)') avg_ov
write(psb_out_unit,'(" avg no_overlap : ",es12.5)') avg_no
write(psb_out_unit,'(" speedup (no/ov) : ",f10.4)') speedup
write(psb_out_unit,'(" gain (%) : ",f10.4)') gain_pct
write(psb_out_unit,'(" overlap vs no_overlap err = ",es12.5)') err_persistent
if ((err_isend > tol) .or. (err_neighbor > tol) .or. (err_persistent > tol)) then
write(psb_out_unit,'(" WARNING: mismatch exceeds tolerance ",es12.5)') tol
end if
end if
call psb_gefree(x_baseline, desc_a, info)
call psb_gefree(x_isend, desc_a, info)
call psb_gefree(x_neighbor, desc_a, info)
call psb_gefree(x_persistent, desc_a, info)
call psb_gefree(y_baseline, desc_a, info)
call psb_gefree(y_neighbor, desc_a, info)
call psb_gefree(y_persistent, desc_a, info)
call psb_gefree(y_ov_isend, desc_a, info)
call psb_gefree(y_ov_neighbor, desc_a, info)
call psb_gefree(y_ov_persistent, desc_a, info)
call psb_gefree(y_no_isend, desc_a, info)
call psb_gefree(y_no_neighbor, desc_a, info)
call psb_gefree(y_no_persistent, desc_a, info)
call psb_spfree(a, desc_a, info)
call psb_cdfree(desc_a, info)

@ -17,6 +17,7 @@
!
program psb_comm_test
use psb_base_mod
use psb_error_mod, only: psb_set_debug_level, psb_debug_ext_
use psi_mod
use psb_comm_factory_mod, only: psb_comm_set, psb_comm_free
use psb_comm_schemes_mod, only: psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, &
@ -30,6 +31,7 @@ program psb_comm_test
integer(psb_ipk_) :: iters
character(len=32) :: arg
character(len=16) :: mode
logical :: debug_swapdata
! ---- descriptor / context ----
type(psb_ctxt_type) :: ctxt
@ -58,13 +60,16 @@ program psb_comm_test
real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent
integer(psb_lpk_), allocatable :: glob_col(:)
character(len=40) :: name
real(psb_dpk_) :: huge_d
name = 'test_halo_new'
tol = 1.0d-12
huge_d = huge(1.0_psb_dpk_)
n_pass = 0
n_total = 0
iters = 5
mode = 'both'
debug_swapdata = .false.
! ---- parse command-line argument for idim ----
idim = 10
@ -92,9 +97,15 @@ program psb_comm_test
call get_command_argument(i+1, arg)
read(arg, *) mode
end if
else if (trim(arg) == '--debug') then
debug_swapdata = .true.
end if
end do
if (debug_swapdata) then
call psb_set_debug_level(psb_debug_ext_)
end if
run_baseline = .false.
run_neighbor = .false.
run_persistent = .false.
@ -180,11 +191,35 @@ program psb_comm_test
! 3. Allocate three D vectors (scratch) and fill owned entries
! ==================================================================
call psb_geall(v_baseline, desc_a, info)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geall baseline error:', info
call psb_abort(ctxt)
end if
call psb_geall(v_neighbor, desc_a, info)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geall neighbor error:', info
call psb_abort(ctxt)
end if
call psb_geall(v_neighbor_persistent, desc_a, info)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geall persistent-neighbor error:', info
call psb_abort(ctxt)
end if
call psb_geasb(v_baseline, desc_a, info, scratch=.true.)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geasb baseline error:', info
call psb_abort(ctxt)
end if
call psb_geasb(v_neighbor, desc_a, info, scratch=.true.)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geasb neighbor error:', info
call psb_abort(ctxt)
end if
call psb_geasb(v_neighbor_persistent, desc_a, info, scratch=.true.)
if (info /= psb_success_) then
write(psb_err_unit,*) my_rank, 'geasb persistent-neighbor error:', info
call psb_abort(ctxt)
end if
! Fill owned entries with the global index value
allocate(vals(ncol))
@ -207,6 +242,10 @@ program psb_comm_test
do i = 1, ncol
expected(i) = real(glob_col(i), psb_dpk_)
end do
allocate(result_baseline(ncol), result_neighbor(ncol), result_persistent(ncol))
result_baseline = huge_d
result_neighbor = huge_d
result_persistent = huge_d
! ==================================================================
! 6. Baseline halo exchange (Isend/Irecv in one call)
@ -342,16 +381,33 @@ program psb_comm_test
! ==================================================================
! 8. Extract results and compare
! ==================================================================
result_baseline = v_baseline%get_vect()
result_neighbor = v_neighbor%get_vect()
result_persistent = v_neighbor_persistent%get_vect()
result_baseline = v_baseline%v%v
result_neighbor = v_neighbor%v%v
result_persistent = v_neighbor_persistent%v%v
! Debug: Check if results are properly populated
if (my_rank == 0 .and. debug_swapdata) then
write(psb_out_unit,'("DEBUG: ncol=",i0," nrow=",i0)') ncol, nrow
write(psb_out_unit,'("DEBUG: size(result_baseline)=",i0)') size(result_baseline)
if (ncol > 0) then
write(psb_out_unit,'("DEBUG: result_baseline(1:min(5,ncol))=",5(es12.5,1x))') &
& result_baseline(1:min(5,ncol))
write(psb_out_unit,'("DEBUG: expected(1:min(5,ncol))=",5(es12.5,1x))') &
& expected(1:min(5,ncol))
end if
end if
if (run_baseline .and. run_neighbor) then
n_total = n_total + 1
err = maxval(abs(result_baseline(1:ncol) - result_neighbor(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_baseline(1:ncol) - result_neighbor(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] cross-check baseline vs neighbor : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -362,10 +418,15 @@ program psb_comm_test
if (run_baseline) then
n_total = n_total + 1
err = maxval(abs(result_baseline(1:ncol) - expected(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_baseline(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] baseline absolute correctness : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -376,10 +437,15 @@ program psb_comm_test
if (run_neighbor) then
n_total = n_total + 1
err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] neighbor absolute correctness : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -390,10 +456,15 @@ program psb_comm_test
if (run_baseline .and. run_persistent) then
n_total = n_total + 1
err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] cross-check baseline vs pers-nei : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -404,10 +475,15 @@ program psb_comm_test
if (run_persistent) then
n_total = n_total + 1
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] pers-neigh absolute correctness : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -426,12 +502,17 @@ program psb_comm_test
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_)
result_neighbor = v_neighbor%get_vect()
result_neighbor = v_neighbor%v%v
n_total = n_total + 1
err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] neighbor topology reuse : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -450,12 +531,17 @@ program psb_comm_test
call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_)
call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_)
result_persistent = v_neighbor_persistent%get_vect()
result_persistent = v_neighbor_persistent%v%v
n_total = n_total + 1
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
err = huge_d
if (ncol > 0) then
err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol)))
else
err = 0.0_psb_dpk_
end if
call psb_amx(ctxt, err)
if (my_rank == 0) then
if (err < tol) then
if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then
write(psb_out_unit,'(" [PASS] pers-neigh buffer reuse : err = ",es12.5)') err
n_pass = n_pass + 1
else
@ -467,6 +553,7 @@ program psb_comm_test
! ==================================================================
! 9. Summary
! ==================================================================
call psb_barrier(ctxt)
if (my_rank == 0) then
write(psb_out_unit,'("================================================")')
write(psb_out_unit,'(" Results: ",i0," / ",i0," tests passed")') n_pass, n_total
@ -477,6 +564,7 @@ program psb_comm_test
end if
write(psb_out_unit,'("================================================")')
end if
call psb_barrier(ctxt)
! ==================================================================
! 10. Cleanup

@ -0,0 +1,62 @@
#!/usr/bin/env bash
#SBATCH --job-name=psb_swapdata_test
#SBATCH --partition=boost_usr_prod
#SBATCH --time=02:00:00
#SBATCH --nodes=4
#SBATCH --ntasks=128
#SBATCH --ntasks-per-node=32
#SBATCH --cpus-per-task=1
#SBATCH --threads-per-core=1
#SBATCH --gpus-per-node=4
#SBATCH --export=ALL
#SBATCH -A CNHPC_1736213
#SBATCH --output=psb_comm_swapdata_%j.out
#SBATCH --error=psb_comm_swapdata_%j.err

# Sweep the swapdata communication test over a range of problem sizes.
#
# For every DIM produced by `seq DIM_START DIM_STEP DIM_END` the test
# executable is launched once through srun with --dim/--iters/--mode, plus
# --debug when DEBUG is anything other than "0".
#
# Tunables (environment variables, defaults in parentheses):
#   DIM_START (20), DIM_STEP (20), DIM_END (300)  arguments handed to seq
#   ITERS (10)         forwarded as --iters
#   MODE (both)        forwarded as --mode
#   DEBUG (0)          any value other than 0 adds --debug
#   SRUN_NODES (2)     node count of each srun step
#   SRUN_NTASKS (64)   task count of each srun step
#
# NOTE(review): the allocation above requests 4 nodes / 128 tasks while each
# step uses 2 nodes / 64 tasks by default -- confirm this is intended.

set -euo pipefail

# Environment tuned like the main comm test script.
export UCX_VFS_ENABLE=n
export UCX_VFS_USE_FUSE=n
export UCX_STATS_DEST=none
export UCX_LOG_LEVEL=error
export UCX_TLS=dc,sm,self
export UCX_NET_DEVICES=mlx5_0:1
export OMPI_MCA_coll=^hcoll,han
export OMPI_MCA_coll_hcoll_enable=0
export UCX_MEMTYPE_CACHE=n
export UCX_CLOSE_TIMEOUT=10s

EXEC=./test/comm/swapdata/runs/psb_comm_test
DIM_START=${DIM_START:-20}
DIM_END=${DIM_END:-300}
DIM_STEP=${DIM_STEP:-20}
ITERS=${ITERS:-10}
MODE=${MODE:-both}
DEBUG=${DEBUG:-0}
SRUN_NODES=${SRUN_NODES:-2}
SRUN_NTASKS=${SRUN_NTASKS:-64}

# Optional flags appended to every launch; stays empty unless DEBUG is set.
EXTRA_ARGS=()
if [[ "$DEBUG" != "0" ]]; then
  EXTRA_ARGS+=(--debug)
fi

if [[ ! -x "$EXEC" ]]; then
  echo "Executable not found: $EXEC" >&2
  echo "Build the test first with 'make' in test/comm/swapdata/" >&2
  exit 1
fi

# Build the sweep up front: as a plain assignment, a failing seq (non-numeric
# or zero step) aborts the job through errexit instead of silently yielding an
# empty loop.
DIMS=$(seq "$DIM_START" "$DIM_STEP" "$DIM_END")
if [[ -z "$DIMS" ]]; then
  echo "Empty sweep: DIM_START=$DIM_START DIM_STEP=$DIM_STEP DIM_END=$DIM_END" >&2
  exit 1
fi

echo "Running swapdata comm test"
echo " EXEC=$EXEC"
echo " DIM_START=$DIM_START"
echo " DIM_END=$DIM_END"
echo " DIM_STEP=$DIM_STEP"
echo " ITERS=$ITERS"
echo " MODE=$MODE"
echo " DEBUG=$DEBUG"
echo " SRUN_NODES=$SRUN_NODES"
echo " SRUN_NTASKS=$SRUN_NTASKS"

for DIM in $DIMS; do
  echo ""
  echo "=== Running DIM=$DIM ==="
  # ${arr[@]+...} expands to nothing for an empty array; a bare "${arr[@]}"
  # trips `set -u` ("unbound variable") on bash older than 4.4.
  srun --exclusive -N"$SRUN_NODES" -n"$SRUN_NTASKS" "$EXEC" \
    --dim "$DIM" --iters "$ITERS" --mode "$MODE" \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}
done
Loading…
Cancel
Save